import json
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import numpy as np
from typing import List, Dict, Tuple
import argparse
from rich.console import Console
from rich.progress import track
import random
import matplotlib.pyplot as plt
from datetime import datetime
import torch
import sys

# Shared rich console used for the status/log output printed by this file.
console = Console()

# Fix all random seeds (torch, numpy, Python `random`) to 42 at import time.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Directory containing this file, and the project root two levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, "..", ".."))

# Put the project root at the front of sys.path so the `src.` package
# imports that follow can be resolved.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.train_utils.train_faclens_linear import NFPDataset, load_model_data, train_epoch_multi_model, evaluate_multi_model, train_dataset_model, MultiDimAdapter, safe_rich_print, save_results
from src.train_utils.save_config import save_run_metadata

def collect_input_dims(
    feats_dir: str,
    layer_type: str,
    datasets: List[str],
    models: List[str],
    operate_type: str,
) -> List[int]:
    """Collect the distinct feature dimensions across all model/dataset pairs.

    For every (model, dataset) pair, looks for
    ``<feats_dir>/<model>_<operate_type>/<dataset>/<layer_type>/train.pt``.
    Each file that exists is loaded and the size of dimension 1 of its
    ``"features"`` entry is recorded. Missing files are skipped silently;
    files that fail to load (or lack a usable ``"features"`` entry) are
    reported and skipped.

    Args:
        feats_dir: Root directory of the extracted features.
        layer_type: Layer sub-directory name (e.g. "last").
        datasets: Dataset sub-directory names to scan.
        models: Model names; each is suffixed with ``_<operate_type>``.
        operate_type: Token-selection suffix appended to each model folder.

    Returns:
        Sorted list of unique feature dimensions; empty if no file was found.
    """
    in_dims = set()

    for model in models:
        model_folder = f"{model}_{operate_type}"
        for dataset in datasets:
            data_path = os.path.join(
                feats_dir, model_folder, dataset, layer_type, "train.pt"
            )
            if not os.path.exists(data_path):
                continue
            try:
                # Only the shape is needed: load onto the CPU so tensors that
                # were saved from a GPU are not restored onto that device
                # (which fails on CPU-only hosts and wastes GPU memory).
                data = torch.load(data_path, map_location="cpu")
                dim = int(data["features"].shape[1])
            except Exception as e:
                console.print(f"[red]加载失败 {data_path}: {e}[/red]")
                continue
            # Release the loaded file before the next load to keep peak memory down.
            del data
            in_dims.add(dim)
            console.print(f"收集维度: {model}/{dataset} -> {dim}")

    return sorted(in_dims)


class gsm8k_linear(nn.Module):
    """MLP classifier over features of varying input width.

    Inputs first pass through a ``MultiDimAdapter`` built from ``target_dim``
    and ``in_dims``. They then go through a stack of
    ``LayerNorm -> Linear -> ReLU`` blocks whose output widths are given by
    ``hidden``. A final linear head produces 2 logits.

    Args:
        target_dim: Target width given to the adapter; also the input width
            of the first encoder block.
        in_dims: Input widths registered with the adapter. Defaults to
            ``[target_dim]``.
        hidden: Output widths of the successive encoder blocks. The default
            ``(512, 128, 32)`` reproduces the original fixed architecture
            (identical module/state_dict layout).

    Raises:
        ValueError: if ``hidden`` is empty.
    """

    def __init__(
        self,
        target_dim: int = 4096,
        in_dims: List[int] = None,
        hidden: Tuple[int, ...] = (512, 128, 32),
    ):
        super().__init__()
        self.target_dim = target_dim

        # Multi-dimension adapter; by default it only registers target_dim.
        if in_dims is None:
            in_dims = [target_dim]
        self.adapter = MultiDimAdapter(target_dim, in_dims)

        self.hidden = tuple(hidden)
        if not self.hidden:
            raise ValueError("hidden must contain at least one layer width")

        def block(in_dim, out_dim):
            return nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, out_dim),
                nn.ReLU(inplace=True),
            )

        # Chain target_dim -> hidden[0] -> hidden[1] -> ... -> hidden[-1].
        widths = (target_dim,) + self.hidden
        self.encoder = nn.Sequential(
            *(block(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:]))
        )
        self.clf = nn.Linear(self.hidden[-1], 2)

    def forward(self, x):
        """Return ``(logits, features)`` for input batch ``x``.

        ``x`` is cast to the model's parameter dtype before the adapter. The
        adapter output is cast again before the encoder, in case the adapter
        emits a different dtype. ``features`` is the encoder output fed to
        the classification head.
        """
        param_dtype = next(self.parameters()).dtype

        # Tensor.to() returns the tensor itself when the dtype already matches.
        x = self.adapter(x.to(param_dtype))
        x = x.to(param_dtype)

        features = self.encoder(x)
        logits = self.clf(features)
        return logits, features

def _str2bool(value):
    """Parse a command-line boolean such as 'True'/'False'/'1'/'0'/'yes'/'no'.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty string (including "False") becomes True. This converter is
    used instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def main():
    """Train a FacLens classifier for each dataset given in ``--datasets``.

    The function runs these steps in order:

    1. Parse CLI arguments.
    2. Create a timestamped output directory.
    3. Collect feature dimensions and load features for the hard-coded model list.
    4. Train one ``gsm8k_linear`` model per entry in ``--datasets``. A failure
       on one dataset is reported with a traceback and does not stop the others.
    5. Save the collected results via ``save_results`` at the end.
    """
    import traceback

    parser = argparse.ArgumentParser(description="为每个数据集训练FacLens分类器")
    parser.add_argument("--feats_dir", default=os.path.join(PROJECT_ROOT,"feats/gsm8k/gsm8k_feats_split"), help="特征目录")
    parser.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT,"results/gsm8k/gsm8k_results"), help="输出目录")
    parser.add_argument("--device", default="cuda", help="设备")
    parser.add_argument("--layer_type", default="last", help="层类型", choices=["quarter", "middle", "three_quarters", "last", "second_last", "first", "all"])
    parser.add_argument("--datasets", nargs='+', default=["gsm8k_train"],
                       help="要训练的数据集列表")
    parser.add_argument("--batch_size", type=int, default=64, help="批量大小")
    parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
    parser.add_argument("--weight_decay", type=float, default=0, help="权重衰减")
    parser.add_argument("--operate_type", default="last_token", help="操作类型", choices=["prompt_last_token", "answer_first_token", "last_token"])
    parser.add_argument("--level_filter", default=None, help="只处理指定level的数据，例如'Level 5'")
    parser.add_argument("--max_samples", type=int, default=None, help="最大训练样本数量，None表示使用全部样本")
    # Previously type=bool, which made "--filter_incomplete False" evaluate to True.
    # "--filter_incomplete True/False" still works; a bare flag now means True.
    parser.add_argument(
        "--filter_incomplete",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="是否过滤pred_answer为'Incomplete'的样本",
    )
    args = parser.parse_args()

    # Create a per-run output directory named by timestamp.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    args.output_dir = os.path.join(args.output_dir, timestamp)
    os.makedirs(args.output_dir, exist_ok=True)

    # Models whose features are used.
    # models = ["llama3.1-8b-instruct", "phi3.5-mini-4b-instruct", "granite-3.1-8b-instruct", "qwen2.5-7b-instruct", "qwen2.5-14b-instruct"]
    models = ["llama3.2-3b-instruct", "ministral-8b-instruct-2410", "falcon3-10b-instruct", "phi4-14b"]
    # NOTE(review): features are loaded for "gsm8k" while the training loop
    # iterates over --datasets (default "gsm8k_train"); confirm that
    # train_dataset_model maps between these names as intended.
    datasets = ["gsm8k"]

    # Collect every input feature dimension so the adapter can register them.
    console.print("[bold yellow]收集特征维度信息...[/bold yellow]")
    in_dims = collect_input_dims(args.feats_dir, args.layer_type, datasets, models, args.operate_type)
    console.print(f"收集到的特征维度: {in_dims}")

    # Load the per-model feature data.
    models_data = load_model_data(args.feats_dir, args.layer_type, datasets, models, args.operate_type)

    # Train one classifier per requested dataset.
    console.print("[bold yellow]开始为每个数据集训练模型...[/bold yellow]")
    all_results = {}

    # Run metadata (hyper-parameters + model structure) is written only once.
    saved_meta = False

    for dataset_name in args.datasets:
        console.print(f"\n[bold cyan]======= 训练数据集: {dataset_name} =======[/bold cyan]")
        try:
            model = gsm8k_linear(target_dim=4096, in_dims=in_dims)

            if not saved_meta:
                meta_path = save_run_metadata(args, model, models, args.output_dir)
                console.print(f"[green]已保存本次运行元信息到: {meta_path}[/green]")
                saved_meta = True

            results = train_dataset_model(
                dataset_name, models_data, args.output_dir, args.device, models, args.batch_size, args.lr, args.weight_decay, args.level_filter, args.max_samples, args.filter_incomplete, model
            )
            all_results[dataset_name] = results
            console.print(f"[green]数据集 {dataset_name} 训练完成![/green]")
        except Exception as e:
            # Report and continue with the remaining datasets.
            safe_rich_print(f"训练数据集 {dataset_name} 失败: {e}", "red")
            traceback.print_exc()

    save_results(args, all_results)

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
