# Fine-tune Qwen2-7B with LoRA via Unsloth, then run an interactive Q&A loop.
from safetensors.torch import load_model, save_model
from unsloth import FastLanguageModel

from local_dataset import LocalJsonDataset
# --- Base-model loading configuration -------------------------------------
max_seq_length = 2048   # context window used for tokenization and training
dtype = None            # None: let FastLanguageModel pick a dtype
load_in_4bit = False    # keep full-precision weights (no 4-bit quantization)

# Load the local Qwen2-7B checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./model/Qwen2-7B",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Wrap the base model with LoRA adapters on all attention and MLP
# projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                                   # LoRA rank
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",   # Unsloth's checkpointing variant
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
# Build the training dataset from a local JSON file; the project-local
# LocalJsonDataset handles tokenizer-aware preprocessing.
custom_dataset = LocalJsonDataset(
    json_file='train_data.json',
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
)
dataset = custom_dataset.get_dataset()
from trl import SFTTrainer from transformers import TrainingArguments from unsloth import is_bfloat16_supported
# Configure supervised fine-tuning over the "text" field of the dataset.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,      # effective batch size: 4 * 8 = 32
        warmup_steps=20,
        max_steps=2000,
        learning_rate=5e-5,
        # Prefer bf16 where the hardware supports it; otherwise fall back
        # to fp16 — exactly one of the two flags is True.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
    ),
)
# Run training, then persist the LoRA adapter weights and tokenizer files.
trainer.train()
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")

# Switch the model into Unsloth's optimized inference mode before generation.
FastLanguageModel.for_inference(model)
def generate_answer(question: str) -> str:
    """Generate an answer for *question* using the fine-tuned model.

    The question is wrapped in the same Chinese prompt template used during
    training, run through ``model.generate``, and the decoded text is cut at
    the first ``<|im_end|>`` marker (if one survives decoding).

    Args:
        question: The user's question text.

    Returns:
        The decoded output (prompt plus completion), whitespace-stripped.
    """
    input_text = f"下面列出了一个问题. 请写出问题的答案.\n####问题:{question}\n####答案:"
    # Fix: move the inputs to wherever the model actually lives instead of
    # hard-coding "cuda" — this also works for CPU or non-default GPU
    # placements, and is identical behavior when the model is on cuda:0.
    inputs = tokenizer(
        [input_text],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=2048, use_cache=True)
    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return decoded_output.split('<|im_end|>')[0].strip()
# Simple interactive loop: keep answering questions until the user types
# 'exit' (case-insensitive).
print("请输入您的问题,输入'exit'退出:")
while (query := input("> ")).lower() != 'exit':
    answer = generate_answer(query)
    print("---")
    print(answer)
print("程序已退出。")