【千帆SDK】使用文生图数据集进行模型微调生成Pokemon风格图片
大模型开发/技术交流
- LLM
- API
- 大模型训练
4月15日5285看过
💡学习前小提示
请大家点击链接并加🌟:https://github.com/baidubce/bce-qianfan-sdk
在 trainer 发起 finetune 中,我们已经学习了如何使用
trainer
+dataset
发起文生文微调任务,同时也体验了模型评估,批量推理,服务部署等流程;除了纯文本的生成模型外,千帆平台也提供了针对文心一格,以及开源的StableDiffusion模型的训练微调。
本例将基于qianfan==0.3.6.1展示通过Dataset加载本地数据集,并上传到千帆平台,基于
Stable-Diffusion
进行fine-tune,以实现pokemon风格的图片生成能力。
! pip install "qianfan[dataset_base]" -U! pip install datasets==2.14.6 # huggingface datasets! pip install fsspec==2023.9.2 # fix load_dataset error
import qianfanqianfan.__version__
前置准备
-
初始化千帆安全认证AK、SK
import osos.environ["QIANFAN_ACCESS_KEY"] = "your_ak"os.environ["QIANFAN_SECRET_KEY"] = "your_sk"
导入依赖
-
qianfan.trainer.consts
trainer使用中所用到的常量 -
qianfan.resources.console.consts
api层面定义的字段常量 -
qianfan.trainer.configs
trainer使用所需要的config配置数据类 -
qianfan.resources.QfMessages
用于组装qianfan.ChatCompletion的输入messages -
qianfan.trainer.finetune.Finetune
大语言模型fine-tune任务Trainer实现 -
qianfan.dataset.Dataset
千帆dataset类,用于管理千帆平台、本地、第三方数据集的导入导出,数据清洗等操作
from qianfan.trainer.consts import ActionStatefrom qianfan.model.consts import ServiceTypefrom qianfan.resources.console import consts as console_constsfrom qianfan.trainer.configs import TrainConfigfrom qianfan.model.configs import DeployConfigfrom qianfan.resources import QfMessagesfrom qianfan.trainer.finetune import Finetunefrom qianfan.dataset import Datasetfrom qianfan.utils import enable_logimport loggingenable_log(logging.INFO)
我们此次选用huggingface的开源数据集用于生成pokemon风格的图片
# 从huggingface 导入数据集:import datasetsdataset = datasets.load_dataset("svjack/pokemon-blip-captions-en-zh", split='train')
dataset.column_names# 输出:['image', 'en_text', 'zh_text']
将huggingface上的文生图数据集,增加指令,并转存成本地的数据集目录
import osimport jsonpokemon_style_instruction = "pokemon,"save_ds_dir = "./pokemon_ds"if not os.path.exists(save_ds_dir):os.mkdir(save_ds_dir)for i, v in enumerate(dataset):v["image"].save(f"{save_ds_dir}/{i}.jpg")with open(f"{save_ds_dir}/{i}.json", "w") as f:json.dump({"prompt": f'{pokemon_style_instruction} {v["en_text"]}'}, f)
数据集加载
千帆SDK提供了数据集实现帮助我们可以快速的加载本地的数据集到内存,并通过设定DataSource数据源以保存至本地和千帆平台。
from qianfan.dataset import Datasetfrom qianfan.dataset.data_source import FileDataSourcefrom qianfan.dataset.data_source.base import FormatTypefile_data_source = FileDataSource(path=save_ds_dir, file_format=FormatType.Text2Image)ds = Dataset.load(file_data_source)print(ds.list(0))
从本地数据集上传到BOS
# 保存到千帆平台from qianfan.dataset.data_source import QianfanDataSourcefrom qianfan.resources.console import consts as console_constsbos_bucket_name = "sdk-test"bos_bucket_file_path = "/sdk_ds/"qianfan_dataset_name = "random_sdk_train_t2i"# 创建千帆数据集,并上传保存qianfan_data_source = QianfanDataSource.create_bare_dataset(name=qianfan_dataset_name,template_type=console_consts.DataTemplateType.Text2Image,storage_type=console_consts.DataStorageType.PrivateBos,storage_id=bos_bucket_name,storage_path=bos_bucket_file_path,)ds = ds.save(qianfan_data_source)
发起图生文训练
这里我们选用
Stable-Diffusion-XL-Base-1.0
作为基础模型,
from qianfan.trainer.consts import PeftTypetrainer = Finetune(train_type="Stable-Diffusion-XL-Base-1.0",train_config=TrainConfig(peft_type=PeftType.LoRA,batch_size=8,epoch=20,learning_rate=0.00005,),dataset=ds,)
运行任务
同步运行trainer,训练直到模型发布完成
trainer.run()
获取finetune任务输出:
trainer.output
使用sdk发起部署流程,这一步需要到前端控制台进行支付才能完成:
#-# cell_skipfrom qianfan.model import Service, Modelfrom qianfan.model.consts import ServiceTypefrom qianfan.resources.console.consts import DeployPoolType# 从训练结果中获取模型对象m: Model = trainer.output["model"]sft_svc: Service = m.deploy(DeployConfig(name="random_t2i_sdk1",endpoint_prefix="sdpoke1",replicas=1, # 副本数, 与qps强绑定pool_type=DeployPoolType.PrivateResource, # 私有资源池service_type=ServiceType.Text2Image,))
使用Finetune之后的模型服务调用:
#-# cell_skipfrom qianfan.resource import Text2Image### 使用Model & Service调用模型problem="pokemon, a blue monkey with a hat"#获取服务对象,即ChatCompletion等类型的对象t2i: qianfan.Text2Image = sft_svc.get_res()from PIL import Imageimport ioresp = t2i.do(prompt=problem, with_decode="base64")img_data = resp["body"]["data"][0]["image"]img = Image.open(io.BytesIO(img_data))display(img)
评论