import os
from typing import Dict, Any, Optional, Tuple
import functools
import math
import warnings
from pathlib import Path
import json

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from torch import nn
from transformers.models.gemma2.modeling_gemma2 import apply_rotary_pos_emb, repeat_kv

from patching_gemma import logger
from patching_gemma.models.gemma2 import Gemma2Model
from patching_gemma.models.utils.data_processing.collate_function import collate_fn_everything
from patching_gemma.models.utils.data_processing.dataset import RequestDataset

class Gemma2FindFVHeads(Gemma2Model):
    def run(self, task, limit, batch_size, log_dir) -> None:
        """Run the head-patching experiment for ``task`` and save the averaged losses.

        Builds a ``RequestDataset`` (with corrupted contexts) and a ``DataLoader``
        over it, initialises the accumulators that the patched attention forward
        and :meth:`get_fv_scores` write into, runs :meth:`get_fv_scores`, and then
        stores every entry of ``self.loss_after_ablation`` as
        ``<log_dir>/loss_after_ablation/<name>.npy``.

        Args:
            task: Task object. Must report ``can_be_token_separable`` as truthy and
                expose ``TOKEN_TYPES`` (last dimension of the per-head loss tensor).
            limit: Forwarded to ``RequestDataset`` (presumably the maximum number
                of requests -- confirm against that class).
            batch_size: Batch size used by the ``DataLoader``.
            log_dir: Directory under which ``loss_after_ablation/`` is created.
        """
        self.task = task
        assert task.can_be_token_separable, "task must be token separable"
        # Patch every head up to and including the target one (see attn_forward).
        self.cumulative = True

        dataset = RequestDataset(task, limit, corrupted=True, tokenizer=self.tokenizer)
        # Do not index past the end when the dataset holds fewer than 3 requests.
        self.model_logs["first_3_dataset_examples"] = [
            dataset[i] for i in range(min(3, len(dataset)))
        ]
        self.num_requests = len(dataset)

        # Leave one CPU for the main process. ``sched_getaffinity`` only exists on
        # some platforms, so fall back to the plain CPU count elsewhere.
        if hasattr(os, "sched_getaffinity"):
            available_cpus = len(os.sched_getaffinity(0))
        else:
            available_cpus = os.cpu_count() or 1
        collate = functools.partial(
            collate_fn_everything,
            padding_left=False,
            use_corrupted_activations=True,
            tokenizer=self.tokenizer,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate,
            num_workers=max(available_cpus - 1, 0),
        )

        num_layers = len(self.model.model.layers)
        num_heads = self.model.model.layers[0].self_attn.num_heads

        # "no_ablation": scalar loss of the unpatched corrupted run.
        # "attn": loss indexed by (layer, head, token type) after patching.
        self.loss_after_ablation = {
            "no_ablation": 0,
            "attn": torch.zeros(num_layers, num_heads, self.task.TOKEN_TYPES),
        }
        # Per-layer / per-head slots that the patched attention forward fills
        # during the clean (non-corrupted) pass.
        self.not_corrupted_activations = {
            layer: {head: None for head in range(num_heads)}
            for layer in range(num_layers)
        }

        self.get_fv_scores(dataloader)

        out_dir = Path(log_dir).joinpath("loss_after_ablation")
        out_dir.mkdir(exist_ok=True, parents=True)
        for name, value in self.loss_after_ablation.items():
            with open(str(out_dir.joinpath(f"{name}.npy")), "wb") as file:
                # ``as_tensor`` accepts both the Python scalar and the existing
                # tensor without the copy-construct warning of ``torch.tensor``.
                np.save(file, torch.as_tensor(value).cpu().numpy())

    def get_fv_scores(self, dataloader) -> None:
        """Accumulate the corrupted-run loss with and without patched head outputs.

        For every batch:

        1. The model is run on the clean inputs (``batch[5]``) with
           ``self.is_corrupted_run = False``. This assumes the attention layers
           were patched via :meth:`break_into`, so that ``attn_forward`` stores
           each head's output in ``self.not_corrupted_activations``.
        2. The model is run on the corrupted contexts (``batch[6]``) with
           ``self.what_to_ablate = None`` to obtain the baseline loss.
        3. The corrupted contexts are re-run once per ``(layer, head)`` with
           ``self.what_to_ablate = (layer, head, tp)``; only
           ``tp == task.LAST_SEP_TYPE`` is evaluated.

        Each batch loss is weighted by the batch's number of predictive
        positions and the sums are divided by the total number of predictive
        positions at the end. Results are written to
        ``self.loss_after_ablation``; per-batch losses and the first three
        decoded examples/targets go to ``self.model_logs``.

        Args:
            dataloader: Yields indexable batches where index 2 holds lengths,
                5 the clean inputs, 6 the corrupted contexts and 7 the target ids.
        """
        examples = []
        first_3_targets = []

        num_predictive_inds_total = 0

        # Send inputs to wherever the model lives instead of assuming "cuda".
        # Offloaded models expose "meta" parameters; keep the previous default
        # for those.
        device = next(self.model.parameters()).device
        if device.type == "meta":
            device = torch.device("cuda")

        num_layers = len(self.model.model.layers)
        num_heads = self.model.model.layers[0].self_attn.num_heads
        # Only the LAST_SEP token type is patched; filtering once keeps the
        # original "iterate all types, skip the rest" semantics.
        token_types_to_patch = [
            tp for tp in range(self.task.TOKEN_TYPES) if tp == self.task.LAST_SEP_TYPE
        ]

        for batch_i, batch in enumerate(dataloader):
            logger.info(f"Starting run for new batch {batch_i}/{len(dataloader)}")
            inputs = batch[5].to(device)
            target_ids = batch[7].to(device)
            lens = batch[2]
            for i in range(inputs["input_ids"].shape[0]):
                if len(examples) >= 3:
                    break
                examples.append(self.tokenizer.decode(inputs["input_ids"][i].detach().cpu()))
                first_3_targets.append(self.tokenizer.decode(target_ids[i]))

            corrupted_contexts = batch[6].to(device)

            # TODO: averaging only over non-padding tokens (need to think how to do it in omp scores compute too)
            # Token types are read by the patched attention forward through
            # ``self.tp_inds``; clean and corrupted contexts must align.
            self.tp_inds = self.task.get_token_types_for_contexts_with_targets(
                self.tokenizer, inputs["input_ids"].detach().cpu())
            self.corrupted_tp_inds = self.task.get_token_types_for_contexts_with_targets(
                self.tokenizer, corrupted_contexts["input_ids"].detach().cpu())
            assert (self.corrupted_tp_inds == self.tp_inds).all()

            predictive_inds = ((self.tp_inds == self.task.TARGET_TYPE) | (self.tp_inds == self.task.LAST_SEP_TYPE))
            # Clear the last marked position of every row: cumsum * mask peaks
            # exactly at the last True entry.
            last_marked = (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)
            predictive_inds[torch.arange(predictive_inds.shape[0]), last_marked] = 0
            num_predictive = torch.count_nonzero(predictive_inds).item()
            num_predictive_inds_total += num_predictive

            # Clean pass: only run for its side effect of caching head outputs.
            self.is_corrupted_run = False
            with torch.no_grad():
                self.model(**inputs)

            # Baseline corrupted pass without any patching.
            self.is_corrupted_run = True
            logger.info("Starting ablation")
            self.what_to_ablate = None
            with torch.no_grad():
                logits = self.model(**corrupted_contexts).logits
                loss_value = self.task.loss_function(logits, target_ids, lens, self.tp_inds).item()
            self.loss_after_ablation["no_ablation"] += loss_value * num_predictive
            self.model_logs.setdefault("no_ablation", {}).setdefault("loss", []).append(loss_value)

            for layer in range(num_layers):
                logger.info(f"Starting ablation for layer {layer}")
                for head in range(num_heads):
                    for tp in token_types_to_patch:
                        self.what_to_ablate = (layer, head, tp)
                        with torch.no_grad():
                            logits = self.model(**corrupted_contexts).logits
                            loss_value = self.task.loss_function(
                                logits, target_ids, lens, self.tp_inds).item()
                        self.loss_after_ablation["attn"][layer, head, tp] += loss_value * num_predictive
                        log_key = str(self.what_to_ablate)
                        self.model_logs.setdefault(log_key, {}).setdefault("loss", []).append(loss_value)

            del inputs
            torch.cuda.empty_cache()

        self.model_logs["first_3_targets"] = first_3_targets
        # NOTE: the key spelling ("exampels") is kept because it is part of the logged output.
        self.model_logs["first_3_loader_generate_exampels"] = examples

        if num_predictive_inds_total == 0:
            # Empty loader or no predictive positions: nothing to average, and
            # dividing would raise ZeroDivisionError.
            warnings.warn("No predictive positions were found; losses are left un-normalised (zero).")
            return

        self.loss_after_ablation["no_ablation"] /= num_predictive_inds_total
        self.loss_after_ablation["attn"] /= num_predictive_inds_total

    def break_into(self) -> None:
        """Replace each layer's attention ``forward`` with the patching ``attn_forward``.

        The callable found on every ``self_attn`` module is remembered in
        ``self.prev_forwards`` (one entry per layer, in layer order) before it is
        overwritten. ``self.hook_handles`` is reset to an empty list.
        """
        self.hook_handles = []
        self.prev_forwards = []

        for layer_idx, decoder_layer in enumerate(self.model.model.layers):
            attn_module = decoder_layer.self_attn
            self.prev_forwards.append(attn_module.forward)
            # ``attn_forward`` is a staticmethod, so the attention module is bound
            # explicitly as its ``self`` and this wrapper is passed as ``llama_model``.
            attn_module.forward = functools.partial(
                self.attn_forward,
                layer=layer_idx,
                self=attn_module,
                llama_model=self,
            )

    def break_out(self) -> None:
        """Undo :meth:`break_into`: restore the saved forwards and remove hooks.

        The callables stored in ``self.prev_forwards`` were read from the
        attention modules themselves, so they are already bound to their module.
        They are therefore re-installed as-is; wrapping them again with
        ``self=<module>`` would make every later call fail with
        "got multiple values for argument 'self'".
        """
        for layer, original_forward in enumerate(self.prev_forwards):
            self.model.model.layers[layer].self_attn.forward = original_forward
        for h in self.hook_handles:
            h.remove()

    @staticmethod
    def attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        layer: Optional[int] = None,
        llama_model: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention forward that exposes, caches and patches per-head outputs.

        The attention computation mirrors the eager attention forward it
        replaces. The only change is the output projection, which is applied
        head by head so each head's contribution can be stored or overwritten
        before the contributions are summed.

        Behaviour per head, depending on state held by ``llama_model``:

        * ``past_key_value`` given and ``q_len == 1`` (cached generation step):
          nothing is stored or patched.
        * ``llama_model.is_corrupted_run`` is false: the head output is saved to
          ``llama_model.not_corrupted_activations[layer][head]`` (detached, on CPU).
        * otherwise, if ``llama_model.what_to_ablate == (L, H, tp)`` selects this
          head, positions where ``llama_model.tp_inds == tp`` are overwritten
          with the saved clean output. With ``llama_model.cumulative`` every head
          up to and including ``(L, H)`` in (layer, head) order is selected;
          without it only ``(L, H)`` itself.

        Args:
            self: The attention module this forward is installed on (bound
                explicitly because this is a staticmethod).
            hidden_states: 3-D input; the first two sizes are read as
                ``(bsz, q_len)``.
            attention_mask: Optional additive mask, sliced to the key length.
            position_ids: Positions passed to the rotary embedding.
            past_key_value: Optional cache object providing ``update``.
            output_attentions: Whether to return the attention weights.
            use_cache: Accepted for signature compatibility; not read here.
            cache_position: Forwarded to the cache update.
            layer: Index of the layer this forward belongs to.
            llama_model: The wrapper holding the patching state described above.

        Returns:
            ``(output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: If the attention output does not have shape
                ``(bsz, num_heads, q_len, head_dim)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if self.config.attn_logit_softcapping is not None:
            attn_weights = attn_weights / self.config.attn_logit_softcapping
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config.attn_logit_softcapping
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, q_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()

        # The stock implementation does, at this point:
        #     attn_output = attn_output.view(bsz, q_len, -1)
        #     attn_output = self.o_proj(attn_output)
        # Here the projection is split per head so each head's contribution to
        # the layer output can be cached / replaced individually.

        o_proj_weight_t = self.o_proj.weight.T
        # Cached single-token generation steps are neither stored nor patched.
        skip_patching = past_key_value is not None and q_len == 1

        head_outputs = []
        for cur_head in range(self.num_heads):
            rows = slice(cur_head * self.head_dim, (cur_head + 1) * self.head_dim)
            head_out = attn_output[:, :, cur_head, :] @ o_proj_weight_t[rows, :]

            if skip_patching:
                pass
            elif not llama_model.is_corrupted_run:
                # Clean pass: remember this head's output for later patching.
                llama_model.not_corrupted_activations[layer][cur_head] = head_out.detach().cpu()
            else:
                target = llama_model.what_to_ablate
                if target is not None:
                    target_layer, target_head, tp_to_ablate = target[0], target[1], target[2]
                    if llama_model.cumulative:
                        # Every head up to and including the target one.
                        should_patch = layer < target_layer or (
                            layer == target_layer and cur_head <= target_head
                        )
                    else:
                        should_patch = layer == target_layer and cur_head == target_head
                    if should_patch:
                        # presumably tp_inds is (bsz, q_len), matching the two
                        # leading dims of head_out -- TODO confirm
                        mask = llama_model.tp_inds == tp_to_ablate
                        clean = llama_model.not_corrupted_activations[layer][cur_head]
                        head_out[mask] = clean[mask].to(head_out.dtype).to(head_out.device)

            head_outputs.append(head_out)

        head_outputs_together = sum(head_outputs)

        # ``self.o_proj(...)`` would also add the projection bias (when the
        # layer has one); the per-head matmuls above do not, so add it back.
        o_proj_bias = getattr(self.o_proj, "bias", None)
        if o_proj_bias is not None:
            head_outputs_together = head_outputs_together + o_proj_bias

        if not output_attentions:
            attn_weights = None

        return head_outputs_together, attn_weights, past_key_value