from heapq import merge
import logging
from operator import ge
from pyexpat import model
from sre_parse import State
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.parameter import Parameter
from tqdm import tqdm
from copy import deepcopy
from typing import Dict, Iterator, List, Optional
from src.datasets.common import maybe_dictionarize
from src.route_merged_model import RouteMergedModel
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import os

# Logger for this module, keyed on the module's import name.
log = logging.getLogger(name=__name__)
# Lightweight alias: a "state dict" is represented as a plain ``dict``.
StateDict = dict

def draw_distribution_tie(
    merged_model: nn.Module,
    SFT_model: nn.Module,
    dataloader: Optional[torch.utils.data.DataLoader],
    device: str,
    dataset_name: str,
    type: str,
    save_root: str = '/home/pzc/OT_fusion/results/',
    max_batches: Optional[int] = 201,
) -> str:
    """Plot a joint t-SNE of SFT-model and merged-model outputs on one figure.

    Both models are run on the same batches from ``dataloader``. Their outputs
    are stacked, embedded into 2-D with a single t-SNE fit, and drawn as one
    scatter plot (blue = SFT model, red = merged model). The figure is saved
    as ``<save_root>/<type>/<dataset_name>_feature_distribution.png``.

    Args:
        merged_model: The merged/fused model.
        SFT_model: The fine-tuned reference model.
        dataloader: Source of batches. Each batch is passed through
            ``maybe_dictionarize`` and must provide an ``"images"`` entry.
        device: Device that both models and the inputs are moved to.
        dataset_name: Used as the plot title and in the output file name.
        type: Sub-directory of ``save_root`` that the figure is written to.
            The name shadows the builtin but is kept for caller compatibility.
        save_root: Root results directory. The default is the previously
            hard-coded location.
        max_batches: Maximum number of batches to process; ``None`` processes
            the whole loader. The default of 201 matches the previous
            ``if i >= 200: break`` behaviour.

    Returns:
        Path of the saved PNG file.

    Raises:
        ValueError: If ``dataloader`` is ``None``, ``max_batches`` is < 1,
            or the loader yields no batches.
    """
    if dataloader is None:
        raise ValueError("draw_distribution_tie requires a dataloader, got None")
    if max_batches is not None and max_batches < 1:
        raise ValueError(f"max_batches must be >= 1 or None, got {max_batches}")

    save_dir = os.path.join(save_root, type)
    # plt.savefig does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    merged_model = merged_model.to(device)
    SFT_model = SFT_model.to(device)

    # Extract features in eval mode so dropout / batch-norm do not perturb
    # them. Remember each model's previous mode and restore it afterwards,
    # so the caller's models are left as they were.
    merged_was_training = merged_model.training
    sft_was_training = SFT_model.training
    merged_model.eval()
    SFT_model.eval()

    features_SFT = []
    features_merged = []
    try:
        with torch.no_grad():
            for i, data in enumerate(tqdm(dataloader, leave=False)):
                data: Dict[str, Tensor] = maybe_dictionarize(data)
                x = data["images"].to(device)
                features_SFT.append(SFT_model(x).cpu())
                features_merged.append(merged_model(x).cpu())
                if max_batches is not None and i + 1 >= max_batches:
                    break
    finally:
        merged_model.train(merged_was_training)
        SFT_model.train(sft_was_training)

    if not features_SFT:
        raise ValueError("dataloader yielded no batches; nothing to plot")

    # Concatenate per-batch outputs along the batch dimension.
    # Assumes each model returns a 2-D (batch, dim) tensor, as t-SNE requires
    # — TODO confirm for all callers.
    features_SFT = torch.cat(features_SFT, dim=0).numpy()
    features_merged = torch.cat(features_merged, dim=0).numpy()
    n_sft = features_SFT.shape[0]

    # Fit t-SNE once on both sets so their 2-D coordinates share one space.
    features_combined = np.concatenate([features_SFT, features_merged], axis=0)
    tsne = TSNE(n_components=2, random_state=42)
    features_tsne = tsne.fit_transform(features_combined)
    # The first n_sft rows came from the SFT model; the rest from the merged model.
    tsne_sft = features_tsne[:n_sft]
    tsne_merged = features_tsne[n_sft:]

    save_path = os.path.join(save_dir, f'{dataset_name}_feature_distribution.png')
    fig = plt.figure(figsize=(6, 5))
    try:
        plt.scatter(tsne_sft[:, 0], tsne_sft[:, 1], c='blue', label='SFT Model', alpha=0.6)
        plt.scatter(tsne_merged[:, 0], tsne_merged[:, 1], c='red', label='Merged Model', alpha=0.6)

        # Large fonts for the title, tick labels and legend.
        plt.title(f'{dataset_name}', fontsize=20)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.legend(fontsize=18)

        plt.savefig(save_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    log.info("Saved feature distribution plot to %s", save_path)
    return save_path

