import os
import argparse
import math
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
import numpy as np

# 尝试导入 MiniMind 模型
try:
    from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
except ImportError:
    import sys
    sys.path.append(os.getcwd())
    from model.model_minimind import MiniMindConfig, MiniMindForCausalLM

# ==========================================
# 1. 数据集类
# ==========================================
class LocalJsonlDataset(Dataset):
    """Causal-LM pretraining dataset backed by a local JSON Lines file.

    Each non-blank line should be a JSON object whose text is stored under
    ``text`` or, as a fallback, ``content``. Lines that are not valid JSON
    objects, or that hold no non-empty string under either key, are skipped
    and counted. The count is reported after loading.

    Args:
        file_path: Path to the ``.jsonl`` file.
        tokenizer: Hugging Face style tokenizer. It is called as
            ``tokenizer(text, max_length=..., padding='max_length',
            truncation=True, return_tensors='pt')``.
        max_len: Every sample is truncated or right-padded to this many tokens.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """

    def __init__(self, file_path, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.samples = []

        print(f"Loading data from {file_path}...")

        if not os.path.exists(file_path):
            # Debug aid: list what *is* in the target directory so a typo in
            # the file name is easy to spot. dirname() returns "" for a bare
            # file name, which must mean the current directory rather than a
            # missing one.
            parent_dir = os.path.dirname(file_path) or "."
            if os.path.exists(parent_dir):
                print(f"Error: File not found. Contents of {parent_dir}:")
                print(os.listdir(parent_dir))
            else:
                print(f"Error: Directory {parent_dir} does not exist.")
            raise FileNotFoundError(f"Data file not found at: {file_path}")

        skipped = 0
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                text = self._extract_text(line)
                if text:
                    self.samples.append(text)
                else:
                    skipped += 1
        print(f"Loaded {len(self.samples)} samples.")
        if skipped:
            print(f"Skipped {skipped} malformed or empty lines.")

    @staticmethod
    def _extract_text(line):
        """Return the text payload of one JSONL line, or None if unusable."""
        try:
            obj = json.loads(line)
        except json.JSONDecodeError:
            return None
        # A valid JSON line may still be a list/number/string, not an object.
        if not isinstance(obj, dict):
            return None
        # Support both field names; fall back to 'content' when 'text' is
        # missing *or* empty.
        for key in ('text', 'content'):
            value = obj.get(key)
            if isinstance(value, str) and value:
                return value
        return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return ``(input_ids, input_ids)``.

        The same tensor is returned as both input and label; the training
        loop performs the causal shift and ignores padding via the loss.
        """
        enc = self.tokenizer(
            self.samples[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dim of 1 for a single string.
        input_ids = enc['input_ids'].squeeze(0)
        return input_ids, input_ids

# ==========================================
# 2. 训练主程序
# ==========================================
def _parse_args():
    """Build and parse the command-line arguments for one training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["dense", "moe"], required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_data.jsonl")

    # Architecture parameters
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--moe_inter_dim", type=int, default=1024)

    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=5e-4)
    # Seeds the train/test split, model init and shuffling, so dense and MoE
    # runs are compared on the same held-out samples.
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _load_tokenizer(tokenizer_path="./model"):
    """Load the local fast tokenizer and fill in any missing special tokens.

    Raises:
        ValueError: If no pad token id can be resolved. The training loss
            uses it as ``ignore_index``.
    """
    print(f"Loading Tokenizer from local {tokenizer_path} directory...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        # Fallback: load the raw tokenizer.json directly.
        tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=os.path.join(tokenizer_path, "tokenizer.json")
        )

    # Fill in special tokens the tokenizer config does not define (same order
    # as before: pad, eos, bos, unk).
    defaults = {
        "pad_token": "<|endoftext|>",
        "eos_token": "<|im_end|>",
        "bos_token": "<|im_start|>",
        "unk_token": "<|endoftext|>",
    }
    for attr, value in defaults.items():
        if getattr(tokenizer, attr) is None:
            setattr(tokenizer, attr, value)

    if tokenizer.pad_token_id is None:
        raise ValueError(
            f"Pad token {tokenizer.pad_token!r} has no id in the tokenizer vocabulary."
        )
    return tokenizer


def _build_config(args, vocab_size):
    """Return the MiniMindConfig for ``args.mode`` ('dense' or 'moe')."""
    common_args = {
        "hidden_size": args.hidden_size,
        "num_hidden_layers": 8,
        "num_attention_heads": 8,
        "vocab_size": vocab_size,  # read from the tokenizer, not hard-coded
        "max_position_embeddings": 2048,
        "dropout": 0.0,
        "flash_attn": False
    }

    if args.mode == "moe":
        return MiniMindConfig(
            **common_args,
            use_moe=True,
            n_routed_experts=4,
            n_shared_experts=1,
            num_experts_per_tok=2,
            intermediate_size=args.moe_inter_dim
        )
    # 5x width: presumably chosen to match the MoE's 4 routed + 1 shared
    # experts of size moe_inter_dim.
    return MiniMindConfig(
        **common_args,
        use_moe=False,
        intermediate_size=args.moe_inter_dim * 5
    )


def _evaluate(model, loader, pad_token_id, device, max_batches=50):
    """Return token-averaged next-token CE over at most ``max_batches`` batches.

    Padding positions (``pad_token_id``) are excluded. Returns None when no
    target tokens were seen, e.g. for an empty test split. The model's
    train/eval mode is restored afterwards.
    """
    # Sum per-token losses and divide once, so a short final batch or a batch
    # with few real tokens is not over-weighted.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction="sum")
    total_loss = 0.0
    total_tokens = 0
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for i, (t_in, _) in enumerate(loader):
            if i >= max_batches:
                break
            t_in = t_in.to(device)
            logits = model(input_ids=t_in).logits
            s_logits = logits[..., :-1, :].contiguous()
            s_labels = t_in[..., 1:].contiguous()
            total_loss += criterion(
                s_logits.view(-1, s_logits.size(-1)), s_labels.view(-1)
            ).item()
            total_tokens += (s_labels != pad_token_id).sum().item()
    model.train(was_training)
    if total_tokens == 0:
        return None
    return total_loss / total_tokens


def _dump_logs(logs, output_dir):
    """Write the accumulated train/test log entries to training_logs.json."""
    with open(os.path.join(output_dir, "training_logs.json"), "w") as f:
        json.dump(logs, f)


def main():
    """Train a dense or MoE MiniMind model from the command line.

    Writes ``training_logs.json`` (refreshed at every evaluation and once at
    the end) and ``model_final.pth`` into ``--output_dir``.

    Raises:
        ValueError: If the dataset is too small to leave any training samples.
    """
    import itertools

    args = _parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------
    # A. Tokenizer (local)
    # ------------------------------------------------
    tokenizer = _load_tokenizer("./model")
    vocab_size = len(tokenizer)
    print(f"Tokenizer loaded. Vocab size: {vocab_size}")

    # ------------------------------------------------
    # B. Model
    # ------------------------------------------------
    config = _build_config(args, vocab_size)
    print(f"[{args.mode.upper()}] Initializing Model...")
    model = MiniMindForCausalLM(config).to(device)

    total = sum(p.numel() for p in model.parameters())
    print(f"[{args.mode.upper()}] Total Params: {total/1e6:.2f}M")

    # ------------------------------------------------
    # C. Data
    # ------------------------------------------------
    full_ds = LocalJsonlDataset(args.data_path, tokenizer, max_len=512)

    train_size = int(0.99 * len(full_ds))
    test_size = len(full_ds) - train_size
    if train_size == 0:
        raise ValueError(
            f"Not enough samples in {args.data_path} to train (got {len(full_ds)})."
        )
    # Seeded split: every run sees the same held-out test set.
    split_gen = torch.Generator().manual_seed(args.seed)
    train_ds, test_ds = torch.utils.data.random_split(
        full_ds, [train_size, test_size], generator=split_gen
    )

    pin = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=pin)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # ------------------------------------------------
    # D. Training loop
    # ------------------------------------------------
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    logs = []
    global_step = 0

    if args.max_steps > 0:
        # Step-bounded run: iterate epochs until max_steps is reached.
        epochs = itertools.count()
        print(f"Training for {args.max_steps} steps...")
    else:
        epochs = range(args.epochs)
        print(f"Training for {args.epochs} epochs...")

    print("Starting Training Loop...")
    model.train()

    training_finished = False

    for epoch in epochs:
        if training_finished:
            break

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for inputs, _ in pbar:
            inputs = inputs.to(device)
            optimizer.zero_grad()

            outputs = model(input_ids=inputs)
            logits = outputs.logits

            # Next-token prediction: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs[..., 1:].contiguous()

            ce_loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # The model output may lack an aux loss, or carry a plain number
            # instead of a tensor; treat a missing one as zero so dense and
            # MoE share this path.
            aux_loss = getattr(outputs, "aux_loss", None)
            if aux_loss is None:
                aux_loss = 0.0
            total_loss = ce_loss + aux_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            global_step += 1

            if global_step % 50 == 0:
                ce_value = ce_loss.item()
                aux_value = aux_loss.item() if torch.is_tensor(aux_loss) else float(aux_loss)
                logs.append({
                    "step": global_step,
                    "type": "train",
                    "loss": ce_value,
                    "aux_loss": aux_value
                })
                pbar.set_postfix({"ce": f"{ce_value:.4f}", "step": global_step})

            if global_step % 250 == 0:
                avg_test = _evaluate(model, test_loader, tokenizer.pad_token_id, device)
                # Skip the entry rather than logging a fake ~0 loss when the
                # test split is empty.
                if avg_test is not None:
                    logs.append({"step": global_step, "type": "test", "loss": avg_test})
                _dump_logs(logs, args.output_dir)

            if args.max_steps > 0 and global_step >= args.max_steps:
                training_finished = True
                break

    # Final flush so runs shorter than 250 steps (and the tail of longer runs)
    # still leave a complete log file.
    _dump_logs(logs, args.output_dir)
    torch.save(model.state_dict(), os.path.join(args.output_dir, "model_final.pth"))
    print(f"Training Finished at step {global_step}.")

# Script entry point: start training only when run directly, not on import.
if __name__ == "__main__":
    main()