from __future__ import annotations
from abc import abstractmethod
import numpy
import torch
from tensordict import TensorDict, TensorClass, tensorclass


# Causal attention-mask construction helper.
def get_causal_mask(
    tot_nodes: int,
    num_inputs: int,
    num_outputs=1,
    const_node=True,
    device=torch.device("cpu"),
) -> torch.Tensor:
    """Build a boolean lower-triangular mask with a shifted diagonal.

    Args:
        tot_nodes: Side length of the square mask (total number of nodes).
        num_inputs: Number of input nodes; contributes to the diagonal shift.
        num_outputs: Number of output nodes; contributes to the diagonal shift.
        const_node: When falsy, the diagonal shift is reduced by one.
        device: Device on which the mask is allocated.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, 1, tot_nodes, tot_nodes)`` in
        which entry ``[0, 0, row, col]`` is ``True`` exactly when
        ``col - row <= shift``.
    """
    # The band above the main diagonal that stays visible is
    # num_inputs + num_outputs wide, minus one when there is no constant node.
    shift = num_inputs + num_outputs - (0 if const_node else 1)
    all_visible = torch.ones(
        tot_nodes, tot_nodes, dtype=torch.bool, device=device
    )
    banded = torch.tril(all_visible, diagonal=shift)
    # Two leading singleton axes so the mask broadcasts over them.
    return banded.unsqueeze(0).unsqueeze(0)

@torch.jit.script
def kldiv_activation(action_logits: torch.Tensor):
    """Return log-probabilities over all non-batch entries of the logits.

    Every dimension after the first is flattened into one axis, and a
    log-softmax is taken along that axis.
    """
    flat_logits = torch.flatten(action_logits, start_dim=1)
    return torch.log_softmax(flat_logits, dim=-1)


# @torch.jit.script
def normalize_action(action_ints: torch.Tensor):
    """Flatten per-sample action weights and scale each row to sum to one.

    The division is done out of place. ``torch.flatten`` may return a view of
    its input, so an in-place division would silently rescale the caller's
    tensor, and it would also raise for integer inputs because a true
    division cannot be stored back into an integer tensor.

    Args:
        action_ints: Tensor whose first dimension indexes samples; all
            remaining dimensions are flattened into one.

    Returns:
        A new tensor of shape ``(action_ints.shape[0], -1)`` whose rows sum to
        one. A row whose entries sum to zero yields non-finite values, since
        no epsilon is added.
    """
    action = torch.flatten(action_ints, start_dim=1)
    return action / torch.sum(action, dim=-1, keepdim=True)


@torch.jit.script
def tanh_activation(value: torch.Tensor):
    """Apply the element-wise hyperbolic tangent, mapping values into (-1, 1)."""
    return torch.tanh(value)


@torch.jit.script
def tanh_activation2(value: torch.Tensor):
    """Apply tanh and rescale the result from (-1, 1) to (0, 1)."""
    squashed = torch.tanh(value)
    return (squashed + 1) / 2


@torch.jit.script
def base_activation(tensor: torch.Tensor):
    """Collapse every dimension after the first into a single axis."""
    return tensor.flatten(1)


def load_model(model, checkpoint):
    """Copy matching entries of ``checkpoint``'s state dict into ``model``.

    Entries whose name does not appear in ``model.state_dict()`` are skipped,
    so a checkpoint with extra tensors can still be loaded partially.

    Args:
        model: Module that receives the values; updated in place.
        checkpoint: Object exposing ``state_dict()`` that provides the values.

    Raises:
        RuntimeError: If a matching entry cannot be copied into the
            destination tensor (for example because the shapes are
            incompatible).
    """
    # The tensors returned by state_dict() share storage with the module's
    # parameters and buffers, so copy_ into them updates the model itself.
    model_state = model.state_dict()
    with torch.no_grad():
        # .items() is required: iterating the dict directly yields only keys.
        for name, param in checkpoint.state_dict().items():
            if name not in model_state:
                continue
            if isinstance(param, torch.nn.Parameter):
                param = param.data
            model_state[name].copy_(param)


def get_mask(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None,
    is_causal: bool,
    n_pis: int | list[int] = 0,
) -> torch.Tensor | None:
    """Build an additive float attention mask from causal and padding rules.

    Args:
        hidden_states: Tensor whose first two dimensions are read as
            ``(batch, seq_len)``; also supplies the target device.
        attention_mask: Optional 2-D mask indexed as ``(batch, key)`` where
            non-zero means "may be attended to". Any dtype convertible to
            bool is accepted.
        is_causal: Whether to apply a lower-triangular mask.
        n_pis: Diagonal offset of the causal mask. A list gives one offset per
            mask, and the resulting masks are stacked along the first
            dimension.

    Returns:
        ``None`` when neither rule applies. Otherwise a ``float32`` tensor
        that is ``0.0`` where attention is allowed and the most negative
        ``float32`` value where it is blocked. Its shape is
        ``(batch, 1, seq_len, seq_len)`` when causal, else
        ``(batch, 1, 1, key)``.
    """
    mask = None
    if is_causal:
        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        all_visible = torch.ones(
            (seq_len, seq_len),
            dtype=torch.bool,
            device=hidden_states.device,
        )
        if isinstance(n_pis, list):
            # tril is out of place, so the shared template is never modified.
            mask = torch.stack(
                [all_visible.tril(diagonal=pi) for pi in n_pis]
            )[:, None, :, :]
        else:
            mask = all_visible.tril(diagonal=n_pis)[None, None, :, :].repeat(
                batch, 1, 1, 1
            )
        # The last position is visible from every query position.
        mask[:, :, :, -1] = True

    if attention_mask is not None:
        # Cast unconditionally: without it a 0/1 integer or float mask reaches
        # `~mask` / masked_fill below, which only work on boolean tensors.
        padding = attention_mask[:, None, None, :].to(torch.bool)
        mask = padding if mask is None else mask & padding

    if mask is None:
        return None
    return torch.zeros_like(
        mask, dtype=torch.float32, device=hidden_states.device
    ).masked_fill(~mask, torch.finfo(torch.float32).min)


def combine_masks(
    causal_mask: torch.Tensor | None, attention_mask: torch.Tensor | None
) -> torch.Tensor | None:
    """Intersect an optional causal mask with an optional padding mask.

    The padding mask is always converted to bool. This keeps the result dtype
    the same whether or not a causal mask is supplied, so the result is never
    the raw integer or float padding mask.

    Args:
        causal_mask: Optional boolean mask that broadcasts against
            ``(batch, 1, 1, key)``.
        attention_mask: Optional 2-D mask indexed as ``(batch, key)`` where
            non-zero means "may be attended to".

    Returns:
        ``causal_mask`` unchanged when ``attention_mask`` is ``None`` (which
        may itself be ``None``). Otherwise a boolean mask: the padding mask
        reshaped to ``(batch, 1, 1, key)``, AND-ed with ``causal_mask`` when
        one is given.
    """
    if attention_mask is None:
        return causal_mask
    padding = attention_mask[:, None, None, :].to(torch.bool)
    if causal_mask is None:
        return padding
    return causal_mask & padding


def prepare_kvcache_generation(input_embeds, cache, mask=None):
    """Drop the already-cached prefix from the embeddings and the mask.

    The number of cached positions is read from dimension ``-3`` of
    ``cache["0"]["keys"]``.

    Args:
        input_embeds: Tensor indexed as ``(batch, position, feature)``.
        cache: Nested mapping providing ``cache["0"]["keys"]``.
        mask: Optional 4-D mask whose third dimension indexes query positions.

    Returns:
        The trimmed embeddings alone when ``mask`` is ``None``; otherwise a
        ``(trimmed_embeds, trimmed_mask)`` tuple in which the mask keeps only
        the rows of the positions that are not cached yet.
    """
    cached_len = cache["0"]["keys"].shape[-3]
    trimmed_embeds = input_embeds[:, cached_len:, :]
    if mask is None:
        return trimmed_embeds
    return trimmed_embeds, mask[:, :, cached_len:, :]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep,
    dim=1)``: a ``(batch, num_key_value_heads, seqlen, head_dim)`` input
    becomes ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpacking first means a non-4-D input is rejected even when n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it, then fold
        # it into the head axis so copies of one head end up adjacent.
        tiled = hidden_states.unsqueeze(2).expand(
            batch, kv_heads, n_rep, seq_len, head_dim
        )
        hidden_states = tiled.reshape(
            batch, kv_heads * n_rep, seq_len, head_dim
        )
    return hidden_states


class KVCache(TensorDict):
    """TensorDict that holds one cache entry per layer id plus target tensors."""

    def __init__(
        self,
        batch_size: int,
        device=torch.device("cpu"),
    ):
        """Create an empty cache.

        Args:
            batch_size: Size of the single batch dimension of the TensorDict.
            device: Device assigned to the TensorDict.
        """
        super().__init__(
            {},
            batch_size=[batch_size],
            device=device,
        )

    def __contains__(self, key):
        """Return whether ``key`` is one of the top-level keys.

        The lookup goes through ``self.keys()``; testing ``key in self`` here
        would re-enter this method and recurse without bound.
        """
        return key in self.keys()

    def add_layer(
        self,
        layer_id: str,
        layer: "TransformerLayerKVCache | ActionHeadKVCache",
    ):
        """Store ``layer`` under ``layer_id``, replacing any existing entry."""
        self[layer_id] = layer

    def set_target(self, target):
        """Store ``target["nodes"]`` and ``target["reward"]`` as target keys/values."""
        self["target_keys"] = target["nodes"]
        self["target_values"] = target["reward"]

    def get_layer_cahce(self, layer_id: str) -> "LayerKVCache":
        """Return the entry stored under ``layer_id``.

        The misspelled name is kept for existing callers; prefer
        ``get_layer_cache``.
        """
        return self[layer_id]

    # Correctly spelled alias of ``get_layer_cahce``.
    get_layer_cache = get_layer_cahce


class LayerKVCache(TensorClass):
    """Base class for the per-layer cache entries returned by ``KVCache``.

    Declares the methods a layer cache is expected to provide.
    NOTE(review): whether ``@abstractmethod`` is enforced here depends on the
    metaclass used by ``TensorClass``, which is not visible in this file;
    confirm before relying on instantiation being blocked.
    """

    @abstractmethod
    def __init__(self):
        """Initialise the cache; this base implementation does nothing."""
        pass

    @abstractmethod
    def update_kvcache(self, new_key, new_value):
        """Add ``new_key`` / ``new_value`` to the cache; no-op in this base class."""
        pass


class TransformerLayerKVCache(LayerKVCache):
    """Key/value cache of a single transformer layer.

    New entries are concatenated along dimension ``-3`` of the stored tensors.
    The ``keys`` and ``values`` properties are aliases for the ``key_cache``
    and ``value_cache`` fields.
    """

    key_cache: torch.Tensor  # cached keys; also reachable through ``keys``
    value_cache: torch.Tensor  # cached values; also reachable through ``values``

    def __len__(self) -> int:
        """Return the size of dimension ``-3`` of the cached keys."""
        cached_keys = self.key_cache
        return cached_keys.shape[-3]

    def update_kvcache(self, new_key, new_value):
        """Append ``new_key`` and ``new_value`` and return both grown caches.

        Returns:
            The ``(key_cache, value_cache)`` pair after the update.
        """
        grown_keys = torch.cat([self.key_cache, new_key], dim=-3)
        self.key_cache = grown_keys
        grown_values = torch.cat([self.value_cache, new_value], dim=-3)
        self.value_cache = grown_values
        return self.key_cache, self.value_cache

    def update_keys(self, new_key):
        """Append ``new_key`` to the key cache only and return the grown cache."""
        grown_keys = torch.cat([self.key_cache, new_key], dim=-3)
        self.key_cache = grown_keys
        return self.key_cache

    def set_target(
        self, target_keys: torch.Tensor, target_values: torch.Tensor
    ) -> None:
        """Store the target tensors on this cache.

        NOTE(review): the keys overwrite ``key_cache`` while the values go to
        a separate ``target_values`` attribute. This asymmetry looks
        accidental; confirm the intended storage with the callers.
        """
        self.key_cache = target_keys
        self.target_values = target_values

    @property
    def keys(self) -> torch.Tensor:
        """Alias for ``key_cache``."""
        return self.key_cache

    @keys.setter
    def keys(self, new_key: torch.Tensor):
        self.key_cache = new_key

    @property
    def values(self) -> torch.Tensor:
        """Alias for ``value_cache``."""
        return self.value_cache

    @values.setter
    def values(self, new_value: torch.Tensor):
        self.value_cache = new_value


class ActionHeadKVCache(TensorClass):
    """Cache of the attention weights produced by the action head."""

    attention_weights: torch.Tensor  # accumulated weights, grown by one per update

    def update_attention_weights(self, attention_weights):
        """Grow the stored weights by one slot and write the new weights into it.

        The stored tensor is zero-padded by one entry at the end of each of
        its last two dimensions, then ``attention_weights`` is assigned to
        index ``-1`` of the third dimension.

        Returns:
            The updated stored tensor.
        """
        # (left, right, top, bottom) for the last two dims: one extra entry
        # at the end of each.
        grow_by_one = (0, 1, 0, 1)
        self.attention_weights = torch.nn.functional.pad(
            self.attention_weights, grow_by_one
        )
        # presumably the stored tensor is 3-D, making this the new last
        # column; TODO confirm the layout against the caller
        self.attention_weights[:, :, -1] = attention_weights
        return self.attention_weights