logo

机器学习|从0开发大模型之数据预处理

本文主要介绍数据的预处理。

1、找大模型的数据

前面写了一篇文章《ChatGPT|大语言模型训练有哪些开源数据集? 》(mp.weixin.qq.com/s?__biz=MzA…
不过在开发大模型,需要根据实际的需求可以找到不同的数据,比如如果需要英文预料,那么就需要找到英文的预料,目前我们的 myllm 项目主要是中文小模型,所以找了一些中文相关数据:
如果需要其他数据可以在 huggingface 查找,地址:huggingface.co/datasets,我也…

2、数据预处理

下载数据以后,按照如下流程处理:
  • 提取文件的文本数据
  • 将文本数据进行截断,比如某段文本超过限制的上下文大小(如:512),就需要截断,增加截断标识
  • 将文本转换为token,格式化存储token数据
处理以下格式的数据:
  
  
  
  
  
  
[
{
"completion": "昭通机场(ZPZT)是位于中国云南昭通的民用机场,始建于1935年,1960年3月开通往返航班“昆明-昭通”,原来属军民合用机场。1986年机场停止使用。1991年11月扩建,于1994年2月恢复通航。是西南地区「文明机场」,通航城市昆明。 机场占地1957亩,飞行区等级为4C,有一条跑道,长2720米,宽48米,可供波音737及以下机型起降。机坪面积6600平方米,停机位2个,航站楼面积1900平方米。位于城东6公里处,民航路与金鹰大道交叉处。\n航点\n客服电话\n昭通机场客服电话:0870-2830004",
"source": "wikipedia.zh2307"
}
]
处理代码如下:
  
  
  
  
  
  
tokenizer = AutoTokenizer.from_pretrained('./my_tokenizer', use_fast=False)
basepath = "../datasets"
# 截断数据
def split_text(text, n = 512):
return [text[i: i + n] for i in range(0, len(text), n)]
# 整理wikipedia-cn-20230720-filtered数据,下载地址:https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered
def process_wiki_clean():
with open(f'{basepath}/wikipedia-cn-20230720-filtered.json', 'r', encoding='utf-8') as file:
data = ujson.loads(file.read())
data_len = len(data)
doc_ids = []
for idx, line in enumerate(data):
text_input = line['completion']
text_arr = split_text(text_input)
for text in text_arr:
text_id = tokenizer(f'{bos_token}{text}{eos_token}').data['input_ids']
print("text_id: ", text_id, ", text: ", text)
if len(text_id) > 5:
doc_ids += text_id
if idx % (int(data_len / 20)) == 0:
print(f"[{idx}/{data_len}] {text}")
arr = np.array(doc_ids, dtype=np.uint16)
with open(f'{basepath}/wikipedia-cn-20230720-filtered.bin', 'wb') as f:
f.write(arr.tobytes())
其中 text_id 输出就是从前面训练的Tokenizer中输出的对应的词ID,然后将 doc_ids 通过 numpy 序列化为 wikipedia-cn-20230720-filtered.bin 文件。

3、合并多个数据

可以将多个数据,代码如下:
  
  
  
  
  
  
# 将多个数据合并为一个文件
def pretrain_process():
process_wiki_clean()
data_path_list = [
f'{basepath}/wikipedia-cn-20230720-filtered.bin',
]
data_list = []
for data_path in data_path_list:
with open(data_path, 'rb') as f:
data = np.fromfile(f, dtype=np.uint16)
data_list.append(data)
arr = np.concatenate(data_list)
print(arr.shape)
with open(f'{basepath}/pretrain_data.bin', 'wb') as f:
f.write(arr.tobytes())
最后训练数据是 pretrain_data.bin,数据大小 361M

参考

(1)Wiki中文百科:
(2)天工数据集:huggingface.co/datasets/Sk…
————————————————
版权声明:本文为稀土掘金博主「周末程序猿」的原创文章
原文链接:https://juejin.cn/post/7432959779624992804
如有侵权,请联系千帆社区进行删除
评论
用户头像