import os
import sys
sys.path.append('..')
import string
from character_tokenizer import CharacterTokenizer

from tasks import task_registry
from data import add_special_tokens, tokenization_train, PromptAnswerDataCollator
from data import tokenization_eval
import numpy as np

import torch
# Inference-only script: disable autograd globally through the supported switch
# rather than entering a `no_grad` context manager that is never exited.
torch.set_grad_enabled(False)
from transformers import AutoModelForCausalLM, LlamaForCausalLM
import matplotlib.pyplot as plt
import seaborn as sns
from trainers import LogitsProcessorWithLossMask
import functools
from tqdm import tqdm
from joblib import Memory

# Expand "~" ourselves: the location is handed over as a plain path string, so an
# unexpanded "~/.cache" would end up as a literal "~" directory under the current
# working directory instead of the user's home cache.
memory = Memory(os.path.expanduser('~/.cache'), verbose=0)

# Every figure produced by this script is written below this directory.
save_dir = 'plots/analysis'
os.makedirs(save_dir, exist_ok=True)


def get_sorted_checkpoint_paths(ckpt_dir, take_last=False):
    """List the ``checkpoint-<N>`` entries of a directory ordered by ``N``.

    Args:
        ckpt_dir: Directory whose entries are scanned.
        take_last: When true, keep only the entry with the largest number.

    Returns:
        ``(paths, nums)``: the joined paths and their integer suffixes, both in
        ascending numeric order.

    Raises:
        ValueError: If the text after the last ``-`` of a matching path is not
            an integer.
    """
    numbered = []
    for entry in os.listdir(ckpt_dir):
        if not entry.startswith('checkpoint-'):
            continue
        path = os.path.join(ckpt_dir, entry)
        # The number is parsed from the full path, i.e. whatever follows its last '-'.
        numbered.append((int(path.split('-')[-1]), path))

    numbered.sort(key=lambda pair: pair[0])
    if take_last:
        numbered = numbered[-1:]

    sorted_paths = [path for _, path in numbered]
    sorted_nums = [num for num, _ in numbered]
    return sorted_paths, sorted_nums

def get_data(tasks, test_lengths, eval=False):
    """Build one collated batch of 100 generated examples per task.

    Args:
        tasks: Mapping from task id to a task name looked up in ``task_registry``.
        test_lengths: Keyword arguments forwarded to every task generator.
        eval: When true, examples go through ``tokenization_eval``; otherwise
            through ``tokenization_train``.

    Returns:
        ``(data, tokenizer)`` where ``data`` maps each task id to the output of
        ``PromptAnswerDataCollator`` for that task's examples.
    """
    tokenizer = CharacterTokenizer(string.ascii_letters + string.digits + string.punctuation + ' ')
    # Must be an assignment; the previous `==` only compared and discarded the result,
    # so the padding side was never configured.
    # NOTE(review): presumably the collator honours `tokenizer.padding_side` - confirm.
    tokenizer.padding_side = 'left'
    collator = PromptAnswerDataCollator(tokenizer=tokenizer)

    tokenize = tokenization_eval if eval else tokenization_train
    # Column used to count how many rows the tokenization step produced.
    length_key = 'eval_input_ids' if eval else 'input_ids'

    data = {}

    for task_id, task_name in tasks.items():
        # Re-seeding for each task hands every task generator the same random stream.
        rng = np.random.default_rng(42)
        examples = []

        for _ in range(100):
            prompt, target, loss_mask = task_registry[task_name](rng=rng, **test_lengths)
            if loss_mask is None:
                loss_mask = [1] * len(target)
            ex = {'prompt': [prompt], 'target': [target], 'loss_mask': [loss_mask]}
            ex = add_special_tokens(ex, tokenizer, task_id=task_id)
            ex = tokenize(ex, tokenizer)
            # Turn the column-oriented result into one dict per row for the collator.
            examples.extend(
                {name: column[row] for name, column in ex.items()}
                for row in range(len(ex[length_key]))
            )

        data[task_id] = collator(examples)
    return data, tokenizer

@memory.cache
def get_hs_attn(ckpt_paths, tasks, test_lengths):
    """Mean absolute attention difference between tasks 'A' and 'B' per checkpoint.

    Each checkpoint is loaded, run on the training-style batch of every task with
    ``output_attentions=True``, and the attention maps of tasks ``'A'`` and ``'B'``
    are compared on their common top-left ``seq_len x seq_len`` window.

    Args:
        ckpt_paths: Checkpoint directories to load, in plotting order.
        tasks: Mapping from task id to task name; must contain ``'A'`` and ``'B'``.
        test_lengths: Keyword arguments forwarded to the task generators.

    Returns:
        1-D array with one value per checkpoint: the absolute attention difference
        averaged within each (layer, head) pair and then across all pairs.
    """
    data, _ = get_data(tasks, test_lengths, eval=False)

    all_attn_sims = []

    for ckpt_path in tqdm(ckpt_paths, desc="Processing checkpoints for HS/Attn"):
        model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype='auto', device_map='auto')
        model.eval()

        attentions = {}
        for task_id, batch in data.items():
            output = model(**batch, return_dict=True, output_hidden_states=False, output_attentions=True)
            # `.cpu()` is needed before `.numpy()` whenever `device_map='auto'`
            # placed the model (and therefore its outputs) on an accelerator.
            attentions[task_id] = np.stack([
                layer_attn.detach().float().cpu().numpy() for layer_attn in output['attentions']
            ])
            del output

        # Release the checkpoint before the next one is loaded.
        del model
        torch.cuda.empty_cache()

        # Reduce right away so only one checkpoint's attention maps are alive at a time.
        attn_a = attentions['A']
        attn_b = attentions['B']
        # Stacked layout is (layer, <batch, presumably>, head, query, key).
        num_layers, _, num_attn_heads, seq_len_a, _ = attn_a.shape
        seq_len = min(seq_len_a, attn_b.shape[3])

        per_head = [
            [
                np.abs(attn_a[layer, :, head, :seq_len, :seq_len] -
                       attn_b[layer, :, head, :seq_len, :seq_len]).mean()
                for head in range(num_attn_heads)
            ]
            for layer in range(num_layers)
        ]
        all_attn_sims.append(np.array(per_head).mean())

    all_attn_sims = np.array(all_attn_sims)
    print(all_attn_sims.shape)

    return all_attn_sims

def compute_attn_diff_and_plot(ckpt_paths, ckpt_nums, all_attn_sims, ax=None, color=None):
    """Plot the per-checkpoint attention difference as a line.

    Args:
        ckpt_paths: Unused; kept so existing callers keep working.
        ckpt_nums: X positions, also used as the tick labels.
        all_attn_sims: Y values, one per entry of ``ckpt_nums``.
        ax: Axes to draw on. A new figure with one axes is created when omitted.
        color: Line colour passed to ``ax.plot``.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        sns.set_theme(style='whitegrid')
        # A freshly created Figure has no axes, so `get_axes()[0]` raised IndexError.
        _, ax = plt.subplots(figsize=(5, 3.5))
    ax.plot(ckpt_nums, all_attn_sims, label='Attention difference', color=color)
    ax.set_xticks(ckpt_nums, ckpt_nums)
    ax.set_xlabel('Checkpoint')
    ax.legend(loc='upper right')

    return ax

def _mean_ablation_hook(module, input_tuple, target_h, n_heads, h_dim):
    """Forward pre-hook that overwrites one head's slice with its scalar mean.

    The first positional input is viewed as ``(batch, seq, n_heads, h_dim)``; every
    element belonging to head ``target_h`` is replaced by the single mean taken
    over that head's whole slice. The incoming tensor itself is left untouched.

    Args:
        module: The hooked module (unused).
        input_tuple: Positional inputs of the module; element 0 is the tensor to edit.
        target_h: Index of the head to ablate.
        n_heads: Number of heads the last dimension is split into.
        h_dim: Size of each head's slice.

    Returns:
        A new input tuple with the edited tensor first and the remaining inputs unchanged.
    """
    hidden_states = input_tuple[0]
    batch_size, seq_len, hidden_dim_ = hidden_states.shape

    per_head = hidden_states.view(batch_size, seq_len, n_heads, h_dim)

    # Work on a copy so the caller's tensor is not modified in place.
    ablated = per_head.clone()
    fill_value = per_head[:, :, target_h, :].mean().item()
    ablated[:, :, target_h, :] = fill_value

    # Flatten back to the layout the hooked module expects.
    return (ablated.view(batch_size, seq_len, hidden_dim_),) + input_tuple[1:]

def eval_batch(model, inputs, tokenizer):
    """Greedy-decode a batch and return its exact-match accuracy.

    Args:
        model: Object exposing ``generate`` and ``generation_config``
            (``max_length`` and ``num_beams`` are read).
        inputs: Batch with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``,
            plus an optional ``'loss_mask'``.
        tokenizer: Provides ``pad_token_id``, used to pad predictions and labels.

    Returns:
        Fraction of rows whose generated continuation equals the labels at every
        position, as a Python float.
    """
    labels = inputs['labels']
    kwargs = {
        'max_new_tokens': len(labels[0]) if model.generation_config.max_length is None else None
    }
    # Constrain decoding only when some positions are actually masked out.
    if 'loss_mask' not in inputs or (inputs['loss_mask'] == 1).all():
        logits_processor = None
    else:
        logits_processor = [LogitsProcessorWithLossMask(inputs, model.generation_config.num_beams)]
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        logits_processor=logits_processor,
        do_sample=False,
        use_cache=True,
        **kwargs
    )
    # Drop the prompt; what remains is the generated continuation.
    generated = outputs[:, len(inputs['input_ids'][0]):]

    # Pad both sides to a common width. Previously a continuation longer than the
    # labels raised a shape error on assignment; now any extra non-pad token simply
    # makes that row a mismatch.
    # NOTE(review): assumes labels are padded with `tokenizer.pad_token_id` - confirm.
    num_rows = labels.shape[0]
    width = max(generated.shape[1], labels.shape[1])
    predictions = labels.new_full((num_rows, width), tokenizer.pad_token_id)
    predictions[:, :generated.shape[1]] = generated
    references = labels.new_full((num_rows, width), tokenizer.pad_token_id)
    references[:, :labels.shape[1]] = labels

    accuracy = (predictions == references).all(1).float().mean().item()
    return accuracy

@memory.cache
def evaluate_checkpoints(ckpt_paths, tasks, test_lengths):
    """Measure exact-match accuracy of every checkpoint on every task.

    Args:
        ckpt_paths: Checkpoint directories, evaluated in the given order.
        tasks: Mapping from task id to task name.
        test_lengths: Keyword arguments forwarded to the task generators.

    Returns:
        Dict mapping each task id to a list of accuracies, one per checkpoint in
        the order of ``ckpt_paths``.
    """
    data, tokenizer = get_data(tasks, test_lengths, eval=True)

    eval_results = {}
    for task_id in tasks:
        eval_results[task_id] = []

    progress = tqdm(ckpt_paths, desc="Evaluating checkpoints")
    for ckpt_path in progress:
        print(ckpt_path)
        model = LlamaForCausalLM.from_pretrained(ckpt_path, torch_dtype='auto', device_map='auto')
        # Generation must pad and stop with the character tokenizer's special tokens.
        generation_config = model.generation_config
        generation_config.pad_token_id = tokenizer.pad_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        model.eval()

        for task_id in data:
            eval_results[task_id].append(eval_batch(model, data[task_id], tokenizer))

    return eval_results

def plot_accuracy(ckpt_nums, eval_results, ax=None, color=None):
    """Plot the accuracy gap ``B - A`` across checkpoints.

    Args:
        ckpt_nums: X positions, one per checkpoint.
        eval_results: Dict with per-checkpoint accuracy sequences under ``'A'`` and ``'B'``.
        ax: Axes to draw on. A new figure with one axes is created when omitted.
        color: Line colour passed to ``ax.plot``.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        # A freshly created Figure has no axes, so `get_axes()[0]` raised IndexError.
        _, ax = plt.subplots(figsize=(5, 3.5))

    gap = np.array(eval_results['B']) - np.array(eval_results['A'])
    ax.plot(ckpt_nums, gap, label='Generalization Gap', color=color)
    ax.set_ylim(-0.05, 1)

    ax.set_xticks(ckpt_nums)
    ax.set_xlabel('Checkpoint')
    ax.legend(loc='upper right')

    return ax

@memory.cache
def ablate_heads_and_evaluate(ckpt_paths, tasks, test_lengths, eval_results):
    """Ablate every attention head in turn and record the resulting accuracy.

    For each checkpoint, task, layer and head, a forward pre-hook on that layer's
    ``self_attn.o_proj`` replaces the head's input slice with its mean, and the
    batch is re-evaluated with ``eval_batch``. A heatmap of the accuracy change
    relative to ``eval_results`` is saved per (checkpoint, task) under
    ``<save_dir>/tmp``.

    Args:
        ckpt_paths: Checkpoint directories, processed in the given order.
        tasks: Mapping from task id to task name.
        test_lengths: Keyword arguments forwarded to the task generators.
        eval_results: Un-ablated accuracies, ``eval_results[task_id][i]`` belonging
            to ``ckpt_paths[i]``.

    Returns:
        Array of shape ``(len(ckpt_paths), number_of_tasks)`` holding the ablated
        accuracy averaged over all layers and heads.
    """
    data, tokenizer = get_data(tasks, test_lengths, eval=True)

    # The heatmap directory was never created, so saving into it failed.
    heatmap_dir = os.path.join(save_dir, 'tmp')
    os.makedirs(heatmap_dir, exist_ok=True)

    # Allocated once the first checkpoint's config is known, which avoids loading
    # that checkpoint a second time just to read its dimensions.
    ablation_results = None

    for i, ckpt_path in enumerate(tqdm(ckpt_paths, desc="Ablating heads")):
        print(ckpt_path)
        model = LlamaForCausalLM.from_pretrained(ckpt_path, torch_dtype='auto', device_map='auto')
        # Use the same generation settings as the un-ablated baseline so that the
        # accuracy differences computed below compare like with like.
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.eval()

        num_layers = model.config.num_hidden_layers
        num_heads = model.config.num_attention_heads
        head_dim = model.config.hidden_size // num_heads
        if ablation_results is None:
            ablation_results = np.zeros((len(ckpt_paths), len(data), num_layers, num_heads))

        # Label taken from the path itself ("…/checkpoint-500" -> "500") instead of
        # reading a module-level `ckpt_nums` that is not a parameter of this function.
        ckpt_label = os.path.basename(os.path.normpath(ckpt_path)).split('-')[-1]

        for j, (task_id, inputs) in enumerate(data.items()):
            for layer in tqdm(range(num_layers), desc=f"Task {task_id} Layers", leave=False):
                target_module = model.model.layers[layer].self_attn.o_proj
                for head in range(num_heads):
                    specific_ablation_hook = functools.partial(
                        _mean_ablation_hook,
                        target_h=head,
                        n_heads=num_heads,
                        h_dim=head_dim
                    )
                    handle = target_module.register_forward_pre_hook(specific_ablation_hook)
                    try:
                        ablation_results[i, j, layer, head] = eval_batch(model, inputs, tokenizer)
                    finally:
                        # Never leave a stale hook on the model, even if evaluation fails.
                        handle.remove()

            # Draw on an explicit figure and close it afterwards, so figures do not
            # accumulate and pyplot's current figure is not left pointing at a heatmap.
            heat_fig, heat_ax = plt.subplots(figsize=(5, 4))
            sns.heatmap(ablation_results[i, j] - eval_results[task_id][i], annot=True, fmt='.2f', ax=heat_ax)
            heat_fig.savefig(os.path.join(heatmap_dir, f'ablation_{task_id}_{ckpt_label}.pdf'), bbox_inches='tight')
            plt.close(heat_fig)

        # Release the checkpoint before the next one is loaded.
        del model
        torch.cuda.empty_cache()

    if ablation_results is None:
        # No checkpoints were given: return an empty result of the documented rank.
        return np.zeros((0, len(data)))

    ablation_results = ablation_results.mean(axis=(2, 3))
    return ablation_results

def compute_ablation_difference(ckpt_nums, ablation_results, eval_results, ax=None, color=None):
    """Plot how differently head ablation affects tasks 'A' and 'B'.

    For every checkpoint the plotted value is
    ``|(ablated_A - baseline_A) - (ablated_B - baseline_B)|``.

    Args:
        ckpt_nums: X positions, one per checkpoint.
        ablation_results: 2-D array indexed ``[checkpoint, task]``; column 0 is
            paired with ``eval_results['A']`` and column 1 with ``eval_results['B']``.
        eval_results: Dict with per-checkpoint baseline accuracies under ``'A'`` and ``'B'``.
        ax: Axes to draw on. A new figure with one axes is created when omitted.
        color: Line colour passed to ``ax.plot``.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        # A freshly created Figure has no axes, so `get_axes()[0]` raised IndexError.
        _, ax = plt.subplots(figsize=(5, 4))

    drop_a = ablation_results[:, 0] - np.array(eval_results['A'])
    drop_b = ablation_results[:, 1] - np.array(eval_results['B'])
    ax.plot(ckpt_nums, np.abs(drop_a - drop_b), label='Head ablation map difference', color=color)
    ax.set_xticks(ckpt_nums)
    ax.set_xlabel('Checkpoint')
    ax.legend(loc='upper right')
    ax.set_ylim(-0.05, 1)

    return ax

# Experiment configuration.
# `ckpt_dirs`, `all_tasks` and `test_lengths` are parallel lists: they are zipped
# together below, so the k-th active (uncommented) entry of each list describes the
# same experiment. Keep the number of active entries identical across the three lists.

# Training output directories; each is scanned for `checkpoint-<N>` sub-directories.
ckpt_dirs = [
    # '../out/-llama-384-6-6-1024-rope-copy-l=6_17-MQAR-l=6_33-SFT-seed-43',
    # '../out/-llama-384-6-6-1024-rope-capitalize_reverse-l=6_17-capitalize-l=6_33-reverse-l=6_33-SFT-seed-45',
    # '../out/-llama-384-6-6-1024-rope-copy-l=6_17-reverse-l=6_33-SFT-seed-44',
    # '../out/-llama-384-6-6-1024-rope-reverse_add-la=1_17-lb=1_17-reverse_sub-la=1_33-lb=1_33-SFT-seed-45',
    '../out/-llama-384-6-6-1024-rope-reverse_add-la=1_17-lb=1_17-reverse_add_no_carry-la=1_33-lb=1_33-reverse_add_only_carry-la=1_33-lb=1_33-SFT-seed-45',
    # '../out/-llama-384-6-6-1024-rope-reverse_add-la=1_17-lb=1_17-copy_first_op-la=1_33-lb=1_33-SFT-seed-43'
]

# Task pair compared for each run: task ids 'A' and 'B' mapped to `task_registry` names.
all_tasks = [
    # {'A': 'copy', 'B': 'MQAR'},
    # {'A': 'capitalize_reverse', 'B': 'capitalize'},
    # {'A': 'copy', 'B': 'reverse'},
    # {'A': 'reverse_add', 'B': 'reverse_sub'},
    {'A': 'reverse_add', 'B': 'reverse_add_no_carry'},
    # {'A': 'reverse_add', 'B': 'copy_first_op'}
]

# Keyword arguments forwarded to the task generators when building the test batches.
test_lengths = [
    # {'l': [16, 33]},
    # {'l': [16, 33]},
    # {'l': [6, 33]},
    {'la': [16, 33]},
    # {'la': [16, 33]},
    # {'la': [4, 33]},
]

colors = sns.color_palette("tab10", 3)

# zip() stops at the shortest list, so a mismatch (easy to introduce when toggling
# the commented-out entries above) would silently drop or mis-pair experiments.
if not (len(ckpt_dirs) == len(all_tasks) == len(test_lengths)):
    raise ValueError(
        "ckpt_dirs, all_tasks and test_lengths must have the same number of entries, "
        f"got {len(ckpt_dirs)}, {len(all_tasks)} and {len(test_lengths)}"
    )

for ckpt_dir, tasks, test_length_params in tqdm(zip(ckpt_dirs, all_tasks, test_lengths), total=len(ckpt_dirs), desc="Overall Progress"):
    ckpt_paths, ckpt_nums = get_sorted_checkpoint_paths(ckpt_dir)

    sns.set_theme(style='whitegrid')
    plt.rc('grid', linestyle='--')
    fig = plt.figure(figsize=(4, 5))
    # Three stacked panels sharing the checkpoint axis:
    # attention difference (top), ablation difference (middle), accuracy gap (bottom).
    axes = fig.subplots(3, 1, sharex=True)

    all_attn_sims = get_hs_attn(ckpt_paths, tasks, test_length_params)
    compute_attn_diff_and_plot(ckpt_paths, ckpt_nums, all_attn_sims, ax=axes[0], color=colors[0])

    eval_results = evaluate_checkpoints(ckpt_paths, tasks, test_length_params)
    plot_accuracy(ckpt_nums, eval_results, ax=axes[2], color=colors[1])

    ablation_results = ablate_heads_and_evaluate(ckpt_paths, tasks, test_length_params, eval_results)
    compute_ablation_difference(ckpt_nums, ablation_results, eval_results, ax=axes[1], color=colors[2])

    # Only the bottom panel keeps its x-axis label.
    axes[0].set_xlabel(None)
    axes[1].set_xlabel(None)

    # Rotate the x-ticks for better readability
    for ax in axes:
        ax.tick_params(axis='x', rotation=45)

    # Adjust bottom margin to make room for rotated labels
    fig.subplots_adjust(bottom=0.15)
    fig.tight_layout(pad=0.0)  # Reduce padding between subfigures
    # Save through the Figure object: pyplot's "current figure" may be a different
    # one if any figure was created since `fig`, in which case `plt.savefig` would
    # write the wrong plot.
    fig.savefig(os.path.join(save_dir, f'{tasks["A"]}_{tasks["B"]}.pdf'), bbox_inches='tight')
    plt.close(fig)
