import pdb
import json
import os
import torch
import torch.nn as nn
from torch.cuda import OutOfMemoryError
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches


# Axis along which 1-D layer weights are unsqueezed in wrapped_model_delta_output so that
# they can be treated as 2-D (rows, columns) matrices. A module-level assignment is already
# global; the `global` statement is only meaningful inside a function.
layernorm_dim = 1


def get_model_and_tokenizer(model_name, padding_side="right"):
    """Load a bfloat16 causal LM (``device_map="auto"``) together with its tokenizer.

    The model family is picked by case-insensitive substring match on ``model_name``:
    "llama"/"mistral" first, then "qwen" (loaded with ``trust_remote_code=True``).
    In every case the tokenizer's pad token is set to its EOS token and its padding
    side to ``padding_side``.

    Args:
        model_name: hub id or local path of the checkpoint.
        padding_side: side on which the tokenizer pads.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        NotImplementedError: if ``model_name`` matches none of the supported families.
    """
    lowered_name = model_name.lower()
    llama_family = "llama" in lowered_name or "mistral" in lowered_name
    if not llama_family and "qwen" not in lowered_name:
        raise NotImplementedError(f"model name: {model_name}!")

    # Only the Qwen path passes trust_remote_code=True to the loaders.
    remote_kwargs = {} if llama_family else {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **remote_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto", **remote_kwargs
    )

    if not llama_family:
        # Tokenizers exposing `eod_id` get it as both pad and EOS id; otherwise the
        # "<|endoftext|>" string is used for both.
        if hasattr(tokenizer, "eod_id"):
            tokenizer.pad_token_id = tokenizer.eod_id
            tokenizer.eos_token_id = tokenizer.eod_id
        else:
            tokenizer.pad_token = "<|endoftext|>"
            tokenizer.eos_token = "<|endoftext|>"

    # Common tail for every supported family: pad with EOS, apply the padding side.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = padding_side
    return model, tokenizer
        

class wrapped_model_delta_output:
    """Track how one layer's outputs differ between a fine-tuned model and its base model.

    A forward hook on the fine-tuned layer calls :meth:`add_batch` to stash the layer's latest
    input/output on the CPU. :meth:`update` then subtracts the tensors captured from the matching
    base-model layer and folds the output delta into ``scaler_columns``. That is a running
    statistic with one entry per row of ``layer.weight``:

        scaler_columns[j] == (sum over updates of ||delta_out[..., j]||_2 ** 2) / nsamples

    where each norm runs over all tokens of one update and ``nsamples`` counts the batch entries
    folded in so far.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", model_id=0):
        """Set up an empty accumulator sized from ``layer.weight``.

        Args:
            layer: the wrapped module; must expose ``weight``.
            layer_id: opaque identifier stored on the instance.
            layer_name: qualified module name, used to build ``param_names``.
            model_id: opaque identifier stored on the instance.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        layer_weight = layer.weight.data

        # Lift 1-D weights to 2-D (along the module-level `layernorm_dim`) so that both
        # `rows` and `columns` are always defined.
        if layer_weight.ndim == 1:
            layer_weight = layer_weight.unsqueeze(layernorm_dim)

        # For nn.Linear, weight is (out_features, in_features): one row per output channel.
        self.rows = layer_weight.shape[0]
        self.columns = layer_weight.shape[1]

        self.param_names = {i: f"{layer_name}_param_{i}" for i in range(self.rows)}

        # Running per-row statistic and the number of batch entries folded into it.
        self.scaler_columns = torch.zeros((self.rows), device="cpu")
        self.nsamples = 0

        self.layer_id = layer_id
        self.model_id = model_id
        self.layer_name = layer_name

        # Latest captured input/output (CPU copies); replaced by deltas in `update`.
        self.current_inp = None
        self.current_out = None

    def add_batch(self, inp, out):
        """Store detached CPU copies of the layer's latest input and output."""
        self.current_inp = inp.detach().cpu()
        self.current_out = out.detach().cpu()

    def update(self, base_inp, base_out):
        """Turn the stored tensors into (self - base) deltas and accumulate the output delta.

        If the two models produced differently shaped tensors (e.g. different token counts or
        output widths), both are cropped to their common leading extent along every dimension
        before subtracting. Each call contributes to the statistic exactly once.

        Raises:
            RuntimeError: if either side has no captured activations.
            ValueError: if paired tensors have different ranks.
        """
        if any(t is None for t in (self.current_inp, self.current_out, base_inp, base_out)):
            raise RuntimeError(
                f"{self.layer_name}: missing activations; run both forward passes before update()"
            )
        if self.current_out.shape != base_out.shape:
            print(f"out shape not match: {self.current_out.shape}, {base_out.shape}")

        cur_inp, ref_inp = self._crop_to_common(self.current_inp, base_inp)
        cur_out, ref_out = self._crop_to_common(self.current_out, base_out)
        # Subtract in float32 so the difference of the two (possibly bfloat16) activations
        # is not re-rounded to a low-precision dtype.
        self.current_inp = cur_inp.float() - ref_inp.float()
        self.current_out = cur_out.float() - ref_out.float()
        self.get_delta_scale_columns(self.current_inp, self.current_out)

    @staticmethod
    def _crop_to_common(a, b):
        """Crop two same-rank tensors to their shared leading extent along every dimension."""
        if a.ndim != b.ndim:
            raise ValueError(
                f"cannot align tensors of rank {a.ndim} and {b.ndim}: "
                f"{tuple(a.shape)} vs {tuple(b.shape)}"
            )
        common = tuple(slice(0, min(x, y)) for x, y in zip(a.shape, b.shape))
        return a[common], b[common]

    def get_delta_scale_columns(self, inp, out):
        """Fold one output delta into the running per-row statistic.

        ``inp`` is accepted for interface symmetry but not used. A 2-D ``out`` is treated as
        a batch of one. For nn.Linear (and LlamaRMSNorm) the leading dimensions are flattened
        so each output channel's norm runs over every token.
        """
        if out.dim() == 2:
            out = out.unsqueeze(0)
        batch = out.shape[0]
        if isinstance(self.layer, nn.Linear):
            if out.dim() == 3:
                out = out.reshape((-1, out.shape[-1]))
            out = out.t()
        elif isinstance(self.layer, LlamaRMSNorm):
            if out.dim() == 3:
                out = out.reshape((-1, self.rows))
            out = out.t()

        # Rescale the previous mean so that, after adding this term, scaler_columns equals
        # (sum of all squared norms) / nsamples.
        self.scaler_columns *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch

        # float32 (not bfloat16) so squared norms keep full precision before accumulation.
        col_sq_norms = torch.norm(out.float(), p=2, dim=1) ** 2 / self.nsamples
        col_sq_norms = col_sq_norms.to(self.scaler_columns.device)
        # If the output was cropped to a narrower width than this layer (the base model has
        # fewer channels), only the overlapping rows receive a contribution.
        self.scaler_columns[: col_sq_norms.shape[0]] += col_sq_norms


def get_matching_layer_names(model1, model2):
    """Return the set of qualified names that are ``nn.Linear`` modules in both models."""

    def linear_names(model):
        # Qualified names of every nn.Linear submodule of `model`.
        return {
            qualified_name
            for qualified_name, submodule in model.named_modules()
            if isinstance(submodule, nn.Linear)
        }

    return linear_names(model1) & linear_names(model2)


def register_hooks(model, reference_layers=None):
    """Attach a ``wrapped_model_delta_output`` recorder to matching ``nn.Linear`` modules.

    Args:
        model: module tree to instrument.
        reference_layers: optional collection of qualified module names; when given, only
            ``nn.Linear`` modules whose name is in it are instrumented.

    Returns:
        List of recorders in ``model.named_modules()`` order. Each recorder's
        ``hook_handle`` attribute holds the handle returned by ``register_forward_hook``,
        so callers can detach it with ``hook_handle.remove()``.

    Re-instrumenting a module replaces its previous recorder hook instead of stacking a new
    one. Without this, repeated calls (main calls this once per dataset) would leave stale
    hooks that keep copying activations to the CPU on every forward pass.
    """
    recorders = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if reference_layers is not None and name not in reference_layers:
            continue

        stale_handle = getattr(module, "_delta_output_hook_handle", None)
        if stale_handle is not None:
            stale_handle.remove()

        recorder = wrapped_model_delta_output(module, layer_name=name)
        # `h=recorder` binds this iteration's recorder (avoids late-binding closures).
        handle = module.register_forward_hook(lambda m, i, o, h=recorder: h.add_batch(i[0], o))
        module._delta_output_hook_handle = handle
        recorder.hook_handle = handle
        recorders.append(recorder)
    return recorders


def get_math_task_prompt(instruction):
    """Wrap ``instruction`` in the instruction/response template, asking for step-by-step reasoning."""
    sections = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
        "### Instruction:\n{}".format(instruction),
        "### Response: Let's think step by step.",
    )
    return "\n\n".join(sections)


def generate_code_task_prompt(input_text):
    """Wrap ``input_text`` in the instruction/response template used for code tasks."""
    sections = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
        "### Instruction:\n{}".format(input_text),
        "### Response:",
    )
    return "\n\n".join(sections)


def prepare_inputs(data, tokenizer, dataset_name, use_template):
    """Tokenize one record (with ``"query"`` and ``"response"`` keys) into model inputs.

    Args:
        data: mapping with ``"query"`` and ``"response"`` entries.
        tokenizer: callable tokenizer; called with ``return_tensors='pt'`` and
            ``add_special_tokens=False``.
        dataset_name: selects the prompt template when ``use_template`` is set.
        use_template: when false (default use), query and response are joined by a newline.

    Returns:
        Whatever ``tokenizer`` returns for the assembled text.

    Raises:
        NotImplementedError: if ``use_template`` is set and ``dataset_name`` contains neither
            "gsm8k" nor "mbpp" (previously this surfaced as an UnboundLocalError).
    """
    if not use_template:    # default
        text = f"{data['query']}\n{data['response']}"
    elif "gsm8k" in dataset_name:
        # The response is appended directly after the math prompt, with no separator.
        text = f"{get_math_task_prompt(data['query'])}{data['response']}"
    elif "mbpp" in dataset_name:
        text = f"{generate_code_task_prompt(data['query'])}\n{data['response']}"
    else:
        raise NotImplementedError(f"no prompt template for dataset: {dataset_name}!")
    return tokenizer(text, return_tensors='pt', add_special_tokens=False)


def _save_scaler_figure(scaler_values, top_indices, layer_name, save_path):
    """Draw one layer's scaler curve with its top channels marked and save it as a PDF."""
    fig = plt.figure(figsize=(12, 6))  # Increased figure size for better readability
    try:
        plt.plot(scaler_values)
        plt.title(f"Layer: {layer_name}", fontsize=20)
        plt.xlabel("Column Index", fontsize=20)
        plt.ylabel("Scaler Value", fontsize=20)
        plt.tick_params(axis='both', labelsize=20)

        if len(top_indices):
            plt.scatter(top_indices, scaler_values[top_indices], s=100, c='red')
        for index in top_indices:
            plt.annotate(
                str(index),
                xy=(index, scaler_values[index]),
                xytext=(index + 1, scaler_values[index] + 0.1),
                arrowprops=dict(facecolor='black', shrink=0.05),
                horizontalalignment='right',
                verticalalignment='bottom',
                fontsize=18,
            )
        plt.savefig(save_path, dpi=300, bbox_inches='tight', format='pdf')
    finally:
        # Release the figure even if saving fails, so figures don't pile up across layers.
        plt.close(fig)


def plot_global_record_scaler_columns(hooks, dataset_name, model_name, use_template, save_dir, top_n_rows):
    """Plot and save each recorder's ``scaler_columns`` and its top-``top_n_rows`` channels.

    Output root ``<root>`` is ``save_dir``, or ``save_dir + "-use_template"`` when
    ``use_template`` is set. For every recorder this writes, under
    ``<root>/<model_name>/<layer_name>/``:

    * ``<dataset_name>_scaler_columns.pdf`` - the curve with the top channels annotated,
    * ``<dataset_name>_scaler_values.npy`` - the raw values,
    * ``<dataset_name>_top_info.json`` - indices, values and param names of the top channels.

    It also writes one aggregate ``<root>/<model_name>/<dataset_name>_all_top_info.json``.

    Args:
        hooks: recorders exposing ``scaler_columns``, ``layer_name`` and ``param_names``.
        dataset_name: prefix for every output file name.
        model_name: sub-directory under the output root.
        use_template: selects the output root (see above).
        save_dir: base output directory.
        top_n_rows: number of largest channels to record; values <= 0 select none.
    """
    root_dir = f"{save_dir}-use_template" if use_template else save_dir
    model_dir = f"{root_dir}/{model_name}"
    all_top_info = {}

    for hook in hooks:
        scaler_values = hook.scaler_columns.detach().cpu().numpy()
        order = np.argsort(scaler_values)
        # Largest `top_n_rows` entries in ascending order. max(..., 0) avoids the
        # `[-0:]` pitfall where top_n_rows == 0 would select every index.
        top_indices = order[max(len(order) - top_n_rows, 0):]

        save_dir_path = f"{model_dir}/{hook.layer_name}"
        os.makedirs(save_dir_path, exist_ok=True)

        _save_scaler_figure(
            scaler_values,
            top_indices,
            hook.layer_name,
            f"{save_dir_path}/{dataset_name}_scaler_columns.pdf",
        )
        np.save(f"{save_dir_path}/{dataset_name}_scaler_values.npy", scaler_values)

        top_info = {
            "layer_name": hook.layer_name,
            "top_indices": top_indices.tolist(),
            "top_scaler_values": scaler_values[top_indices].tolist(),
            "param_names": [hook.param_names[int(idx)] for idx in top_indices],
        }
        with open(f"{save_dir_path}/{dataset_name}_top_info.json", "w") as f:
            json.dump(top_info, f, indent=4)

        all_top_info[hook.layer_name] = top_info

    # The aggregate file lives in the same tree as the per-layer files. Create the directory
    # explicitly in case `hooks` was empty.
    os.makedirs(model_dir, exist_ok=True)
    all_top_save_path = f"{model_dir}/{dataset_name}_all_top_info.json"
    with open(all_top_save_path, "w") as f:
        json.dump(all_top_info, f, indent=4)

    print(f"Save all layers top-{top_n_rows} info in {all_top_save_path}")
    print(f"All figures and top-{top_n_rows} info have been saved in {root_dir}/")

def main(base_model_path, model_path, dataset_list, use_template, save_dir, top_n_rows=100):
    """Record per-channel output deltas between a fine-tuned model and its base, per dataset.

    Every record of each JSONL file in ``dataset_list`` is run through both models. The
    recorders on their shared ``nn.Linear`` layers accumulate the fine-tuned-minus-base output
    statistic, which ``plot_global_record_scaler_columns`` then plots and saves.

    Args:
        base_model_path: path/name of the base model.
        model_path: path/name of the fine-tuned model.
        dataset_list: JSONL file paths; each non-blank line is one record for ``prepare_inputs``.
        use_template: forwarded to ``prepare_inputs`` and the plotting function.
        save_dir: output root for the plotting function.
        top_n_rows: number of top channels to record per layer.

    Raises:
        RuntimeError: if the two models' recorders are not aligned layer-for-layer.
    """
    model, tokenizer = get_model_and_tokenizer(model_path)
    # For checkpoint paths shaped like ckpt_sft/<model>/<task>/<version>/checkpoint-N (as in
    # __main__), the third-from-last component is the task directory; "mask_merging" paths
    # use their last component instead.
    if "mask_merging" not in model_path:
        model_name = model_path.split("/")[-3]
    else:
        model_name = model_path.split("/")[-1]

    model_base, tokenizer_base = get_model_and_tokenizer(base_model_path)
    # Depends only on the two models, so compute it once rather than once per dataset.
    matching_layers = get_matching_layer_names(model, model_base)

    for dataset_path in dataset_list:
        dataset_name = dataset_path.split("/")[-1].split(".")[0]
        # JSON-lines input; `with` closes the file. Blank lines (e.g. an extra empty line at
        # the end) are skipped instead of crashing json.loads.
        with open(dataset_path, "r", encoding="utf-8") as f:
            dataloader = [json.loads(line) for line in f if line.strip()]

        hooks = register_hooks(model, matching_layers)
        hooks_base = register_hooks(model_base, matching_layers)
        # Recorders are paired positionally below, so verify the pairing once up front.
        hook_names = [h.layer_name for h in hooks]
        base_hook_names = [h.layer_name for h in hooks_base]
        if hook_names != base_hook_names:
            raise RuntimeError(
                f"recorder layers do not match between models: {len(hook_names)} vs {len(base_hook_names)} layers"
            )

        # Only the hooked activations are used, so skip building autograd graphs.
        with torch.no_grad():
            for data in tqdm(dataloader):
                model(**prepare_inputs(data, tokenizer, dataset_name, use_template=use_template))
                model_base(**prepare_inputs(data, tokenizer_base, dataset_name, use_template=use_template))
                for hook, hook_base in zip(hooks, hooks_base):
                    hook.update(hook_base.current_inp, hook_base.current_out)

        plot_global_record_scaler_columns(hooks, dataset_name, model_name, use_template=use_template, save_dir=save_dir, top_n_rows=top_n_rows)


import signal
import sys
import traceback

def signal_handler(signum, frame):
    """Print the interrupted stack and terminate the process.

    Exits with status ``128 + signum``, the shell convention for termination by signal,
    rather than 0. Wrapper scripts and schedulers therefore do not mistake an interrupted
    run for a successful one. ``sys.exit`` raises ``SystemExit``, which ``except Exception``
    handlers do not catch.
    """
    print(f"Received signal: {signum}. Exiting gracefully...")
    traceback.print_stack(frame)
    sys.exit(128 + int(signum))

def debug_main(base_model_path, model_path, dataset_list, use_template, save_dir, top_n_rows):
    """Run ``main`` once, with SIGINT and SIGTERM routed to ``signal_handler``.

    Any ``Exception`` raised by ``main`` is printed with its traceback and then swallowed,
    so a caller looping over several checkpoints continues with the next one.
    """
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, signal_handler)

    try:
        print("Running... Press Ctrl+C to exit.")
        main(base_model_path, model_path, dataset_list, use_template, save_dir, top_n_rows=top_n_rows)
    except Exception as err:
        print(f"An exception occurred: {err}")
        traceback.print_exc()
        


if __name__ == "__main__":
    # Base model locations.
    llama_model_path = "Llama/Llama3/Meta-Llama-3-8B"
    mistral_model_path = "Mistral/Mistral-7B"
    qwen2_5_model_path = "Qwen/Qwen2.5/Qwen2.5-7B"

    # Fine-tuned checkpoints to compare against their base model.
    llama3_model_path_list_20k = [
        "ckpt_sft/llama3-8b/infinity_math_20k/v0-20241226-185450/checkpoint-231",
    ]

    mistral_model_path_list_20k = [
        "ckpt_sft/mistral-7b/infinity_code_20k/v1-20241229-151803/checkpoint-231",
    ]

    # NOTE(review): these checkpoints sit under "qwen2.5-14b" but are compared against the
    # Qwen2.5-7B base below. If the fine-tune really started from a 14B model, layer shapes
    # presumably will not line up and the deltas are meaningless. Confirm the pairing.
    qwen2_5_model_path_list_20k = [
        "ckpt_sft/qwen2.5-14b/csqa_20k/v0-20250126-190756/checkpoint-231",
    ]

    data_list_sample_from_20k = [
        "data/infinity_math_sample_50_seed_42.jsonl",
    ]

    top_n_rows = 100
    use_template = False
    base_model = qwen2_5_model_path.split('/')[-1]
    save_dir = f"dsr_dist_visualization-top_{top_n_rows}_rows-{base_model}"
    for model_path in qwen2_5_model_path_list_20k:
        debug_main(qwen2_5_model_path, model_path, data_list_sample_from_20k, use_template=use_template, save_dir=save_dir, top_n_rows=top_n_rows)

    print("Visualization finished!")
