import os
import pdb
import random
from itertools import product

from typing import Dict, List, Tuple, Any, Optional, Union, Callable
import argparse

import tqdm
from jaxtyping import Float
import numpy as np
import torch
from tinycss2.nth import N_DASH_DIGITS_RE
from torch import Tensor
from torch.utils.data import Subset, Dataset
import einops
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding
import nnsight
from nnsight import CONFIG, LanguageModel, util
from transformer_lens import HookedTransformer

import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import plotly.express as px
import plotly.io as pio
from transformers.pytorch_utils import find_pruneable_heads_and_indices

pio.renderers.default = "plotly_mimetype+notebook_connected+colab+notebook"

import sys

sys.path.append("../pp_experiment")
from utils import get_model_and_tokenizer, load_dataloader, get_random_guess_baseline, fix_random_seed, str_to_bool, \
    compute_topk_components, find_previous_query_box_pos, is_int_with_negatives, stupid_pad
from run_patching import build_parser, post_arg_parse_fix, get_model_and_dataset, cache_logit_and_hidden, \
    maybe_logit_soft_capping, get_patch_score, plot_patching_results, maybe_patch_or_load_cache


def get_token_labels(clean_tokens: List[int], ctf_tokens: List[int], tokenizer: AutoTokenizer) -> List[str]:
    """Build one axis label per token position for a clean/counterfactual pair.

    Every token id is decoded individually with ``tokenizer.decode``. Positions
    where both decoded strings agree are labelled ``"<text>_<idx>"``; positions
    where they differ are labelled ``"<ctf_text>-><clean_text>_<idx>"``.

    Args:
        clean_tokens: token ids of the clean prompt.
        ctf_tokens: token ids of the counterfactual prompt.
        tokenizer: object exposing ``decode(token)``.

    Returns:
        One label per aligned position; the sequences are zipped, so the result
        is as long as the shorter of the two inputs.
    """
    decode = tokenizer.decode
    clean_texts = list(map(decode, clean_tokens))
    ctf_texts = list(map(decode, ctf_tokens))

    def _label(position, clean_text, ctf_text):
        # Differing positions show the substitution explicitly.
        if clean_text != ctf_text:
            return f"{ctf_text}->{clean_text}_{position}"
        return f"{clean_text}_{position}"

    return [
        _label(position, clean_text, ctf_text)
        for position, (clean_text, ctf_text) in enumerate(zip(clean_texts, ctf_texts))
    ]


def activation_patching_residual_stream_single_prompt(dataset, correct_index: List[int], model, args):
    """Run layer-by-token activation patching for one clean/corrupted prompt pair.

    The clean prompt is traced once to cache one activation per layer (selected by
    ``args.component``) together with the summed logit / log-probability of
    ``correct_index`` at the last position. The corrupted prompt is traced once for
    the same baseline quantities. Then, for every (layer, token position), the
    corrupted prompt is re-run with the cached clean activation written into that
    single position, and the score selected by ``args.metric`` is recorded.

    Args:
        dataset: a single example providing ``'base_tokens'`` (clean) and
            ``'source_tokens'`` (corrupted).
        correct_index: vocabulary id(s) whose logit / log-probability is scored.
            When ``args.use_object_index`` is truthy, only that entry is used.
        model: traced language model exposing ``tokenizer``, ``config``,
            ``model.layers`` and ``lm_head``.
        args: namespace providing ``component``, ``metric`` and ``use_object_index``.

    Returns:
        A list with ``num_hidden_layers`` entries, each a list of floats with one
        score per token position.

    Raises:
        NotImplementedError: if ``args.component`` or ``args.metric`` is not supported.
    """
    # Fail fast: previously an unknown component silently cached the MLP output and
    # only raised after the (expensive) clean and corrupted runs had completed.
    if args.component not in ("resid", "attn_out", "mlp_out"):
        raise NotImplementedError(f"unsupported component: {args.component}")
    if args.metric not in ("norm_logit_diff", "logit_ratio", "logp_ratio", "logp", "norm_logp_diff"):
        raise NotImplementedError(f"unsupported metric: {args.metric}")

    clean_input_ids = stupid_pad([dataset['base_tokens']], model.tokenizer)
    # NOTE(review): the position count comes from the padded token ids, while the traces
    # below re-tokenize the decoded prompt text — confirm both yield the same length.
    N_TOKENS = len(clean_input_ids[0])
    clean_prompt = model.tokenizer.decode(dataset['base_tokens'], skip_special_tokens=True)
    corrupted_prompt = model.tokenizer.decode(dataset['source_tokens'], skip_special_tokens=True)

    # NOTE(review): truthiness test — an index of 0 is treated as "not set"; confirm the
    # default value of --use_object_index before relying on index 0.
    if args.use_object_index:
        correct_index = correct_index[args.use_object_index]

    N_LAYERS = model.config.num_hidden_layers

    # Clean run: cache the chosen component's activation at every layer.
    clean_hs = []
    with torch.no_grad():
        with model.trace(clean_prompt) as tracer:
            for layer_idx in range(N_LAYERS):
                if args.component == "resid":
                    clean_hs.append(model.model.layers[layer_idx].output[0].save())
                elif args.component == "attn_out":
                    clean_hs.append(model.model.layers[layer_idx].self_attn.o_proj.input.save())
                else:
                    clean_hs.append(model.model.layers[layer_idx].mlp.output.save())
            # Get logits from the lm_head.
            clean_logits = model.lm_head.output
            clean_logits = maybe_logit_soft_capping(clean_logits, model).save()
            # Summed logit / log-probability of the target token(s) at the last position.
            clean_logit = clean_logits[0, -1, correct_index].sum().save()
            clean_logprob = torch.nn.functional.log_softmax(clean_logits[0, -1], dim=-1)[correct_index].sum().save()

    # Corrupted run: same baseline quantities without any intervention.
    with torch.no_grad():
        with model.trace(corrupted_prompt) as tracer:
            corrupted_logits = model.lm_head.output
            corrupted_logits = maybe_logit_soft_capping(corrupted_logits, model).save()
            corrupted_logit = corrupted_logits[0, -1, correct_index].sum().save()
            corrupted_logprob = torch.nn.functional.log_softmax(corrupted_logits[0, -1], dim=-1)[correct_index].sum().save()

    # Activation Patching Intervention
    patching_results = []
    bar = tqdm.tqdm(total=N_LAYERS * N_TOKENS)
    # Iterate through all the layers
    for layer_idx in range(N_LAYERS):
        _patching_results = []
        # Iterate through all tokens
        for token_idx in range(N_TOKENS):
            # Patching corrupted run at given layer and token
            with torch.no_grad():
                with model.trace(corrupted_prompt) as tracer:
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    if args.component == "resid":
                        model.model.layers[layer_idx].output[0][:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]
                    elif args.component == "attn_out":
                        model.model.layers[layer_idx].self_attn.o_proj.input[:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]
                    elif args.component == "mlp_out":
                        model.model.layers[layer_idx].mlp.output[:, token_idx, :] = clean_hs[layer_idx][:, token_idx, :]
                    else:
                        raise NotImplementedError

                    patched_logits = model.lm_head.output
                    patched_logits = maybe_logit_soft_capping(patched_logits, model)
                    # Sum over the target token(s) so this is directly comparable with the
                    # (already summed) clean / corrupted baselines. Summing only after the
                    # ratio, as before, subtracted the summed baseline from every element.
                    patched_logprob = torch.nn.functional.log_softmax(patched_logits[0, -1], dim=-1)[correct_index].sum()
                    patched_logit = patched_logits[0, -1, correct_index].sum()

                    if args.metric == "norm_logit_diff":
                        # Fraction of the clean-vs-corrupted logit gap recovered by the patch.
                        patched_result = ((patched_logit - corrupted_logit) / (clean_logit - corrupted_logit)).save()
                    elif args.metric == "logit_ratio":
                        patched_result = (patched_logit / clean_logit).save()
                    elif args.metric == "logp_ratio":
                        patched_result = ((patched_logprob - clean_logprob) / clean_logprob).save()
                    elif args.metric == "logp":
                        patched_result = patched_logprob.save()
                    elif args.metric == "norm_logp_diff":
                        patched_result = ((patched_logprob - corrupted_logprob) / (clean_logprob - corrupted_logprob)).save()
                    else:
                        raise NotImplementedError

            # patched_result is always a scalar here, so no len(correct_index) check is
            # needed (that check raised TypeError when correct_index was a single int).
            _patching_results.append(patched_result.detach().item())
            bar.update(1)
        patching_results.append(_patching_results)
    bar.close()
    print(f"Clean logit: {clean_logit}")
    print(f"Corrupted logit: {corrupted_logit}")
    return patching_results


def activation_patching_residual_stream_batched(
    model: LanguageModel,
    dataset: Dataset,
    label_tokens: List[List[str]],
    args: argparse.Namespace,
):
    """
    Batched residual-stream patching, scored by prediction accuracy per layer.

    For every layer and every batch, the layer output of the corrupted
    (``source_tokens``) run is read at the configured token positions and written
    into the clean (``base_tokens``) run at the same layer. The top predictions at
    each example's last token position are then compared, as decoded / stripped /
    lower-cased text, against that example's expected labels.

    Args:
        model: traced language model exposing ``tokenizer``, ``config``,
            ``model.layers`` and ``lm_head``.
        dataset: provides ``"base_tokens"``, ``"source_tokens"`` and
            ``"base_last_token_indices"`` columns.
        label_tokens: per-example list of expected target token ids.
        args: namespace providing ``batch_size`` and ``debug``.

    Returns:
        A tuple ``(argmax_correct_any, argmax_correct_full, topk_correct_full)``:
        * ``argmax_correct_any[layer]``: fraction of examples whose top-1 prediction
          is one of the expected labels;
        * ``argmax_correct_full[layer][example]``: 0/1 per label, whether the label
          equals the top-1 prediction;
        * ``topk_correct_full[layer][example]``: 0/1 per label, whether the label is
          among the top-``len(labels)`` predictions.
    """
    clean_tokens = dataset["base_tokens"]
    corrupted_tokens = dataset["source_tokens"]
    # NOTE(review): clean and corrupted positions are both read from
    # "base_last_token_indices"; confirm whether the corrupted side should use a
    # source-side column instead.
    clean_pos = list(dataset["base_last_token_indices"])
    corrupted_pos = list(dataset["base_last_token_indices"])
    last_token_pos = np.array(dataset["base_last_token_indices"])

    N_LAYERS = model.config.num_hidden_layers
    N_DATA = len(clean_tokens)
    # Keep enough predictions to cover the example with the most labels; debug
    # printing additionally needs at least the top five.
    MAX_N_LABELS = max((len(l) for l in label_tokens), default=0)
    if args.debug:
        MAX_N_LABELS = max(MAX_N_LABELS, 5)

    argmax_correct_any = []
    argmax_correct_full = []
    topk_correct_full = []

    # Iterate through all the layers
    for layer_idx in tqdm.tqdm(range(N_LAYERS)):
        _argmax_correct_any = 0
        _argmax_correct_full = []
        _topk_correct_full = []

        # iterate through batches
        for batch_i in range(0, N_DATA, args.batch_size):
            batch_indices = list(range(batch_i, min(N_DATA, batch_i + args.batch_size)))
            # Select rows one by one so this also works when the columns are plain lists.
            batch_corrupted_tokens = stupid_pad([corrupted_tokens[bi] for bi in batch_indices], model.tokenizer)
            batch_clean_tokens = stupid_pad([clean_tokens[bi] for bi in batch_indices], model.tokenizer)
            batch_clean_token_pos = [clean_pos[bi] for bi in batch_indices]
            batch_corrupted_token_pos = [corrupted_pos[bi] for bi in batch_indices]
            row_idx = list(range(len(batch_indices)))
            col_idx = last_token_pos[batch_indices].tolist()

            # Patching at given layer and token positions
            torch.cuda.empty_cache()
            with torch.no_grad():
                with model.trace() as tracer:
                    # get corrupt residual
                    # NOTE(review): `[:, positions]` selects every listed position for every
                    # row of the batch; this matches per-example patching only when all
                    # examples in the batch share the same position — confirm.
                    with tracer.invoke(batch_corrupted_tokens):
                        corrupt_layer_out = model.model.layers[layer_idx].output[0][:, batch_clean_token_pos].clone()

                    # patch into clean run
                    with tracer.invoke(batch_clean_tokens):
                        model.model.layers[layer_idx].output[0][:, batch_corrupted_token_pos] = corrupt_layer_out
                        logits = model.lm_head.output
                        last_token_logits = logits[row_idx, col_idx]
                        topk_pred_saved = last_token_logits.argsort(dim=-1, descending=True)[:, :MAX_N_LABELS].save()

            topk_pred = topk_pred_saved.detach().cpu().numpy()

            for i, data_idx in enumerate(batch_indices):
                labels = label_tokens[data_idx]  # multiple target objects
                label_texts = [model.tokenizer.decode(l).strip().lower() for l in labels]
                topk_pred_texts = [model.tokenizer.decode(l).strip().lower() for l in topk_pred[i, :len(label_texts)]]
                if args.debug:
                    print(f"Corrupted Sentence: {model.tokenizer.decode(batch_corrupted_tokens[i])}")
                    print(f"Clean     Sentence: {model.tokenizer.decode(batch_clean_tokens[i])}")
                    topk_five_texts = [model.tokenizer.decode(l).strip().lower() for l in topk_pred[i, :5]]
                    print(f"Expected Labels: {label_texts}")
                    print(f"Top 5 prediction: {topk_five_texts}")

                # An example without labels has no predictions to compare against.
                if topk_pred_texts and topk_pred_texts[0] in label_texts:
                    _argmax_correct_any += 1

                _argmax_correct_full.append(
                    [1 if label_text == topk_pred_texts[0] else 0 for label_text in label_texts]
                )
                _topk_correct_full.append(
                    [1 if label_text in topk_pred_texts else 0 for label_text in label_texts]
                )

        argmax_correct_any.append(_argmax_correct_any / N_DATA)
        argmax_correct_full.append(_argmax_correct_full)
        topk_correct_full.append(_topk_correct_full)

    return argmax_correct_any, argmax_correct_full, topk_correct_full


def main(args):
    """
    Run activation patching across layers X tokens and store the results.

    Loads the model and dataset via ``get_model_and_dataset``, prints the first
    clean/corrupted prompt pair, and then either
    * (``--single_example``) patches the first example only and writes an HTML
      heatmap next to the cached ``.npy`` results, or
    * runs the batched variant over the whole dataset (plotting is still TODO).

    Raises:
        ValueError: in batched mode when ``args.query_id`` is None.
    """
    _, dataset, model = get_model_and_dataset(args)

    if "cma_remove" in args.counterfactual_format:
        # for removal need to measure removed object logp
        correct_index = dataset['source_labels']
    else:
        correct_index = dataset['labels']

    # just take first example (newer data appended in front)
    first_example = dataset[0]
    clean_prompt = model.tokenizer.decode(first_example['base_tokens'], skip_special_tokens=True)
    corrupted_prompt = model.tokenizer.decode(first_example['source_tokens'], skip_special_tokens=True)
    print(f"clean_prompt: \n{clean_prompt}")
    print(f"corrupted_prompt: \n{corrupted_prompt}")
    print(f"correct_index for label {model.tokenizer.decode(correct_index[0])} = {correct_index[0]}")

    # Created once up front; both branches write into this directory.
    os.makedirs(args.output_dir, exist_ok=True)
    if args.single_example:
        ## below works for cma on 1put, have not tested for 1 remove
        result_stem = f"{args.output_dir}/activation_patching_metric{args.metric}_idx{args.num_samples}"
        patching_results = maybe_patch_or_load_cache(
            f"{result_stem}.npy",
            activation_patching_residual_stream_single_prompt,
            model=model, dataset=first_example, args=args, correct_index=correct_index[0],
        )
        token_labels = get_token_labels(first_example['base_tokens'], first_example['source_tokens'], model.tokenizer)
        # The title previously used `{args.query_ops_order=}`, which rendered the
        # expression text inside the title ("query_ops=args.query_ops_order=...").
        fig = plot_patching_results(np.array(patching_results).squeeze(), token_labels,
                                    f"{args.model} all_ops={args.ops_order}, query_ops={args.query_ops_order}, ctf={args.counterfactual_format}",
                                    labels={"x": "Tokens", "y": "Layer", "color": f"{args.metric}: {model.tokenizer.decode(correct_index[0])}"},
                                    centered=args.metric in ["norm_logit_diff", "norm_logp_diff"])
        plotly.offline.plot(fig, filename=f"{result_stem}.html", auto_open=False)
    else:
        # Explicit raise instead of assert so the check survives `python -O`.
        if args.query_id is None:
            raise ValueError("to compute batch results, need to ensure all datapoints are match up in length")
        patching_results = maybe_patch_or_load_cache(
            f"{args.output_dir}/activation_patching_metric{args.metric}_n{args.num_samples}.npy",
            activation_patching_residual_stream_batched,
            model=model, dataset=dataset, args=args, label_tokens=correct_index,
        )
        # TODO: for plotting, we would need to only plot important token positions


def add_args(parser: argparse.ArgumentParser):
    """Register this script's patching options on ``parser`` and return it.

    Adds ``--component`` (where to patch), ``--single_example`` (boolean flag)
    and ``--metric`` (which score to compute).
    """
    option_specs = (
        ('--component', dict(help='where to patch', type=str, default="resid",
                             choices=["resid", "attn_out", "mlp_out"])),
        ('--single_example', dict(action='store_true')),
        ('--metric', dict(type=str, default="norm_logit_diff",
                          choices=["norm_logit_diff", "logit_ratio", "logp_ratio", "logp", "norm_logp_diff"])),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser

if __name__ == "__main__":
    # Script entry point: extend the shared patching parser with this script's
    # options, parse the command line, apply the post-parse fix-ups, seed the
    # random generators, then run the experiment.
    parser = add_args(build_parser())
    args = parser.parse_args()
    print(f"ARGS: {args}")
    post_arg_parse_fix(args)
    fix_random_seed(args.seed)
    main(args)




