Commit 0c02927d by tongtao.ling

Initial commit

parent 53a872fa
.vscode/
__pycache__/
outputs/
lightning_logs/
mengzi-t5-base/
MIT License
Copyright (c) 2021 Shivanand Roy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
<img align="center" src="data/st5.png" alt="simpleT5">
<p align="center">
<b>Quickly train T5/mT5/byT5/CodeT5 models in just 3 lines of code
</b>
</p>
<p align="center">
<a href="https://badge.fury.io/py/simplet5"><img src="https://badge.fury.io/py/simplet5.svg" alt="PyPI version" height="18"></a>
<a href="https://badge.fury.io/py/simplet5">
<img alt="Stars" src="https://img.shields.io/github/stars/Shivanandroy/simpleT5?color=blue">
</a>
<a href="https://pepy.tech/project/simplet5">
<img alt="Stats" src="https://static.pepy.tech/personalized-badge/simplet5?period=total&units=international_system&left_color=black&right_color=brightgreen&left_text=Downloads">
</a>
<a href="https://opensource.org/licenses/MIT">
<img alt="License" src="https://img.shields.io/badge/License-MIT-yellow.svg">
</a>
</p>
**simpleT5** is built on top of PyTorch Lightning ⚡️ and Transformers 🤗 and lets you quickly train your T5 models.
> T5 models can be used for several NLP tasks such as summarization, question answering (QA), question generation (QG), translation, text generation, and more.
Here's a link to a [Medium article](https://snrspeaks.medium.com/simplet5-train-t5-models-in-just-3-lines-of-code-by-shivanand-roy-2021-354df5ae46ba) along with an [example Colab notebook](https://colab.research.google.com/drive/1JZ8v9L0w0Ai3WbibTeuvYlytn0uHMP6O?usp=sharing)
## Install
```shell
# It's advisable to create a new python environment and install simplet5
pip install --upgrade simplet5
```
## Usage
**simpleT5** for summarization task [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JZ8v9L0w0Ai3WbibTeuvYlytn0uHMP6O?usp=sharing)
```python
# import
from simplet5 import SimpleT5
# instantiate
model = SimpleT5()
# load (supports t5, mt5, byT5 and CodeT5 models)
model.from_pretrained("t5","t5-base")
# train
model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text
            eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text
            source_max_token_len = 512,
            target_max_token_len = 128,
            batch_size = 8,
            max_epochs = 5,
            use_gpu = True,
            outputdir = "outputs",
            early_stopping_patience_epochs = 0,
            precision = 32
            )
# load trained T5 model
model.load_model("t5","path/to/trained/model/directory", use_gpu=False)
# predict
model.predict("input text for prediction")
```
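The `train_df` and `eval_df` arguments above are plain pandas dataframes with `source_text` and `target_text` columns. A minimal sketch of preparing them (the column names come from the snippet above; the example rows and split point are made up):

```python
import pandas as pd

# build a toy dataframe with the two columns simpleT5 expects
df = pd.DataFrame(
    {
        "source_text": ["summarize: the cat sat on the mat", "summarize: dogs bark loudly"],
        "target_text": ["cat on mat", "dogs bark"],
    }
)

# simple train/eval split
train_df = df.iloc[:1].reset_index(drop=True)
eval_df = df.iloc[1:].reset_index(drop=True)
```

In practice you would read a larger CSV into `df` and split it the same way.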
## Articles
- [Geek Culture: simpleT5 — Train T5 Models in Just 3 Lines of Code](https://medium.com/geekculture/simplet5-train-t5-models-in-just-3-lines-of-code-by-shivanand-roy-2021-354df5ae46ba)
- [Abstractive Summarization with SimpleT5⚡️](https://snrspeaks.medium.com/abstractive-summarization-with-simplet5-%EF%B8%8F-344a78f73265)
- [Training T5 model in just 3 lines of Code with ONNX Inference](https://medium.com/mlearning-ai/training-t5-model-in-just-3-lines-of-code-with-onnx-inference-ff5b6678c757)
- [Kaggle: simpleT5⚡️ - Generating one line summary of papers](https://www.kaggle.com/mathurinache/simplet5-generating-one-line-summary-of-papers)
- [Youtube: Abstractive Summarization Demo with SimpleT5](https://www.youtube.com/watch?v=jgKj-7v2UYU)
## Acknowledgements
- [Transformers by HuggingFace 🤗](https://huggingface.co/transformers/)
- [Pytorch Lightning ⚡️](https://www.pytorchlightning.ai/)
- [Fastt5](https://github.com/Ki6an/fastT5)
from fastapi import FastAPI, Request
import datetime
import torch
import uvicorn
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

PORT = 5555

app = FastAPI()
model = SimpleT5()
model.tokenizer = T5Tokenizer.from_pretrained("./mengzi-t5-base")
model.model = T5ForConditionalGeneration.from_pretrained(
    "./outputs/simplet5-epoch-4-train-loss-0.1129-val-loss-0.2851"
)
model.device = torch.device("cpu")


@app.post("/disflu")
async def disfluency_detection(request: Request):
    # parse the JSON body directly; no need to dump and re-load it
    json_post = await request.json()
    text = json_post.get("text")
    result = model.predict(text)
    now_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return {"result": result, "status": 200, "time": now_time}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT, workers=1)
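A small client sketch for the `/disflu` endpoint above (hypothetical; it assumes the server is running on localhost:5555, and the response shape mirrors the handler's return value):

```python
# payload shape expected by the /disflu handler
payload = {"text": "input text for prediction"}

# With the server up, the endpoint can be exercised like this:
#   import requests
#   resp = requests.post("http://localhost:5555/disflu", json=payload)
#   resp.json()  # {"result": [...], "status": 200, "time": "..."}
```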
import setuptools
from os import path
here = path.abspath(path.dirname(__file__))
with open(path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()
setuptools.setup(
name="simplet5",
version="0.1.4",
license="apache-2.0",
author="Shivanand Roy",
author_email="shivanandroy.official@gmail.com",
description="simpleT5 is built on top of PyTorch-lightning ⚡️ and Transformers 🤗 that lets you quickly train your T5 models.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/Shivanandroy/simpleT5",
project_urls={
"Repo": "https://github.com/Shivanandroy/simpleT5",
"Bug Tracker": "https://github.com/Shivanandroy/simpleT5/issues",
},
keywords=[
"T5",
"simpleT5",
"transformers",
"NLP",
"finetune",
"fine-tuning",
"pytorch",
"summarization",
"translation",
"training",
"classification",
"Q&A",
"inference",
"fast inference",
],
packages=setuptools.find_packages(),
python_requires=">=3.5",
install_requires=[
"numpy",
"pandas",
"sentencepiece",
"torch>=1.7.0,!=1.8.0", # excludes torch v1.8.0
"transformers==4.16.2",
"pytorch-lightning==1.5.10",
],
classifiers=[
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
],
)
from .simplet5 import SimpleT5
import torch
import numpy as np
import pandas as pd
from transformers import (
T5ForConditionalGeneration,
MT5ForConditionalGeneration,
ByT5Tokenizer,
PreTrainedTokenizer,
RobertaTokenizer,
T5TokenizerFast as T5Tokenizer,
MT5TokenizerFast as MT5Tokenizer,
)
from transformers import AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelWithLMHead, AutoTokenizer
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
torch.cuda.empty_cache()
pl.seed_everything(42)
class PyTorchDataModule(Dataset):
""" PyTorch Dataset class """
def __init__(
self,
data: pd.DataFrame,
tokenizer: PreTrainedTokenizer,
source_max_token_len: int = 512,
target_max_token_len: int = 512,
):
"""
initiates a PyTorch Dataset Module for input data
Args:
            data (pd.DataFrame): input pandas dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
tokenizer (PreTrainedTokenizer): a PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, ByT5Tokenizer, or RobertaTokenizer)
source_max_token_len (int, optional): max token length of source text. Defaults to 512.
target_max_token_len (int, optional): max token length of target text. Defaults to 512.
"""
self.tokenizer = tokenizer
self.data = data
self.source_max_token_len = source_max_token_len
self.target_max_token_len = target_max_token_len
def __len__(self):
""" returns length of data """
return len(self.data)
def __getitem__(self, index: int):
""" returns dictionary of input tensors to feed into T5/MT5 model"""
data_row = self.data.iloc[index]
source_text = data_row["source_text"]
source_text_encoding = self.tokenizer(
source_text,
max_length=self.source_max_token_len,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
target_text_encoding = self.tokenizer(
data_row["target_text"],
max_length=self.target_max_token_len,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
        labels = target_text_encoding["input_ids"]
        labels[labels == 0] = -100  # replace pad token id (0) with -100 so padding is ignored by the loss
return dict(
source_text_input_ids=source_text_encoding["input_ids"].flatten(),
source_text_attention_mask=source_text_encoding["attention_mask"].flatten(),
labels=labels.flatten(),
labels_attention_mask=target_text_encoding["attention_mask"].flatten(),
)
class LightningDataModule(pl.LightningDataModule):
""" PyTorch Lightning data class """
def __init__(
self,
train_df: pd.DataFrame,
test_df: pd.DataFrame,
tokenizer: PreTrainedTokenizer,
batch_size: int = 4,
source_max_token_len: int = 512,
target_max_token_len: int = 512,
num_workers: int = 2,
):
"""
initiates a PyTorch Lightning Data Module
Args:
train_df (pd.DataFrame): training dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
test_df (pd.DataFrame): validation dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
tokenizer (PreTrainedTokenizer): PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, ByT5Tokenizer, or RobertaTokenizer)
batch_size (int, optional): batch size. Defaults to 4.
source_max_token_len (int, optional): max token length of source text. Defaults to 512.
target_max_token_len (int, optional): max token length of target text. Defaults to 512.
"""
super().__init__()
self.train_df = train_df
self.test_df = test_df
self.batch_size = batch_size
self.tokenizer = tokenizer
self.source_max_token_len = source_max_token_len
self.target_max_token_len = target_max_token_len
self.num_workers = num_workers
def setup(self, stage=None):
self.train_dataset = PyTorchDataModule(
self.train_df,
self.tokenizer,
self.source_max_token_len,
self.target_max_token_len,
)
self.test_dataset = PyTorchDataModule(
self.test_df,
self.tokenizer,
self.source_max_token_len,
self.target_max_token_len,
)
def train_dataloader(self):
""" training dataloader """
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
def test_dataloader(self):
""" test dataloader """
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
def val_dataloader(self):
""" validation dataloader """
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
class LightningModel(pl.LightningModule):
""" PyTorch Lightning Model class"""
def __init__(
self,
tokenizer,
model,
outputdir: str = "outputs",
save_only_last_epoch: bool = False,
):
"""
initiates a PyTorch Lightning Model
Args:
tokenizer : T5/MT5/ByT5/CodeT5 tokenizer
model : T5/MT5/ByT5/CodeT5 model
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
save_only_last_epoch (bool, optional): If True, save just the last epoch else models are saved for every epoch
"""
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.outputdir = outputdir
self.average_training_loss = None
self.average_validation_loss = None
self.save_only_last_epoch = save_only_last_epoch
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
""" forward step """
output = self.model(
input_ids,
attention_mask=attention_mask,
labels=labels,
decoder_attention_mask=decoder_attention_mask,
)
return output.loss, output.logits
    def training_step(self, batch, batch_idx):
""" training step """
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
labels_attention_mask = batch["labels_attention_mask"]
loss, outputs = self(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
)
self.log(
"train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
)
return loss
    def validation_step(self, batch, batch_idx):
""" validation step """
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
labels_attention_mask = batch["labels_attention_mask"]
loss, outputs = self(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
)
self.log(
"val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
)
return loss
    def test_step(self, batch, batch_idx):
""" test step """
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
labels_attention_mask = batch["labels_attention_mask"]
loss, outputs = self(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
)
self.log("test_loss", loss, prog_bar=True, logger=True)
return loss
def configure_optimizers(self):
""" configure optimizers """
return AdamW(self.parameters(), lr=0.0001)
def training_epoch_end(self, training_step_outputs):
""" save tokenizer and model on epoch end """
self.average_training_loss = np.round(
torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
4,
)
path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"
if self.save_only_last_epoch:
if self.current_epoch == self.trainer.max_epochs - 1:
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
else:
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
def validation_epoch_end(self, validation_step_outputs):
_loss = [x.cpu() for x in validation_step_outputs]
self.average_validation_loss = np.round(
torch.mean(torch.stack(_loss)).item(),
4,
)
class SimpleT5:
""" Custom SimpleT5 class """
def __init__(self) -> None:
""" initiates SimpleT5 class """
pass
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
"""
        loads a T5/mT5/byT5/CodeT5 model for training/finetuning
        Args:
            model_type (str, optional): "t5", "mt5", "byt5" or "codet5". Defaults to "t5".
model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
"""
if model_type == "t5":
self.tokenizer = T5Tokenizer.from_pretrained(f"{model_name}")
self.model = T5ForConditionalGeneration.from_pretrained(
f"{model_name}", return_dict=True
)
elif model_type == "mt5":
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_name}")
self.model = MT5ForConditionalGeneration.from_pretrained(
f"{model_name}", return_dict=True
)
elif model_type == "byt5":
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
self.model = T5ForConditionalGeneration.from_pretrained(
f"{model_name}", return_dict=True
)
elif model_type =="codet5":
self.tokenizer = RobertaTokenizer.from_pretrained(f"{model_name}")
self.model = T5ForConditionalGeneration.from_pretrained(
f"{model_name}", return_dict=True
)
def train(
self,
train_df: pd.DataFrame,
eval_df: pd.DataFrame,
source_max_token_len: int = 512,
target_max_token_len: int = 512,
batch_size: int = 8,
max_epochs: int = 5,
use_gpu: bool = True,
outputdir: str = "outputs",
early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
precision=32,
logger="default",
dataloader_num_workers: int = 2,
save_only_last_epoch: bool = False,
):
"""
trains T5/MT5 model on custom dataset
Args:
            train_df (pd.DataFrame): training dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
            eval_df (pd.DataFrame): validation dataframe. Dataframe must have 2 columns --> "source_text" and "target_text"
source_max_token_len (int, optional): max token length of source text. Defaults to 512.
target_max_token_len (int, optional): max token length of target text. Defaults to 512.
batch_size (int, optional): batch size. Defaults to 8.
max_epochs (int, optional): max number of epochs. Defaults to 5.
use_gpu (bool, optional): if True, model uses gpu for training. Defaults to True.
outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
            early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training if val_loss does not improve after the specified number of epochs. Set 0 to disable early stopping. Defaults to 0 (disabled)
precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
logger (pytorch_lightning.loggers) : any logger supported by PyTorch Lightning. Defaults to "default". If "default", pytorch lightning default logger is used.
dataloader_num_workers (int, optional): number of workers in train/test/val dataloader
save_only_last_epoch (bool, optional): If True, saves only the last epoch else models are saved at every epoch
"""
self.data_module = LightningDataModule(
train_df,
eval_df,
self.tokenizer,
batch_size=batch_size,
source_max_token_len=source_max_token_len,
target_max_token_len=target_max_token_len,
num_workers=dataloader_num_workers,
)
self.T5Model = LightningModel(
tokenizer=self.tokenizer,
model=self.model,
outputdir=outputdir,
save_only_last_epoch=save_only_last_epoch,
)
# add callbacks
callbacks = [TQDMProgressBar(refresh_rate=5)]
if early_stopping_patience_epochs > 0:
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.00,
patience=early_stopping_patience_epochs,
verbose=True,
mode="min",
)
callbacks.append(early_stop_callback)
# add gpu support
gpus = 1 if use_gpu else 0
# add logger
loggers = True if logger == "default" else logger
# prepare trainer
trainer = pl.Trainer(
logger=loggers,
callbacks=callbacks,
max_epochs=max_epochs,
gpus=gpus,
precision=precision,
log_every_n_steps=1,
)
# fit trainer
trainer.fit(self.T5Model, self.data_module)
def load_model(
self, model_type: str = "t5", model_dir: str = "outputs", use_gpu: bool = False
):
"""
loads a checkpoint for inferencing/prediction
Args:
            model_type (str, optional): "t5", "mt5", "byt5" or "codet5". Defaults to "t5".
            model_dir (str, optional): path to model directory. Defaults to "outputs".
            use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to False.
"""
if model_type == "t5":
self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")
elif model_type == "mt5":
self.model = MT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
elif model_type == "byt5":
self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
elif model_type =="codet5":
self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
self.tokenizer = RobertaTokenizer.from_pretrained(f"{model_dir}")
if use_gpu:
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
                raise Exception("No GPU found; set use_gpu=False to run on CPU")
else:
self.device = torch.device("cpu")
self.model = self.model.to(self.device)
def predict(
self,
source_text: str,
max_length: int = 512,
num_return_sequences: int = 1,
num_beams: int = 2,
top_k: int = 50,
top_p: float = 0.95,
do_sample: bool = True,
repetition_penalty: float = 2.5,
length_penalty: float = 1.0,
early_stopping: bool = True,
skip_special_tokens: bool = True,
clean_up_tokenization_spaces: bool = True,
):
"""
generates prediction for T5/MT5 model
Args:
source_text (str): any text for generating predictions
max_length (int, optional): max token length of prediction. Defaults to 512.
num_return_sequences (int, optional): number of predictions to be returned. Defaults to 1.
num_beams (int, optional): number of beams. Defaults to 2.
top_k (int, optional): Defaults to 50.
top_p (float, optional): Defaults to 0.95.
do_sample (bool, optional): Defaults to True.
repetition_penalty (float, optional): Defaults to 2.5.
length_penalty (float, optional): Defaults to 1.0.
early_stopping (bool, optional): Defaults to True.
skip_special_tokens (bool, optional): Defaults to True.
clean_up_tokenization_spaces (bool, optional): Defaults to True.
Returns:
list[str]: returns predictions
"""
input_ids = self.tokenizer.encode(
source_text, return_tensors="pt", add_special_tokens=True
)
input_ids = input_ids.to(self.device)
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,  # pass through so the do_sample argument takes effect
            num_return_sequences=num_return_sequences,
        )
preds = [
self.tokenizer.decode(
g,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
for g in generated_ids
]
return preds
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tongtao.ling/miniconda3/envs/t5/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Global seed set to 42\n",
"GPU available: True, used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]\n",
"\n",
" | Name | Type | Params\n",
"-----------------------------------------------------\n",
"0 | model | T5ForConditionalGeneration | 247 M \n",
"-----------------------------------------------------\n",
"247 M Trainable params\n",
"0 Non-trainable params\n",
"247 M Total params\n",
"990.311 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation sanity check: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tongtao.ling/miniconda3/envs/t5/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Global seed set to 42\n",
"/home/tongtao.ling/miniconda3/envs/t5/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4: 100%|██████████| 376/376 [01:29<00:00, 4.19it/s, loss=0.129, v_num=2, train_loss_step=0.126, val_loss_step=0.240, val_loss_epoch=0.286, train_loss_epoch=0.113] \n"
]
}
],
"source": [
"# import\n",
"from simplet5 import SimpleT5\n",
"from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
"import pandas as pd\n",
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
"# instantiate\n",
"model = SimpleT5()\n",
"\n",
"model.tokenizer = T5Tokenizer.from_pretrained(\"./mengzi-t5-base\")\n",
"model.model = T5ForConditionalGeneration.from_pretrained(\"./mengzi-t5-base\")\n",
"# load (supports t5, mt5, byT5 and CodeT5 models)\n",
"# model.from_pretrained(\"./mengzi-t5-base\")\n",
"\n",
"\n",
"df = pd.read_csv(\"./data/raw.csv\")\n",
"train_df = df[:2700]\n",
"eval_df = df[2700:]\n",
"# train\n",
"model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text\n",
" eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text\n",
" source_max_token_len = 128, \n",
" target_max_token_len = 128,\n",
" batch_size = 8,\n",
" max_epochs = 5,\n",
" use_gpu = True,\n",
" outputdir = \"outputs\",\n",
" early_stopping_patience_epochs = 0,\n",
" precision = 32\n",
" )\n",
"\n",
"# # load trained T5 model\n",
"# model.load_model(\"t5\",\"path/to/trained/model/directory\", use_gpu=False)\n",
"\n",
"# # predict\n",
"# model.predict(\"input text for prediction\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['现在孩子大了已经读初中了,我自己也反思了。']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load trained T5 model\n",
"import torch\n",
"model.model = T5ForConditionalGeneration.from_pretrained(\"./outputs/simplet5-epoch-4-train-loss-0.1129-val-loss-0.2851\")\n",
"model.device = torch.device(\"cpu\")\n",
"# predict\n",
"model.predict(\"现在呢孩子大了已经读初中了,我呢自己呢这个也思反思了。\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['类似的我们这里现场来开个会打开先问你用的是哪个是用的阵列还是用的哪个?']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict(\"类似的我们这里现场来开个会哦诶打开了先问你用的是哪个是用的阵列还是用的哪个?\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
"\n",
"tokenizer = T5Tokenizer.from_pretrained(\"./mengzi-t5-base\")\n",
"model = T5ForConditionalGeneration.from_pretrained(\"./mengzi-t5-base\")"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"data_dir_source = \"./data/dis.raw.3k\"\n",
"data_dir_target = \"./data/flu.raw.3k\"\n",
"\n",
"with open(data_dir_source,\"r\",encoding=\"utf-8\") as f:\n",
" source = f.readlines()\n",
"\n",
"with open(data_dir_target,\"r\",encoding=\"utf-8\") as f:\n",
" target = f.readlines()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"with open(\"./data/raw.csv\",\"w\",encoding=\"utf-8\") as f:\n",
" f.write(\"source_text,target_text\\n\")\n",
" for i,j in zip(source,target):\n",
" sou = i.replace(\",\",\",\")\n",
" tar = j.replace(\",\",\",\")\n",
" f.write(sou.strip()+\",\"+tar)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>source_text</th>\n",
" <th>target_text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>恩总体说来呢,对孩子的关心呢是很不够的。</td>\n",
" <td>总体说来,对孩子的关心是很不够的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>一般性在早上六点出去,要到晚上六点回来。</td>\n",
" <td>一般性在早上六点出去,要到晚上六点回来。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>问他到什么地方去哪玩呢,都不肯说的。</td>\n",
" <td>问他到什么地方去哪玩,都不肯说的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>所以说呢他越不肯说呢,我心里想恼火,基基本上呢要打他的。</td>\n",
" <td>他越不肯说,我心里想恼火,基本上要打他的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>那么越打他呢,他越不肯说。</td>\n",
" <td>越打他,他越不肯说。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2995</th>\n",
" <td>恩只要你上车的时候保持好警惕,注意好随身随身带的一些贵重物品啊,然后找一感不要太麻痹大意,我...</td>\n",
" <td>只要你上车的时候保持好警惕,注意好随身带的一些贵重物品,然后不要太麻痹大意,我想这小偷一般也...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2996</th>\n",
" <td>我当时抓了小偷的时候我感觉有点紧张啊,我想警车快点来,要不然话万一这小偷有同伙一个来一大堆人...</td>\n",
" <td>我当时抓了小偷的时候我感觉有点紧张,我想警车快点来,要不然万一这小偷有同伙来一大堆人来,我想...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2997</th>\n",
" <td>不过还好啦,小偷没有什么同伙。</td>\n",
" <td>不过还好,小偷没有同伙。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2998</th>\n",
" <td>随着社会的发展,中学生上网日益普遍,有些家长认为上网很浪费时间,而有些家长却认为上网可以帮助...</td>\n",
" <td>随着社会的发展,中学生上网日益普遍,有些家长认为上网很浪费时间,而有些家长却认为上网可以帮助...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2999</th>\n",
" <td>在大多数家长面前,孩子上网都是躲着,我认为上网是弊大于利。</td>\n",
" <td>在大多数家长面前,孩子上网都是躲着,我认为上网是弊大于利。</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3000 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" source_text \\\n",
"0 恩总体说来呢,对孩子的关心呢是很不够的。 \n",
"1 一般性在早上六点出去,要到晚上六点回来。 \n",
"2 问他到什么地方去哪玩呢,都不肯说的。 \n",
"3 所以说呢他越不肯说呢,我心里想恼火,基基本上呢要打他的。 \n",
"4 那么越打他呢,他越不肯说。 \n",
"... ... \n",
"2995 恩只要你上车的时候保持好警惕,注意好随身随身带的一些贵重物品啊,然后找一感不要太麻痹大意,我... \n",
"2996 我当时抓了小偷的时候我感觉有点紧张啊,我想警车快点来,要不然话万一这小偷有同伙一个来一大堆人... \n",
"2997 不过还好啦,小偷没有什么同伙。 \n",
"2998 随着社会的发展,中学生上网日益普遍,有些家长认为上网很浪费时间,而有些家长却认为上网可以帮助... \n",
"2999 在大多数家长面前,孩子上网都是躲着,我认为上网是弊大于利。 \n",
"\n",
" target_text \n",
"0 总体说来,对孩子的关心是很不够的。 \n",
"1 一般性在早上六点出去,要到晚上六点回来。 \n",
"2 问他到什么地方去哪玩,都不肯说的。 \n",
"3 他越不肯说,我心里想恼火,基本上要打他的。 \n",
"4 越打他,他越不肯说。 \n",
"... ... \n",
"2995 只要你上车的时候保持好警惕,注意好随身带的一些贵重物品,然后不要太麻痹大意,我想这小偷一般也... \n",
"2996 我当时抓了小偷的时候我感觉有点紧张,我想警车快点来,要不然万一这小偷有同伙来一大堆人来,我想... \n",
"2997 不过还好,小偷没有同伙。 \n",
"2998 随着社会的发展,中学生上网日益普遍,有些家长认为上网很浪费时间,而有些家长却认为上网可以帮助... \n",
"2999 在大多数家长面前,孩子上网都是躲着,我认为上网是弊大于利。 \n",
"\n",
"[3000 rows x 2 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"data = pd.read_csv(\"./data/raw.csv\")\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>source_text</th>\n",
" <th>target_text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>恩总体说来呢,对孩子的关心呢是很不够的。</td>\n",
" <td>总体说来,对孩子的关心是很不够的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>一般性在早上六点出去,要到晚上六点回来。</td>\n",
" <td>一般性在早上六点出去,要到晚上六点回来。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>问他到什么地方去哪玩呢,都不肯说的。</td>\n",
" <td>问他到什么地方去哪玩,都不肯说的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>所以说呢他越不肯说呢,我心里想恼火,基基本上呢要打他的。</td>\n",
" <td>他越不肯说,我心里想恼火,基本上要打他的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>那么越打他呢,他越不肯说。</td>\n",
" <td>越打他,他越不肯说。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1495</th>\n",
" <td>恩军役已经去了就是在朝恩朝鲜的边境已经去去了一年多,恩希望他半年之后回来可以有呃继续拍一些恩...</td>\n",
" <td>就是在朝鲜的边境已经去了一年多,希望他半年之后回来可以继续拍一些好的片子。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1496</th>\n",
" <td>恩然后韩国有很多恩很棒的歌手。</td>\n",
" <td>然后韩国有很多很棒的歌手。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1497</th>\n",
" <td>嗯像flightsky嗯,还有那个HOT,虽然说他们几乎像是已经解散了,但是我认为他们永远是...</td>\n",
" <td>像flightsky,还有HOT,虽然他们几乎像是已经解散了,但是我认为他们永远是最棒的。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1498</th>\n",
" <td>恩然后就是呃港台港台,因为现在港台的新生歌手就是一代接一代出了很多,但是我觉其中最可爱活泼的...</td>\n",
" <td>然后就是港台,现在港台的新生歌手一代接一代出了很多,但是我觉其中最可爱活泼的就是SHE。</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1499</th>\n",
" <td>我非常喜欢听她们的歌,她们已经出了七张七张专辑,恩从第一盘单身宿恩女生宿舍开始就是恩和现,其...</td>\n",
" <td>我非常喜欢听她们的歌,她们已经出了七张专辑,从第一盘女生宿舍开始其实和她们现在的声音和现在的...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1500 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" source_text \\\n",
"0 恩总体说来呢,对孩子的关心呢是很不够的。 \n",
"1 一般性在早上六点出去,要到晚上六点回来。 \n",
"2 问他到什么地方去哪玩呢,都不肯说的。 \n",
"3 所以说呢他越不肯说呢,我心里想恼火,基基本上呢要打他的。 \n",
"4 那么越打他呢,他越不肯说。 \n",
"... ... \n",
"1495 恩军役已经去了就是在朝恩朝鲜的边境已经去去了一年多,恩希望他半年之后回来可以有呃继续拍一些恩... \n",
"1496 恩然后韩国有很多恩很棒的歌手。 \n",
"1497 嗯像flightsky嗯,还有那个HOT,虽然说他们几乎像是已经解散了,但是我认为他们永远是... \n",
"1498 恩然后就是呃港台港台,因为现在港台的新生歌手就是一代接一代出了很多,但是我觉其中最可爱活泼的... \n",
"1499 我非常喜欢听她们的歌,她们已经出了七张七张专辑,恩从第一盘单身宿恩女生宿舍开始就是恩和现,其... \n",
"\n",
" target_text \n",
"0 总体说来,对孩子的关心是很不够的。 \n",
"1 一般性在早上六点出去,要到晚上六点回来。 \n",
"2 问他到什么地方去哪玩,都不肯说的。 \n",
"3 他越不肯说,我心里想恼火,基本上要打他的。 \n",
"4 越打他,他越不肯说。 \n",
"... ... \n",
"1495 就是在朝鲜的边境已经去了一年多,希望他半年之后回来可以继续拍一些好的片子。 \n",
"1496 然后韩国有很多很棒的歌手。 \n",
"1497 像flightsky,还有HOT,虽然他们几乎像是已经解散了,但是我认为他们永远是最棒的。 \n",
"1498 然后就是港台,现在港台的新生歌手一代接一代出了很多,但是我觉其中最可爱活泼的就是SHE。 \n",
"1499 我非常喜欢听她们的歌,她们已经出了七张专辑,从第一盘女生宿舍开始其实和她们现在的声音和现在的... \n",
"\n",
"[1500 rows x 2 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[:1500]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "t5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
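The CSV-building cell in the notebook above keeps `raw.csv` parseable with a single delimiter by swapping each ASCII comma inside a sentence for a full-width comma (`,`) before writing the row. A minimal sketch of that escaping step (the helper name `csv_safe` is illustrative, not part of the repo):

```python
def csv_safe(line: str) -> str:
    """Escape a sentence for the two-column CSV: swap ASCII commas for
    full-width commas (,) and strip surrounding whitespace/newlines."""
    return line.replace(",", ",").strip()
```

Each output row then becomes `csv_safe(src) + "," + csv_safe(tgt)`; note the original cell keeps the target line's own trailing newline rather than appending one.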
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from simplet5 import SimpleT5

# Attach the fine-tuned checkpoint to a SimpleT5 wrapper manually: load the
# tokenizer and model directly with transformers and run inference on CPU.
model = SimpleT5()
model_path = "./outputs/simplet5-epoch-4-train-loss-0.1129-val-loss-0.2851"
model.tokenizer = T5Tokenizer.from_pretrained(model_path)
model.model = T5ForConditionalGeneration.from_pretrained(model_path)
model.device = torch.device("cpu")
# Equivalent one-liner:
# model.load_model("t5", model_path, use_gpu=False)

test1 = "你这是要准备干嘛呀?我,我再进一个人啊你这不就一个人吗?"
test2 = "你这样不是会刺激吗?你这样。"

result = model.predict("现在呢孩子大了已经读初中了,我呢自己呢这个也思反思了。")
print(result)
result1 = model.predict(test1)
print(result1)
result2 = model.predict(test2)
print(result2)
\ No newline at end of file
curl -X POST "http://127.0.0.1:5555/disflu" \
-H 'Content-Type: application/json' \
-d '{"text": "现在呢孩子大了已经读初中了,我呢自己呢这个也思反思了。"}'
\ No newline at end of file
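The curl command above assumes a local HTTP service exposing the model at `POST /disflu` on port 5555. The repository's actual server code is not shown here; the sketch below is an assumed minimal Flask shape for such an endpoint, with a placeholder `predict` standing in for the SimpleT5 model call (the `{"result": ...}` response key is also an assumption):

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

def predict(text):
    # Placeholder: the real service would call model.predict(text) on the
    # fine-tuned SimpleT5 checkpoint loaded once at startup.
    return text

@app.route("/disflu", methods=["POST"])
def disflu():
    # Expects a JSON body like {"text": "..."} as in the curl example above.
    payload = request.get_json(force=True)
    corrected = predict(payload["text"])
    return jsonify({"result": corrected})

# app.run(host="127.0.0.1", port=5555)  # uncomment to serve locally
```

Loading the model once at module import (rather than per request) keeps request latency down, since checkpoint loading dominates startup time.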