import os
from typing import Dict, Any, Optional, Tuple, Union, List
import functools
import math
import warnings
from pathlib import Path
import json

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from patching_gemma import logger
from patching_gemma.models.llama3 import Llama3Model
from patching_gemma.models.utils.data_processing.collate_function import collate_fn_everything
from patching_gemma.models.utils.data_processing.dataset import RequestDataset

class Llama3AblateEdges(Llama3Model):
    """Llama-3 wrapper that ablates attention edges between token types.

    After ``break_into`` is called, every attention layer is evaluated once per
    token type ("stream"). The inputs of all streams are stacked type-major
    along the batch dimension. The attention output at a position whose type
    is ``tp`` is taken from stream ``tp``.

    For each ablated edge ``(src_tp, dst_tp)``, the input of stream ``dst_tp``
    at positions of type ``src_tp`` is overwritten with the layer input
    recorded on a corrupted copy of the context. Tokens of type ``dst_tp``
    therefore attend to corrupted ``src_tp`` tokens.

    This replacement is skipped on single-token cached decoding steps. On those
    steps the output is taken from the ``TARGET_TYPE`` stream.
    """

    def run(self, task, limit, batch_size, affect_whom, log_dir) -> None:
        """Run the ablation experiment for ``task`` and record results in ``self.model_logs``.

        Args:
            task: task object. It must have ``can_be_token_separable`` set and
                provide ``TOKEN_TYPES``, the token-type helpers and ``loss``.
            limit: forwarded to ``RequestDataset``.
            batch_size: DataLoader batch size.
            affect_whom: iterable of ``(src_type, dst_type)`` pairs whose
                attention edges are ablated.
            log_dir: not used by this method.
        """
        assert task.can_be_token_separable
        ablate_edges_between_types = affect_whom
        self.ablate_edges_between_types = ablate_edges_between_types
        self.task = task
        self.model_logs["ablate_edges_between_types"] = ablate_edges_between_types

        # Normalise to tuples so pairs given as lists (e.g. parsed from JSON)
        # still match in the membership test below.
        ablated = {tuple(edge) for edge in ablate_edges_between_types}
        num_types = self.task.TOKEN_TYPES
        self.model_logs["not_ablate"] = [
            (tp1, tp2)
            for tp1 in range(num_types)
            for tp2 in range(num_types)
            if (tp1, tp2) not in ablated
        ]

        dataset = RequestDataset(task, limit, corrupted=True, tokenizer=self.tokenizer)
        self.model_logs["dataset_examples"] = [dataset[i] for i in range(len(dataset))]
        self.num_requests = len(dataset)

        self.generate(dataset, batch_size)

    def _reset_corrupted_activations(self) -> None:
        """Create an empty store ``["input_to_attn_per_type"][layer][head][type] -> None``."""
        num_layers = len(self.model.model.layers)
        num_heads = self.model.model.layers[0].self_attn.num_heads
        num_types = self.task.TOKEN_TYPES
        self.corrupted_activations = {
            "input_to_attn_per_type": {
                layer: {head: {tp: None for tp in range(num_types)} for head in range(num_heads)}
                for layer in range(num_layers)
            },
        }

    def _release_corrupted_activations(self) -> None:
        """Drop every stored corrupted activation and then the store itself."""
        for per_head in self.corrupted_activations["input_to_attn_per_type"].values():
            for per_type in per_head.values():
                per_type.clear()
        self.corrupted_activations = None

    def generate(self, dataset, batch_size) -> None:
        """Evaluate teacher-forced loss and generation accuracy with the ablation active.

        For every batch there are two corrupted/clean pairs of passes:

        1. Loss pass. Corrupted activations are recorded on the full corrupted
           context. The clean full context is then run with the ablation, and
           ``task.loss`` is evaluated.
        2. Generation pass. Corrupted activations are recorded on the corrupted
           prompt. ``model.generate`` then produces up to 7 new tokens from the
           clean prompt, with the ablation applied while the prompt is processed.

        Results go to ``self.model_logs``. This method requires ``run`` to have
        set ``self.task``, ``self.num_requests`` and
        ``self.ablate_edges_between_types``.
        """
        self.generate_mode = True
        logger.debug("Start generate part")

        try:
            available_cpus = len(os.sched_getaffinity(0))
        except AttributeError:  # os.sched_getaffinity is not available on every platform
            available_cpus = os.cpu_count() or 1
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                collate_fn=functools.partial(collate_fn_everything,
                                                             padding_left=True, use_corrupted_activations=True,
                                                             tokenizer=self.tokenizer),
                                num_workers=max(available_cpus - 1, 0))
        # Move inputs to the device of the model's parameters rather than a hard-coded "cuda".
        device = self.model.device

        sum_accuracy = 0
        sum_loss = 0
        num_predictive_inds_total = 0
        continuations = []
        examples = []
        is_correct = []
        all_targets = []

        for batch in dataloader:
            # Batch indices used here (produced by collate_fn_everything):
            # 0 inputs, 2 lens, 3 targets, 4 corrupted_contexts,
            # 5 full_inputs, 6 full_corrupted_contexts, 7 target_ids.
            inputs = batch[0].to(device)
            full_inputs = batch[5].to(device)
            full_corrupted_contexts = batch[6].to(device)
            corrupted_contexts = batch[4].to(device)
            target_ids = batch[7].to(device)
            lens = batch[2]
            targets = batch[3]

            # Keep the first three decoded prompts for the logs.
            if len(examples) < 3:
                for ids in inputs["input_ids"][: 3 - len(examples)]:
                    examples.append(self.tokenizer.decode(ids.detach().cpu()))

            # ---- Loss pass on the full context (prompt + targets) ----
            self._reset_corrupted_activations()
            self.tp_inds = self.task.get_token_types_for_contexts_with_targets(
                self.tokenizer, full_inputs["input_ids"].detach().cpu())
            self.corrupted_tp_inds = self.task.get_token_types_for_contexts_with_targets(
                self.tokenizer, full_corrupted_contexts["input_ids"].detach().cpu())
            # Patching is position-wise, so clean and corrupted layouts must agree.
            assert (self.corrupted_tp_inds == self.tp_inds).all()

            self.is_corrupted_run = True
            with torch.no_grad():
                self.model(**full_corrupted_contexts, use_cache=False)
            self.is_corrupted_run = False

            predictive_inds = (self.tp_inds == self.task.TARGET_TYPE) | (self.tp_inds == self.task.LAST_SEP_TYPE)
            # Clear the last predictive position of every row (the argmax of the running count
            # restricted to predictive positions). Presumably it has no next target to predict;
            # TODO confirm against task.loss.
            last_predictive = (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)
            predictive_inds[torch.arange(predictive_inds.shape[0]), last_predictive] = 0
            num_predictive = torch.count_nonzero(predictive_inds).item()
            num_predictive_inds_total += num_predictive
            with torch.no_grad():
                logits = self.model(**full_inputs, use_cache=False).logits
                loss = self.task.loss(logits, target_ids.to(logits.device), lens, self.tp_inds)
            # NOTE(review): assumes task.loss is a mean over the predictive positions.
            sum_loss += loss.item() * num_predictive
            self._release_corrupted_activations()

            # ---- Generation pass from the prompt ----
            self._reset_corrupted_activations()
            self.tp_inds = self.task.get_token_types_for_contexts(
                self.tokenizer, inputs["input_ids"].detach().cpu())
            self.corrupted_tp_inds = self.task.get_token_types_for_contexts(
                self.tokenizer, corrupted_contexts["input_ids"].detach().cpu())
            assert (self.corrupted_tp_inds == self.tp_inds).all()

            self.is_corrupted_run = True
            with torch.no_grad():
                self.model(**corrupted_contexts, use_cache=False)
            self.is_corrupted_run = False
            out = self.model.generate(**inputs, max_new_tokens=7)

            # Prompts are left-padded to a common length, so new tokens start at that column.
            prompt_len = inputs["input_ids"].shape[1]
            for i, generated in enumerate(out):
                target = targets[i]
                continuation = self.tokenizer.decode(generated[prompt_len:])
                correct = int(continuation.strip().startswith(target))
                continuations.append(continuation)
                is_correct.append(correct)
                all_targets.append(target)
                sum_accuracy += correct

            del inputs
            self._release_corrupted_activations()
            torch.cuda.empty_cache()

        # Key spelling kept as-is for downstream readers of model_logs.
        self.model_logs["first_3_loader_generate_exampels"] = examples
        if self.num_requests:
            self.model_logs["accuracy"] = sum_accuracy / self.num_requests
        else:
            logger.warning("No requests were evaluated; accuracy is NaN")
            self.model_logs["accuracy"] = float("nan")
        if num_predictive_inds_total:
            self.model_logs["loss"] = sum_loss / num_predictive_inds_total
        else:
            logger.warning("No predictive positions were found; loss is NaN")
            self.model_logs["loss"] = float("nan")
        self.model_logs["targets"] = all_targets
        self.model_logs["predictions"] = continuations
        self.model_logs["is_correct_prediction"] = is_correct

    def break_into(self) -> None:
        """Install the ablation-aware ``forward`` on every decoder layer and its attention module.

        The previous ``forward`` attributes are saved in ``self.prev_forwards`` so
        that ``break_out`` can reinstall them.
        """
        self.hook_handles = []
        self.prev_forwards = {"decoder": [], "attn": []}

        for layer, decoder_layer in enumerate(self.model.model.layers):
            attn = decoder_layer.self_attn
            # attn_forward / decoder_forward are staticmethods. The patched module is
            # bound as the ``self`` keyword, and this wrapper is bound as ``llama_model``.
            self.prev_forwards["attn"].append(attn.forward)
            attn.forward = functools.partial(self.attn_forward, layer=layer,
                                             self=attn, llama_model=self)

            self.prev_forwards["decoder"].append(decoder_layer.forward)
            decoder_layer.forward = functools.partial(self.decoder_forward, layer=layer,
                                                      self=decoder_layer, llama_model=self)

    def break_out(self) -> None:
        """Reinstall the ``forward`` callables saved by ``break_into`` and remove hooks."""
        layers = self.model.model.layers
        # The saved values are what ``module.forward`` returned in break_into (normally bound
        # methods), so they are reinstalled unchanged. Wrapping a bound method in
        # partial(..., self=module) would pass ``self`` twice.
        saved = zip(self.prev_forwards["attn"], self.prev_forwards["decoder"])
        for layer, (attn_forward, decoder_forward) in enumerate(saved):
            layers[layer].self_attn.forward = attn_forward
            layers[layer].forward = decoder_forward
        for h in self.hook_handles:
            h.remove()
        self.hook_handles = []

    @staticmethod
    def attn_forward(
        self,
        hidden_states: Dict[int, Dict[int, torch.Tensor]],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        layer: Optional[int] = None,
        llama_model: Optional[Any] = None,
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention forward evaluated separately for every token-type stream.

        Args:
            self: the patched attention module (bound in ``break_into``).
            hidden_states: ``hidden_states[head][tp]`` is the normalised input
                that attention head ``head`` sees in the stream of type ``tp``.
                Each tensor has shape ``(bsz, q_len, hidden)``.
            attention_mask: additive mask. It is sliced to the key length and
                replicated per stream. Assumed to be HF's 4D
                ``(bsz, 1, q_len, >=kv_len)`` causal mask (TODO confirm).
            position_embeddings: ``(cos, sin)``. When their batch size equals
                ``bsz`` they are replicated for every stream.
            past_key_value: KV cache. It stores the stacked
                ``TOKEN_TYPES * bsz`` batch.
            layer, llama_model: bound in ``break_into``. ``llama_model`` supplies
                ``task`` and ``tp_inds``.

        Returns:
            ``(output, attn_weights, past_key_value)``. ``output`` has shape
            ``(bsz, q_len, o_proj out-features)`` and ``attn_weights`` covers
            the stacked batch.
        """
        assert self.config.pretraining_tp == 1
        num_types = llama_model.task.TOKEN_TYPES
        original_bsz, q_len, _ = hidden_states[0][0].size()
        extended_bsz = original_bsz * num_types

        ref = hidden_states[self.num_heads - 1][num_types - 1]

        def new_buffers():
            return [torch.empty(original_bsz, q_len, self.num_heads, self.head_dim,
                                device=ref.device, dtype=ref.dtype)
                    for _ in range(num_types)]

        query_per_type = new_buffers()
        key_per_type = new_buffers()
        value_per_type = new_buffers()

        # Each query head may see a different input, so K/V are projected per query head.
        # The head's KV group is selected directly, which replaces repeat_kv.
        # NOTE(review): every (head, type) pair runs the full projections and keeps only one
        # head's slice, which is expensive.
        for head in range(self.num_heads):
            kv_head = head // self.num_key_value_groups
            for tp in range(num_types):
                x = hidden_states[head][tp]
                query_per_type[tp][:, :, head, :] = self.q_proj(x).view(original_bsz, q_len, self.num_heads, self.head_dim)[:, :, head, :]
                key_per_type[tp][:, :, head, :] = self.k_proj(x).view(original_bsz, q_len, self.num_key_value_heads, self.head_dim)[:, :, kv_head, :]
                value_per_type[tp][:, :, head, :] = self.v_proj(x).view(original_bsz, q_len, self.num_key_value_heads, self.head_dim)[:, :, kv_head, :]

        # Stack the streams type-major along the batch: rows [tp*bsz:(tp+1)*bsz] belong to stream tp.
        # Result shape: (extended_bsz, num_heads, q_len, head_dim).
        query_states = torch.vstack(query_per_type).transpose(1, 2)
        key_states = torch.vstack(key_per_type).transpose(1, 2)
        value_states = torch.vstack(value_per_type).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            assert len(position_embeddings) == 2
            cos, sin = position_embeddings
        # cos/sin computed for the original batch are replicated for every stream.
        # A batch of 1 broadcasts and is left as-is.
        if cos.shape[0] == original_bsz:
            cos = cos.expand(num_types, original_bsz, q_len, -1).reshape(extended_bsz, q_len, -1)
            sin = sin.expand(num_types, original_bsz, q_len, -1).reshape(extended_bsz, q_len, -1)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # No repeat_kv here: K/V were already expanded to all query heads above.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        kv_len = key_states.shape[-2]
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, :kv_len]
            # Slice before replicating so masks longer than kv_len are handled.
            # The stream order matches the stacked batch.
            causal_mask = causal_mask.expand(num_types, original_bsz, 1, q_len, kv_len).reshape(extended_bsz, 1, q_len, kv_len)
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (extended_bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(extended_bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (extended_bsz, q_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()

        # Output projection, applied head by head with the matching column block of o_proj.
        # NOTE(review): o_proj bias is never added. This is fine only if the model has no
        # attention bias; confirm for the loaded config.
        o_weight_t = self.o_proj.weight.T
        is_decode_step = past_key_value is not None and q_len == 1
        target_tp = llama_model.task.TARGET_TYPE
        type_masks = None if is_decode_step else [llama_model.tp_inds == tp for tp in range(num_types)]

        head_outputs_together = []
        for cur_head in range(self.num_heads):
            w_head = o_weight_t[cur_head * self.head_dim : (cur_head + 1) * self.head_dim, :]
            if is_decode_step:
                # Single cached token: take the TARGET_TYPE stream.
                stream = attn_output[target_tp * original_bsz : (target_tp + 1) * original_bsz, :, cur_head, :]
                head_outputs_together.append(stream.to(w_head.device) @ w_head)
            else:
                # Zero-initialised so positions matching no type contribute 0 rather than
                # uninitialised memory. The buffer lives on the device the values are computed on.
                head_out = torch.zeros(original_bsz, q_len, self.o_proj.weight.shape[0],
                                       device=w_head.device, dtype=attn_output.dtype)
                for tp, positions in enumerate(type_masks):
                    stream = attn_output[tp * original_bsz : (tp + 1) * original_bsz, :, cur_head, :]
                    # Each position takes its output from the stream of its own type.
                    head_out[positions] = (stream.to(w_head.device) @ w_head)[positions]
                head_outputs_together.append(head_out)

        assert len(head_outputs_together) == self.num_heads
        head_outputs_together = sum(head_outputs_together)

        return head_outputs_together, attn_weights, past_key_value

    @staticmethod
    def decoder_forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        self = None,
        layer = None,
        llama_model = None,
        **kwargs: Any,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Decoder-layer forward that gives attention a separate input per (head, token type).

        The layer input is handled according to the kind of step:

        * Corrupted run (``llama_model.is_corrupted_run``): the input is stored
          in ``llama_model.corrupted_activations``.
        * Clean run: for every ablated edge ``(tp, to_tp)``, stream ``to_tp`` gets
          the stored corrupted input at positions of type ``tp``.
        * Single-token cached decoding step: the input is left unpatched.

        The MLP block and the outputs tuple follow the standard decoder layer.
        """
        if layer % 10 == 0:
            logger.info(f"Decoder layer {layer}, forward pass")

        residual = hidden_states
        num_heads = self.self_attn.num_heads
        num_types = llama_model.task.TOKEN_TYPES

        # One private copy per (head, type) so patching one copy cannot leak into another.
        hidden_states_per_head_and_type = {
            head: {tp: hidden_states.clone() for tp in range(num_types)}
            for head in range(num_heads)
        }

        is_decode_step = past_key_value is not None and hidden_states.shape[1] == 1
        if not is_decode_step:
            store = None
            if llama_model.is_corrupted_run or llama_model.ablate_edges_between_types:
                store = llama_model.corrupted_activations["input_to_attn_per_type"][layer]
            if llama_model.is_corrupted_run:
                # Nothing is patched in the corrupted run, so every (head, type) entry equals the
                # layer input. A single CPU snapshot is stored and shared by all entries; they
                # are only read later.
                snapshot = hidden_states.detach().clone().cpu()
                for head in range(num_heads):
                    for tp in range(num_types):
                        store[head][tp] = snapshot
            else:
                for (tp, to_tp) in llama_model.ablate_edges_between_types:
                    src_positions = llama_model.tp_inds == tp
                    for head in range(num_heads):
                        dst = hidden_states_per_head_and_type[head][to_tp]
                        dst[src_positions] = store[head][to_tp][src_positions].to(dst.device).detach()

        normed_inputs = {
            head: {tp: self.input_layernorm(per_type[tp]) for tp in range(num_types)}
            for head, per_type in hidden_states_per_head_and_type.items()
        }

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs