import torch,os,os,json
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from huggingface_hub import login
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from peft import PeftModel

# Model ID of the base model. QwenAnalyzer always loads its tokenizer from
# this ID, independent of the model_path it is given.
base_model = "mistralai/Mistral-7B-Instruct-v0.2"

class QwenAnalyzer:
    """
    Inspect a causal LM: greedy generations with per-token log-probabilities,
    hidden states, attention maps, plus simple plotting helpers.

    Despite the name, nothing here is Qwen-specific: the tokenizer always comes
    from the module-level ``base_model`` and the weights from ``model_path``.
    """

    def __init__(self, model_path, device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        Load the tokenizer and model, move the model to ``device``, set eval mode.

        :param model_path: model ID or local checkpoint directory passed to
            ``AutoModelForCausalLM.from_pretrained``
        :param device: device to run on (default: CUDA if available, else CPU)
        """
        self.device = device
        # The tokenizer is always the base model's, regardless of model_path.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            trust_remote_code=True
        )
        # Weights come from model_path only, loaded exactly once in bfloat16.
        # NOTE(review): if model_path is a bare LoRA adapter directory, recent
        # transformers releases (with peft installed) resolve the base model
        # from its adapter_config.json here — confirm for the versions in use.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16
        ).to(self.device)
        self.model.eval()

    def get_log_probs(self, prompt, max_new_tokens=50):
        """
        Greedily generate a continuation and return per-token log-probabilities.

        :param prompt: input text
        :param max_new_tokens: maximum number of tokens to generate
        :return: dict with "generated_text", "tokens" (each generated token
            decoded on its own), "token_ids" (list of ints) and "log_probs"
            (one float per generated token)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                do_sample=False
            )

        # The returned sequence echoes the prompt; keep only the new tokens.
        generated_ids = outputs.sequences[0, prompt_len:]

        # outputs.scores[i] holds the step-i scores that produced
        # generated_ids[i]. Upcast before log_softmax: in bfloat16 the
        # resulting log-probabilities keep only ~3 significant digits.
        log_probs = [
            torch.log_softmax(step_scores.float(), dim=-1)[0, token_id].item()
            for step_scores, token_id in zip(outputs.scores, generated_ids)
        ]

        return {
            "generated_text": self.tokenizer.decode(generated_ids),
            "tokens": [self.tokenizer.decode(t) for t in generated_ids],
            "token_ids": generated_ids.tolist(),
            "log_probs": log_probs
        }

    def get_hidden_states(self, prompt, layer_index=-1):
        """
        Run one forward pass and return the hidden states of one layer.

        :param prompt: input text
        :param layer_index: index into the model's ``hidden_states`` tuple
            (-1 = last layer)
        :return: dict with "hidden_states" (float64 array of the first batch
            element) and "input_ids" (integer array of token IDs)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        layer = outputs.hidden_states[layer_index]
        return {
            # numpy has no bfloat16, so upcast before converting.
            "hidden_states": layer[0].cpu().to(torch.float64).numpy(),
            # Token IDs stay integers so they can be fed back to the tokenizer.
            "input_ids": inputs["input_ids"][0].cpu().numpy()
        }

    def get_attention_maps(self, prompt):
        """
        Run one forward pass and return every layer's attention weights.

        :param prompt: input text
        :return: dict with "attentions" (one float32 array per layer for the
            first batch element) and "tokens" (the input tokens as strings)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)

        # One tensor per layer, each [batch_size, num_heads, seq_len, seq_len].
        # Keep sample 0 and upcast: numpy cannot represent bfloat16, so calling
        # .numpy() on the raw bf16 tensor raises TypeError.
        return {
            "attentions": [layer_attn[0].float().cpu().numpy() for layer_attn in outputs.attentions],
            "tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        }

    @staticmethod
    def visualize_embeddings(embeddings, labels=None, method="pca", title="Embedding Projection", fig_name='1'):
        """
        Project embeddings to 2-D, scatter-plot them and save ``{fig_name}.png``.

        :param embeddings: array of embeddings; inputs with more than 2 dims are
            averaged over axis 1 first (presumably the sequence axis)
        :param labels: optional per-point values used to colour the points
        :param method: reduction method, "pca" or "tsne" (case-insensitive)
        :param title: figure title
        :param fig_name: output path without the ".png" extension
        :raises ValueError: if ``method`` is not "pca" or "tsne"
        """
        method_key = method.lower()
        if method_key not in ("pca", "tsne"):
            raise ValueError("Method must be 'pca' or 'tsne'")

        embeddings = np.asarray(embeddings)
        # Collapse each sample to a single vector by averaging over axis 1.
        if embeddings.ndim > 2:
            embeddings = embeddings.mean(axis=1)

        if method_key == "pca":
            reducer = PCA(n_components=2)
        else:
            import inspect
            # t-SNE requires perplexity < n_samples; cap it for small inputs.
            perplexity = min(30, max(len(embeddings) - 1, 1))
            # scikit-learn 1.5 renamed n_iter to max_iter (n_iter removed in 1.7).
            iter_kw = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
            reducer = TSNE(n_components=2, perplexity=perplexity, **{iter_kw: 300})

        projected = reducer.fit_transform(embeddings)

        fig = plt.figure(figsize=(10, 8))
        if labels is not None:
            scatter = plt.scatter(projected[:, 0], projected[:, 1], c=labels)
            plt.colorbar(scatter)
        else:
            plt.scatter(projected[:, 0], projected[:, 1])
        plt.title(title)
        plt.xlabel(f"{method.upper()} 1")
        plt.ylabel(f"{method.upper()} 2")
        # Save before show(): with interactive backends the figure is torn down
        # once its window closes, so saving afterwards writes a blank image.
        fig.savefig(f'{fig_name}.png', format='png')
        plt.show()
        plt.close(fig)

    @staticmethod
    def plot_attention_heatmap(attention, tokens_x, tokens_y, layer=0, head=0, title="Attention Heatmap", fig_name='1'):
        """
        Draw one attention head as a heatmap and save ``{fig_name}.png``.

        :param attention: indexable as attention[layer][head] -> 2-D weight
            matrix, e.g. the "attentions" list from get_attention_maps
        :param tokens_x: x-axis token labels
        :param tokens_y: y-axis token labels
        :param layer: layer to plot
        :param head: head to plot
        :param title: figure title
        :param fig_name: output path without the ".png" extension
        """
        fig = plt.figure(figsize=(12, 10))
        ax = sns.heatmap(
            attention[layer][head],
            xticklabels=tokens_x,
            yticklabels=tokens_y,
            cmap="viridis",
            square=True
        )
        ax.set_title(f"{title}\nLayer {layer}, Head {head}")
        plt.xticks(rotation=45, ha="right")
        plt.yticks(rotation=0)
        plt.tight_layout()
        # Save before show(); see visualize_embeddings for why the order matters.
        fig.savefig(f'{fig_name}.png', format='png')
        plt.show()
        plt.close(fig)

def compare_finetuned_models(original_model_path, finetuned_model_path, prompts, fig_name1, fig_name2):
    """
    Compare two models on the same prompts through their last-layer hidden states.

    For every prompt, both models' greedy generations are printed. The squared
    Euclidean distance between the two models' last-layer hidden states
    (summed over all tokens and hidden dimensions) is then accumulated. The
    result is a distance, not a similarity: 0 means identical hidden states,
    and larger values mean the models diverge more.

    :param original_model_path: path/ID of the first model
    :param finetuned_model_path: path/ID of the second model
    :param prompts: non-empty list of prompts to analyse
    :param fig_name1: currently unused; kept for call compatibility
    :param fig_name2: currently unused; kept for call compatibility
    :return: mean over prompts of the summed squared distance, as a float
    :raises ValueError: if ``prompts`` is empty (checked before any model loads)
    """
    if not prompts:
        raise ValueError("prompts must be non-empty")

    original_analyzer = QwenAnalyzer(original_model_path)
    finetuned_analyzer = QwenAnalyzer(finetuned_model_path)

    total_sq_dist = 0.0
    for prompt in prompts:
        print(f"\nAnalyzing prompt: {prompt}")

        orig_gen = original_analyzer.get_log_probs(prompt)
        ft_gen = finetuned_analyzer.get_log_probs(prompt)
        print(f"Original output: {orig_gen['generated_text']}")
        print(f"Finetuned output: {ft_gen['generated_text']}")

        orig_hidden = original_analyzer.get_hidden_states(prompt)["hidden_states"]
        ft_hidden = finetuned_analyzer.get_hidden_states(prompt)["hidden_states"]
        # Both analyzers tokenize with the same base tokenizer, so the two
        # arrays cover the same tokens and have matching shapes.
        total_sq_dist += float(((orig_hidden - ft_hidden) ** 2).sum())

    return total_sq_dist / len(prompts)
    


if __name__ == "__main__":
    # Script entry: build a test prompt set from each branch's dataset, compare
    # every pair of branch models, and dump the scores to JSON.
    import argparse
    from itertools import combinations
    from datasets import load_dataset
    parser = argparse.ArgumentParser()
    parser.add_argument("--models_and_datas", type=str, default="models_and_datas_full_qwen_3b")
    parser.add_argument("--folder_name", type=str, default='output/0512newloracompare/1')
    parser.add_argument("--top_k", type=int, default=5)
    args = parser.parse_args()

    from models_and_datas import get_models_and_datas
    models = get_models_and_datas(args.models_and_datas)
    from glue_utils import get_texts
    from run import simple_model_name_to_ckpt
    # Every key except the first is a branch to compare; sort for a stable order.
    branches = sorted(list(models.keys())[1:])
    folder_name = args.folder_name
    top_k = args.top_k

    os.makedirs(folder_name, exist_ok=True)
    test_prompts = []
    # NOTE(review): if the tokenizer also prepends its own BOS token, the
    # literal "<s>" here yields a doubled BOS — confirm against the tokenizer.
    prompt_format = '''<s>[INST] {} [/INST] '''
    for branch in branches:
        dataset = models[branch]['datasets'][0]
        ds_val = load_dataset(dataset, split='test')
        prompts = [prompt_format.format(t["input"]) for t in ds_val]
        # Take the first top_k prompts from each branch's test split.
        test_prompts += prompts[:top_k]

    sim = {}

    # Compare every unordered pair of branches, whatever their count;
    # combinations() keeps the sorted (i < j) order.
    for a, b in combinations(branches, 2):
        model_a, model_b = models[a]['model'][0], models[b]['model'][0]
        model_a, model_b = simple_model_name_to_ckpt(model_a), simple_model_name_to_ckpt(model_b)

        sim[a + '-' + b] = compare_finetuned_models(model_a,
                                                    model_b,
                                                    test_prompts,
                                                    folder_name + f'/{a}_in_{a}_{b}',
                                                    folder_name + f'/{b}_in_{a}_{b}')

    # The results file sits next to folder_name, not inside it.
    with open(f'{folder_name}_compare.json', 'w') as f:
        json.dump(sim, f)
    