在Codelab对llama3做Lora Fine tune微调

Unsloth 高效微调大模型的工具,通过Unsloth微调Llama3, Mistral, Gemma 速度提升2-5倍,内存减少70%!

Codelab 创建一个jupyter notebook

在这里插入图片描述
选择 T4 GPU
在这里插入图片描述
安装Fine tune 相关的lib

%%capture
import torch
major_version, minor_version= torch.cuda.get_device_capability()
# Must install separately since Colab has torch 2.2.1, which breaks packages
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
if major_version >= 8:
  # Use this for new GPs like Ampere, Hopper GPUs(RTX 30xx. RIX 40xx, A100. H100. L40)
  !pip install -no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
else:
  # Use this for older GPUs (V100, Tesla T4, RTX 20xx)
  !pip install --no-deps xformers trl peft accelerate bitsandbytes
pass

下载llama3

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False

# 4bit pre quantized models we support for 4x faster downloading + no OOMs
fourbit_models = [
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/mistral-7b-instruct-bnb-4bit",
    "unsloth/llama-2-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",
    "unsloth/gemma-7b-it-bnb-4bit",
    "unsloth/gemma-2b-bnb-4bit",
    "unsloth/gemma-2b-it-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf

)

在这里插入图片描述

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None # And LoftQ
)

在这里插入图片描述

加载hugging face数据集

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}
"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
  instructions = examples["instruction"]
  inputs = examples["input"]
  outputs = examples["output"]
  texts = []
  for instruction, input, output in zip(instructions, inputs, outputs):
    # Must add EOS_TOKEN, otherwise your generation will go on forever!
    text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
    texts.append(text)
  return { "text": texts, }
pass

from datasets import load_dataset
dataset = load_dataset("pinzhenchen/alpaca-cleaned-zh", split="train")
dataset = dataset.map(formatting_prompts_func, batched=True,)



在这里插入图片描述
HuggingFace 官网, 点击数据集 Datasets

在这里插入图片描述
搜索数据集 alpaca-cleaned-zh
在这里插入图片描述
复制数据集的名字 pinzhenchen/alpaca-cleaned-zh
在这里插入图片描述
定义training 方法

from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

打印显存使用情况

#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = (gpu_stats.name). Max memory = (max_memory) GB.")
print(f"(start_gpu_memory) GB of memory reserved.")

在这里插入图片描述
开始FineTune

trainer_stats = trainer.train()

在这里插入图片描述

#@title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory*100, 3)
lora_percentage = round(used_memory_for_lora / max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} GB.")

在这里插入图片描述
用fineTune 过的model,做问答

# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
    [
        alpaca_prompt.format(
            "如何保持健康", # instruction
            "", # input
            "", # output - leave this blank for generation!
        )
    ], return_tensors = "pt"
).to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache=True)
tokenizer.batch_decode(outputs)

在这里插入图片描述
TextStreamer 流式一个字一个字地打印结果

# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
    [
        alpaca_prompt.format(
            "续写这段话", # instruction
            "天天向上,好好学习", # input
            "", # output - leave this blank for generation!
        )
    ], return_tensors = "pt"
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens=128)

在这里插入图片描述
保存model到google drive 和 HuggingFace

model.save_pretrained("lora_model") # local saving
model.push_to_hub("zgpeace/lora_model", token="####") # online saving

google drive
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/597253.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

等保测评—Linux-CentOS标准范例截图

密码输入错误无法登录 用户账户情况包含root、guanli、shenji 查看审计用户权限 身份鉴别: cat /etc/passwd,核查用户名和 UID,是否存在同样的用户名和 UID cat /etc/shadow,查看文件中各用户名状态 , 核查密码一栏为…

day1Qt作业

#include "mywidget.h"MyWidget::MyWidget(QWidget *parent): QWidget(parent) {this->resize(540,415);//窗口大小this->setFixedSize(540,415);//固定窗口大小this->setWindowTitle("QQ");//标题this->setWindowIcon(QIcon("E:\\hqyjap…

获取转转数据,研究完转转请求,tx在算法方面很友好。

本篇文章仅供学习讨论。 文章中涉及到的代码、实例,仅是个人日常学习研究的部分成果。 如有不当,请联系删除。 在研究完阿里的算法以后(其实很难说研究完,还有很多内容没有研究透,只能说暂时告一段落)&…

【CTF Web】XCTF GFSJ0475 get_post Writeup(HTTP协议+GET请求+POST请求)

get_post X老师告诉小宁同学HTTP通常使用两种请求方法,你知道是哪两种吗? 解法 用 Postman 发送一个 GET 请求,提交一个名为a,值为1的变量。 http://61.147.171.105:65402/?a1用 Postman 发送一个 POST 请求,提交一个名为b,值为…

Embeddings原理、使用方法、优缺点、案例以及注意事项

Embeddings是一种将高维数据映射到低维空间的技术,常用于处理自然语言处理(NLP)和计算机视觉(CV)任务。Embeddings可以将复杂的高维数据转换为低维稠密向量,使得数据可以更容易地进行处理和分析。本文将介绍…

国内首发 | CSA大中华区启动《AI安全产业图谱(2024)》调研

在人工智能(AI)技术的快速发展浪潮中,AI安全已成为全球关注的焦点。为应对AI安全带来的挑战,确保AI技术的健康发展,全球范围内的研究机构、企业和技术社区都在积极探索解决方案。 在这一背景下,CSA大中华区…

在2G到4g小区重选过程中,4g频点没有优先级信息,最后UE无法重选到4g,是否正常?

这个确实是老问题了,要翻开GSM 的协议找答案。 GSM cell reselection算法分为cell ranking based和priority based两种方式。cell ranking based 只能从GSM重选到UTRAN;而priority based则可以重选到UTRAN和EUTRA。 根据priority based重选算法的描述&am…

【WP】第一届 “帕鲁杯“ - CTF挑战赛 Web 全解

Web Web-签到 考点:审计py代码 from flask import Flask, request, jsonify import requests from flag import flag # 假设从 flag.py 文件中导入了 flag 函数 app Flask(__name__)app.route(/, methods[GET, POST]) def getinfo():url request.args.get(url)i…

java08基础(值传递和引用传递 类和对象)

目录 一. 值传递和引用传递 1. 值传递 2. 引用传递 二. 面向对象思想 三. 类和对象 1. 类 2. 对象 2.1 使用 2.2 成员变量和局部变量区别 2.3 操作成员方法 2.4 this关键字(初始) 2.5 构造方法 (见java09) 一. 值传递和引用传递 1. 值传递 值传递是指在调用函数时将…

webpack4和webpack5区别1---loader

webpack4处理图片和字体的loader file-loader file-loader的作用是处理webpack中的静态资源文件。File Loader可以将各种类型的文件,如图像、字体、视频等转换为模块并加载到Web应用程序中。它通过import或require语句引入文件资源,并将其放置在输出目…

多链路聚合设备是什么

多链路聚合设备属于通信指挥装备。 乾元通多链路聚合设备,它能够将多个网络链路聚合成一个逻辑链路,以实现高速、稳定、可靠的数据传输。多链路聚合设备的核心技术包括链路聚合、负载均衡、故障切换等,能够智能管理和优化利用不同网络链路&a…

概率论 科普

符号优先级 概率公式中一共有三种符号:分号 ; 、逗号 , 、竖线 | 。 ; 分号代表前后是两类东西,以概率P(x;θ)为例,分号前面是x样本,分号后边是模型参数。分号前的 表示的是这个式子用来预测分布的随机变量x,分号后的…

第一篇:刚接触测试你应该知道什么

欢迎你接触软件测试这一行业。 刚接触它时,你肯定或多或少会有疑惑,我该做什么?大家口口相传的软件测试就是 【点点点】 真的是你日常的工作吗? 那么本文我将陪你一起,对我们刚接触到测试这个工作以后,应该…

为什么电子商务安全是速度和保护之间的平衡行为

微信搜索关注公众号网络研究观,获取更多信息。 电子商务世界是一把双刃剑。虽然它为企业和消费者提供了便利和可访问性,但它也为网络犯罪分子提供了诱人的目标。在这个不断变化的环境中,优先考虑安全不再是一种选择;而是一种选择&…

nginx--tcp负载均衡

mysql负载均衡 安装mysql yum install -y mariadb-server systemctl start mariadb systemctl enable mariadb ss -ntl创建数据库并授权 MariaDB [(none)]> create database wordpress; Query OK, 1 row affected (0.00 sec)MariaDB [(none)]> grant all privileges o…

使用 electron-vite-vue 构建 electron + vue3 项目并打包

文章目录 一、使用 electron-vite-vue 构建 Vue3 项目1、创建项目并安装相关依赖2、安装依赖时报错 (operation not permitted) 二、项目打包1、执行打包命令2、下载失败处理3、手动方式下载后,将文件放至指定路径下4、打包成功后 参考资料 一、使用 electron-vite-…

C++证道之路第十八章探讨C++新标准

一、复习前面介绍过的C11新功能 1、新类型 C11新增了类型long long 和unsigned long long 新增了类型char16_t 和char32_t 2、统一的初始化 C11扩大了用大括号括起的列表(初始化列表)的使用范围,使其可以用于所有内置类型和用户定义的类…

JavaScript练习

1.冒泡排序 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </head>…

设备树与/sys/bus/platform/devices与/sys/devices目录关系

设备树与sys/bus/platform/devices sysfs文件系统中/sys/bus/platform/devices下的设备是由设备树生成&#xff0c; 根节点下有compatible的子节点都会在/bus/platform/devices生成节点 总线 I2C、SPI 等控制器会在/bus/platform/devices生成节点 总线 I2C、SPI 节点下的子节点…

ESP8266做主机 手机网络助手为从机

ATCIFSR查看地址&#xff0c;一般ESP8266 为192.168.4.1 在手机上下载网络调试助手&#xff0c;打开TCP客户端 创建后192.168.4.1 端口8089然后连接ESP8266热点。 ESP向手机发数据前先发送要发几个数据ATCIPSEND0,8表示发8个&#xff0c;然后再发8个数 上面创建好热点后&…
最新文章