import math
from abc import ABC, abstractmethod

import torch
from torch import nn
from torch.nn import functional as F

import math
from models.autoregressive_models import (
    MyEmbed,
    EnhancedLinesModel,
)


def dict_to_cpu(dictionary):
    """Return a copy of *dictionary* with every tensor moved to the CPU.

    Nested dictionaries are converted recursively; any other value is placed
    in the result unchanged. The result is always a plain ``dict``, whatever
    mapping type was passed in.
    """

    def _to_cpu(value):
        # Tensors are moved, dicts are walked, everything else is passed through.
        if isinstance(value, torch.Tensor):
            return value.cpu()
        if isinstance(value, dict):
            return dict_to_cpu(value)
        return value

    return {key: _to_cpu(value) for key, value in dictionary.items()}


class AbstractNetwork(ABC, torch.nn.Module):
    """Abstract interface for the inference networks defined in this module.

    Concrete subclasses must implement ``initial_inference`` and
    ``recurrent_inference``. Helpers for exporting and importing weights are
    provided on top of the ``torch.nn.Module`` state dict.
    """

    def __init__(self):
        """Do nothing.

        ``torch.nn.Module.__init__`` is intentionally NOT invoked here (the
        ``super().__init__()`` call is disabled), so a subclass has to make
        sure the module gets initialised through another base class.
        NOTE(review): presumably this avoids the cooperative ``super()`` chain
        reaching a sibling base class without its arguments -- confirm.
        """
        # super().__init__()

    @abstractmethod
    def initial_inference(self, observation):
        """Run inference from a raw observation (implemented by subclasses)."""

    @abstractmethod
    def recurrent_inference(self, encoded_state, action):
        """Run inference from an encoded state and an action (implemented by subclasses)."""

    def get_weights(self):
        """Return the module's state dict with all tensors moved to the CPU."""
        weights = self.state_dict()
        return dict_to_cpu(weights)

    def set_weights(self, weights):
        """Load *weights* (a state dict) into this module."""
        self.load_state_dict(weights)


class DynamicsMLPNetwork(nn.Module):
    """Dynamics core producing the next latent state and the reward logits.

    Despite the class name, the core is an ``nn.LSTM``. Every call feeds a
    single time step (sequence length 1) and passes no recurrent state in, so
    the LSTM starts from its default initial state on each call.
    """

    def __init__(
        self,
        input_size=256,
        hidden_size=512,
        output_size=256,
        num_layers=2,
        bias=True,
        full_support_size=601,
    ):
        """Create the LSTM core and the two linear heads.

        Args:
            input_size: Feature size of the input vector.
            hidden_size: Hidden size of the LSTM (input size of both heads).
            output_size: Size of the predicted next state.
            num_layers: Number of stacked LSTM layers.
            bias: Whether the LSTM uses bias terms.
            full_support_size: Number of reward logits.
        """
        super().__init__()

        # use lstm instead like in https://arxiv.org/pdf/2109.00527.pdf
        lstm_options = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "bias": bias,
            "batch_first": True,
        }
        self.lstm = nn.LSTM(**lstm_options)

        self.prediction_head = nn.Linear(hidden_size, output_size)
        self.reward_head = nn.Linear(hidden_size, full_support_size)

    def forward(self, x):
        """Map ``x`` to ``(next_state, reward_logits)``.

        ``x`` gets a length-1 sequence axis inserted at position 1 before the
        LSTM and that axis is removed again from the LSTM output.
        """
        single_step = x.unsqueeze(1)
        lstm_output, _ = self.lstm(single_step)
        features = lstm_output.squeeze(1)
        return self.prediction_head(features), self.reward_head(features)


class MuZeroLinesModel:
    """Factory: instantiating this class returns the network selected by name."""

    def __new__(cls, network, *args, **kwargs):
        """Build the network called *network*.

        Only ``"enhanced-keypoints-tokens"`` is supported; the remaining
        arguments are forwarded to ``EnhancedLinesModelRL``.

        Raises:
            NotImplementedError: For any other network name.
        """
        if network != "enhanced-keypoints-tokens":
            raise NotImplementedError
        return EnhancedLinesModelRL(*args, **kwargs)


# Cosine similarity computed along dim=1 (one value per row for 2-D inputs).
cosine_similarity_loss_fn = nn.CosineSimilarity(dim=1)


class EnhancedLinesModelRL(AbstractNetwork, EnhancedLinesModel):
    def __init__(
        self,
        max_step=40,
        support_size=250,
        checkpoint=None,
        support_scaling_factor=None,
        add_predictor=False,  # used for consistency loss
        cls_token=False,
        *args,
        **kwargs
    ):
        """Build the RL heads on top of ``EnhancedLinesModel`` and optionally load weights.

        ``self.embedding_dim``, ``self.image_size``, ``self.max_seq_length``,
        ``self.encoder_config`` and ``self.decoder_config`` are read here and
        are expected to be provided by ``EnhancedLinesModel.__init__``
        (defined outside this file).

        Args:
            max_step: Stored as ``self.max_step``; ``dynamics`` advances
                ``step_percentage`` by ``1 / max_step`` per action.
            support_size: Half-width of the categorical support; the value and
                reward heads emit ``2 * support_size + 1`` logits.
            checkpoint: Optional checkpoint handed to ``torch.load``. If the
                loaded object has a ``"weights"`` entry, that entry is used as
                the state dict. Loading is best effort: on failure the
                problem is printed and construction continues.
            support_scaling_factor: Stored on the instance; not read elsewhere
                in this class.
            add_predictor: When true, builds ``self.predictor``
                (Linear -> BatchNorm1d -> ReLU -> Linear).
            cls_token: When true, creates the learnable ``self.cls_token_emb``
                that ``get_constant_features`` writes over the global context
                embedding.
            *args: Forwarded to ``EnhancedLinesModel.__init__``.
            **kwargs: Forwarded to ``EnhancedLinesModel.__init__``.
        """
        AbstractNetwork.__init__(self)
        EnhancedLinesModel.__init__(self, *args, **kwargs)

        self.full_support_size = support_size * 2 + 1
        self.support_scaling_factor = support_scaling_factor

        # Learned stand-in for the edge embedding when the action picks the
        # FIRST keypoint of an edge (no edge geometry exists yet).
        self.first_choice_new_edge_embed_surrogate = nn.Parameter(
            torch.randn((self.embedding_dim * 2,)) * 1 / (self.embedding_dim * 2),
            requires_grad=True,
        )

        # One embedding per rounded line length, up to the image diagonal.
        self.line_length_embed = MyEmbed(
            math.ceil(self.image_size * math.sqrt(2)), self.embedding_dim
        )
        self.keypoints_type_embed = MyEmbed(self.max_seq_length, self.embedding_dim)

        self.max_step = max_step

        self.step_percentage_embed = nn.Linear(1, self.embedding_dim)

        self.value_prediction_net = nn.Sequential(
            nn.Linear(
                self.encoder_config["hidden_size"] + self.embedding_dim,
                self.embedding_dim,
            ),
            nn.GELU(),
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, self.full_support_size),
        )

        self.encode_vector_space = nn.Sequential(
            nn.Linear(self.encoder_config["hidden_size"], self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, self.embedding_dim),
        )

        # takes last output of decoder and the encoder embedding based on the action selected
        self.dynamics_network = DynamicsMLPNetwork(
            input_size=self.embedding_dim * 3 + self.decoder_config["hidden_size"],
            hidden_size=self.embedding_dim * 2,
            output_size=self.decoder_config["hidden_size"],
            num_layers=3,  # num_layers=2
            bias=True,
            full_support_size=self.full_support_size,
        )

        if add_predictor:
            dim = self.decoder_config["hidden_size"]
            pred_dim = self.decoder_config["hidden_size"] // 2
            # build a 2-layer predictor
            self.predictor = nn.Sequential(
                nn.Linear(dim, pred_dim, bias=False),
                nn.BatchNorm1d(pred_dim),
                nn.ReLU(inplace=True),  # hidden layer
                nn.Linear(pred_dim, dim),
            )  # output layer

        self.cls_token = cls_token

        if self.cls_token:
            # Instead of using a global image embedding, use this to initialize the decoding
            self.cls_token_emb = nn.Parameter(
                torch.randn(
                    self.embedding_dim,
                )
                * 1
                / (self.embedding_dim * 3)
            )

        if checkpoint is not None:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code -- only load checkpoints from trusted sources.
            loaded_checkpoint = torch.load(checkpoint, map_location="cpu")
            if "weights" in loaded_checkpoint:
                loaded_checkpoint = loaded_checkpoint["weights"]

            try:
                self.load_state_dict(loaded_checkpoint, strict=False)
            except RuntimeError as e:
                print("RuntimeError while loading from state dict.")
                print(e)
                print("Retrying by loading only slow trained parameters..")

                # This is mainly in case we are trying to load a pre-trained autoregressive model
                fast_prefixes = self.fast_trained_parameters_names()
                new_loaded_checkpoint = {
                    k: v
                    for k, v in loaded_checkpoint.items()
                    if not k.startswith(fast_prefixes)
                }
                try:
                    self.load_state_dict(new_loaded_checkpoint, strict=False)
                except Exception as retry_error:
                    # Still best effort, but no longer swallows
                    # KeyboardInterrupt/SystemExit and reports the cause.
                    print("Invalid model architecture ...")
                    print(retry_error)
                    print("Skipping loading entirely..")

    def get_parameters(self):
        """Return the parameter groups, slow group first, then fast.

        Each group is a dict holding the parameter list under ``"params"``
        and the group label (``"slow"`` / ``"fast"``) under ``"type"``.
        """
        slow_group = {"params": self.slow_trained_parameters(), "type": "slow"}
        fast_group = {"params": self.fast_trained_parameters(), "type": "fast"}
        return [slow_group, fast_group]

    def fast_trained_parameters_names(self):
        """Return the name prefixes of the parameters in the "fast" group.

        The result is a tuple so that it can be passed directly to
        ``str.startswith``. It always contains the eight RL-specific prefixes;
        the optional predictor and cls-token prefixes are appended when the
        corresponding attributes exist on the instance.

        The previous implementation chained two conditional expressions
        without parentheses, which Python parses as
        ``A if c1 else (B if c2 else C)``; because ``self.add_predictor`` is
        never assigned by this class, the base prefixes were dropped and only
        ``("cls_token",)`` was returned.
        """
        names = [
            "first_choice_new_edge_embed_surrogate",
            "line_length_embed",
            "keypoints_type_embed",
            "step_percentage_embed",
            "project_to_logits",
            "value_prediction_net",
            "encode_vector_space",
            "dynamics_network",
        ]

        # The optional consistency-loss predictor is registered as
        # ``self.predictor``, so its parameter names start with "predictor."
        # (the constructor flag is called ``add_predictor``, but no parameter
        # carries that name).
        if hasattr(self, "predictor"):
            names.append("predictor.")

        # Prefix-matches the optional ``cls_token_emb`` parameter.
        if hasattr(self, "cls_token"):
            names.append("cls_token")

        return tuple(names)

    def fast_trained_parameters(self):
        """Return the parameters whose names start with a "fast" prefix.

        Parameters are returned in ``named_parameters()`` order. The prefixes
        come from ``fast_trained_parameters_names`` and are computed once.
        """
        prefixes = self.fast_trained_parameters_names()
        # named_parameters() yields (name, Parameter) pairs only, so no
        # sub-module expansion is needed.
        return [
            param
            for name, param in self.named_parameters()
            if name.startswith(prefixes)
        ]

    def slow_trained_parameters(self):
        """Return the parameters whose names do NOT start with a "fast" prefix.

        Parameters are returned in ``named_parameters()`` order. The prefixes
        come from ``fast_trained_parameters_names`` and are computed once.
        """
        prefixes = self.fast_trained_parameters_names()
        # named_parameters() yields (name, Parameter) pairs only, so no
        # sub-module expansion is needed.
        return [
            param
            for name, param in self.named_parameters()
            if not name.startswith(prefixes)
        ]

    def prediction(self, encoded_state):
        """Compute the policy logits and value logits for an encoded state.

        Reads ``"state"``, ``"step_percentage"``, ``"keypoints_embeddings"``
        and ``"keypoints_mask"`` from *encoded_state*.

        Returns:
            ``(policy_logits, value)`` where ``value`` is the output of
            ``value_prediction_net`` and ``policy_logits`` holds one score per
            entry of ``keypoints_embeddings``; masked-out entries are pushed
            down by ``1e9``.
        """
        state = encoded_state["state"]

        # Value head: latent state concatenated with the embedded step progress.
        step_fraction = encoded_state["step_percentage"].view(-1, 1).float()
        embedded_step = self.step_percentage_embed(step_fraction)
        value = self.value_prediction_net(torch.cat([state, embedded_step], 1))

        # Policy head: scaled dot product between one pointer vector per
        # sample and every keypoint embedding.
        pointers = self.project_to_logits(state)
        keys = encoded_state["keypoints_embeddings"].permute(0, 2, 1)
        scores = torch.matmul(pointers.unsqueeze(1), keys)
        scores = scores / math.sqrt(float(self.embedding_dim))

        # The mask gets one extra always-valid slot in front.
        # NOTE(review): presumably this slot matches a special leading entry
        # of ``keypoints_embeddings`` -- confirm against ``_prepare_context``.
        valid = F.pad(encoded_state["keypoints_mask"], pad=(1, 0), value=1)[:, None]

        masked_scores = scores * valid - (1.0 - valid) * 1e9

        # always only one element in this case
        return masked_scores[:, 0], value

    def get_constant_features(self, observation):
        """Compute the features that stay fixed across the steps of an episode.

        Builds the keypoint validity mask from ``observation["keypoints_len"]``
        and runs ``_prepare_context`` on the image, the keypoints and that
        mask. When ``self.cls_token`` is set, every row of the global context
        embedding is overwritten in place with ``self.cls_token_emb``.

        Returns:
            ``(keypoints_embeddings, cross_attention_features,
            global_context_embedding, keypoints_mask)``.
        """
        keypoints = observation["keypoints"]

        # 1 for real keypoints, 0 for the padding after each sample's length.
        keypoints_mask = torch.ones(keypoints.shape[:2], device=keypoints.device)
        for row, valid_len in enumerate(observation["keypoints_len"]):
            keypoints_mask[row, valid_len:] = 0

        context = self._prepare_context(observation["image"], keypoints, keypoints_mask)
        keypoints_embeddings, cross_attention_features, global_context_embedding = context

        if self.cls_token:
            # replace global context embedding with the cls token ...
            for row in range(global_context_embedding.shape[0]):
                global_context_embedding[row] = self.cls_token_emb

        return (
            keypoints_embeddings,
            cross_attention_features,
            global_context_embedding,
            keypoints_mask,
        )

    def representation(self, observation, global_features=None):
        """Encode an observation into the state dictionary used by the other heads.

        Args:
            observation: Mapping that provides ``"current_edges"``,
                ``"current_edges_len"``, ``"step_percentage"`` and
                ``"keypoints"`` (plus whatever ``get_constant_features``
                reads when *global_features* is not given).
            global_features: Optional precomputed result of
                ``get_constant_features``; computed here when ``None``.

        Returns:
            Dict with the latent ``"state"`` (decoder output at position 0)
            and the bookkeeping entries consumed by ``prediction`` and
            ``dynamics``.
        """
        if global_features is None:
            global_features = self.get_constant_features(observation)
        (
            keypoints_embeddings,
            cross_attention_features,
            global_context_embedding,
            keypoints_mask,
        ) = global_features

        lines = observation["current_edges"].long()
        current_edges_len = observation["current_edges_len"]

        # 1 where a position lies inside the current edge sequence, else 0.
        batch_size, max_len = lines.shape[0], lines.shape[1]
        positions = (
            torch.arange(max_len)[None, :]
            .repeat(batch_size, 1)
            .to(current_edges_len.device)
        )
        inside_sequence = positions < current_edges_len[:, None]
        lines_masks = torch.zeros_like(lines).masked_fill(inside_sequence, 1)

        decoder_inputs = self._embed_inputs(
            lines,
            lines_masks,
            keypoints_embeddings,
            global_context_embedding,
        )
        decoder_outputs = self.decoder(
            decoder_inputs,
            sequential_context_embeddings=[cross_attention_features],
        )

        # The decoder output at position 0 of every sample is the latent state.
        sample_rows = torch.arange(decoder_outputs.shape[0])
        first_position = torch.zeros(decoder_outputs.shape[0]).long()
        latent_space = decoder_outputs[sample_rows, first_position]

        return {
            "state": latent_space,
            "step_percentage": observation["step_percentage"],
            "keypoints": observation["keypoints"],
            "keypoints_embeddings": keypoints_embeddings,
            "keypoints_mask": keypoints_mask,
            "lines": lines,
            "current_edges_len": current_edges_len,
            "next_step_is_first_keypoints": current_edges_len % 2 == 0,
        }

    def dynamics(self, encoded_state, action, pool=None):
        """Apply *action* to *encoded_state*; return the next state and reward logits.

        The dynamics network sees the latent state, an encoding of the chosen
        keypoint embedding and an "edge" embedding. For samples choosing the
        first keypoint of an edge the edge embedding is the learned surrogate
        parameter. For samples choosing the second keypoint it is the
        concatenation of a line-length embedding and a keypoint-type
        embedding; the latter is indexed by how often each endpoint already
        occurs among the earlier entries of ``lines``.

        Args:
            encoded_state: State dict as produced by ``representation`` or by
                a previous ``dynamics`` call.
            action: Chosen indices; flattened for indexing and concatenated to
                ``lines`` along dim 1 (assumes shape ``(batch, 1)`` --
                TODO confirm with callers).
            pool: Unused; kept for interface compatibility.

        Returns:
            ``(next_encoded_state, reward)``.
        """
        batch_size = action.shape[0]
        device = encoded_state["keypoints_embeddings"].device
        flat_action = action.view(-1)
        lines = encoded_state["lines"]
        current_edges_len = encoded_state["current_edges_len"]

        chosen_vertices = self.encode_vector_space(
            encoded_state["keypoints_embeddings"][
                torch.arange(batch_size, device=device), flat_action
            ]
        )

        new_edge_embed = torch.empty(
            (batch_size, self.embedding_dim * 2),
            device=device,
        )
        first_choices = torch.logical_or(
            encoded_state["next_step_is_first_keypoints"],
            current_edges_len == 0,
        )

        # when first keypoints
        new_edge_embed[first_choices] = self.first_choice_new_edge_embed_surrogate

        second_choices = torch.logical_not(first_choices)
        if second_choices.any():
            # when second keypoints, need to calculate edge characteristics

            # Build the row index on the mask's own device: indexing a CPU
            # ``arange`` with a mask living on another device fails.
            rows = torch.arange(batch_size, device=second_choices.device)[
                second_choices
            ]
            # Position of the most recent entry of each selected sample.
            last_position = current_edges_len[second_choices] - 1

            # Stored indices are shifted down by one before being used to
            # index ``keypoints``.
            first_keypoints = lines[rows, last_position] - 1
            second_keypoints = flat_action[second_choices] - 1

            # Count, per selected sample, how often each endpoint appears
            # among the entries strictly before ``last_position``.
            selected_lines = lines[second_choices] - 1
            positions = torch.arange(
                selected_lines.shape[1], device=selected_lines.device
            )
            is_earlier = positions[None, :] < last_position[:, None]
            first_counts = (
                (selected_lines == first_keypoints[:, None]) & is_earlier
            ).sum(dim=1)
            second_counts = (
                (selected_lines == second_keypoints[:, None]) & is_earlier
            ).sum(dim=1)

            starts = encoded_state["keypoints"][rows, first_keypoints]
            ends = encoded_state["keypoints"][rows, second_keypoints]

            edge_sizes = self.line_length_embed(
                torch.round(torch.norm(ends - starts, dim=-1)).long()
            )
            keypoints_types = self.keypoints_type_embed(
                torch.stack([first_counts, second_counts], dim=-1)
            ).sum(1)

            new_edge_embed[second_choices] = torch.cat(
                [edge_sizes, keypoints_types], dim=1
            )

        # potentially also add the angle of the newly generated edge??

        augmented_space = torch.cat(
            [encoded_state["state"], chosen_vertices, new_edge_embed], dim=-1
        )
        next_encoded_state_val, reward = self.dynamics_network(augmented_space)

        next_encoded_state = {
            "state": next_encoded_state_val,
            "step_percentage": encoded_state["step_percentage"] + 1 / self.max_step,
            "keypoints": encoded_state["keypoints"],
            "keypoints_embeddings": encoded_state["keypoints_embeddings"],
            "keypoints_mask": encoded_state["keypoints_mask"],
            # NOTE(review): the action is appended after the LAST column of
            # ``lines``; if ``lines`` is padded beyond ``current_edges_len``
            # the new entry will not sit at index ``current_edges_len`` --
            # confirm that callers never pass padded sequences here.
            "lines": torch.cat([lines, action], dim=1),
            "current_edges_len": current_edges_len + 1,
            "next_step_is_first_keypoints": torch.logical_not(
                encoded_state["next_step_is_first_keypoints"]
            ),
        }

        return next_encoded_state, reward

    def initial_inference(self, observation, global_features=None):
        """Encode *observation* and predict policy and value for the root state.

        The returned reward is constant: the logarithm of a one-hot
        distribution on the middle bin of the support (``0`` at the centre,
        ``-inf`` elsewhere), repeated for every sample.

        Returns:
            ``(value, reward, policy_logits, encoded_state)``.
        """
        encoded_state = self.representation(
            observation, global_features=global_features
        )
        policy_logits, value = self.prediction(encoded_state)

        centre_bin = torch.tensor([[self.full_support_size // 2]]).long()
        one_hot = torch.zeros(1, self.full_support_size).scatter(1, centre_bin, 1.0)
        batched_one_hot = one_hot.repeat(value.shape[0], 1).to(value.device)
        reward = torch.log(batched_one_hot)

        return value, reward, policy_logits, encoded_state

    def recurrent_inference(self, encoded_state, action):
        """Advance *encoded_state* by *action* and predict on the resulting state.

        Returns:
            ``(value, reward, policy_logits, next_encoded_state)``.
        """
        successor_state, reward = self.dynamics(encoded_state, action)
        policy_logits, value = self.prediction(successor_state)
        return value, reward, policy_logits, successor_state

    def forward(
        self,
        observation_batch,
        action_batch,
        observation_at_td_steps_batch=None,
        use_consistency_loss=False,
    ):
        """Unroll the model over *action_batch* and collect the predictions.

        Args:
            observation_batch: Observation for the initial inference.
            action_batch: Actions indexed as ``action_batch[:, i]``; index 0
                is skipped, the unroll uses ``i = 1 .. shape[1] - 1``.
            observation_at_td_steps_batch: Observation encoded as the target
                for the consistency loss; required when
                *use_consistency_loss* is true.
            use_consistency_loss: Whether to compute the consistency loss.

        Returns:
            ``(predictions, consistency_loss, batch_size)`` where
            ``predictions`` is a list of ``(value, reward, policy_logits)``
            tuples (initial step first), ``consistency_loss`` is the negated
            cosine similarity between the detached target state and the last
            unrolled state (or an integer zero tensor when disabled) and
            ``batch_size`` is ``action_batch.shape[0]`` as a tensor.

        Raises:
            ValueError: If *use_consistency_loss* is true but
                *observation_at_td_steps_batch* is ``None``.
        """
        if use_consistency_loss and observation_at_td_steps_batch is None:
            # Fail early with a clear message instead of an opaque error
            # deep inside the encoder after the whole unroll has run.
            raise ValueError(
                "observation_at_td_steps_batch is required when "
                "use_consistency_loss is True"
            )

        # training done in forward for easier multi-gpu training
        value, reward, policy_logits, hidden_state = self.initial_inference(
            observation_batch
        )

        predictions = [(value, reward, policy_logits)]

        for i in range(1, action_batch.shape[1]):
            value, reward, policy_logits, hidden_state = self.recurrent_inference(
                hidden_state, action_batch[:, i]
            )

            # Scale the gradient at the start of the dynamics function (See paper appendix Training)
            # NOTE(review): ``dynamics`` passes the same
            # ``keypoints_embeddings`` tensor object through every step, so
            # one more hook is attached to it per iteration and its gradient
            # ends up scaled by 0.5 once per unroll step -- confirm that this
            # compounding is intended.
            for tensor in hidden_state.values():
                if tensor.requires_grad:
                    tensor.register_hook(lambda grad: grad * 0.5)

            predictions.append((value, reward, policy_logits))

        device = hidden_state["state"].device

        if use_consistency_loss:
            _, _, _, last_true_hidden_state = self.initial_inference(
                observation_at_td_steps_batch
            )

            # The target branch is detached: only the unrolled state gets gradients.
            consistency_loss = -cosine_similarity_loss_fn(
                last_true_hidden_state["state"].detach(), hidden_state["state"]
            )
        else:
            consistency_loss = torch.tensor(0, device=device)

        return (
            predictions,
            consistency_loss,
            torch.tensor(action_batch.shape[0], device=device),
        )


######### End LinesModel #########
##################################


### Replace scaling with a simple linear scaling for the beginning ...
def calculate_support_scaling_factor(support_size, max_reward):
    """Return ``support_size / max_reward``, the linear support scaling factor.

    Multiplying a value by this factor maps ``max_reward`` onto
    ``support_size``.
    """
    scaling_factor = support_size / max_reward
    return scaling_factor


def support_to_scalar(logits, support_size, scaling_factor):
    """
    Transform a categorical representation to a scalar
    See paper appendix Network Architecture

    Args:
        logits: Unnormalised scores; the softmax is taken over dim 1, which
            must have ``2 * support_size + 1`` entries.
        support_size: Half-width of the integer support
            ``[-support_size, support_size]``.
        scaling_factor: Linear factor that was applied when encoding; the
            expected support value is divided by it.

    Returns:
        Tensor holding the expected value with dim 1 kept (size 1).
    """
    # Decode to a scalar: expectation of the support under softmax(logits).
    probabilities = torch.softmax(logits, dim=1)
    support = torch.arange(
        -support_size,
        support_size + 1,
        dtype=torch.float32,
        device=probabilities.device,
    ).expand(probabilities.shape)
    x = torch.sum(support * probabilities, dim=1, keepdim=True)

    # Invert the (linear) scaling applied by scalar_to_support.
    return x / scaling_factor


def scalar_to_support(x, support_size, scaling_factor):
    """
    Transform a scalar to a categorical representation with (2 * support_size + 1) categories
    See paper appendix Network Architecture

    ``x`` is indexed as a 2-D tensor (``x.shape[0]`` by ``x.shape[1]``). Each
    scaled and clamped value is split between its two neighbouring integer
    bins in proportion to its fractional part. Although the result is named
    ``logits`` it holds these weights, not unnormalised scores.
    """
    num_bins = 2 * support_size + 1

    # Scale linearly, then keep the value inside the support range.
    scaled = torch.clamp(scaling_factor * x, -support_size, support_size)

    lower = scaled.floor()
    upper_weight = scaled - lower
    lower_weight = 1 - upper_weight

    lower_index = (lower + support_size).long()
    upper_index = lower + support_size + 1
    # The upper neighbour falls outside the support when the value sits on
    # the top bin; it is redirected to bin 0 with weight 0.
    out_of_range = 2 * support_size < upper_index
    upper_weight = upper_weight.masked_fill(out_of_range, 0.0)
    upper_index = upper_index.masked_fill(out_of_range, 0.0).long()

    logits = torch.zeros(scaled.shape[0], scaled.shape[1], num_bins).to(scaled.device)
    logits.scatter_(2, lower_index.unsqueeze(-1), lower_weight.unsqueeze(-1))
    logits.scatter_(2, upper_index.unsqueeze(-1), upper_weight.unsqueeze(-1))
    return logits
