from typing import List
import torch

from common.constants import CLASSIFICATION, TRANSFORMER_FEEDFORWARD_SIZE, TRANSFORMER_HEADS, TRANSFORMER_LAYER_NUM
from tabular_diffusion.denoising_models.columnar_embedding_for_graph import ColumnarEmbeddingForGraph
from tabular_diffusion.denoising_models.time_embedding import TimeEmbedding


class CompleteTransformerDenoisingModel(torch.nn.Module):
    """Transformer-based denoising network for mixed continuous/categorical tabular rows.

    Every column is embedded by ``ColumnarEmbeddingForGraph`` into a token of width
    ``hidden_size``. If there are continuous columns, one ``torch.nn.Transformer``
    processes the whole token sequence and its outputs at the continuous positions
    are kept. If there are categorical columns, a second transformer does the same
    for the categorical positions.

    A time-step embedding is then added to every token and the sequence is
    layer-normalised. Finally, a per-column MLP head maps each token to that
    column's output: width 1 for a continuous column, width ``n`` for a categorical
    column with ``n`` classes.
    """

    def __init__(self,
                 num_cont: int,
                 num_classes: List[int],
                 hidden_size: int = 128,
                 timesteps: int = 1000,
                 params: dict = None,
                 problem_type: str = CLASSIFICATION,
                 with_target=True):
        """Build the embedding, the transformer(s), the normalisation and the output heads.

        :param num_cont: int, number of continuous columns (they come first in ``x``)
        :param num_classes: List[int], number of classes of each categorical column
        :param hidden_size: int, token / embedding width
        :param timesteps: int, number of diffusion steps passed to ``TimeEmbedding``
        :param params: dict, transformer hyper-parameters. It must provide
            ``TRANSFORMER_HEADS``, ``TRANSFORMER_LAYER_NUM`` and
            ``TRANSFORMER_FEEDFORWARD_SIZE`` whenever at least one transformer is built.
        :param problem_type: str, in this class only used to set ``target_index`` and
            ``target_dtype``
        :param with_target: stored as ``self.with_target``; not read in this class
        :raises TypeError: if a transformer is needed but ``params`` is None
        """
        super().__init__()
        self.with_target = with_target
        self.num_cont = num_cont
        self.problem_type = problem_type
        self.num_classes = num_classes

        self.time_embedding = TimeEmbedding(dim=hidden_size, num_steps=timesteps)

        # Not read inside this class; presumably consumed by callers — confirm.
        self.target_index = -1 if problem_type == CLASSIFICATION else 0
        self.target_dtype = torch.long if problem_type == CLASSIFICATION else torch.float
        self.columnar_embedding = ColumnarEmbeddingForGraph(con_features_num=num_cont,
                                                            cat_features_num=len(num_classes),
                                                            cat_features_degrees=num_classes,
                                                            latent_space_size=hidden_size,
                                                            null_in_categorical_embedding=False,
                                                            global_cls_num=0)

        # Keep the construction order (cont, cat, norm, heads) so parameter
        # initialisation consumes the RNG exactly as before.
        if self.num_cont > 0:
            self.transformer_model_cont = self._build_transformer(hidden_size, params)
        if len(self.num_classes) > 0:
            self.transformer_model_cat = self._build_transformer(hidden_size, params)
        self.norm_layer = torch.nn.LayerNorm([num_cont + len(num_classes), hidden_size])

        # One head per column: continuous columns first (width 1), then categorical ones.
        self.outputs = torch.nn.ModuleList(
            [self._build_output_head(hidden_size, 1) for _ in range(num_cont)]
            + [self._build_output_head(hidden_size, n) for n in num_classes])

    @staticmethod
    def _build_transformer(hidden_size: int, params: dict) -> torch.nn.Transformer:
        """Create one batch-first encoder-decoder transformer configured from ``params``."""
        if params is None:
            raise TypeError("params must provide TRANSFORMER_HEADS, TRANSFORMER_LAYER_NUM "
                            "and TRANSFORMER_FEEDFORWARD_SIZE")
        return torch.nn.Transformer(d_model=hidden_size,
                                    nhead=params[TRANSFORMER_HEADS],
                                    num_encoder_layers=params[TRANSFORMER_LAYER_NUM],
                                    num_decoder_layers=params[TRANSFORMER_LAYER_NUM],
                                    dim_feedforward=params[TRANSFORMER_FEEDFORWARD_SIZE],
                                    batch_first=True)

    @staticmethod
    def _build_output_head(hidden_size: int, out_features: int) -> torch.nn.Sequential:
        """Build one output head.

        The head is two Linear/LayerNorm/ReLU stages followed by a projection to
        ``out_features``.
        """
        return torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size),
                                   torch.nn.LayerNorm(hidden_size),
                                   torch.nn.ReLU(),
                                   torch.nn.Linear(hidden_size, hidden_size),
                                   torch.nn.LayerNorm(hidden_size),
                                   torch.nn.ReLU(),
                                   torch.nn.Linear(hidden_size, out_features))

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                condition: torch.Tensor = None,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Compute the per-column outputs for a batch of (noisy) rows.

        :param x: torch.Tensor, rows whose first ``num_cont`` columns are continuous
            values and whose remaining columns are categorical indices (cast to long)
        :param t: torch.Tensor, diffusion time steps fed to ``TimeEmbedding``
        :param condition: unused; kept for interface compatibility
        :param mask: torch.Tensor or None, per-column mask used as the transformers'
            key padding mask. Assumed shape is (batch, num_cont + len(num_classes))
            — TODO confirm. The encoder ignores columns where ``mask < 1``; the decoder
            ignores columns where ``mask > 0``. With ``None``, no padding mask is applied.
        :return: torch.Tensor, the head outputs concatenated along dim 1. That is one
            value per continuous column, followed by ``num_classes[j]`` values for
            categorical column ``j``.
        """
        tokens = self.columnar_embedding(x[:, :self.num_cont],
                                         x[:, self.num_cont:].to(torch.long))

        if mask is None:
            src_padding_mask = tgt_padding_mask = None
        else:
            # In torch.nn.Transformer a True entry marks a key position to ignore.
            src_padding_mask = mask < 1
            tgt_padding_mask = mask > 0

        x_cont = None
        if self.num_cont > 0:
            x_cont = self.transformer_model_cont(src=tokens,
                                                 tgt=tokens,
                                                 src_key_padding_mask=src_padding_mask,
                                                 tgt_key_padding_mask=tgt_padding_mask)
        x_cat = None
        if len(self.num_classes) > 0:
            x_cat = self.transformer_model_cat(src=tokens,
                                               tgt=tokens,
                                               src_key_padding_mask=src_padding_mask,
                                               tgt_key_padding_mask=tgt_padding_mask)

        if x_cont is not None and x_cat is not None:
            # Continuous positions come from the continuous transformer and
            # categorical positions from the categorical one. (Previously the raw
            # embedding was re-concatenated here, discarding both outputs.)
            hidden = torch.cat([x_cont[:, :self.num_cont, :],
                                x_cat[:, self.num_cont:, :]], dim=1)
        elif x_cat is None:
            hidden = x_cont
        else:
            hidden = x_cat

        # Broadcast the time embedding over the column axis;
        # presumably TimeEmbedding returns (batch, hidden_size) — confirm.
        hidden = hidden + self.time_embedding(t).unsqueeze(dim=1)
        hidden = self.norm_layer(hidden)
        return torch.cat([head(hidden[:, i, :]) for i, head in enumerate(self.outputs)], dim=1)

    def save_checkpoint_model(self,
                              metric: float,
                              epoch: int,
                              checkpoint_path: str):
        """Save the current model state.

        Nothing is written when ``checkpoint_path`` is None.

        :param metric: float, metric value stored alongside the weights
        :param epoch: int, current epoch
        :param checkpoint_path: str, path where the model has to be saved
        :return: None
        """
        if checkpoint_path is None:
            return
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'metric': metric
        }, checkpoint_path)

    def load_checkpoint_model(self, checkpoint_path: str, device: str):
        """Load the model state from a checkpoint written by ``save_checkpoint_model``.

        :param checkpoint_path: str, path where the model state is stored
        :param device: str, device the model is moved to after loading. ``'cuda'``
            loads the tensors onto ``cuda:0``. Any other value loads them onto the CPU
            first, so GPU-saved checkpoints also load on CPU-only machines.
        :return: None
        """
        # Choose whatever GPU device number you want for the 'cuda' case.
        map_location = "cuda:0" if device == 'cuda' else "cpu"
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.to(device)
