import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

import torch
import copy
import torch.nn as nn
import string
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import einops
from tqdm import tqdm
import gc
import torch.nn.functional as F
from load_global_variables import *
import multiprocessing
import concurrent

from transformers import AutoTokenizer, AutoModelForCausalLM

def clean():
    """Run the garbage collector and empty the CUDA cache, twice in a row."""
    for _ in range(2):
        gc.collect()
        torch.cuda.empty_cache()

def generate(text, g, model=None, max_length=500):
    """Sample a continuation of ``text`` and return the decoded sequence.

    Args:
        text: Prompt handed to ``g.tokenizer``.
        g: Object exposing ``tokenizer`` and ``model_chat``.
        model: Object whose ``generate`` method is called; when ``None`` the
            function falls back to ``g.model_chat``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first entry of ``g.tokenizer.batch_decode`` applied to the
        generated tokens.
    """
    if model is None:
        print("Using cpu chat model")
        model = g.model_chat

    tokenizer = g.tokenizer
    # The encoded prompt is always moved to "cuda", whatever device the
    # model itself reports below.
    encoded = tokenizer(text, return_tensors="pt").to("cuda")
    print("Input:", tokenizer.batch_decode(encoded["input_ids"]))

    # Debug output: model device, prompt device and the raw encoding.
    print(model.device)
    print(encoded["input_ids"].device)
    print(encoded)

    generated = model.generate(**encoded, max_length=max_length, do_sample=True)
    decoded = tokenizer.batch_decode(generated)
    return decoded[0]

def make_interpolation(m1, m2, alpha, m3=None, params="all", dtype = torch.float32):
    """Write ``alpha * m1 + (1 - alpha) * m2`` into the parameters of ``m3``.

    Args:
        m1: Model whose parameters are weighted by ``alpha``.
        m2: Model whose parameters are weighted by ``1 - alpha``.
        alpha: Interpolation coefficient.
        m3: Destination model, modified in place. When ``None``:
            * ``alpha > 1 - 1e-12`` loads and returns
              ``model_name_dict["pretrained"]`` without touching ``m1``/``m2``;
            * otherwise ``model_name_dict["chat"]`` is loaded as the
              destination, and returned untouched if ``alpha < 1e-12``.
        params: Which parameters to blend: ``"all"`` (every parameter),
            ``"layers"`` (``model.layers`` only) or ``"lm_head"``
            (``lm_head`` only).
        dtype: ``torch_dtype`` used when a model has to be loaded.

    Returns:
        The destination model (``m3``), or the freshly loaded model for the
        two shortcut cases above.

    Raises:
        ValueError: If an interpolation is required and ``params`` is not one
            of the three supported selections.
    """
    if m3 is None:
        if alpha > 1-1e-12:
            print("Loading pretrained model")
            return AutoModelForCausalLM.from_pretrained(
                model_name_dict["pretrained"], device_map="auto", torch_dtype=dtype
            )

        m3 = AutoModelForCausalLM.from_pretrained(
            model_name_dict["chat"], device_map="auto", torch_dtype=dtype
        )
        if alpha < 1e-12:
            print("Returning chat model")
            return m3

    if params == "all":
        m1_params = m1.parameters()
        m2_params = m2.parameters()
        m3_params = m3.named_parameters()
    elif params == "layers":
        m1_params = m1.model.layers.parameters()
        m2_params = m2.model.layers.parameters()
        m3_params = m3.model.layers.named_parameters()
    elif params == "lm_head":
        m1_params = m1.lm_head.parameters()
        m2_params = m2.lm_head.parameters()
        m3_params = m3.lm_head.named_parameters()
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on m1_params.
        raise ValueError(
            f"Unknown params selection {params!r}; "
            "expected 'all', 'layers' or 'lm_head'"
        )

    # The blend is a plain value copy, so no autograd graph is needed.
    with torch.no_grad():
        for p1, p2, (_name, p3) in zip(m1_params, m2_params, m3_params):
            # Parameters frozen in m1 are left as they are in m3.
            if not p1.requires_grad:
                continue
            blended = alpha * p1.to(p3.device) + (1 - alpha) * p2.to(p3.device)
            p3.data.copy_(blended)

    return m3

def get_total_probs(logprobs, keys, g):
    """Pool probability mass over tokens that decode to the same text as a key.

    Two tokens are treated as equivalent when their decoded text matches after
    ``.lower().strip()``. For every key the probabilities of all equivalent
    vocabulary entries are summed, the sums are normalised to add up to one
    across the keys, and the log of each share is returned.

    Args:
        logprobs: Log-probabilities indexed as ``logprobs[0, token_id]``; only
            row 0 is read.
        keys: Iterable of token ids to report on.
        g: Object exposing ``tokenizer.decode``.

    Returns:
        ``{key: log(normalised pooled probability)}`` with Python floats as
        values; an empty dict when ``keys`` is empty.
    """
    keys = list(keys)
    if not keys:
        return {}

    decode = g.tokenizer.decode

    def normalise(token):
        return decode(token).lower().strip()

    # Decode each key once and index the keys by their normalised text, so the
    # vocabulary scan below needs one decode and one lookup per token.
    keys_by_text = {}
    for key in keys:
        keys_by_text.setdefault(normalise(key), []).append(key)

    probs = torch.exp(logprobs[0])
    totals = {
        key: torch.zeros((), dtype=probs.dtype, device=probs.device)
        for key in keys
    }
    for tok in range(probs.shape[-1]):
        for key in keys_by_text.get(normalise(tok), ()):
            totals[key] = totals[key] + probs[tok]

    total_prob = sum(totals.values())
    return {key: torch.log(value / total_prob).item() for key, value in totals.items()}

def argmax_dict(d):
    """Return the ``(key, value)`` pair of ``d`` with the largest value.

    On ties the pair that comes first in iteration order wins. Returns
    ``None`` when ``d`` is empty.
    """
    return max(d.items(), key=lambda pair: pair[1], default=None)

def get_logprob_next(text, model, g, chosen_tok = None, logit_bias=None):
    """Compute next-token log-probabilities for ``text`` under ``model``.

    Args:
        text: Prompt passed to ``g.tokenizer``.
        model: Callable model; ``model(input_ids).logits`` is read at the last
            position.
        g: Object exposing ``tokenizer``.
        chosen_tok: Token index to score. When ``None`` the argmax is used,
            restricted to the keys of ``logit_bias`` if that is non-empty.
        logit_bias: Mapping whose keys are the token ids of interest (the
            values are not used here). ``None`` is treated as empty.

    Returns:
        A tuple ``(logprob_next, chosen_tok_text, logprobs_selected)``:
        * ``logprobs_selected`` is the result of ``get_total_probs`` for the
          keys of ``logit_bias``;
        * when it is non-empty, ``chosen_tok_text`` / ``logprob_next`` are its
          best key and that key's value (this overrides ``chosen_tok``);
        * otherwise ``chosen_tok_text`` is ``keys[chosen_tok]`` (a token id,
          despite the name) and ``logprob_next`` is its log-probability.
    """
    logit_bias = {} if logit_bias is None else logit_bias
    tokens = g.tokenizer(text, return_tensors="pt")

    # Log-softmax is taken in float64 over the last position's logits.
    out = model(tokens["input_ids"]).logits[:, -1, :]
    out = out.to(torch.float64)
    out = torch.nn.functional.log_softmax(out, dim=-1)
    out_orig = out

    keys = list(range(0, out.shape[-1]))
    if logit_bias and chosen_tok is None:
        # Restrict the argmax below to the tokens of interest; `keys` maps the
        # restricted column index back to a token id.
        keys = list(logit_bias.keys())
        out = out[:, keys]

    if chosen_tok is None:
        chosen_tok = out.argmax(axis=-1)

    logprob_next = out[0, chosen_tok]

    # Fix: get_total_probs needs the tokenizer holder `g`; it was previously
    # called without it and raised TypeError on every call.
    logprobs_selected = get_total_probs(out_orig, logit_bias.keys(), g)
    if len(logprobs_selected) > 0:
        chosen_tok_text, logprob_next = argmax_dict(logprobs_selected)
    else:
        chosen_tok_text = keys[chosen_tok]

    return logprob_next, chosen_tok_text, logprobs_selected

def make_logit_bias(logit_bias_text_keys, g):
    """Convert a ``{text: bias}`` mapping into a ``{token_id: bias}`` mapping.

    ``tokenizer.encode`` returns a list of ids, which cannot be used as a
    dictionary key, so each text is encoded without special tokens and must
    map to exactly one token.

    Args:
        logit_bias_text_keys: Mapping from token text to bias value.
        g: Object exposing ``tokenizer.encode``.

    Returns:
        Mapping from the single token id of each text to its bias value.

    Raises:
        ValueError: If a text does not encode to exactly one token.
    """
    logit_bias = {}
    for text, bias in logit_bias_text_keys.items():
        token_ids = list(g.tokenizer.encode(text, add_special_tokens=False))
        if len(token_ids) != 1:
            raise ValueError(
                f"Expected {text!r} to encode to exactly one token, got {token_ids}"
            )
        logit_bias[token_ids[0]] = bias
    return logit_bias

def make_logit_bias_text_keys(logit_selected, g):
    """Re-key ``logit_selected`` by the decoded text of each token-id key.

    Args:
        logit_selected: Mapping whose keys are passed to
            ``g.tokenizer.decode``.
        g: Object exposing ``tokenizer.decode``.

    Returns:
        A new dict with the decoded text as key and the original value.
    """
    text_keyed = {}
    for token_id, value in logit_selected.items():
        text_keyed[g.tokenizer.decode(token_id)] = value
    return text_keyed

def generate_random_text(length):
    """Return ``length`` characters drawn at random from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        # One random.choice call per character.
        picked.append(random.choice(alphabet))
    return "".join(picked)

def add_gibberish(text, gibberish, l):
    """Prefix ``text`` with the first ``l`` characters of ``gibberish`` and a space.

    The separating space is always inserted, even when the prefix is empty.
    """
    prefix = gibberish[:l]
    return " ".join((prefix, text))

def compute_derivative_wrt_alpha(text, g, model_interpolated=None, eps=1e-3, logit_bias=None):
    """Finite-difference derivative of next-token log-probs w.r.t. ``alpha`` at 0.

    The model is evaluated at ``alpha = 0`` and at ``alpha = eps`` (via
    ``make_interpolation`` between ``g.model_pretrained`` and
    ``g.model_chat``) and the difference of the results is divided by ``eps``.

    Args:
        text: Prompt to score.
        g: Object exposing ``tokenizer``, ``model_pretrained`` and
            ``model_chat``.
        model_interpolated: Optional destination model reused (and
            overwritten) by ``make_interpolation``.
        eps: Step in ``alpha`` used for the forward difference.
        logit_bias: Mapping whose keys are the token ids of interest.
            ``None`` is treated as empty.

    Returns:
        A tuple ``(derivative, tok_text, derivatives_dict)``: the derivative
        of the value returned by ``get_logprob_next``, the decoded token
        chosen at ``alpha = 0``, and the per-token derivatives keyed by
        decoded token text.
    """
    logit_bias = {} if logit_bias is None else logit_bias
    with torch.no_grad():
        # Fix: `g` is a required argument of get_logprob_next and
        # make_logit_bias_text_keys; it was previously omitted in every call
        # below, which raised TypeError.
        model_interpolated = make_interpolation(g.model_pretrained, g.model_chat, alpha=0.0, m3=model_interpolated)
        logit_next_alpha_0, chosen_tok, logit_selected_alpha_0 = get_logprob_next(
            text, model_interpolated, g, logit_bias=logit_bias
        )

        model_interpolated = make_interpolation(g.model_pretrained, g.model_chat, alpha=eps, m3=model_interpolated)
        logit_next_alpha_eps, _, logit_selected_alpha_eps = get_logprob_next(
            text, model_interpolated, g, chosen_tok=chosen_tok, logit_bias=logit_bias
        )

        logit_selected_alpha_0 = make_logit_bias_text_keys(logit_selected_alpha_0, g)
        logit_selected_alpha_eps = make_logit_bias_text_keys(logit_selected_alpha_eps, g)
        tok_text = g.tokenizer.decode(chosen_tok)

        derivatives_dict = {
            key: (value_eps - logit_selected_alpha_0[key]) / eps
            for key, value_eps in logit_selected_alpha_eps.items()
        }

        derivative = (logit_next_alpha_eps - logit_next_alpha_0)/eps

    return derivative, tok_text, derivatives_dict

def compute_probs_at_alpha(text, g, chosen_tok=None, model_interpolated=None, alpha=1e-3, logit_bias=None, interpolate_again=True):
    """Score ``text`` with the model interpolated at ``alpha``.

    Args:
        text: Prompt to score.
        g: Object exposing ``tokenizer``, ``model_pretrained`` and
            ``model_chat``.
        chosen_tok: Optional token index forwarded to ``get_logprob_next``.
        model_interpolated: Model to evaluate; also the destination reused by
            ``make_interpolation`` when ``interpolate_again`` is true.
        alpha: Interpolation coefficient passed to ``make_interpolation``.
        logit_bias: Mapping whose keys are the token ids of interest.
            ``None`` is treated as empty.
        interpolate_again: When false, ``model_interpolated`` is used as
            given and no interpolation is performed.

    Returns:
        A tuple ``(tok_text, logit_selected_alpha)``: the decoded token
        reported by ``get_logprob_next`` and its selected log-probabilities
        keyed by decoded token text.
    """
    logit_bias = {} if logit_bias is None else logit_bias
    if interpolate_again:
        model_interpolated = make_interpolation(g.model_pretrained, g.model_chat, alpha=alpha, m3=model_interpolated)

    # Fix: `g` is a required argument of both helpers and was missing,
    # which raised TypeError.
    _, answer_tok, logit_selected_alpha = get_logprob_next(
        text, model_interpolated, g, chosen_tok=chosen_tok, logit_bias=logit_bias
    )

    logit_selected_alpha = make_logit_bias_text_keys(logit_selected_alpha, g)
    tok_text = g.tokenizer.decode(answer_tok)

    return tok_text, logit_selected_alpha

def save_data_list(data_list, fname):
    """Write ``data_list`` to ``fname`` as JSON indented by four spaces.

    The file is opened in ``"w"`` mode, so existing content is replaced.
    """
    with open(fname, "w") as handle:
        json.dump(data_list, fp=handle, indent=4)

def compute_probs_at_alpha_many_texts(
        texts, 
        g,
        chosen_tok=None, 
        model_interpolated=None, 
        alpha=1e-3, 
        logit_bias=None, 
        interpolate_again=True,
        params="all",
):
    """Score several texts with one model interpolated at ``alpha``.

    Args:
        texts: Iterable of prompts.
        g: Object exposing ``tokenizer``, ``model_pretrained`` and
            ``model_chat``.
        chosen_tok: Optional token index forwarded unchanged to
            ``get_logprob_next`` for every text.
        model_interpolated: Model to evaluate; also the destination reused by
            ``make_interpolation`` when ``interpolate_again`` is true.
        alpha: Interpolation coefficient passed to ``make_interpolation``.
        logit_bias: Mapping whose keys are the token ids of interest.
            ``None`` is treated as empty.
        interpolate_again: When false, ``model_interpolated`` is used as
            given.
        params: Parameter selection passed to ``make_interpolation``.

    Returns:
        A list with one dict per text holding ``idx``, ``params``, ``alpha``,
        ``text``, ``ans`` (decoded answer token) and one ``logprob_<token>``
        entry per selected token. Each dict is also printed as JSON.
    """
    logit_bias = {} if logit_bias is None else logit_bias
    if interpolate_again:
        model_interpolated = make_interpolation(g.model_pretrained, g.model_chat, alpha=alpha, m3=model_interpolated, params=params)

    out = []
    for i, text in enumerate(texts):
        # Fixes: (1) `g` was missing from both helper calls (TypeError);
        # (2) the result used to be stored back into `chosen_tok`, so the
        # first text's token was silently forced onto all later texts.
        _, answer_tok, logit_selected_alpha = get_logprob_next(
            text, model_interpolated, g, chosen_tok=chosen_tok, logit_bias=logit_bias
        )

        logit_selected_alpha = make_logit_bias_text_keys(logit_selected_alpha, g)
        tok_text = g.tokenizer.decode(answer_tok)

        info = {
            "idx": i,
            "params": params, 
            "alpha": alpha,
            "text": text,
            "ans": tok_text,
            **{f"logprob_{k}": v for k,v in logit_selected_alpha.items()}
        }

        out.append(info)

        print(json.dumps(info, indent=4))

    return out

def pretraining_attribution(hidden_states_pre, hidden_states_rlhf, w):
    """Split projected hidden-state updates into three additive components.

    With ``h_rlhf[i]`` / ``h_pre[i]`` the i-th entries of the two inputs:

    * initial state: ``h_rlhf[0]``;
    * "pretraining" update i: ``h_pre[i + 1] - h_rlhf[i]``;
    * "rlhf" update i: ``(h_rlhf[i + 1] - h_rlhf[i])`` minus the pretraining
      update i.

    Each component is summed over the updates and projected onto the rows of
    ``w``.

    Args:
        hidden_states_pre: Sequence of rank-2 tensors (asserted), read as
            ``(B, H)`` by the einsum patterns below.
        hidden_states_rlhf: Sequence of rank-2 tensors of the same layout.
        w: Projection matrix read as ``(N, H)``; a tensor or anything
            ``torch.as_tensor`` accepts.

    Returns:
        A dict of CPU tensors:
        * ``attr_initial_state`` / ``attr_pretraining`` / ``attr_rlhf``:
          absolute projected component divided by the sum of the three
          absolute projected components, shape ``(B, N)``;
        * ``isotropic_attr_*``: the same ratio computed from the absolute
          values of the unprojected components summed over the last axis,
          shape ``(B,)``;
        * ``logits``: sum of the three projected components, shape ``(B, N)``.
    """
    # Use the GPU when there is one; the maths is device-agnostic, so CPU-only
    # hosts no longer crash on a hard-coded .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # as_tensor avoids torch.tensor(tensor), which warns and makes an extra
    # full copy of the projection matrix on every call.
    w = torch.as_tensor(w).detach().to(device=device, dtype=torch.float32)

    assert all([len(x.shape) == 2 for x in hidden_states_pre]) and all([len(x.shape) == 2 for x in hidden_states_rlhf])

    # Differences are taken in the inputs' dtype and only then cast to float32.
    increments_rlhf = [
        (h_rlhf.to(device) - h_rlhf_prev.to(device)).to(torch.float32)
        for h_rlhf, h_rlhf_prev in zip(hidden_states_rlhf[1:], hidden_states_rlhf[:-1])
    ]

    increments_pre = [
        (h_pre.to(device) - h_rlhf_prev.to(device)).to(torch.float32)
        for h_pre, h_rlhf_prev in zip(hidden_states_pre[1:], hidden_states_rlhf[:-1])
    ]

    initial_state = hidden_states_rlhf[0].to(torch.float32).to(device)
    pretraining_circuit_outs = torch.concat([
        h_pre.unsqueeze(0) 
        for h_pre in increments_pre
    ])
    rlhf_circuit_outs = torch.concat([
        (h_rlhf-h_pre).unsqueeze(0) 
        for h_pre, h_rlhf in zip(increments_pre, increments_rlhf)
    ])

    # The "L B H, N H -> B N" patterns sum over both the update axis L and
    # the hidden axis H.
    initial_state_projected = einops.einsum(initial_state, w, "B H, N H -> B N")
    pretraining_circuits_projected = einops.einsum(pretraining_circuit_outs, w, "L B H, N H -> B N")
    rlhf_circuits_projected = einops.einsum(rlhf_circuit_outs, w, "L B H, N H -> B N")

    components = torch.stack([
        initial_state_projected,
        pretraining_circuits_projected,
        rlhf_circuits_projected
    ], dim=0).cpu()

    attributions = components.abs() / components.abs().sum(dim=0, keepdim=True)
    logits = components.sum(dim=0)

    components_isotropic = torch.stack([
        initial_state,
        pretraining_circuit_outs.sum(dim=0),
        rlhf_circuit_outs.sum(dim=0),
    ], dim=0).cpu()

    components_isotropic = components_isotropic.abs().sum(dim=-1)
    attributions_isotropic = components_isotropic / components_isotropic.sum(dim=0, keepdim=True)

    out = {
        "attr_initial_state": attributions[0],
        "attr_pretraining": attributions[1],
        "attr_rlhf": attributions[2],
        "isotropic_attr_initial_state": attributions_isotropic[0],
        "isotropic_attr_pretraining": attributions_isotropic[1],
        "isotropic_attr_rlhf": attributions_isotropic[2],
        "logits": logits,
    }

    # Drop the large intermediates before asking the allocator to release
    # memory.
    del initial_state_projected, pretraining_circuits_projected, rlhf_circuits_projected
    del increments_pre, increments_rlhf, pretraining_circuit_outs, rlhf_circuit_outs

    clean()

    return out

def compute_next_token_pretraining_attribution(text, g):
    """Build a per-token attribution table for the next token after ``text``.

    The prompt is run through the model returned by ``make_interpolation``
    for ``alpha = 0`` and then through the one returned for ``alpha = 1``;
    the second run additionally receives the first run's hidden states via
    the ``hidden_states_list`` keyword. The last-position hidden states of
    both runs are passed to ``pretraining_attribution``.

    Args:
        text: Prompt to analyse.
        g: Object exposing ``tokenizer``, ``model_pretrained`` and
            ``model_chat``.

    Returns:
        A ``pandas.DataFrame`` with one row per id in
        ``range(g.tokenizer.vocab_size)`` and the columns ``tok`` (token id),
        ``id`` (decoded token text), ``attr_pretraining``, ``attr_rlhf``,
        ``attr_initial_state``, ``unnormalized_logit``, ``logit``,
        ``logprob`` and ``prob``.
    """
    tokens = g.tokenizer(text, return_tensors="pt")

    print(tokens["input_ids"].shape)

    clean()

    model_interpolated = make_interpolation(g.model_pretrained, g.model_chat, 0)

    clean()

    with torch.no_grad():
        out_rlhf = model_interpolated(tokens["input_ids"], output_hidden_states=True)
        hidden_states_rlhf = out_rlhf.hidden_states

    # Drop the first model before the second one is created.
    del model_interpolated
    clean()

    model_interpolated = make_interpolation(g.model_pretrained, g.model_chat, 1)

    clean()

    with torch.no_grad():
        out_pretraining = model_interpolated(tokens["input_ids"], output_hidden_states=True, hidden_states_list=hidden_states_rlhf)
        hidden_states_pretraining = out_pretraining.hidden_states

    with torch.no_grad():
        rms_norm_weight = g.model_chat.model.norm.weight
        unembedding_matrix = g.model_chat.lm_head.weight
        normalized_unembedding = einops.einsum(unembedding_matrix, rms_norm_weight, "N H, H -> N H")

    # Fix: pretraining_attribution asserts rank-2 inputs, but the model
    # returns one state per position. Keep only the last position, as
    # batched_inference does, instead of tripping that assertion.
    final_states_pretraining = [h[:, -1] for h in hidden_states_pretraining]
    final_states_rlhf = [h[:, -1] for h in hidden_states_rlhf]

    out = pretraining_attribution(final_states_pretraining, final_states_rlhf, normalized_unembedding)
    true_logits = out_rlhf.logits[0, -1, :]
    # Explicit dim: the implicit-dim form of softmax is deprecated.
    true_logprobs = F.log_softmax(true_logits, dim=-1)
    true_probs = F.softmax(true_logits, dim=-1)

    vocab_size = g.tokenizer.vocab_size

    def column(values):
        # One bulk conversion per column instead of one .item() per cell.
        return values[:vocab_size].tolist()

    df = pd.DataFrame({
        "tok": list(range(vocab_size)),
        "id": [g.tokenizer.decode(tok) for tok in range(vocab_size)],
        "attr_pretraining": column(out["attr_pretraining"][0]),
        "attr_rlhf": column(out["attr_rlhf"][0]),
        "attr_initial_state": column(out["attr_initial_state"][0]),
        "unnormalized_logit": column(out["logits"][0]),
        "logit": column(true_logits),
        "logprob": column(true_logprobs),
        "prob": column(true_probs),
    })
    return df

def batched_inference(model, g):
    """Return a no-grad batched forward function bound to ``model``.

    As a side effect ``g.tokenizer`` is configured to pad on the left using
    its EOS token as the pad token.

    The returned ``inference(texts, **kwargs)``:
      * asserts that ``texts`` is exactly a ``list``;
      * tokenizes with padding and calls ``model`` with ``input_ids`` and
        ``attention_mask`` moved to ``"cuda"`` plus ``**kwargs``;
      * returns ``(out, hidden_states, hidden_states_final, true_logits)``
        where ``hidden_states`` is ``out.hidden_states``,
        ``hidden_states_final`` is each hidden state at the last position
        (``None`` when no hidden states were returned) and ``true_logits``
        are the logits at the last position;
      * always runs the garbage collector and empties the CUDA cache before
        returning or propagating an exception.
    """
    tok = g.tokenizer
    # TODO check
    tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    @torch.no_grad()
    def inference(texts, **kwargs):
        try:
            gc.collect()
            torch.cuda.empty_cache()

            assert type(texts) == list, "Need list of texts for batched inference"
            n_texts = len(texts)
            batch = tok(texts, padding=True, return_tensors="pt")

            out = model(
                input_ids=batch.input_ids.to("cuda"),
                attention_mask=batch.attention_mask.to("cuda"),
                **kwargs
            )

            hidden_states = out.hidden_states
            if hidden_states is None:
                hidden_states_final = None
            else:
                # With left padding the last position is the final token of
                # every row.
                hidden_states_final = [layer[:, -1] for layer in hidden_states]

            true_logits = out.logits[torch.arange(n_texts), -1]

            return out, hidden_states, hidden_states_final, true_logits
        finally:
            for _ in range(2):
                gc.collect()
                torch.cuda.empty_cache()

    return inference

def inference_step(inference, these_texts, out):
    """Run ``inference`` on one batch and store per-text results in ``out``.

    Args:
        inference: Callable returning ``(out, hidden_states,
            hidden_states_final, true_logits)`` for a list of texts.
        these_texts: The texts of this batch.
        out: Dict keyed by text; the entry of every text in the batch is
            replaced.

    Returns:
        ``out``, where each text of the batch maps to a dict with
        ``hidden_states_rlhf`` and ``hidden_states_final_rlhf`` (lists of CPU
        tensors with a leading axis of size one) and ``true_logits`` (CPU
        tensor).
    """
    clean()

    with torch.no_grad():
        result = inference(these_texts, output_hidden_states=True)
        out_rlhf, hidden_states_rlhf, hidden_states_final_rlhf, true_logits_rlhf = result

    for row, text in enumerate(these_texts):
        # Slice this text's row out of every tensor and move it to the CPU.
        full_states = [layer[row].cpu().unsqueeze(0) for layer in hidden_states_rlhf]
        final_states = [layer[row].cpu().unsqueeze(0) for layer in hidden_states_final_rlhf]
        out[text] = {
            "hidden_states_rlhf": full_states,
            "hidden_states_final_rlhf": final_states,
            "true_logits": true_logits_rlhf[row].cpu(),
        }

    # Drop the batch-level references held by this frame.
    del out_rlhf, hidden_states_rlhf, hidden_states_final_rlhf, true_logits_rlhf, result

    return out

def attribution_computation_step(inference, these_texts, normalized_unembedding, out):
    """Compute attributions for one batch from hidden states stored in ``out``.

    The per-text tensors saved under ``hidden_states_rlhf`` and
    ``hidden_states_final_rlhf`` are concatenated along axis 0, index by
    index. ``inference`` is then called on the batch with the concatenated
    full states as ``hidden_states_list``, and its final hidden states are
    passed to ``pretraining_attribution`` together with the concatenated
    final states.

    Args:
        inference: Callable returning ``(out, hidden_states,
            hidden_states_final, true_logits)``.
        these_texts: The texts of this batch; each must have an entry in
            ``out``.
        normalized_unembedding: Projection matrix forwarded to
            ``pretraining_attribution``.
        out: Dict keyed by text holding the stored hidden states.

    Returns:
        The dict returned by ``pretraining_attribution``.
    """
    clean()

    n_states = len(out[these_texts[0]]["hidden_states_rlhf"])

    def gather(field, index):
        # Stack the `index`-th stored tensor of every text along axis 0.
        return torch.concat([out[text][field][index] for text in these_texts], dim=0)

    these_hidden_states_rlhf = [gather("hidden_states_rlhf", idx) for idx in range(n_states)]
    these_hidden_states_final_rlhf = [gather("hidden_states_final_rlhf", idx) for idx in range(n_states)]

    clean()

    with torch.no_grad():
        result = inference(
            these_texts,
            output_hidden_states=True,
            hidden_states_list=these_hidden_states_rlhf,
        )
        hidden_states_final_pretraining = result[2]

    clean()

    attr_outs = pretraining_attribution(
        hidden_states_final_pretraining,
        these_hidden_states_final_rlhf,
        normalized_unembedding,
    )

    del these_hidden_states_final_rlhf, these_hidden_states_rlhf, hidden_states_final_pretraining

    clean()

    return attr_outs

def _make_one_df(p):
    """Build the top-token attribution table for one text.

    Args:
        p: Tuple ``(text, k, attr_outs, out)`` where ``k`` is the row of
            ``text`` inside ``attr_outs`` and ``out[text]["true_logits"]``
            holds the logits of ``text``.

    Returns:
        A tuple ``(text, df, avg_attr_rlhf, isotropic_dict)``:
        * ``df``: one row for each of the (up to) 100 highest logits, sorted
          by ``logit`` descending, with the columns ``tok`` (token id),
          ``id`` (decoded token text), ``attr_pretraining``, ``attr_rlhf``,
          ``attr_initial_state``, ``unnormalized_logit``, ``logit``,
          ``logprob`` and ``prob``;
        * ``avg_attr_rlhf``: mean of ``attr_outs["attr_rlhf"][k]``;
        * ``isotropic_dict``: row ``k`` of every ``attr_outs`` entry whose
          key contains ``"isotropic"``.
    """
    text, k, attr_outs, out = p
    g = GlobalModels(silent=True)
    try:
        true_logits = out[text]["true_logits"]
        true_logprobs = F.log_softmax(true_logits, dim=-1)
        true_probs = F.softmax(true_logits, dim=-1)

        # topk raises when asked for more entries than exist, so cap at the
        # number of logits.
        n_top = min(100, true_logits.shape[-1])
        topk_idx = torch.topk(true_logits, k=n_top).indices.tolist()
        avg_attr_rlhf = attr_outs["attr_rlhf"][k, :].mean().item()
        isotropic_dict = {key:val[k].item() for key,val in attr_outs.items() if "isotropic" in key}

        df = pd.DataFrame([
            {
                "tok": tok,
                "id": g.tokenizer.decode(tok),
                "attr_pretraining": attr_outs["attr_pretraining"][k, tok].item(),
                "attr_rlhf": attr_outs["attr_rlhf"][k, tok].item(),
                "attr_initial_state": attr_outs["attr_initial_state"][k, tok].item(),
                "unnormalized_logit": attr_outs["logits"][k, tok].item(),
                "logit": true_logits[tok].item(),
                "logprob": true_logprobs[tok].item(),
                "prob": true_probs[tok].item()
            }
            for tok in topk_idx
        ])

        df = df.sort_values("logit", ascending=False).head(100).reset_index(drop=True)

        return text, df, avg_attr_rlhf, isotropic_dict
    finally:
        # Fix: close() used to be skipped whenever anything above raised.
        g.close()

def make_dfs(these_texts, attr_outs, progress_bar, out, g):
    """Build the per-text tables of one batch and store them in ``out``.

    ``_make_one_df`` is run for every text on a thread pool. For each text
    the entries ``df``, ``avg_attr_rlhf`` and every isotropic attribution are
    written into ``out[text]``, then ``progress_bar`` is advanced by the
    batch size.

    Args:
        these_texts: The texts of this batch, in the row order of
            ``attr_outs``.
        attr_outs: Attribution dict for the batch.
        progress_bar: Object with an ``update(n)`` method.
        out: Dict keyed by text; updated in place.
        g: Unused here; kept for interface compatibility.

    Returns:
        ``out``.
    """
    # `import concurrent` alone does not load the `futures` submodule; import
    # it explicitly so this does not depend on another module having done so.
    import concurrent.futures

    jobs = [(text, k, attr_outs, out) for k, text in enumerate(these_texts)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = list(executor.map(_make_one_df, jobs))

    for text, df, avg_attr_rlhf, isotropic_dict in results:
        out[text]["df"] = df
        out[text]["avg_attr_rlhf"] = avg_attr_rlhf
        for key, val in isotropic_dict.items():
            out[text][key] = val

    progress_bar.update(len(these_texts))
    return out

def compute_many_next_token_pretraining_attribution(texts, g, batch_size=None, dtype=torch.float32):
    """Compute next-token attribution tables for many texts in two passes.

    Pass 1 loads the model returned by ``make_interpolation`` for
    ``alpha = 0`` and stores every text's hidden states and logits. Pass 2
    loads the model for ``alpha = 1``, replays the stored hidden states
    through it batch by batch and builds the per-text tables.

    Args:
        texts: List of prompts. Results are keyed by prompt text, so
            duplicate prompts share one entry.
        g: Object exposing ``tokenizer`` and ``model_chat``.
        batch_size: Number of texts per forward pass; defaults to all texts
            in one batch.
        dtype: ``torch_dtype`` used when loading the models.

    Returns:
        A list with one dict per distinct prompt holding ``prompt``, ``df``,
        ``avg_attr_rlhf``, ``isotropic_attr_initial_state``,
        ``isotropic_attr_pretraining`` and ``isotropic_attr_rlhf``. An empty
        ``texts`` returns ``[]`` without loading anything.

    Raises:
        ValueError: If ``batch_size`` is smaller than one.
    """
    # Empty input used to reach range(0, 0, 0) and fail with an unrelated
    # ValueError, after a model had already been loaded.
    if len(texts) == 0:
        return []
    if batch_size is None:
        batch_size = len(texts)
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Pre-bind everything the cleanup below touches, so a failure early in
    # the try block is not masked by an UnboundLocalError from the cleanup.
    model_interpolated = None
    inference = None
    rms_norm_weight = None
    unembedding_matrix = None
    normalized_unembedding = None
    try:
        out = {k: {} for k in texts}

        gc.collect()
        torch.cuda.empty_cache()

        with torch.no_grad():
            rms_norm_weight = g.model_chat.model.norm.weight
            unembedding_matrix = g.model_chat.lm_head.weight
            normalized_unembedding = einops.einsum(unembedding_matrix, rms_norm_weight, "N H, H -> N H")

        # Pass 1: collect hidden states and logits.
        model_interpolated = make_interpolation(None, None, 0, dtype=dtype)
        inference = batched_inference(model_interpolated, g)

        with tqdm(total=len(texts)) as progress_bar:
            for i in range(0, len(texts), batch_size):
                j = min(i + batch_size, len(texts))
                these_texts = texts[i:j]

                out = inference_step(inference, these_texts, out)

                gc.collect()
                torch.cuda.empty_cache()

                progress_bar.update(j-i)

        # Release the first model before the second one is loaded.
        inference = None
        model_interpolated = None
        gc.collect()
        torch.cuda.empty_cache()

        # Pass 2: replay the stored states and build the tables.
        model_interpolated = make_interpolation(None, None, 1, dtype=dtype)
        inference = batched_inference(model_interpolated, g)

        gc.collect()
        torch.cuda.empty_cache()

        with tqdm(total=len(texts)) as progress_bar:
            for i in range(0, len(texts), batch_size):
                j = min(i + batch_size, len(texts))
                these_texts = texts[i:j]

                gc.collect()
                torch.cuda.empty_cache()

                print("attributions")
                attr_outs = attribution_computation_step(inference, these_texts, normalized_unembedding, out)
                print("making dfs")
                # make_dfs advances the progress bar itself.
                out = make_dfs(these_texts, attr_outs, progress_bar, out, g)

        gc.collect()
        torch.cuda.empty_cache()

        model_interpolated = None
        inference = None

        gc.collect()
        torch.cuda.empty_cache()

        return [
            {
                "prompt": k, 
                "df": v["df"],
                "avg_attr_rlhf": v["avg_attr_rlhf"],
                "isotropic_attr_initial_state": v["isotropic_attr_initial_state"],
                "isotropic_attr_pretraining": v["isotropic_attr_pretraining"],
                "isotropic_attr_rlhf": v["isotropic_attr_rlhf"],
            } 
            for k,v in out.items()
        ]
    finally:
        # Rebinding to None (rather than `del`) is safe whether or not the
        # names were ever assigned a real value.
        rms_norm_weight = unembedding_matrix = normalized_unembedding = None
        model_interpolated = None
        inference = None
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()

