import torch,os,os,json
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity

class QwenAnalyzer:
    """
    Inspection helper around a causal language model and its tokenizer.

    Provides token-level log probabilities of greedy generations, per-layer
    hidden states and attention maps for a prompt, plus static plotting
    helpers for the extracted arrays.
    """

    def __init__(self, model_path, device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        Load the tokenizer and model and put the model in eval mode.

        :param model_path: Model path or hub id (e.g. "Qwen/Qwen-7B").
        :param device: Device to run on. The default is chosen once, when the
            class is defined: "cuda" if available, otherwise "cpu".
        """
        self.device = device
        # NOTE: trust_remote_code=True runs Python code shipped with the
        # checkpoint; only load checkpoints from sources you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16
        ).to(self.device)
        self.model.eval()

    def get_log_probs(self, prompt, max_new_tokens=50):
        """
        Greedily generate a continuation and return its token-level log probabilities.

        :param prompt: Input text.
        :param max_new_tokens: Maximum number of tokens to generate.
        :return: Dict with "generated_text" (decoded continuation), "tokens"
            (each generated token decoded on its own), "token_ids" (list of
            ints) and "log_probs" (one float per generated token).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False
        )

        # The returned sequence starts with the prompt; keep only the new tokens.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = outputs.sequences[0, prompt_length:]

        # One score tensor per generation step. Upcast before log_softmax so
        # the normalisation is not done in half precision.
        log_probs = [
            torch.log_softmax(step_scores[0].float(), dim=-1)[token_id].item()
            for step_scores, token_id in zip(outputs.scores, generated_ids)
        ]

        return {
            "generated_text": self.tokenizer.decode(generated_ids),
            "tokens": [self.tokenizer.decode(t) for t in generated_ids],
            "token_ids": generated_ids.tolist(),
            "log_probs": log_probs
        }

    def get_hidden_states(self, prompt, layer_index=-1):
        """
        Return the hidden states of one layer for a prompt.

        :param prompt: Input text.
        :param layer_index: Index into the model's tuple of hidden states
            (-1 selects the last entry).
        :return: Dict with "hidden_states" (float64 NumPy array for the first
            batch item) and "input_ids" (the prompt's token ids as a NumPy array).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        hidden_states = outputs.hidden_states[layer_index]
        return {
            "hidden_states": hidden_states[0].cpu().to(torch.float64).numpy(),
            # NOTE(review): token ids are returned as float64, not integers.
            # Kept as-is to preserve the return contract; cast with
            # .astype(int) before feeding them back to a tokenizer.
            "input_ids": inputs["input_ids"][0].cpu().to(torch.float64).numpy()
        }

    def get_attention_maps(self, prompt):
        """
        Return the attention weights of every layer for a prompt.

        :param prompt: Input text.
        :return: Dict with "attentions" (list with one NumPy array per layer,
            taken from the first batch item) and "tokens" (token strings of
            the prompt).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)

        # outputs.attentions is a tuple with one tensor per layer, each of
        # shape [batch_size, num_heads, seq_len, seq_len]; keep the first sample.
        attentions = outputs.attentions
        return {
            "attentions": [attn[0].cpu().numpy() for attn in attentions],
            "tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        }

    @staticmethod
    def _save_and_show(fig, fig_name):
        """
        Write ``fig`` to ``{fig_name}.png``, display it, then release it.

        Saving must happen before ``plt.show()``: with interactive backends
        the figure is torn down when its window closes, so saving afterwards
        writes an empty image. Closing avoids accumulating open figures when
        many plots are produced in one run.
        """
        fig.savefig(f'{fig_name}.png', format='png')
        plt.show()
        plt.close(fig)

    @staticmethod
    def visualize_embeddings(embeddings, labels=None, method="pca", title="Embedding Projection", fig_name='1'):
        """
        Project embeddings to 2-D, scatter-plot them and save the plot as PNG.

        :param embeddings: Array of embeddings. Arrays with more than two
            dimensions are averaged over axis 1 before projection.
        :param labels: Optional per-point values used to colour the points.
        :param method: Reduction method, "pca" or "tsne" (case-insensitive).
        :param title: Plot title.
        :param fig_name: Output path without extension; ".png" is appended.
        :raises ValueError: If ``method`` is neither "pca" nor "tsne".
        """
        method_key = method.lower()
        if method_key == "pca":
            reducer = PCA(n_components=2)
        elif method_key == "tsne":
            # NOTE(review): newer scikit-learn releases rename n_iter to
            # max_iter, and t-SNE needs more samples than the perplexity;
            # confirm against the installed version before relying on this path.
            reducer = TSNE(n_components=2, perplexity=30, n_iter=300)
        else:
            raise ValueError("Method must be 'pca' or 'tsne'")

        # Collapse the sequence-length axis so each row is one embedding.
        if len(embeddings.shape) > 2:
            embeddings = embeddings.mean(axis=1)

        projected = reducer.fit_transform(embeddings)

        fig = plt.figure(figsize=(10, 8))
        if labels is not None:
            scatter = plt.scatter(projected[:, 0], projected[:, 1], c=labels)
            plt.colorbar(scatter)
        else:
            plt.scatter(projected[:, 0], projected[:, 1])
        plt.title(title)
        plt.xlabel(f"{method.upper()} 1")
        plt.ylabel(f"{method.upper()} 2")
        QwenAnalyzer._save_and_show(fig, fig_name)

    @staticmethod
    def plot_attention_heatmap(attention, tokens_x, tokens_y, layer=0, head=0, title="Attention Heatmap", fig_name='1'):
        """
        Draw one attention head as a heatmap and save the plot as PNG.

        :param attention: Attention weights indexable as
            ``attention[layer][head]``, yielding a 2-D matrix.
        :param tokens_x: Tick labels for the x axis.
        :param tokens_y: Tick labels for the y axis.
        :param layer: Layer to visualise.
        :param head: Head to visualise.
        :param title: Plot title; the layer and head are appended on a second line.
        :param fig_name: Output path without extension; ".png" is appended.
        """
        fig = plt.figure(figsize=(12, 10))
        ax = sns.heatmap(
            attention[layer][head],
            xticklabels=tokens_x,
            yticklabels=tokens_y,
            cmap="viridis",
            square=True
        )
        ax.set_title(f"{title}\nLayer {layer}, Head {head}")
        plt.xticks(rotation=45, ha="right")
        plt.yticks(rotation=0)
        plt.tight_layout()
        QwenAnalyzer._save_and_show(fig, fig_name)

def compare_finetuned_models(original_model_path, finetuned_model_path, prompts, fig_name1, fig_name2):
    """
    Compare two models on a list of prompts.

    For every prompt, both models' generations are printed and their hidden
    states (default layer of ``QwenAnalyzer.get_hidden_states``) are
    collected. The collected hidden states of each model are then plotted
    with PCA, and a distance score between the two models is returned.

    :param original_model_path: Path of the original model.
    :param finetuned_model_path: Path of the fine-tuned model.
    :param prompts: Non-empty sequence of prompts to analyse.
    :param fig_name1: Output path (without ".png") for the original model's plot.
    :param fig_name2: Output path (without ".png") for the fine-tuned model's plot.
    :return: Python float: the squared difference between the two models'
        hidden states, summed over all elements of a prompt and averaged over
        prompts. Despite how the result is stored by callers, this is a
        distance (0 means identical), not a cosine similarity.
    :raises ValueError: If ``prompts`` is empty.
    """
    # Fail before loading any model: otherwise an empty list only surfaces
    # later as an opaque np.concatenate error.
    if len(prompts) == 0:
        raise ValueError("prompts must contain at least one prompt")

    # Load both models.
    original_analyzer = QwenAnalyzer(original_model_path)
    finetuned_analyzer = QwenAnalyzer(finetuned_model_path)

    # Hidden states of every prompt, kept for the visualisation below.
    all_original_embeddings = []
    all_finetuned_embeddings = []

    squared_distance_total = 0.0

    for prompt in prompts:
        print(f"\nAnalyzing prompt: {prompt}")

        # Generations are only used for the side-by-side printout.
        orig_log_probs = original_analyzer.get_log_probs(prompt)
        ft_log_probs = finetuned_analyzer.get_log_probs(prompt)

        print(f"Original output: {orig_log_probs['generated_text']}")
        print(f"Finetuned output: {ft_log_probs['generated_text']}")

        orig_hidden = original_analyzer.get_hidden_states(prompt)["hidden_states"]
        ft_hidden = finetuned_analyzer.get_hidden_states(prompt)["hidden_states"]

        all_original_embeddings.append(orig_hidden)
        all_finetuned_embeddings.append(ft_hidden)

        # Element-wise difference: both arrays must have the same shape,
        # i.e. both models must tokenise the prompt identically.
        squared_distance_total += ((orig_hidden - ft_hidden) ** 2).sum()

        # Attention maps are not part of the score, so they are not computed
        # here. To inspect them, call get_attention_maps(prompt) on an
        # analyzer and pass the result to QwenAnalyzer.plot_attention_heatmap.

    # Stack the hidden states of all prompts along the first axis.
    original_embeddings = np.concatenate(all_original_embeddings)
    finetuned_embeddings = np.concatenate(all_finetuned_embeddings)

    # method="tsne" is also supported by visualize_embeddings.
    print("\nVisualizing embedding spaces...")
    QwenAnalyzer.visualize_embeddings(
        original_embeddings,
        method="pca",
        title="Original Model Embedding Space (PCA)",
        fig_name=fig_name1
    )
    QwenAnalyzer.visualize_embeddings(
        finetuned_embeddings,
        method="pca",
        title="Finetuned Model Embedding Space (PCA)",
        fig_name=fig_name2
    )

    return float(squared_distance_total / len(prompts))

if __name__ == "__main__":
    # Pairwise comparison of the per-task fine-tuned checkpoints: every
    # unordered pair of tasks is compared on the same prompt set and the
    # resulting scores are written to newcompare.json.
    from itertools import combinations

    # Project-local helpers, imported here so that importing this module
    # elsewhere does not require them.
    from models_and_datas import models_and_datas_qwen as models
    from glue_utils import get_texts
    from run import simple_model_name_to_ckpt

    folder_name = 'output/0510compare'
    top_k = 5  # number of prompts taken from each task
    branches = sorted(['sst2', 'cola', 'rte', 'qnli', 'mnli', 'mrpc', 'wnli', 'qqp'])

    os.makedirs(folder_name, exist_ok=True)

    # Chat template: system message, user message, then the opening of the
    # assistant turn that the model is expected to continue.
    prompt_format = (
        "<|im_start|>system\n"
        "{}<|im_end|>\n"
        "<|im_start|>user\n"
        "{}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # Build one shared prompt set from the first top_k texts of every task.
    test_prompts = []
    for branch in branches:
        branch_prompts = [
            prompt_format.format(t["instruction"], t["input"])
            for t in get_texts(branch)
        ]
        test_prompts += branch_prompts[:top_k]

    sim = {}

    # combinations() yields each unordered pair once, in the same order as a
    # nested i < j loop, and stays correct if the task list changes length.
    for a, b in combinations(branches, 2):
        # The first model entry of each task is used.
        model_a = simple_model_name_to_ckpt(models[a]['model'][0])
        model_b = simple_model_name_to_ckpt(models[b]['model'][0])

        sim[a + '-' + b] = compare_finetuned_models(
            model_a,
            model_b,
            test_prompts,
            folder_name + f'/{a}_in_{a}_{b}',
            folder_name + f'/{b}_in_{a}_{b}',
        )

    with open('newcompare.json', 'w', encoding='utf-8') as f:
        json.dump(sim, f)
    