import torch
import einops
from .utils import clean, get_hidden_states, _make_one_df
from tqdm import tqdm
import concurrent

def tokenwise_attributions(initial_state, pretraining_circuit_outs, finetuned_circuit_outs, w=None, **kwargs):
    """Split per-token projections into initial-state, pretraining and fine-tuning shares.

    Each input is moved to the CPU and projected onto the rows of ``w``; the
    projections are then normalised so that the shares of the components sum
    to one for every ``(B, N)`` entry.

    Args:
        initial_state: Tensor with axes ``(B, H)``.
        pretraining_circuit_outs: Tensor with axes ``(L, B, H)``.
        finetuned_circuit_outs: Tensor with axes ``(L, B, H)``.
        w: Projection matrix with axes ``(N, H)``; a tensor or anything else
            ``torch.as_tensor`` accepts. Required, even though the signature
            defaults it to ``None``.
        **kwargs: Accepted and ignored.

    Returns:
        dict mapping ``"tokenwise_*"`` keys to CPU tensors with axes ``(B, N)``:
        coarse shares (layers summed before taking magnitudes), fine-grained
        shares (magnitudes taken per layer, then summed), layer-averaged
        layer-wise shares, and the raw projected logits (total and per
        component).

    Raises:
        TypeError: If ``w`` is ``None``.
    """
    if w is None:
        raise TypeError("tokenwise_attributions requires the projection matrix `w`, got None")

    # as_tensor avoids the copy (and the copy-construct warning) torch.tensor
    # produces when `w` is already a tensor; `w` is never modified in place.
    w = torch.as_tensor(w, dtype=torch.float32).detach().cpu()

    initial_state_projected = einops.einsum(initial_state.cpu(), w, "B H, N H -> B N")
    pretraining_circuits_projected = einops.einsum(pretraining_circuit_outs.cpu(), w, "L B H, N H -> L B N")
    finetuned_circuits_projected = einops.einsum(finetuned_circuit_outs.cpu(), w, "L B H, N H -> L B N")

    # 3 B N -- signed contributions, layers summed before any magnitude is taken.
    components_coarse = torch.stack([
        initial_state_projected,
        pretraining_circuits_projected.sum(dim=0),
        finetuned_circuits_projected.sum(dim=0),
    ], dim=0)

    # 3 B N -- magnitudes taken per layer first, so layers cannot cancel out.
    components_finegrained = torch.stack([
        abs(initial_state_projected),
        abs(pretraining_circuits_projected).sum(dim=0),
        abs(finetuned_circuits_projected).sum(dim=0),
    ], dim=0)

    # 2 L B N
    components_layerwise = torch.stack([
        abs(pretraining_circuits_projected),
        abs(finetuned_circuits_projected),
    ], dim=0)

    coarse_magnitudes = abs(components_coarse)
    attributions_coarse = coarse_magnitudes / coarse_magnitudes.sum(dim=0, keepdims=True)
    # components_finegrained is already non-negative, so no further abs is needed.
    attributions_finegrained = components_finegrained / components_finegrained.sum(dim=0, keepdims=True)

    # Same 1e-12 guard as isotropic_attributions: a layer where both
    # projections are exactly zero would otherwise yield 0/0 = NaN and poison
    # the mean over layers.
    layerwise_totals = 1e-12 + components_layerwise.sum(dim=0, keepdims=True)
    attributions_layerwise = (components_layerwise / layerwise_totals).mean(dim=1)

    return {
        "tokenwise_coarse_attr_initial_state": attributions_coarse[0],
        "tokenwise_coarse_attr_pretraining": attributions_coarse[1],
        "tokenwise_coarse_attr_finetuned": attributions_coarse[2],
        "tokenwise_finegrained_attr_initial_state": attributions_finegrained[0],
        "tokenwise_finegrained_attr_pretraining": attributions_finegrained[1],
        "tokenwise_finegrained_attr_finetuned": attributions_finegrained[2],
        "tokenwise_layerwise_attr_pretraining": attributions_layerwise[0],
        "tokenwise_layerwise_attr_finetuned": attributions_layerwise[1],
        "tokenwise_unnormalized_logits": components_coarse.sum(dim=0),
        "tokenwise_logits_initial_state": components_coarse[0],
        "tokenwise_logits_pretraining_component": components_coarse[1],
        "tokenwise_logits_finetuning_component": components_coarse[2],
    }

def isotropic_attributions(initial_state, pretraining_circuit_outs, finetuned_circuit_outs, **kwargs):
    """Compute projection-free attribution shares by reducing over the last axis.

    All statistics are computed on the CPU. The three components (initial
    state, pretraining circuits, fine-tuning circuits) are compared through
    norms taken over the last axis of the inputs.

    Args:
        initial_state: Tensor, presumably ``(B, H)`` -- confirm against caller.
        pretraining_circuit_outs: Tensor with a leading layer axis, presumably
            ``(L, B, H)``.
        finetuned_circuit_outs: Tensor shaped like ``pretraining_circuit_outs``.
        **kwargs: Accepted and ignored.

    Returns:
        dict mapping ``"isotropic_*"`` keys to CPU tensors: L1, L2 and
        squared-L2 shares of the layer-summed components, the unnormalised L1
        masses, shares based on per-layer magnitudes, layer-averaged shares,
        and the maximum over layers of the fine-tuning share of the
        cumulative sum.
    """
    def _share(masses):
        # Normalise along the leading (component) axis so the shares sum to one.
        return masses / masses.sum(dim=0, keepdims=True)

    # Signed components with the layers already summed.
    summed = torch.stack(
        [initial_state, pretraining_circuit_outs.sum(dim=0), finetuned_circuit_outs.sum(dim=0)],
        dim=0,
    ).cpu()

    # Magnitudes taken layer by layer before summing (already non-negative).
    magnitude_summed = torch.stack(
        [abs(initial_state), abs(pretraining_circuit_outs).sum(dim=0), abs(finetuned_circuit_outs).sum(dim=0)],
        dim=0,
    ).cpu()

    # 2 L B H
    per_layer = torch.stack([pretraining_circuit_outs, finetuned_circuit_outs], dim=0).cpu()

    l1_mass = summed.abs().sum(dim=-1)
    sq_mass = (summed ** 2).sum(dim=-1)
    l2_mass = sq_mass ** 0.5
    fine_mass = magnitude_summed.sum(dim=-1)

    share_l1 = _share(l1_mass)
    share_l2_sq = _share(sq_mass)
    share_l2 = _share(l2_mass)
    share_fine = _share(fine_mass)

    # Per-layer shares, averaged over the layer axis and then over the last axis;
    # the 1e-12 keeps an all-zero entry from producing 0/0.
    per_layer_abs = per_layer.abs()
    per_layer_share = per_layer_abs / (1e-12 + per_layer_abs.sum(dim=0, keepdims=True))
    share_layerwise = per_layer_share.mean(dim=1).mean(dim=-1)

    # Share of the running (cumulative over layers) sums; keep the largest
    # fine-tuning share seen at any depth.
    running_mass = per_layer.cumsum(dim=1).abs().sum(dim=-1)
    running_share = _share(running_mass)
    sup_alphas = running_share[1].max(dim=0).values

    return {
        "isotropic_coarse_attr_initial_state": share_l1[0],
        "isotropic_coarse_attr_pretraining": share_l1[1],
        "isotropic_coarse_attr_finetuned": share_l1[2],

        "isotropic_coarse_l2_attr_initial_state": share_l2[0],
        "isotropic_coarse_l2_attr_pretraining": share_l2[1],
        "isotropic_coarse_l2_attr_finetuned": share_l2[2],

        "isotropic_coarse_l2_sq_attr_initial_state": share_l2_sq[0],
        "isotropic_coarse_l2_sq_attr_pretraining": share_l2_sq[1],
        "isotropic_coarse_l2_sq_attr_finetuned": share_l2_sq[2],

        "isotropic_unnormalized_attr_initial_state": l1_mass[0],
        "isotropic_unnormalized_attr_pretraining": l1_mass[1],
        "isotropic_unnormalized_attr_finetuned": l1_mass[2],

        "isotropic_finegrained_attr_initial_state": share_fine[0],
        "isotropic_finegrained_attr_pretraining": share_fine[1],
        "isotropic_finegrained_attr_finetuned": share_fine[2],

        "isotropic_layerwise_attr_pretraining": share_layerwise[0],
        "isotropic_layerwise_attr_finetuned": share_layerwise[1],

        "isotropic_sup_attr_finetuned": sup_alphas,
    }

class Attributions:
    """Run the fine-tuned and pretrained passes and build per-prompt attribution tables.

    ``models`` must provide ``itos``, ``batched_inference``,
    ``get_normalized_unembedding`` and ``load_gpu``; ``args`` must provide
    ``debug``. Both are stored as-is.
    """

    def __init__(self, models, args):
        self.args = args
        self.models = models
        self.itos = self.models.itos

    @staticmethod
    def _compute_device():
        """Return the current CUDA device when CUDA is available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _compute_layerwise_circuit_outputs(self, hidden_states_pre, hidden_states_finetuned):
        """Turn two stacks of hidden states into per-layer circuit outputs.

        Args:
            hidden_states_pre: Sequence of rank-2 tensors; entry ``i`` (for
                ``i >= 1``) is compared against fine-tuned entry ``i - 1``.
                Entry 0 is only shape-checked.
            hidden_states_finetuned: Sequence of rank-2 tensors, one per depth,
                starting with the initial state.

        Returns:
            ``(initial_state, pretraining_circuit_outs, finetuned_circuit_outs)``
            as float32 tensors on the compute device. The last two carry a
            new leading layer axis; the fine-tuning output of a layer is its
            fine-tuned increment minus its pretraining increment.
        """
        assert all(len(x.shape) == 2 for x in hidden_states_pre) and all(len(x.shape) == 2 for x in hidden_states_finetuned)

        device = self._compute_device()

        # Move every fine-tuned state once; each one is needed by up to three
        # of the differences below.
        finetuned_on_device = [h.to(device) for h in hidden_states_finetuned]
        previous_states = finetuned_on_device[:-1]

        # Differences are taken in the incoming dtype and only then cast, so
        # the numerical result does not depend on where the cast happens.
        increments_finetuned = [
            (h_finetuned - h_prev).to(torch.float32)
            for h_finetuned, h_prev in zip(finetuned_on_device[1:], previous_states)
        ]
        increments_pre = [
            (h_pre.to(device) - h_prev).to(torch.float32)
            for h_pre, h_prev in zip(hidden_states_pre[1:], previous_states)
        ]

        initial_state = finetuned_on_device[0].to(torch.float32)
        pretraining_circuit_outs = torch.stack(increments_pre)
        finetuned_circuit_outs = torch.stack([
            h_finetuned - h_pre
            for h_pre, h_finetuned in zip(increments_pre, increments_finetuned)
        ])

        return initial_state, pretraining_circuit_outs, finetuned_circuit_outs

    def pretraining_attribution(self, hidden_states_pre, hidden_states_finetuned, w):
        """Merge the outputs of ``tokenwise_attributions`` and ``isotropic_attributions``.

        Args:
            hidden_states_pre: Per-depth hidden states from the pretrained pass.
            hidden_states_finetuned: Per-depth hidden states from the fine-tuned pass.
            w: Projection matrix forwarded to both attribution functions.

        Returns:
            A single dict holding every key produced by the two functions.
        """
        initial_state, pretraining_circuit_outs, finetuned_circuit_outs = self._compute_layerwise_circuit_outputs(
            hidden_states_pre, hidden_states_finetuned
        )

        out = {}
        for attr_fn in (tokenwise_attributions, isotropic_attributions):
            out.update(attr_fn(initial_state, pretraining_circuit_outs, finetuned_circuit_outs, w=w))

        clean()

        return out

    def finetuned_forward_pass(self, these_texts, out, **kwargs):
        """Run the fine-tuned model on one batch and cache its outputs in ``out``.

        For every text, ``out[text]`` is replaced by a dict with the CPU copies
        of all hidden states, of the hidden states at position ``-1``, and of
        the logits at position ``-1``.

        Args:
            these_texts: The batch of prompts.
            out: Dict keyed by prompt; mutated in place and returned.
            **kwargs: Forwarded to ``self.models.batched_inference``.

        Returns:
            The same ``out`` dict.

        Raises:
            RuntimeError: If the inference call returns no hidden states.
        """
        clean()

        with torch.no_grad():
            model_output_finetuned = self.models.batched_inference(
                these_texts,
                output_hidden_states=True,
                debug=self.args.debug,
                **kwargs
            )

            hidden_states_finetuned = model_output_finetuned.hidden_states
            if hidden_states_finetuned is None:
                # Nothing downstream can work without them; fail here with a
                # readable message instead of on the first iteration below.
                raise RuntimeError("batched_inference returned no hidden states although output_hidden_states=True was requested")

            # NOTE(review): position -1 is used for every row; this presumes the
            # last position is the one of interest for each sequence in the
            # batch (e.g. no right padding) -- confirm against batched_inference.
            hidden_states_final_finetuned = [x[:, -1] for x in hidden_states_finetuned]
            true_logits_finetuned = model_output_finetuned.logits[:, -1]

        for k, text in enumerate(these_texts):
            out[text] = {
                "hidden_states_finetuned": [x[k].cpu().unsqueeze(0) for x in hidden_states_finetuned],
                "hidden_states_final_finetuned": [x[k].cpu().unsqueeze(0) for x in hidden_states_final_finetuned],
                "true_logits": true_logits_finetuned[k].cpu(),
            }

        del model_output_finetuned, hidden_states_finetuned, hidden_states_final_finetuned, true_logits_finetuned

        return out

    def attribution_computation_step(self, these_texts, normalized_unembedding, out):
        """Run the pretrained pass for one batch and compute its attributions.

        Args:
            these_texts: The batch of prompts; each must already have been
                processed by ``finetuned_forward_pass``.
            normalized_unembedding: Projection matrix handed to
                ``pretraining_attribution`` as ``w``.
            out: Dict keyed by prompt holding the cached fine-tuned outputs.

        Returns:
            The attribution dict produced by ``pretraining_attribution``.
        """
        clean()

        these_hidden_states_finetuned, these_hidden_states_final_finetuned = get_hidden_states(these_texts, out)

        clean()

        # Only forward values are consumed, so skip autograd bookkeeping, as
        # the fine-tuned pass already does.
        with torch.no_grad():
            model_output_pretraining = self.models.batched_inference(
                these_texts,
                output_hidden_states=True,
                use_pretrained_gpu=True,
                hidden_states_list=these_hidden_states_finetuned,
                use_cache=False,
            )

            hidden_states_final_pretraining = [x[:, -1] for x in model_output_pretraining.hidden_states]

        clean()

        attr_outs = self.pretraining_attribution(
            hidden_states_final_pretraining,
            these_hidden_states_final_finetuned,
            normalized_unembedding
        )

        # Drop the references before clean() so it can actually release them.
        del model_output_pretraining, these_hidden_states_final_finetuned, these_hidden_states_finetuned, hidden_states_final_pretraining

        clean()

        return attr_outs

    def make_dfs(self, these_texts, attr_outs, progress_bar, out):
        """Build the two tables for every prompt of a batch on a thread pool.

        Args:
            these_texts: The batch of prompts.
            attr_outs: Attribution dict for the batch.
            progress_bar: Object with an ``update(n)`` method; advanced by the
                batch size once the batch is done.
            out: Dict keyed by prompt; gains ``"tokenwise_df"`` and
                ``"consolidated_df"`` entries.

        Returns:
            The same ``out`` dict.
        """
        # `import concurrent` alone does not load the `futures` submodule, so
        # import the executor explicitly rather than rely on another module
        # having imported it first.
        from concurrent.futures import ThreadPoolExecutor

        jobs = [(text, k, attr_outs, self.itos, out) for k, text in enumerate(these_texts)]
        with ThreadPoolExecutor() as executor:
            results = list(executor.map(_make_one_df, jobs))

        for text, tokenwise_df, consolidated_df in results:
            out[text]["tokenwise_df"] = tokenwise_df
            out[text]["consolidated_df"] = consolidated_df

        progress_bar.update(len(these_texts))
        return out

    def compute(self, texts, batch_size=None, device_map="balanced", _external_pbar=None, **kwargs):
        """Compute attribution tables for every prompt in ``texts``.

        The work is done in two sweeps over the same batches: first the
        fine-tuned forward pass for all batches, then the pretrained pass plus
        attribution and table construction.

        Args:
            texts: Sequence of prompts. Duplicates share one entry.
            batch_size: Prompts per batch; ``None`` means a single batch.
            device_map: Forwarded to ``self.models.load_gpu``.
            _external_pbar: Optional caller-owned progress bar. It is advanced
                during the second sweep only and never closed here.
            **kwargs: Forwarded to ``finetuned_forward_pass``.

        Returns:
            A list of dicts with keys ``"prompt"``, ``"tokenwise_df"`` and
            ``"consolidated_df"``, one per distinct prompt; empty when
            ``texts`` is empty.

        Raises:
            ValueError: If ``batch_size`` is given and smaller than 1.
        """
        clean()

        out = {k: {} for k in texts}

        if batch_size is None:
            # max(..., 1) keeps range() from receiving a zero step on empty input.
            batch_size = max(len(texts), 1)
        elif batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        normalized_unembedding = self.models.get_normalized_unembedding("finetuned")

        self.models.load_gpu("finetuned", device_map=device_map, force=False)

        batches = [texts[start:start + batch_size] for start in range(0, len(texts), batch_size)]
        owns_pbar = _external_pbar is None

        # Sweep 1: fine-tuned forward passes. Only a locally created bar is
        # advanced here, so an external bar counts each prompt once (sweep 2).
        progress_bar = tqdm(total=len(texts)) if owns_pbar else _external_pbar
        for these_texts in batches:
            out = self.finetuned_forward_pass(these_texts, out, **kwargs)
            clean()

            if owns_pbar:
                progress_bar.update(len(these_texts))

        if owns_pbar:
            progress_bar.close()

        clean()

        # Sweep 2: pretrained passes, attributions and tables; make_dfs
        # advances the bar.
        progress_bar = tqdm(total=len(texts)) if owns_pbar else _external_pbar
        for these_texts in batches:
            attr_outs = self.attribution_computation_step(these_texts, normalized_unembedding, out)
            out = self.make_dfs(these_texts, attr_outs, progress_bar, out)

        if owns_pbar:
            progress_bar.close()

        clean()

        final_out = [
            {
                "prompt": k,
                "tokenwise_df": v["tokenwise_df"],
                "consolidated_df": v["consolidated_df"],
            }
            for k, v in out.items()
        ]

        return final_out

    