import json
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import numpy as np
from typing import List, Dict, Tuple
import argparse
from rich.console import Console
from rich.progress import track
import random   
import matplotlib.pyplot as plt
from datetime import datetime
import torch
import sys

console = Console()

# Reproducibility: give Python's `random`, NumPy's global RNG and PyTorch's
# global RNG the same fixed seed so weight init and shuffling repeat across runs.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Directory holding this script, and the repository root two levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, os.pardir, os.pardir))

# Put the repository root first on the import path so the `src.*` imports that
# follow resolve when this file is run directly as a script.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.train_utils.train_faclens_pad import load_model_data, train_dataset_model, pad_to_dim, NFPDataset, save_results, safe_rich_print
from src.train_utils.save_config import save_run_metadata

class mmlu_pro_pad(nn.Module):
    """FacLens-style MLP probe mapping one feature vector to class logits.

    Architecture: a stack of ``LayerNorm -> Linear -> ReLU`` blocks taking
    ``target_dim -> hidden[0] -> ... -> hidden[-1]``, followed by a linear
    head ``hidden[-1] -> num_classes``.

    This module performs no projection or padding. Inputs must already carry
    ``target_dim`` features in their last dimension. Any resizing of
    smaller/larger LLM hidden states has to happen before ``forward``
    (presumably in the training utilities; verify against the caller).

    Args:
        target_dim: Size of the input (last) feature dimension. Defaults to 4096.
        hidden: Widths of the hidden layers, in order. Defaults to ``(512, 128, 32)``.
        num_classes: Number of output logits. Defaults to 2.

    Raises:
        ValueError: If ``hidden`` is empty.
    """

    def __init__(
        self,
        target_dim: int = 4096,
        hidden: Tuple[int, ...] = (512, 128, 32),
        num_classes: int = 2,
    ):
        super().__init__()
        if not hidden:
            raise ValueError("hidden must contain at least one layer width")
        self.target_dim = target_dim
        self.hidden = tuple(hidden)

        def block(in_dim, out_dim):
            # Each stage normalises its input before projecting it.
            return nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, out_dim),
                nn.ReLU(inplace=True),
            )

        dims = (target_dim,) + self.hidden
        # Submodule indices (encoder.<i>.<j>) match the previous hand-written
        # three-block layout, so existing checkpoints still load.
        self.encoder = nn.Sequential(
            *(block(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))
        )

        # The head width follows the last hidden layer instead of a hard-coded 32.
        self.clf = nn.Linear(self.hidden[-1], num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` for ``x``.

        Args:
            x: Tensor whose last dimension is ``target_dim``. It is cast to the
                parameters' dtype when the dtypes differ.

        Returns:
            A tuple ``(logits, features)``:

            - ``logits`` has shape ``(..., num_classes)``.
            - ``features`` is the encoder output, with shape ``(..., hidden[-1])``.
        """
        param_dtype = next(self.parameters()).dtype
        if x.dtype != param_dtype:
            x = x.to(param_dtype)

        features = self.encoder(x)
        logits = self.clf(features)
        return logits, features

def str2bool(value):
    """Parse a command-line boolean such as ``"true"``, ``"False"``, ``"1"`` or ``"no"``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This parser interprets the text instead.

    Args:
        value: Raw option text. A ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def main():
    """Train one FacLens classifier per requested dataset and save the results.

    Steps:

    1. Parse CLI options.
    2. Create a timestamped sub-directory under ``--output_dir``.
    3. Load features for every ``--models`` x ``--datasets`` combination via
       ``load_model_data``.
    4. For each dataset, build a fresh ``mmlu_pro_pad`` and pass it to
       ``train_dataset_model``.

    Run metadata is saved once, when the first model is built. A failure on one
    dataset is reported with a traceback and the loop moves on. All collected
    results are written by ``save_results`` at the end.
    """
    import traceback

    parser = argparse.ArgumentParser(description="为每个数据集训练FacLens分类器")
    parser.add_argument("--feats_dir", default=os.path.join(PROJECT_ROOT,"feats/mmlu_pro/mmlu_pro_feats_split"), help="特征目录")
    parser.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT,"results/mmlu_pro/mmlu_pro_results"), help="输出目录")
    parser.add_argument("--device", default="cuda", help="设备")
    parser.add_argument("--layer_type", default="last", help="层类型", choices=["middle", "last", "second_last"])
    parser.add_argument("--datasets", nargs='+', default=["mmlu_pro"],
                       help="要训练的数据集列表")
    # Source LLMs whose features are used. An alternative roster used in other
    # experiments: internlm2.5-7b-chat, qwen2.5-3b-instruct, falcon3-10b-instruct,
    # llama3.1-8b-instruct, phi4-14b.
    parser.add_argument(
        "--models",
        nargs="+",
        default=["llama3.2-3b-instruct", "qwen2.5-7b-instruct", "ministral-8b-instruct-2410", "falcon3-10b-instruct", "qwen2.5-14b-instruct"],
        help="用于训练的LLM特征来源列表",
    )
    parser.add_argument("--batch_size", type=int, default=128, help="批量大小")
    parser.add_argument("--lr", type=float, default=5e-4, help="学习率")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="权重衰减")
    parser.add_argument("--operate_type", default="last_token", help="操作类型", choices=["last_token", "avg_with_prompt", "avg_without_prompt"])
    parser.add_argument("--level_filter", default=None, help="只处理指定level的数据，例如'Level 5'")
    parser.add_argument("--max_samples", type=int, default=None, help="最大训练样本数量，None表示使用全部样本")
    # type=bool would turn "--filter_incomplete False" into True. str2bool parses
    # the text, and nargs="?"/const=True also lets the bare flag mean True.
    parser.add_argument(
        "--filter_incomplete",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="是否过滤pred_answer为'Incomplete'的样本",
    )
    args = parser.parse_args()

    # Give each run its own timestamped sub-directory so runs never overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    args.output_dir = os.path.join(args.output_dir, timestamp)
    os.makedirs(args.output_dir, exist_ok=True)

    models = args.models

    # Load features for exactly the datasets trained below. This was previously
    # a hard-coded ["mmlu_pro"], so any other --datasets value had no data loaded.
    models_data = load_model_data(args.feats_dir, args.layer_type, args.datasets, models, args.operate_type)

    console.print("[bold yellow]开始为每个数据集训练模型...[/bold yellow]")
    all_results = {}

    # Run metadata (hyper-parameters and model structure) is written only once.
    saved_meta = False

    for dataset_name in args.datasets:
        console.print(f"\n[bold cyan]======= 训练数据集: {dataset_name} =======[/bold cyan]")
        try:
            model = mmlu_pro_pad()

            if not saved_meta:
                meta_path = save_run_metadata(args, model, models, args.output_dir)
                console.print(f"[green]已保存本次运行元信息到: {meta_path}[/green]")
                saved_meta = True

            results = train_dataset_model(
                model, dataset_name, models_data, args.output_dir, args.device, models,
                args.batch_size, args.lr, args.weight_decay, args.level_filter,
                args.max_samples, args.filter_incomplete,
            )
            all_results[dataset_name] = results
            console.print(f"[green]数据集 {dataset_name} 训练完成![/green]")
        except Exception as e:
            # Top-level boundary: report this dataset's failure and continue with the rest.
            safe_rich_print(f"训练数据集 {dataset_name} 失败: {e}", "red")
            traceback.print_exc()

    save_results(args, all_results)


if __name__ == "__main__":
    main()