import itertools
import json
import os

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity

class QwenAnalyzer:
    """Inspection helpers for a causal LM loaded through Hugging Face ``transformers``.

    Provides token-level log probabilities for greedy generations, hidden
    states of a chosen layer, attention maps, and two static plotting helpers
    that both display and save the figure as ``<fig_name>.png``.
    """

    def __init__(self, model_path, device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        Load the tokenizer and model, move the model to ``device`` and set eval mode.

        :param model_path: model id or local path (e.g. "Qwen/Qwen-7B")
        :param device: device to run on; the default is "cuda" if CUDA was
            available when this class was defined, otherwise "cpu"
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        # Weights are always loaded in float16, regardless of the device.
        # NOTE(review): float16 on CPU can be slow or unsupported for some ops — confirm CPU use is intended.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).to(self.device)
        self.model.eval()

    def get_log_probs(self, prompt, max_new_tokens=50):
        """
        Greedily generate a continuation and return per-token log probabilities.

        :param prompt: input text
        :param max_new_tokens: maximum number of tokens to generate
        :return: dict with "generated_text" (decoded continuation), "tokens"
            (each generated token decoded on its own), "token_ids" (list of
            ints) and "log_probs" (log-softmax of each step's scores, taken at
            the token that was actually generated)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                do_sample=False,
            )

        # Drop the prompt prefix; what remains is only the newly generated tokens.
        prompt_len = inputs["input_ids"].shape[1]
        generated_ids = outputs.sequences[0, prompt_len:]

        # outputs.scores has one (batch, vocab) tensor per generation step, aligned
        # with generated_ids. NOTE(review): these are the scores after the
        # generation config's logits processors, not necessarily the raw logits.
        log_probs = [
            torch.log_softmax(step_scores, dim=-1)[0, token_id].item()
            for step_scores, token_id in zip(outputs.scores, generated_ids)
        ]

        return {
            "generated_text": self.tokenizer.decode(generated_ids),
            "tokens": [self.tokenizer.decode(t) for t in generated_ids],
            "token_ids": generated_ids.tolist(),
            "log_probs": log_probs,
        }

    def get_hidden_states(self, prompt, layer_index=-1):
        """
        Run a forward pass and return the hidden states of one layer.

        :param prompt: input text
        :param layer_index: index into the model's ``hidden_states`` tuple
            (-1 = last layer; by HF convention index 0 is the embedding output)
        :return: dict with "hidden_states" (float64 array, first batch item)
            and "input_ids" (integer token ids of the prompt)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        hidden_states = outputs.hidden_states[layer_index]
        return {
            "hidden_states": hidden_states[0].cpu().to(torch.float64).numpy(),
            # Token ids are integers: keep them integral instead of casting to float64.
            "input_ids": inputs["input_ids"][0].cpu().numpy(),
        }

    def get_attention_maps(self, prompt):
        """
        Run a forward pass and return the attention weights of every layer.

        :param prompt: input text
        :return: dict with "attentions" (one array per layer, first batch item,
            shape [num_heads, seq_len, seq_len]) and "tokens" (prompt tokens)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)

        # `attentions` is a tuple with one tensor per layer, each shaped
        # [batch_size, num_heads, seq_len, seq_len].
        # NOTE(review): some attention implementations (e.g. SDPA/flash) may not
        # return attention weights — confirm the model is loaded with one that does.
        attentions = outputs.attentions
        return {
            "attentions": [attn[0].cpu().numpy() for attn in attentions],  # first sample only
            "tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()),
        }

    @staticmethod
    def _make_tsne(n_samples, perplexity=30, n_iter=300):
        """Build a 2-D TSNE reducer that works across scikit-learn versions."""
        # t-SNE requires perplexity < n_samples, so shrink it for small inputs.
        perplexity = min(perplexity, max(n_samples - 1, 1))
        try:
            # Newer scikit-learn names the iteration budget `max_iter` (`n_iter` was deprecated and removed).
            return TSNE(n_components=2, perplexity=perplexity, max_iter=n_iter)
        except TypeError:
            # Older scikit-learn only accepts `n_iter`.
            return TSNE(n_components=2, perplexity=perplexity, n_iter=n_iter)

    @staticmethod
    def visualize_embeddings(embeddings, labels=None, method="pca", title="Embedding Projection", fig_name='1'):
        """
        Project embeddings to 2-D, then show the scatter plot and save it.

        :param embeddings: array of shape (n, dim), or (n, seq_len, dim), which
            is averaged over axis 1 first
        :param labels: optional per-point values used to color the scatter
        :param method: reduction method, "pca" or "tsne" (case-insensitive)
        :param title: plot title
        :param fig_name: output file stem; the figure is saved as "<fig_name>.png"
        :raises ValueError: if ``method`` is not "pca" or "tsne"
        """
        method_key = method.lower()
        if method_key not in ("pca", "tsne"):
            raise ValueError("Method must be 'pca' or 'tsne'")

        # Average over the sequence axis, giving one embedding per item.
        if len(embeddings.shape) > 2:
            embeddings = embeddings.mean(axis=1)

        if method_key == "pca":
            reducer = PCA(n_components=2)
        else:
            reducer = QwenAnalyzer._make_tsne(embeddings.shape[0])

        projected = reducer.fit_transform(embeddings)

        fig = plt.figure(figsize=(10, 8))
        if labels is not None:
            scatter = plt.scatter(projected[:, 0], projected[:, 1], c=labels)
            plt.colorbar(scatter)
        else:
            plt.scatter(projected[:, 0], projected[:, 1])
        plt.title(title)
        plt.xlabel(f"{method.upper()} 1")
        plt.ylabel(f"{method.upper()} 2")
        # Save before show(): with interactive backends show() can leave an empty figure behind.
        plt.savefig(f'{fig_name}.png', format='png')
        plt.show()
        plt.close(fig)

    @staticmethod
    def plot_attention_heatmap(attention, tokens_x, tokens_y, layer=0, head=0, title="Attention Heatmap", fig_name='1'):
        """
        Draw one attention head as a heatmap, then show it and save it.

        :param attention: indexable as attention[layer][head] -> (seq_len, seq_len),
            e.g. the "attentions" list returned by ``get_attention_maps``
        :param tokens_x: x-axis token labels
        :param tokens_y: y-axis token labels
        :param layer: layer to plot
        :param head: head to plot
        :param title: plot title
        :param fig_name: output file stem; the figure is saved as "<fig_name>.png"
        """
        fig = plt.figure(figsize=(12, 10))
        ax = sns.heatmap(
            attention[layer][head],
            xticklabels=tokens_x,
            yticklabels=tokens_y,
            cmap="viridis",
            square=True,
        )
        ax.set_title(f"{title}\nLayer {layer}, Head {head}")
        plt.xticks(rotation=45, ha="right")
        plt.yticks(rotation=0)
        plt.tight_layout()
        # Save before show(): with interactive backends show() can leave an empty figure behind.
        plt.savefig(f'{fig_name}.png', format='png')
        plt.show()
        plt.close(fig)
from evaluations import test_glue
def compare_finetuned_models(original_model_path, datasets, device='cuda:1'):
    """
    Evaluate a checkpoint with ``test_glue`` and return its score.

    :param original_model_path: checkpoint path forwarded to ``test_glue``
    :param datasets: dataset selector forwarded to ``test_glue`` (the script
        entry point passes a single GLUE task name)
    :param device: device string forwarded to ``test_glue``; defaults to the
        previously hard-coded 'cuda:1'
    :return: the first element of the pair returned by ``test_glue``
    """
    score, _ = test_glue(original_model_path, datasets, device)
    return score

if __name__ == "__main__":
    # Cross-task evaluation: for every unordered pair (a, b) of GLUE tasks,
    # evaluate the checkpoint fine-tuned on task `a` on task `b`, then write
    # all scores to rabcompare.json keyed by "a-b".
    from models_and_datas import models_and_datas_qwen as models
    from glue_utils import get_texts  # noqa: F401  (currently unused)
    from run import simple_model_name_to_ckpt

    branches = sorted(['sst2', 'cola', 'rte', 'qnli', 'mnli', 'mrpc', 'wnli', 'qqp'])

    sim = {}
    # combinations() over the sorted list yields pairs in the same order as an
    # i < j double loop, and stays correct if the task list changes length.
    # NOTE(review): only a's model is evaluated on b; b's model is never
    # evaluated on a — confirm this asymmetry is intended.
    for a, b in itertools.combinations(branches, 2):
        model_a = simple_model_name_to_ckpt(models[a]['model'][0])
        sim[a + '-' + b] = compare_finetuned_models(model_a, b)

    with open('rabcompare.json', 'w', encoding='utf-8') as f:
        json.dump(sim, f)
    