import json
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.optim.lr_scheduler import CosineAnnealingLR
from datetime import datetime

# Force offline mode for the Hugging Face libraries.
# NOTE(review): these variables are set after `transformers` has already been
# imported above; library versions that read them at import time may ignore
# them. Model/tokenizer loading still stays offline because the
# from_pretrained calls in train() pass local_files_only=True.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

# Path configuration. PROJECT_ROOT is the parent of this script's directory.
PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = PROJECT_ROOT / "data"    # input circuit descriptions (*.json)
BAYES_DIR = PROJECT_ROOT / "bayes"  # ground-truth scripts named <json stem>_script.py
# Base model checkpoint directory; override with the SFT_MODEL_PATH env var
# instead of editing this file on a different machine/user.
MODEL_PATH = os.environ.get(
    "SFT_MODEL_PATH",
    "/research/d7/gds/yhhan25/.cache/modelscope/hub/models/Qwen/Qwen3-14B",
)
OUTPUT_DIR = PROJECT_ROOT / "SFT" / "output"

class CustomDataset(Dataset):
    """Map-style dataset over an in-memory, indexable sequence of samples.

    Samples are handed back exactly as stored; nothing is copied or transformed.
    """

    def __init__(self, data):
        # Kept by reference; the caller's sequence is not copied.
        self.data = data

    def __len__(self):
        """Return how many samples the wrapped sequence holds."""
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` (any index the sequence accepts)."""
        samples = self.data
        return samples[idx]

def load_template(template_path=None):
    """Read the code template that is embedded in every training prompt.

    Args:
        template_path: File to read. Defaults to ``PROJECT_ROOT / "main.py"``.

    Returns:
        The full text of the template file.
    """
    if template_path is None:
        template_path = PROJECT_ROOT / "main.py"
    template_path = Path(template_path)
    print(f"加载代码模板: {template_path}")
    # Explicit UTF-8: the default text encoding is locale-dependent and can
    # fail on non-ASCII source text under some locales.
    template = template_path.read_text(encoding="utf-8")
    print(f"模板加载成功，长度: {len(template)} 字符")
    return template
 
def _build_prompt(template, circuit_data):
    """Render the user prompt for one circuit: template code, circuit JSON, task instructions."""
    circuit_json = json.dumps(circuit_data, indent=2)
    # Each embedded payload sits in its own properly closed Markdown fence, and
    # no source-code indentation leaks into the prompt text.
    return (
        "请根据以下模板和电路数据生成代码：\n\n"
        "模板代码：\n"
        "```python\n"
        f"{template.rstrip()}\n"
        "```\n\n"
        "电路数据：\n"
        "```json\n"
        f"{circuit_json}\n"
        "```\n\n"
        "要求生成完整的Python代码来处理该电路数据。具体来说，"
        "order_list和rotation_list应根据cells的数量进行调整，以实现合理的布局和旋转。"
        "rotation_list中的旋转选项包括'R0'和'MY'。"
        "只输出Python代码，不要包含任何解释、说明或思考过程。"
    )


def prepare_dataset(data_dir=None, bayes_dir=None, template=None):
    """Pair every circuit JSON with its ground-truth script as a prompt/response sample.

    For each ``<stem>.json`` in ``data_dir`` the matching ``<stem>_script.py``
    in ``bayes_dir`` is the target response; JSON files without a matching
    script are skipped.

    Args:
        data_dir: Directory of circuit JSON files. Defaults to ``DATA_DIR``.
        bayes_dir: Directory of ground-truth scripts. Defaults to ``BAYES_DIR``.
        template: Template code to embed in prompts. Defaults to ``load_template()``.

    Returns:
        List of ``{"prompt": str, "response": str}`` dicts.
    """
    print("\n" + "="*80)
    print("准备训练数据集")
    print("="*80)

    data_dir = DATA_DIR if data_dir is None else Path(data_dir)
    bayes_dir = BAYES_DIR if bayes_dir is None else Path(bayes_dir)
    if template is None:
        template = load_template()
    dataset = []

    # Sorted so the sample order is deterministic across runs and filesystems.
    json_files = sorted(data_dir.glob("*.json"))
    print(f"\n🔍 在 {data_dir} 中找到 {len(json_files)} 个JSON文件")

    for idx, json_file in enumerate(json_files, 1):
        # Ground-truth file name is derived from the JSON file's stem.
        script_name = json_file.stem + "_script.py"
        script_path = bayes_dir / script_name

        print(f"\n[{idx}/{len(json_files)}] 处理: {json_file.name}")

        if not script_path.exists():
            print(f"  ⚠️  跳过 - 未找到对应的ground truth: {script_name}")
            continue

        print(f"找到ground truth: {script_name}")

        with open(json_file, 'r', encoding="utf-8") as f:
            circuit_data = json.load(f)
        print(f"加载电路数据，包含 {len(circuit_data.get('cells', []))} 个cells")

        ground_truth = script_path.read_text(encoding="utf-8")
        print(f"加载ground truth代码，长度 {len(ground_truth)} 字符")

        dataset.append({
            "prompt": _build_prompt(template, circuit_data),
            "response": ground_truth,
        })

    # Summary is printed once, after all files have been processed.
    print(f"\n{'='*80}")
    print(f"数据集准备完成，共 {len(dataset)} 个训练样本")
    print(f"{'='*80}\n")

    return dataset
def format_data_for_training(dataset, tokenizer, max_length=8192, padding="max_length"):
    """Tokenize prompt/response pairs into causal-LM training samples.

    Each pair is rendered with the ``<|im_start|>`` / ``<|im_end|>`` chat
    markers as one user turn followed by one assistant turn.

    Args:
        dataset: Iterable of dicts with ``"prompt"`` and ``"response"`` strings.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Truncation (and, with ``padding="max_length"``, padding)
            length in tokens.
        padding: Forwarded to the tokenizer. ``"max_length"`` gives every
            sample the same length, so the default ``DataLoader`` collate works
            for any batch size.

    Returns:
        List of dicts with 1-D ``input_ids``, ``attention_mask`` and ``labels``
        tensors. ``labels`` equals ``input_ids`` except at padding positions,
        which are set to -100 so the loss ignores them.
    """
    print("格式化数据...")
    formatted_data = []
    num_truncated = 0

    for item in dataset:
        text = (
            f"<|im_start|>user\n{item['prompt']}<|im_end|>\n"
            f"<|im_start|>assistant\n{item['response']}<|im_end|>"
        )

        tokenized = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding=padding,
            return_tensors="pt",
        )
        # Take the single batch row instead of squeeze(), which would also
        # collapse a length-1 sequence dimension.
        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        # A sample that fills the whole window may have lost its tail
        # (including the closing <|im_end|>) to truncation.
        if int(attention_mask.sum()) >= max_length:
            num_truncated += 1

        # Causal LM: labels mirror input_ids, but padding must not contribute
        # to the loss; -100 is the ignore index of the cross-entropy loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        formatted_data.append({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        })

    if num_truncated:
        print(f"⚠️  {num_truncated} 个样本达到 max_length={max_length}，可能已被截断")
    print("✓ 数据格式化完成")
    return formatted_data
class CustomTrainer:
    """Minimal training loop with gradient accumulation and checkpointing.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` that returns an object with a ``loss`` tensor; it must
            also implement ``save_pretrained(path)``.
        train_dataloader: Sized iterable of dict batches holding
            ``input_ids``, ``attention_mask`` and ``labels`` tensors; its
            ``len()`` must equal the number of batches it yields.
        optimizer: Optimizer over the parameters to train.
        tokenizer: Object with ``save_pretrained(path)``, saved with every checkpoint.
        device: Device each batch tensor is moved to before the forward pass.
        gradient_accumulation_steps: Micro-batches per optimizer step (>= 1).
        output_dir: Directory that receives checkpoints; defaults to the
            module-level ``OUTPUT_DIR``.
    """

    def __init__(self, model, train_dataloader, optimizer, tokenizer, device,
                 gradient_accumulation_steps=16, output_dir=None):
        if gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
            )
        self.model = model
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.device = device
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.output_dir = output_dir
        # Number of optimizer steps taken so far, across all epochs.
        self.global_step = 0

    def train(self, num_epochs=3, save_steps=50):
        """Train for ``num_epochs`` epochs.

        Gradients are accumulated over ``gradient_accumulation_steps``
        micro-batches per optimizer step. A trailing partial window at the end
        of an epoch is applied as its own (loss-averaged) step. A step
        checkpoint is written every ``save_steps`` optimizer steps, and an
        epoch checkpoint after every epoch.
        """
        self.model.train()
        accum = self.gradient_accumulation_steps
        num_batches = len(self.train_dataloader)

        for epoch in range(num_epochs):
            print(f"\n🎯 开始第 {epoch + 1}/{num_epochs} 轮训练")

            total_loss = 0.0  # sum of unscaled per-batch losses, for reporting
            self.optimizer.zero_grad()

            for step, batch in enumerate(self.train_dataloader):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )

                # Divide by the real size of the current accumulation window
                # (the last window of an epoch can be shorter than `accum`), so
                # every optimizer step sees the mean loss of its micro-batches.
                window_start = step - step % accum
                window_size = min(accum, num_batches - window_start)
                (outputs.loss / window_size).backward()
                total_loss += outputs.loss.item()

                # Step at the end of every full window AND at the end of the
                # epoch, so a trailing partial window is applied instead of
                # being wiped by the next epoch's zero_grad(); otherwise a
                # dataset smaller than one window would never update the model.
                if (step + 1) % accum == 0 or step + 1 == num_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1

                    if self.global_step % save_steps == 0:
                        self.save_checkpoint()

            if num_batches:
                print(f"   平均损失: {total_loss / num_batches:.4f}")

            # Always checkpoint at the end of an epoch.
            self.save_checkpoint(epoch=epoch + 1)

    def save_checkpoint(self, epoch=None):
        """Save model and tokenizer into a named checkpoint directory.

        Args:
            epoch: When given, the directory is ``checkpoint-epoch-{epoch}``;
                otherwise it is ``checkpoint-step-{global_step}``.
        """
        base_dir = OUTPUT_DIR if self.output_dir is None else Path(self.output_dir)
        if epoch is not None:
            save_path = base_dir / f"checkpoint-epoch-{epoch}"
        else:
            save_path = base_dir / f"checkpoint-step-{self.global_step}"

        save_path.mkdir(parents=True, exist_ok=True)

        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        print(f"💾 检查点已保存至: {save_path}")
 
def train():
    """Run the complete LoRA SFT pipeline.

    Steps:
    1. Load the tokenizer.
    2. Build and tokenize the training set.
    3. Load the base model from ``MODEL_PATH`` in 4-bit NF4 and prepare it for
       k-bit training.
    4. Attach LoRA adapters to the attention projections.
    5. Train with ``CustomTrainer``.
    6. Save the adapter and tokenizer to ``OUTPUT_DIR / "final_model"``.

    Raises:
        RuntimeError: If no (JSON, ground-truth script) pairs were found, so
            there is nothing to train on.
    """
    print("开始SFT训练流程")

    # The tokenizer is cheap to load and is needed to build the dataset.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True, trust_remote_code=True)
    print("✓ Tokenizer加载完成")

    # Build the data before loading the large model so that a missing or empty
    # dataset fails immediately instead of after a multi-minute model load.
    raw_dataset = prepare_dataset()
    if not raw_dataset:
        raise RuntimeError(f"没有可用的训练样本，请检查 {DATA_DIR} 和 {BAYES_DIR}")
    formatted_dataset = format_data_for_training(raw_dataset, tokenizer)

    # Load the 4-bit quantized base model.
    print(f"\n🤖 加载模型: {MODEL_PATH}")
    print("   这可能需要几分钟...")
    # bf16 has the same exponent range as fp32, so it is far less prone to
    # overflow than fp16 when, as here, no gradient scaler is used. Fall back
    # to fp16 on hardware without bf16 support.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        local_files_only=True,
    )
    print("✓ 模型加载完成")

    # Standard PEFT preparation for training on a k-bit base model. With
    # use_gradient_checkpointing=True it also enables gradient checkpointing,
    # trading recompute for activation memory on the long padded sequences.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    # The generation KV cache is unused during training and is incompatible
    # with gradient checkpointing.
    model.config.use_cache = False

    # LoRA configuration.
    print("\n🔧 配置LoRA参数...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    print(f"   LoRA rank: {lora_config.r}")
    print(f"   LoRA alpha: {lora_config.lora_alpha}")
    print(f"   Target modules: {lora_config.target_modules}")

    model = get_peft_model(model, lora_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # NOTE(review): bitsandbytes 4-bit weights are stored packed, so numel()
    # likely under-counts the frozen base parameters; treat the % as approximate.
    all_params = sum(p.numel() for p in model.parameters())
    print("✓ LoRA模型创建完成")
    print(f"   可训练参数: {trainable_params:,} ({100 * trainable_params / all_params:.2f}%)")
    print(f"   总参数: {all_params:,}")

    # Data loader.
    train_dataset = CustomDataset(formatted_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
    )

    # Optimizer over the trainable (LoRA) parameters only; the frozen base
    # weights would just be dead entries in the optimizer.
    print("\n⚙️  配置优化器...")
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=2e-4,
        weight_decay=0.01,
    )

    # Training.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   使用设备: {device}")

    trainer = CustomTrainer(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        tokenizer=tokenizer,
        device=device,
        gradient_accumulation_steps=16,
    )

    print("开始训练...")

    trainer.train(num_epochs=3, save_steps=50)

    # Save the final adapter and tokenizer.
    print("保存最终模型...")
    output_path = OUTPUT_DIR / "final_model"
    output_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
    print(f"✓ 模型已保存至: {output_path}")

    print("训练完成！")
 

# Run training only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    train()