from collections import OrderedDict
from dataclasses import asdict, dataclass
from typing import Any, Dict, List

import fsspec
import mlflow
import numpy
import torch
from rl.utils import get_causal_mask
from tensordict import TensorDict

from torch.optim.optimizer import Optimizer


class TrainCollate(torch.nn.Module):
    """Collate ``(slice, TensorDict)`` samples into a stacked training batch.

    For each sample, ``prepare_sample`` keeps the first ``slice`` node rows (with
    the target rows prepended), zeroes the actions already performed up to that
    point, optionally builds the action/causal masks and the value target, and
    applies random augmentations (target negation and node-feature permutation).
    """

    def __init__(
        self,
        device,
        dtype=torch.float32,
        embedding_size=64,
        train_value=False,
        const_node=True,
        return_action_mask=True,
        get_causal_mask=False,
        negation_prob=0.5,
        permutation_prob=0.5,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.embedding_size = embedding_size
        self.const_node = const_node
        self.return_action_mask = return_action_mask
        self.get_causal_mask = get_causal_mask
        self.train_value = train_value
        # Probability of replacing ``td["target"]`` by ``~td["target"]``.
        self.negation_prob = negation_prob
        # Probability of randomly permuting the feature columns of ``td["nodes"]``.
        self.permutation_prob = permutation_prob
        # Entries kept from each prepared sample (see ``prepare_sample``).
        self.keys = ["nodes", "actions"]
        if return_action_mask:
            self.keys.append("action_mask")
        if get_causal_mask:
            self.keys.append("causal_mask")
        if train_value:
            self.keys.append("reward")
        # When True, the negated target is appended to the original target instead
        # of randomly replacing it.
        self.add_target_negation = False

    def __call__(self, batch):
        """Prepare every sample of ``batch`` and stack them into one TensorDict."""
        batch = self.batch_data(batch)
        # ``torch.device`` accepts both device objects and strings such as "cuda".
        if self.device is not None and torch.device(self.device).type == "cuda":
            # Pinned host memory allows a later non-blocking copy to the GPU.
            batch = batch.contiguous().pin_memory()
        if not self.get_causal_mask:
            batch["causal_mask"] = None  # type: ignore
        return batch

    def batch_data(self, batch):
        """Run ``prepare_sample`` on every sample and stack the results."""
        return torch.stack([self.prepare_sample(td) for td in batch])

    def prepare_sample(self, batch):
        """Prepare a single ``(slice, td)`` sample for training.

        Args:
            batch: Pair ``(slice, td)``. ``slice`` is the number of rows of
                ``td["nodes"]`` to keep and the size the action grid is cut to;
                ``td`` holds the full sample.

        Returns:
            A TensorDict restricted to ``self.keys``.
        """
        slice_len, td = batch

        # Deep-copy so the in-place edits below never modify the source sample.
        td = td.clone(True).contiguous()

        num_inputs = torch.sym_int(td["num_inputs"])
        num_outputs = torch.sym_int(td["num_outputs"])
        # Number of actions already performed at this slice point.
        cur_nodes = slice_len - num_inputs  # type: ignore

        if self.const_node:
            cur_nodes -= 1
        else:
            # Drop the first node row (presumably the constant node) and shift the
            # edge endpoints to the new row indexing.
            td["nodes"] = td["nodes"][1:, :]
            td["left_node"].add_(-1)
            td["right_node"].add_(-1)

        if self.add_target_negation:
            # Append the negated target and account for the extra outputs.
            td["target"] = torch.cat([td["target"], ~td["target"]], dim=0)
            num_outputs *= 2
        elif numpy.random.rand() < self.negation_prob:
            # Augmentation: randomly negate the target.
            td["target"] = ~td["target"]

        # Prepend the target rows to the first ``slice_len`` node rows.
        td["nodes"] = torch.cat((td["target"], td["nodes"][:slice_len]), dim=0)

        td["actions"] = td["actions"][:, :slice_len, :slice_len].to(self.dtype)
        edge = td["edge_type"]
        left = td["left_node"]
        right = td["right_node"]

        # Set the actions already performed to 0.
        td["actions"][edge[:cur_nodes], left[:cur_nodes], right[:cur_nodes]] = 0

        if self.train_value:
            td["reward"] = td["reward"][cur_nodes]  # type: ignore

        if self.return_action_mask:
            td["action_mask"] = self.create_action_mask(slice_len)
            # Mark the already-performed actions in the mask as well.
            td["action_mask"][edge[:cur_nodes], left[:cur_nodes], right[:cur_nodes]] = True

        # Create the causal mask in case we want to pass it explicitly to the model.
        if self.get_causal_mask:
            td["causal_mask"] = self.create_causal_mask(
                num_inputs,  # type: ignore
                num_outputs,  # type: ignore
                td["nodes"].size(-2),
            )

        # Augmentation: permute the node-feature columns with probability
        # ``permutation_prob`` (previously gated on ``negation_prob`` by mistake,
        # which left ``permutation_prob`` unused).
        if numpy.random.rand() < self.permutation_prob:
            self.apply_rand_perm(td)

        td["nodes"] = td["nodes"].to(self.dtype)

        return td.select(*self.keys)

    def set_device(self, device):
        """Set the device used to decide whether batches are pinned."""
        self.device = device

    def create_causal_mask(self, num_inputs: int, num_ouputs: int, num_nodes: int):
        """Return a ``(1, num_nodes, num_nodes)`` boolean lower-triangular mask.

        The triangle is offset by ``num_inputs + num_ouputs`` diagonals (one fewer
        when ``const_node`` is False). The misspelled parameter name is kept for
        keyword-argument compatibility.
        """
        diag = num_inputs + num_ouputs
        if not self.const_node:
            diag -= 1
        mask = torch.ones((num_nodes, num_nodes), dtype=torch.bool).tril_(diag)[
            None, :, :
        ]
        return mask

    def create_action_mask(self, num_nodes):
        """Return a ``(4, num_nodes, num_nodes)`` boolean mask, lower-triangular
        (diagonal included) in each of its 4 slices.

        Dim 0 is indexed by ``edge_type`` in ``prepare_sample``.
        """
        action_mask = (
            torch.triu(
                torch.ones((num_nodes, num_nodes), dtype=torch.bool),
                diagonal=0,
            ).T
        ).repeat(4, 1, 1)
        return action_mask

    def apply_rand_perm(self, td: "TensorDict") -> None:
        """Randomly permute the last-dimension columns of ``td["nodes"]`` in place."""
        rand_perm = torch.randperm(td["nodes"].size(-1))
        td["nodes"] = td["nodes"][:, rand_perm]


@torch.jit.script
def kldiv_activation(action_logits: torch.Tensor):
    """Flatten all but the batch dimension and return log-probabilities over it."""
    flat_logits = torch.flatten(action_logits, start_dim=1)
    return torch.log_softmax(flat_logits, dim=-1)


def normalize_action(action_ints: torch.Tensor):
    """Flatten all but the batch dimension and scale each row to sum to 1.

    Returns a new floating-point tensor and leaves ``action_ints`` untouched.
    Rows that sum to zero produce NaN/inf, as with any division by zero.
    """
    action = torch.flatten(action_ints, start_dim=1)
    # Out-of-place division: ``flatten`` may return a view of the input, so an
    # in-place ``/=`` would write into the caller's tensor (and fails for ints).
    return action / action.sum(dim=-1, keepdim=True)


@torch.jit.script
def tanh_activation(value: torch.Tensor):
    """Apply tanh element-wise, mapping values into [-1, 1]."""
    return torch.tanh(value)


@torch.jit.script
def tanh_activation2(value: torch.Tensor):
    """Rescale ``tanh(value)`` from [-1, 1] to [0, 1]."""
    squashed = torch.tanh(value)
    return squashed.add(1).div(2)


@torch.jit.script
def base_activation(tensor: torch.Tensor):
    """Collapse every dimension after the batch dimension into one."""
    return tensor.flatten(1)


def load_model(model, checkpoint):
    """Copy matching entries from ``checkpoint`` into ``model`` in place.

    Entries whose name does not exist in ``model.state_dict()`` are skipped, so
    partially matching checkpoints can be loaded.

    Args:
        model: Module whose parameters/buffers are overwritten.
        checkpoint: A module (anything exposing ``state_dict()``) or a state-dict
            mapping of name -> tensor.
    """
    model_state = model.state_dict()
    source = (
        checkpoint.state_dict() if hasattr(checkpoint, "state_dict") else checkpoint
    )
    for name, param in source.items():
        if name not in model_state:
            continue
        if isinstance(param, torch.nn.Parameter):
            param = param.data
        # state_dict() tensors share storage with the module, so copy_ updates it.
        model_state[name].copy_(param)


def get_mask(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None,
    is_causal: bool,
    n_pis: int | List[int] = 0,
) -> torch.Tensor | None:
    mask = None
    if is_causal:
        if isinstance(n_pis, list):
            mask = torch.stack(
                [
                    torch.ones(
                        (hidden_states.shape[1], hidden_states.shape[1]),
                        dtype=torch.bool,
                        device=hidden_states.device,
                        requires_grad=False,
                    ).tril(diagonal=pi)
                    for pi in n_pis
                ]
            )[:, None, :, :]
        else:
            causal_mask = torch.ones(
                (hidden_states.shape[1], hidden_states.shape[1]),
                dtype=torch.bool,
                device=hidden_states.device,
                requires_grad=False,
            ).tril(
                diagonal=n_pis
            )  # diagonal inputs has to n_inputs
            mask = causal_mask[None, None, :, :].repeat(hidden_states.shape[0], 1, 1, 1)
        mask[:, :, :, -1] = True
        # causal_mask = torch.full(
        #             (hidden_states.shape[1], hidden_states.shape[1]),
        #             fill_value=torch.finfo(hidden_states.dtype).min,
        #             device=hidden_states.device,
        #             requires_grad=False,
        #         ).triu(diagonal=1) # diagonal inputs has to n_inputs + 1

    if attention_mask is not None:
        mask = (
            attention_mask[:, None, None, :]
            if mask is None
            else mask & attention_mask[:, None, None, :].to(torch.bool)
        )
    if mask is not None:
        mask = torch.zeros_like(
            mask, dtype=torch.float32, device=hidden_states.device
        ).masked_fill(~mask, torch.finfo(torch.float32).min)
    return mask


def combine_masks(
    causal_mask: torch.Tensor | None, attention_mask: torch.Tensor | None
) -> torch.Tensor | None:
    mask = causal_mask
    if attention_mask is not None:
        mask = (
            attention_mask[:, None, None, :]
            if mask is None
            else mask & attention_mask[:, None, None, :].to(torch.bool)
        )
    # if mask is not None:
    #     mask = torch.zeros_like(
    #         mask,
    #         dtype=torch.float32,
    #     ).masked_fill(~mask, torch.finfo(torch.float32).min)
    return mask


@dataclass
class Snapshot:
    """Training checkpoint: model weights, optimizer state and the finished epoch."""

    # Result of ``model.state_dict()`` (parameter/buffer name -> tensor).
    model_state: "OrderedDict[str, torch.Tensor]"
    # Result of ``optimizer.state_dict()``.
    optimizer_state: Dict[str, Any]
    # Epoch number passed to the save function when the snapshot was taken.
    finished_epoch: int


def save_snapshot(
    model: torch.nn.Module,
    optimizer: Optimizer,
    epoch: int,
    path: str,
) -> None:
    """Save model weights, optimizer state and ``epoch`` to ``path`` with ``torch.save``.

    The file holds a dict keyed by the fields of ``Snapshot``, which is the layout
    ``load_snapshot`` expects.
    """
    # Unwrap model if it is a DistributedDataParallel (or any wrapper with ``.module``)
    raw_model = model.module if hasattr(model, "module") else model

    snapshot = Snapshot(
        model_state=raw_model.state_dict(),  # type: ignore
        optimizer_state=optimizer.state_dict(),  # type: ignore
        finished_epoch=epoch,
    )
    # Shallow field dict instead of ``dataclasses.asdict``: ``asdict`` deep-copies
    # every tensor (doubling memory, on the GPU too) and rebuilds the state-dict
    # OrderedDict, dropping the ``_metadata`` attribute ``load_state_dict`` uses.
    torch.save(dict(vars(snapshot)), path)


def load_snapshot(model_path: str) -> Snapshot:
    """Read a snapshot file from any fsspec-openable path into a ``Snapshot``.

    All tensors are mapped to CPU; move them to the target device after loading.
    """
    with fsspec.open(model_path) as handle:
        payload = torch.load(handle, map_location="cpu")  # type: ignore
    return Snapshot(**payload)  # type: ignore


def mlflow_save_snapshot(
    model: torch.nn.Module,
    optimizer: Optimizer,
    epoch: int,
    path: str,
) -> None:
    """Save a training snapshot to ``path`` and log that file as an MLflow artifact.

    The file holds a dict keyed by the fields of ``Snapshot``, written with
    ``torch.save``.
    """
    # Unwrap model if it is a DistributedDataParallel (or any wrapper with ``.module``)
    raw_model = model.module if hasattr(model, "module") else model

    snapshot = Snapshot(
        model_state=raw_model.state_dict(),  # type: ignore
        optimizer_state=optimizer.state_dict(),  # type: ignore
        finished_epoch=epoch,
    )
    # Shallow field dict instead of ``dataclasses.asdict``: ``asdict`` deep-copies
    # every tensor and rebuilds the state-dict OrderedDict, dropping ``_metadata``.
    torch.save(dict(vars(snapshot)), path)
    mlflow.log_artifact(path)


def mlflow_load_snapshot(model_path: str) -> Snapshot:
    """Load a snapshot file written by ``mlflow_save_snapshot``.

    ``mlflow_save_snapshot`` writes the snapshot with ``torch.save`` and only logs
    the file as an artifact, so it is read back with ``torch.load``.
    ``mlflow.pytorch.load_model`` expects a model URI (not an open file) and
    returns a model rather than the snapshot dict. ``model_path`` must be
    openable by fsspec, e.g. a local copy of the downloaded artifact.
    """
    with fsspec.open(model_path) as f:
        snapshot_data = torch.load(f, map_location="cpu")  # type: ignore
    return Snapshot(**snapshot_data)  # type: ignore


def prepare_kvcache_generation(input_embeds, cache, mask=None):
    """Drop the positions already held in ``cache`` before an incremental step.

    Args:
        input_embeds: Tensor indexed as ``[batch, seq, feature]``; positions
            ``[:cached_len]`` along dim 1 are removed.
        cache: Mapping whose ``cache["0"]["keys"]`` holds cached keys; its dim -3
            is taken as the cached sequence length.
        mask: Optional 4-D mask, presumably ``(bsz, heads, q_len, kv_len)``;
            only the query rows (dim 2) of the new positions are kept.

    Returns:
        The trimmed embeddings, or ``(trimmed_embeds, trimmed_mask)`` when
        ``mask`` is given.
    """
    seq_len = cache["0"]["keys"].shape[-3]
    trimmed_embeds = input_embeds[:, seq_len:, :]
    if mask is None:
        return trimmed_embeds
    return trimmed_embeds, mask[:, :, seq_len:, :]


# (bsz, ..., q_len, kv_len)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    bsz, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a broadcast axis after the head axis, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, dim)


class KVCache(TensorDict):
    """Container of per-layer caches, one ``LayerKVCache`` per layer id.

    Layer caches are stored under the string keys ``"0"``, ``"1"``, ... and are
    created with this container's batch size and device.
    """

    def __init__(
        self,
        num_layers: int | None,
        batch_size: int,
        num_key_value_heads: int,
        num_heads: int,
        head_dim: int,
        device=torch.device("cpu"),
    ):
        super().__init__(
            {},
            batch_size=[batch_size],
            device=device,
        )
        # ``num_layers=None`` creates an empty container; fill it with ``add_layer``.
        if num_layers is not None:
            for i in range(num_layers):
                self.add_layer(str(i), num_key_value_heads, num_heads, head_dim)

    def __contains__(self, key):
        """Return whether ``key`` is an entry of this cache."""
        # ``key in self`` would re-enter this method and recurse forever; query the
        # keys view instead.
        return key in self.keys()

    def add_layer(
        self, layer_id: str, num_key_value_heads: int, num_heads: int, head_dim: int
    ):
        """Create an empty ``LayerKVCache`` stored under ``layer_id``."""
        self[layer_id] = LayerKVCache(
            layer_id,
            self.batch_size[0],
            num_key_value_heads,
            num_heads,
            head_dim,
            self.device,
        )

    @classmethod
    def initialize_from_model(cls, model, batch_size=1, device=torch.device("cpu")):
        """Build a cache sized from ``model``'s attributes.

        The number of layer entries is
        ``model.n_layers + 4 * model.n_policy_layers + model.n_value_layers``.
        """
        total_layers = model.n_layers + 4 * model.n_policy_layers + model.n_value_layers
        return cls(
            total_layers,
            batch_size,
            model.num_key_value_heads,
            model.n_heads,
            model.head_dim,
            device,
        )

    def set_target(self, target):
        """Store ``target["nodes"]`` and ``target["reward"]`` as target entries."""
        self["target_keys"] = target["nodes"]
        self["target_values"] = target["reward"]


class LayerKVCache(TensorDict):
    """Key/value/query cache for one attention layer.

    Entries start with size 0 along dim 1 (i.e. dim -3 of these 4-D tensors) and are
    grown by concatenation along dim -3 in the ``update_*`` methods.

    NOTE(review): the ``keys`` and ``values`` properties below shadow
    ``TensorDict.keys()`` / ``TensorDict.values()``. Any TensorDict code path that
    calls ``self.keys()`` on an instance would call the cached tensor instead.
    Confirm instances are only used through the accessors defined here.
    """

    def __init__(
        self,
        layer_id: str,
        batch_size: int,
        num_key_value_heads: int,
        num_heads: int,
        head_dim: int,
        device: torch.device | None,
    ):
        super().__init__(
            {
                # keys/values are sized by num_key_value_heads, queries by num_heads.
                "keys": torch.empty(batch_size, 0, num_key_value_heads, head_dim),
                "values": torch.empty(batch_size, 0, num_key_value_heads, head_dim),
                "target_queries": torch.empty(batch_size, 0, num_heads, head_dim),
                "queries": torch.empty(batch_size, 0, num_heads, head_dim),
                # "attention_weights": torch.empty(batch_size, 0),
            },
            batch_size=[batch_size],
            device=device,
        )
        self.layer_id = layer_id

    def __len__(self) -> int:
        """Return the number of cached positions (size of dim -3 of ``keys``).

        NOTE: this overrides TensorDict's ``len()``, which would otherwise report
        the batch size.
        """
        return self["keys"].shape[-3]

    def update_kvcache(self, new_key, new_value):
        """Append ``new_key``/``new_value`` along dim -3 and return the full tensors."""
        self["keys"] = torch.cat((self["keys"], new_key), dim=-3)
        self["values"] = torch.cat((self["values"], new_value), dim=-3)
        return self["keys"], self["values"]

    def update_keys(self, new_key):
        """Append ``new_key`` along dim -3 and return all cached keys."""
        self["keys"] = torch.cat((self["keys"], new_key), dim=-3)
        return self["keys"]

    def update_queries(self, new_query):
        """Append ``new_query`` along dim -3 and return all cached queries."""
        self["queries"] = torch.cat((self["queries"], new_query), dim=-3)
        return self["queries"]

    def update_target_queries(self, new_query):
        """Append ``new_query`` along dim -3 and return all cached target queries."""
        self["target_queries"] = torch.cat((self["target_queries"], new_query), dim=-3)
        return self["target_queries"]

    def get_queries(self):
        """Return the cached *target* queries (not the ``"queries"`` entry)."""
        return self["target_queries"]

    def update_attention_weights(self, attention_weights):
        """Grow the stored attention weights by one row/column and fill the last column.

        ``pad(..., (0, 1, 0, 1))`` appends one zero slot to each of the last two
        dimensions before ``attention_weights`` is written into ``[:, :, -1]``.

        NOTE(review): ``"attention_weights"`` is not created in ``__init__`` (its
        initializer is commented out), so this lookup fails unless the entry was
        set beforehand — confirm whether this method is still used.
        """
        self["attention_weights"] = torch.nn.functional.pad(
            self["attention_weights"], (0, 1, 0, 1)
        )
        self["attention_weights"][:, :, -1] = attention_weights
        return self["attention_weights"]

    def set_target(
        self, target_keys: torch.Tensor, target_values: torch.Tensor
    ) -> None:
        """Store ``target_keys`` and ``target_values`` as entries of this cache."""
        self["target_keys"] = target_keys
        self["target_values"] = target_values

    @property
    def keys(self) -> torch.Tensor:
        """Cached keys tensor (shadows ``TensorDict.keys()``)."""
        return self["keys"]

    @keys.setter
    def keys(self, new_key: torch.Tensor):
        self["keys"] = new_key

    @property
    def values(self) -> torch.Tensor:
        """Cached values tensor (shadows ``TensorDict.values()``)."""
        return self["values"]

    @values.setter
    def values(self, new_value: torch.Tensor):
        self["values"] = new_value

    @property
    def queries(self) -> torch.Tensor:
        """Cached queries tensor."""
        return self["queries"]

    @queries.setter
    def queries(self, new_query: torch.Tensor):
        self["queries"] = new_query

    @property
    def target_queries(self) -> torch.Tensor:
        """Cached target queries tensor."""
        return self["target_queries"]

    @target_queries.setter
    def target_queries(self, new_query: torch.Tensor):
        self["target_queries"] = new_query
