import os
import pdb
import random
from itertools import product

from typing import Dict, List, Tuple, Any, Optional, Union, Callable
import argparse

from tqdm import tqdm
from jaxtyping import Float
import numpy as np
import torch
from torch import Tensor
import einops
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding
import nnsight
from nnsight import CONFIG, LanguageModel, util
from transformer_lens import HookedTransformer

from transformers.pytorch_utils import find_pruneable_heads_and_indices

import sys
sys.path.append("../pp_experiment")
from utils import get_model_and_tokenizer, get_random_guess_baseline, fix_random_seed, get_random_circuit, get_circuit, get_mean_activations, eval_circuit_performance, get_root_exp_dir, MODEL_TO_SHORT, get_module
from run_patching import build_parser, post_arg_parse_fix, get_model_and_dataset, setup_nnsight


def eval_model_performance(model, dataloader):
    """
    Evaluate first token prediction correctness

    For each example, the logits at the example's ``base_last_token_indices``
    position are sorted in descending order and the first ``len(labels)``
    token ids are compared against the example's labels.

    Args:
        model: object called as ``model(input_ids=...)`` whose result has a
            ``.logits`` attribute; it must also expose ``.device``.
        dataloader: iterable of dict batches providing the keys
            ``"base_tokens"``, ``"labels"`` and ``"base_last_token_indices"``.
            Tensor values of each batch are moved to ``model.device`` in place.

    Returns:
        A tuple ``(accuracy, argmax_correct_full, topk_correct_full)`` where
        ``accuracy`` is the fraction of examples whose top-1 token equals any
        of their labels, rounded to 2 decimals (``0.0`` when the dataloader
        yields no examples), and the two lists hold, per example, one 0/1
        entry per label: whether the top-1 token equals that label, and
        whether that label appears among the top ``len(labels)`` tokens.
    """
    total_count = 0
    argmax_correct_any = 0
    argmax_correct_full = []
    topk_correct_full = []
    # Bound up-front so the cleanup below is valid even for an empty dataloader.
    logits = None

    with torch.no_grad():
        for output in tqdm(dataloader):
            for k, v in output.items():
                if v is not None and isinstance(v, torch.Tensor):
                    output[k] = v.to(model.device)

            logits = model(input_ids=output["base_tokens"]).logits
            for bi, labels in enumerate(output["labels"]):  # multiple target objects
                last_token_logits = logits[bi][output["base_last_token_indices"][bi]]
                topk_pred = torch.argsort(last_token_logits, descending=True)[:len(labels)].cpu().numpy()
                if (topk_pred[0] == labels).sum() > 0:
                    argmax_correct_any += 1

                # Plain equality on purpose: the chained form `a == label > 0`
                # also requires `label > 0` and would score a correct
                # prediction of token id 0 as wrong.
                argmax_correct_full.append(
                    [1 if topk_pred[0] == label else 0 for label in labels]
                )
                topk_correct_full.append(
                    [1 if (topk_pred == label).sum() > 0 else 0 for label in labels]
                )
                total_count += 1

    # Drop the last batch's logits before asking the allocator to release memory.
    del logits
    torch.cuda.empty_cache()
    current_acc = round(argmax_correct_any / total_count, 2) if total_count else 0.0
    return current_acc, argmax_correct_full, topk_correct_full


def _swap_group_heads(model, position_components, own_heads, sibling_heads, incoming_heads):
    """Replace one head group with ``incoming_heads`` inside ``position_components``.

    Heads of ``own_heads`` that also appear in ``sibling_heads`` are kept,
    because both groups are stored under the same token position.
    """
    for layer_idx, head in own_heads:
        layer = get_module(model, layer_idx)
        if [layer_idx, head] not in sibling_heads:
            position_components[layer].remove(head)
    for layer_idx, head in incoming_heads:
        layer = get_module(model, layer_idx)
        position_components[layer].append(head)


def _report_performance(name, acc, argmax_full, topk_full):
    """Print one performance line; return per-label means or ``None``.

    Per-label accuracies are only computed when every example has the same
    number of labels (otherwise the rows cannot be stacked into one array).
    """
    if np.array([len(p) for p in argmax_full]).std() == 0:
        argmax_mean = np.array(argmax_full).sum(0) / len(argmax_full)
        topk_mean = np.array(topk_full).sum(0) / len(topk_full)
        print(f"{name} Performance {acc}. Argmax accuracy by label index: {argmax_mean}. TopK accuracy by label index: {topk_mean} \n")
        return argmax_mean, topk_mean
    print(f"{name} Performance {acc}\n")
    return None


def _faithfulness(acc, model_acc):
    """Return ``acc / model_acc`` rounded to 2 decimals, or NaN if ``model_acc`` is 0."""
    if not model_acc:
        return float("nan")
    return round(acc / model_acc, 2)


def eval_circuit_main(args: argparse.Namespace) -> None:
    """
    evaluate model performance, circuit performance and a random circuit (same size)
    performance

    The evaluated circuit is the one loaded from ``args.circuit1_root_path``
    with the head group selected by ``args.patch_group`` replaced by the
    corresponding group of the circuit at ``args.circuit2_root_path``.
    Results (accuracies and faithfulness ratios relative to the full model)
    are printed; nothing is returned.

    Args:
        args: parsed command line namespace. Attributes read here:
            ``remote``, ``circuit1_root_path``, ``circuit2_root_path``,
            ``n_groupA`` .. ``n_groupD``, ``top_p``, ``patch_group``,
            ``mean_activation_cache_path`` and ``skip_ablate_non_vital_pos``.
            The namespace is also forwarded to ``get_model_and_dataset`` and
            ``get_mean_activations``.
    """
    if args.remote:
        setup_nnsight()

    dataloader, _dataset, model = get_model_and_dataset(args)
    group_sizes = (args.n_groupA, args.n_groupB, args.n_groupC, args.n_groupD)
    circuit_components, heads_A, heads_B, _heads_C, _heads_D = get_circuit(
        model, args.circuit1_root_path, *group_sizes, top_p=args.top_p
    )
    patch_circuit_components, patch_heads_A, patch_heads_B, _, _ = get_circuit(
        model, args.circuit2_root_path, *group_sizes, top_p=args.top_p
    )

    if args.patch_group == "C":
        circuit_components[2] = patch_circuit_components[2]
    elif args.patch_group == "D":
        circuit_components[-1] = patch_circuit_components[-1]
    elif args.patch_group == "A":
        # since both group A and B share the same token location, we need to
        # first remove heads in group A (that's not in group B), and then add
        # the heads in patch circuit
        _swap_group_heads(model, circuit_components[0], heads_A, heads_B, patch_heads_A)
    else:  # group B
        _swap_group_heads(model, circuit_components[0], heads_B, heads_A, patch_heads_B)

    model_acc, model_argmax_full, model_topk_full = eval_model_performance(model, dataloader)
    _report_performance("Model", model_acc, model_argmax_full, model_topk_full)

    # mean activation data also needs to be loaded filtered by operation orders
    mean_activations, modules = get_mean_activations(model=model, args=args, cache_dir=args.mean_activation_cache_path)

    # Used for the circuit AND the random baselines so both are evaluated
    # under the same ablation setting.
    ablate_non_vital_pos = not args.skip_ablate_non_vital_pos

    circuit_acc, c_argmax_full, c_topk_full = eval_circuit_performance(
        model, dataloader, modules, circuit_components, mean_activations, ablate_non_vital_pos=ablate_non_vital_pos,
    )
    _report_performance("Circuit", circuit_acc, c_argmax_full, c_topk_full)

    random_circuit_acc = 0
    random_circuit_argmax_full = []
    random_circuit_topk_full = []
    n_iters = 10
    for i in range(n_iters):
        random_circuit_components = get_random_circuit(model, *group_sizes)
        r_circuit_acc, rc_argmax_full, rc_topk_full = eval_circuit_performance(
            model, dataloader, modules, random_circuit_components, mean_activations, ablate_non_vital_pos=ablate_non_vital_pos,
        )
        random_circuit_acc += r_circuit_acc
        per_label = _report_performance(f"Random Circuit {i}", r_circuit_acc, rc_argmax_full, rc_topk_full)
        if per_label is not None:
            random_circuit_argmax_full.append(per_label[0])
            random_circuit_topk_full.append(per_label[1])
    random_circuit_acc = round(random_circuit_acc / n_iters, 2)
    print(f"Random Circuit Average Performance {random_circuit_acc}")
    if len(random_circuit_argmax_full) > 0:
        print(f"Argmax accuracy by label index: {np.array(random_circuit_argmax_full).mean(0)}")
        print(f"Topk accuracy by label index: {np.array(random_circuit_topk_full).mean(0)}")

    # model_acc is rounded to two decimals upstream and can therefore be 0;
    # report NaN in that case instead of raising ZeroDivisionError.
    print(f"Faithfulness (Circuit): {_faithfulness(circuit_acc, model_acc)}")
    print(f"Faithfulness (Random Circuit): {_faithfulness(random_circuit_acc, model_acc)}")
    return


def add_args(parser: argparse.ArgumentParser):
    """
    Register the circuit-patching evaluation options on ``parser``.

    Options added (in this order): ``--circuit1_root_path``,
    ``--circuit2_root_path``, ``--patch_group`` (one of A/B/C/D),
    ``--skip_ablate_non_vital_pos`` (boolean flag that also accepts the
    ``--no-`` prefixed form) and ``--mean_activation_cache_path``.

    Returns the same parser instance so calls can be chained.
    """
    option_specs = (
        ('--circuit1_root_path', dict(
            help='where base circuit is',
            type=str,
            default="../outputs/nnsight_patch_1put/gemma-2-2b/logp_notLastObj/n200",
        )),
        ('--circuit2_root_path', dict(
            help='where patched-in circuit is',
            type=str,
            default="../outputs/nnsight_patch_noop/gemma-2-2b/logp_lastObjOnly/n200",
        )),
        ('--patch_group', dict(
            help="which group to patch in from circuit2 to circuit1",
            type=str,
            default="A",
            choices=["A", "B", "C", "D"],
        )),
        ('--skip_ablate_non_vital_pos', dict(
            help='skip ablation on non-essential tokens',
            action=argparse.BooleanOptionalAction,
            default=False,
        )),
        ('--mean_activation_cache_path', dict(
            help='path to cache mean activations',
            type=str,
            default="../outputs/nnsight_patch_noop/gemma-2-2b",
        )),
    )
    # Registration order is kept because it determines the order in --help.
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser

if __name__ == "__main__":
    # Script entry point: extend the parser from run_patching with this
    # script's options, then parse the command line.
    args = add_args(build_parser()).parse_args()
    print(f"ARGS: {args}")
    # The return value is not used, so this presumably adjusts `args` in
    # place -- TODO confirm against run_patching.post_arg_parse_fix.
    post_arg_parse_fix(args)
    # Seed before any model/data work so the run is repeatable.
    fix_random_seed(args.seed)
    eval_circuit_main(args)
