import plotly.express as px
import transformer_lens.utils as utils
import numpy as np
import torch

from functools import partial
from typing import List, Optional, Union
import os, sys
import configparser

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
import circuitsvis as cv
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoConfig, AutoModelForCausalLM

from patching_utils import *
def imshow(tensor, **kwargs):
    """Display *tensor* as a heatmap on a diverging red/blue scale centred at 0.

    The tensor is converted with ``utils.to_numpy`` and handed to
    ``plotly.express.imshow``; any extra keyword arguments are forwarded
    unchanged. The figure is shown immediately and nothing is returned.
    """
    values = utils.to_numpy(tensor)
    figure = px.imshow(
        values,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        **kwargs,
    )
    figure.show()


def line(tensor, **kwargs):
    """Display *tensor* as a line plot, using its values for the y axis.

    Extra keyword arguments are forwarded to ``plotly.express.line``. The
    figure is shown immediately and nothing is returned.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Display a scatter plot of *y* against *x*.

    Args:
        x: Values for the horizontal axis (converted with ``utils.to_numpy``).
        y: Values for the vertical axis (converted with ``utils.to_numpy``).
        xaxis: Label used for the ``x`` dimension.
        yaxis: Label used for the ``y`` dimension.
        caxis: Label used for the ``color`` dimension.
        **kwargs: Forwarded unchanged to ``plotly.express.scatter``.

    The figure is shown immediately and nothing is returned.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(x=x, y=y, labels=axis_labels, **kwargs)
    figure.show()
    

## Logit Attribution
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Return the logit difference between the first and second answer token.

    Args:
        logits: Tensor indexed as ``[batch, position, vocab]``; only the last
            position is used.
        answer_tokens: Integer tensor indexed as ``[batch, answer]`` whose
            column 0 and column 1 hold the two token ids being compared.
        per_prompt: When true, return one difference per batch row instead
            of the mean over the batch.

    Returns:
        A 1-D tensor of per-row differences if *per_prompt* is true,
        otherwise their mean as a 0-d tensor.
    """
    # The answer is predicted from the last sequence position only.
    last_position = logits[:, -1, :]
    picked = last_position.gather(dim=-1, index=answer_tokens)
    diff = picked[:, 0] - picked[:, 1]
    return diff if per_prompt else diff.mean()
    

def assess_logit_diff(tokenizer, model, prompts, answers, device):
    """Run *model* on *prompts* and return the per-prompt answer logit difference.

    Args:
        tokenizer: Tokenizer passed through to ``to_single_token`` / ``to_tokens``.
        model: Object exposing ``run_with_cache(tokens)`` returning ``(logits, cache)``.
        prompts: Sequence of prompt strings.
        answers: Sequence of pairs; element 0 and element 1 of each pair are
            each reduced to a single token id.
        device: Device the token tensors are moved to.

    Returns:
        ``(answer_tokens, per_prompt_diff)`` where ``answer_tokens`` has shape
        ``(len(answers), 2)`` and ``per_prompt_diff`` is detached from the graph.
    """
    answer_id_pairs = [
        [to_single_token(tokenizer, pair[0]), to_single_token(tokenizer, pair[1])]
        for pair in answers
    ]
    answer_tokens = torch.tensor(answer_id_pairs).reshape(len(answers), 2).to(device)

    # Measure each prompt first so shorter ones can be padded to the longest.
    prompt_lens = [len(to_tokens(tokenizer, prompt, device)) for prompt in prompts]
    longest = max(prompt_lens)

    padded_prompts = []
    for prompt, length in zip(prompts, prompt_lens):
        padded_prompts.append(
            to_tokens(tokenizer, prompt, device, True, int(longest - length))
        )
    tokens = torch.stack(padded_prompts)

    # The activation cache is produced by the call but not needed here.
    original_logits, _ = model.run_with_cache(tokens)

    per_prompt_diff = logits_to_ave_logit_diff(
        original_logits, answer_tokens, per_prompt=True
    ).detach()
    return answer_tokens, per_prompt_diff


def get_logit_diff(logits, answer_tokens):
    """Return the mean logit gap between column 0 and column 1 of *answer_tokens*.

    Args:
        logits: Either ``[batch, position, vocab]`` (the last position is
            selected) or an already-reduced ``[batch, vocab]`` tensor.
        answer_tokens: Integer tensor whose column 0 holds the "correct" token
            id and column 1 the "incorrect" token id for each batch row.

    Returns:
        A 0-d tensor: the batch mean of ``correct_logit - incorrect_logit``.
    """
    if logits.ndim == 3:
        # Keep only the logits at the last sequence position.
        logits = logits[:, -1, :]
    correct_index = answer_tokens[:, 0].unsqueeze(1)
    incorrect_index = answer_tokens[:, 1].unsqueeze(1)
    gap = logits.gather(1, correct_index) - logits.gather(1, incorrect_index)
    return gap.mean()




def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
    len_prompts: int,
    logit_diff_directions: Float[torch.Tensor, "batch d_model"],
) -> Float[torch.Tensor, "..."]:
    """Project a stack of residual components onto the logit-diff directions.

    Args:
        residual_stack: Stacked residual-stream components; the trailing two
            axes are contracted against ``logit_diff_directions``.
        cache: Activation cache whose ``apply_ln_to_stack`` is used to scale
            the stack (called with ``layer=-1, pos_slice=-1``).
        len_prompts: Divisor applied to the summed contribution; callers are
            expected to pass the number of prompts so the result is a
            per-prompt average.
        logit_diff_directions: One direction per batch row, matched to the
            ``batch d_model`` axes of the scaled stack.

    Returns:
        A tensor with the leading (``...``) axes of the stack, i.e. one value
        per component, not a Python float.
    """
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    # Sum over batch and d_model, leaving one scalar per leading component.
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len_prompts



def get_tokens_cache(tokenizer, model, prompts, answer_tokens, device):
    """Tokenize *prompts*, run *model* with caching, and derive logit-diff directions.

    Args:
        tokenizer: Tokenizer passed through to ``to_tokens``.
        model: Object exposing ``run_with_cache`` and
            ``tokens_to_residual_directions``.
        prompts: Sequence of prompt strings; shorter prompts are padded so
            every row has the length of the longest one.
        answer_tokens: Token ids handed to ``model.tokens_to_residual_directions``;
            index 0 and index 1 along its second axis are subtracted.
        device: Device the token tensors are moved to.

    Returns:
        ``(tokens, original_logits, cache, answer_residual_directions,
        logit_diff_directions)``.
    """
    prompt_lens = [len(to_tokens(tokenizer, text, device)) for text in prompts]
    longest = max(prompt_lens)

    # Second pass: re-tokenize with the padding each prompt needs.
    padded_rows = [
        to_tokens(tokenizer, text, device, True, int(longest - length))
        for text, length in zip(prompts, prompt_lens)
    ]
    tokens = torch.stack(padded_rows)

    original_logits, cache = model.run_with_cache(tokens)
    answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

    first_direction = answer_residual_directions[:, 0]
    second_direction = answer_residual_directions[:, 1]
    logit_diff_directions = first_direction - second_direction

    return (
        tokens,
        original_logits,
        cache,
        answer_residual_directions,
        logit_diff_directions,
    )

# Llama tokenization
## Llama tokenizer helper functions

def decode_single_token(tokenizer, integer):
    """Decode one token id to text, keeping any leading space it carries.

    The id is decoded together with a fixed leading token (id 891) so the
    tokenizer does not apply start-of-string handling to it; the first
    character of the decoded text is then dropped to remove that prefix.
    """
    with_prefix = tokenizer.decode([891, integer])
    return with_prefix[1:]
def to_str_tokens(tokenizer, tokens, prepend_bos=True):
    """Return the per-token strings for *tokens*.

    Args:
        tokenizer: Tokenizer used for encoding (string input) and decoding.
        tokens: A string (tokenized via ``to_tokens``), a 1-D or 2-D
            ``torch.Tensor`` of ids (only row 0 of a 2-D tensor is used), or
            a plain sequence of ids.
        prepend_bos: When true every token is returned; when false the first
            token is dropped from the output.

    Returns:
        A list with one decoded string per token.
    """
    if isinstance(tokens, str):
        # to_tokens requires a device argument; the ids are turned into a
        # plain Python list right below, so the CPU is sufficient here.
        tokens = to_tokens(tokenizer, tokens, "cpu")
    if isinstance(tokens, torch.Tensor):
        if tokens.ndim == 2:
            # Only the first batch row is decoded.
            tokens = tokens[0]
        tokens = tokens.tolist()
    if not prepend_bos:
        tokens = tokens[1:]
    return [decode_single_token(tokenizer, token) for token in tokens]

def to_string(tokens, tokenizer=None):
    """Decode a sequence of token ids into a single string.

    Args:
        tokens: A list of ids, a 1-D ``torch.Tensor``, or a 2-D tensor with
            exactly one row.
        tokenizer: Tokenizer providing ``decode``. When omitted, a
            module-level ``tokenizer`` is used, which keeps existing
            ``to_string(tokens)`` calls working.

    Returns:
        The decoded text. The ids are decoded behind a fixed leading token
        (id 891) whose single decoded character is then stripped.

    Raises:
        NameError: If no tokenizer is passed and no module-level
            ``tokenizer`` exists.
        AssertionError: If a 2-D tensor with more than one row is passed.
    """
    if tokenizer is None:
        # Backward-compatible path: earlier versions read a global name.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError("name 'tokenizer' is not defined")
    if isinstance(tokens, torch.Tensor):
        if tokens.ndim == 2:
            assert tokens.shape[0] == 1
            tokens = tokens[0]
        tokens = tokens.tolist()
    return tokenizer.decode([891, *tokens])[1:]

def to_tokens(tokenizer, string, device, prepend_bos=True, pad_left = 0):
    """Tokenize *string* into a 1-D tensor of ids, optionally left-padded.

    The string is encoded with a leading "|" so that the first two encoded
    ids are always [BOS (1), " |" (891)]; the " |" id is then discarded so
    the remaining ids reflect *string* itself.

    Args:
        tokenizer: Object providing ``encode`` and ``pad_token_id``.
        string: Text to tokenize.
        device: Device the resulting tensor is moved to.
        prepend_bos: Keep the leading BOS id when true, drop it when false.
        pad_left: Number of pad ids to insert before the text ids (after the
            BOS id when it is kept).

    Returns:
        A 1-D ``torch.long`` tensor laid out as ``[BOS?] + pad + text ids``.

    Raises:
        ValueError: If padding is requested but the tokenizer has no
            ``pad_token_id``.
    """
    encoded = tokenizer.encode("|" + string)
    bos_ids, text_ids = encoded[:1], encoded[2:]

    pad_ids = []
    if pad_left > 0:
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError(
                "pad_left > 0 requires the tokenizer to define pad_token_id"
            )
        # Build the padding as a plain list so it concatenates with the ids
        # in both the BOS and no-BOS cases.
        pad_ids = [pad_id] * int(pad_left)

    prefix = bos_ids if prepend_bos else []
    # An explicit dtype keeps the result integer-typed even when it is empty.
    return torch.tensor(prefix + pad_ids + text_ids, dtype=torch.long).to(device)

def to_single_token(tokenizer, string):
    """Return the last token id produced by encoding *string*.

    No check is made that *string* encodes to exactly one token; whatever
    id comes last in ``tokenizer.encode(string)`` is returned.
    """
    encoded = tokenizer.encode(string)
    return encoded[-1]


def get_per_head_logit_diffs(model,cache, len_prompts, logit_diff_directions):
    """Compute each attention head's contribution to the logit difference.

    Args:
        model: Object whose ``cfg.n_layers`` and ``cfg.n_heads`` give the
            grid the flat per-head values are reshaped into.
        cache: Activation cache providing ``stack_head_results`` (called with
            ``layer=-1, pos_slice=-1``).
        len_prompts: Divisor forwarded to ``residual_stack_to_logit_diff``.
        logit_diff_directions: Directions forwarded to
            ``residual_stack_to_logit_diff``.

    Returns:
        ``(per_head_residual, per_head_logit_diffs)`` where the second item
        is arranged as ``[layer, head_index]``.
    """
    # Labels are requested from the cache but are not used further.
    per_head_residual, _labels = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )

    flat_diffs = residual_stack_to_logit_diff(
        per_head_residual, cache, len_prompts, logit_diff_directions
    )

    grid_sizes = {"layer": model.cfg.n_layers, "head_index": model.cfg.n_heads}
    per_head_logit_diffs = einops.rearrange(
        flat_diffs, "(layer head_index) -> layer head_index", **grid_sizes
    )

    return per_head_residual, per_head_logit_diffs



def visualize_attention_patterns(
    model,
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    """Render the attention patterns of the selected heads as an HTML snippet.

    Args:
        model: Object providing ``cfg.n_heads`` and ``to_str_tokens``.
        heads: One flat head index or a collection of them; a flat index is
            split into ``layer = index // n_heads`` and
            ``head = index % n_heads``.
        local_cache: Activation cache indexed as ``["attn", layer]``.
        local_tokens: Tokens converted to strings for the axis labels.
        title: Text placed in an ``<h2>`` heading above the plot.
        max_width: Maximum width of the wrapping ``<div>``, in pixels.

    Returns:
        A string of HTML: the title followed by the circuitsvis plot code,
        wrapped in a width-limited ``<div>``.
    """
    head_indices = [heads] if isinstance(heads, int) else heads

    # Only the first batch item is visualised.
    first_item = 0

    labels: List[str] = []
    per_head_patterns = []
    for flat_index in head_indices:
        layer = flat_index // model.cfg.n_heads
        head_index = flat_index % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")
        # Attention patterns have shape [batch, head_index, query_pos, key_pos].
        per_head_patterns.append(local_cache["attn", layer][first_item, head_index])

    # Token strings serve as the axis labels.
    str_tokens = model.to_str_tokens(local_tokens)

    # One [dest_pos, src_pos] pattern per selected head, stacked on axis 0.
    stacked_patterns = torch.stack(per_head_patterns, dim=0)

    # Take the code form of the plot so it can be concatenated with the title.
    rendered = attention_heads(
        attention=stacked_patterns,
        tokens=str_tokens,
        attention_head_names=labels,
    )
    plot = rendered.show_code()

    title_html = f"<h2>{title}</h2><br/>"
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"