import logging
from operator import ge
from pyexpat import model
from sre_parse import State
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.parameter import Parameter
from tqdm import tqdm
from copy import deepcopy
from typing import Dict, Iterator, List, Optional
from src.datasets.common import maybe_dictionarize
from src.route_merged_model import RouteMergedModel
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import os

log = logging.getLogger(__name__)
# StateDict is simply an alias for the built-in dict type.
StateDict = dict

def draw_distribution(
    task_vector: StateDict,
    merged_state_dict: StateDict,
    pretrained_model: nn.Module,
    dataloader: Optional[torch.utils.data.DataLoader],
    device: str,
    dataset_name: str,
    type: str,
):
    """Plot a joint t-SNE embedding of fine-tuned vs. merged model outputs.

    Two models are derived from ``pretrained_model``: the "SFT" model, built
    by ``build_model(pretrained_model, task_vector, device)``, and the merged
    model, a deep copy loaded with ``merged_state_dict``. Both are run on the
    same batches, their outputs are concatenated, reduced to 2-D with a single
    t-SNE fit, and drawn as two scatter series on one figure.

    The figure is written to
    ``/home/pzc/OT_fusion/results/<type>/<dataset_name>_feature_distribution.png``.

    Args:
        task_vector: Per-parameter offsets passed to ``build_model``.
        merged_state_dict: State dict loaded into a copy of
            ``pretrained_model``.
        pretrained_model: Base model; it is deep-copied, never modified.
        dataloader: Source of batches. Each batch is passed through
            ``maybe_dictionarize`` and its ``"images"`` entry is fed to both
            models.
        device: Device on which both models and the inputs are placed.
        dataset_name: Used as the plot title and as the file-name prefix.
        type: Sub-directory name appended to the results root. The name
            shadows the builtin but is kept so existing keyword callers work.

    Raises:
        ValueError: If ``dataloader`` is ``None`` or yields no batches.
    """
    if dataloader is None:
        raise ValueError("draw_distribution requires a dataloader, got None")

    save_dir = '/home/pzc/OT_fusion/results/' + type
    # Create the target directory up front so a missing folder fails (or is
    # fixed) before the expensive inference and t-SNE steps, not after them.
    os.makedirs(save_dir, exist_ok=True)

    # 1. Build the two models to compare.
    model_sft = build_model(pretrained_model, task_vector, device)
    model_merged = deepcopy(pretrained_model)
    model_merged.load_state_dict(merged_state_dict)
    model_merged = model_merged.to(device)
    # Feature extraction must not use training-time behaviour
    # (dropout, batch-norm statistics updates).
    model_sft.eval()
    model_merged.eval()

    # 2. Collect the outputs of both models on identical inputs.
    features_sft = []
    features_merged = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, leave=False)):
            batch = maybe_dictionarize(batch)
            images = batch["images"].to(device)
            features_sft.append(model_sft(images).cpu())
            features_merged.append(model_merged(images).cpu())
            # Stops after the batch with index 200, i.e. at most 201 batches.
            if batch_idx >= 200:
                break

    if not features_sft:
        raise ValueError("dataloader yielded no batches; nothing to plot")

    # NOTE(review): assumes each model output is 2-D (samples, features),
    # which is what TSNE.fit_transform expects — confirm for this model.
    features_sft = torch.cat(features_sft, dim=0).numpy()
    features_merged = torch.cat(features_merged, dim=0).numpy()

    # 3. Fit one t-SNE on the stacked features so both sets share a space.
    features_combined = np.concatenate([features_sft, features_merged], axis=0)
    tsne = TSNE(n_components=2, random_state=42)
    features_tsne = tsne.fit_transform(features_combined)

    # 4. Rows keep their stacking order: SFT samples first, merged after.
    num_sft = features_sft.shape[0]
    points_sft = features_tsne[:num_sft]
    points_merged = features_tsne[num_sft:]

    # 5. Draw and save; always release the figure, even if saving fails.
    save_path = os.path.join(save_dir, f'{dataset_name}_feature_distribution.png')
    fig = plt.figure(figsize=(6, 5))
    try:
        plt.scatter(points_sft[:, 0], points_sft[:, 1], c='blue', label='SFT Model', alpha=0.6)
        plt.scatter(points_merged[:, 0], points_merged[:, 1], c='red', label='Merged Model', alpha=0.6)

        # Enlarged fonts for the title, tick labels and legend.
        plt.title(f'{dataset_name}', fontsize=20)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.legend(fontsize=18)

        plt.savefig(save_path)
    finally:
        plt.close(fig)

    log.info("Saved feature distribution plot to %s", save_path)


def build_model(
    pretrained_model: nn.Module,
    task_state_dict: StateDict,
    device: str = "cuda:1",
):
    """Return a copy of ``pretrained_model`` with a task vector added to it.

    The pretrained model is deep-copied, so the original is left untouched.
    For every entry of ``task_state_dict`` whose name also appears in the
    copy's state dict, the stored tensor is replaced by
    ``pretrained + task`` (both operands are moved to ``device`` first).
    Entries whose name is not in the model's state dict are ignored.

    Args:
        pretrained_model: Base model providing the starting weights.
        task_state_dict: Mapping from state-dict names to tensors that are
            added to the matching pretrained tensors.
        device: Device used for the addition and for the returned model.

    Returns:
        The updated model copy, moved to ``device``.
    """
    task_model = deepcopy(pretrained_model)
    weights = task_model.state_dict()

    # Compute all shifted tensors first, then overwrite the matching entries.
    shifted = {
        name: weights[name].to(device) + delta.to(device)
        for name, delta in task_state_dict.items()
        if name in weights
    }
    weights.update(shifted)

    task_model.load_state_dict(weights)
    return task_model.to(device)
