from torch import nn
import torch
import itertools
from collections import OrderedDict
from typing import Dict, List
# Local imports
from models.base import BaseModel
from losses.infonce import InfoNCE
from losses.mminfonce import MMInfoNCE


class MMSiamese(BaseModel):
    """ Abstract Siamese network (objective function is unspecified)
    Build a DNN model with a 3-layers MLP as projector as in [1].
    It encodes a pair of data and outputs a pair of representations through
    a single encoder (hence `Siamese` model).

    [1] A Simple Framework for Contrastive Learning of Visual Representations, Chen et al., ICML 2020
    """

    def __init__(self,
                 encoder: nn.Module,
                 projection: nn.Module,
                 optim_kwargs: Dict):
        """
        Args:
            encoder: Multi-modal fusion encoder (e.g. MMFusion)
            projection: MLP projector to the latent space
            optim_kwargs: Optimization hyper-parameters
        """
        super(MMSiamese, self).__init__(optim_kwargs)

        # create the encoder
        self.encoder = encoder

        # build a 3-layers projector
        self.head = projection


    @staticmethod
    def _build_mlp(in_dim, mlp_dim, out_dim):
        return nn.Sequential(OrderedDict([
            ("layer1", nn.Linear(in_dim, mlp_dim)),
            ("bn1", nn.SyncBatchNorm(mlp_dim)),
            ("relu1", nn.ReLU(inplace=True)),
            ("layer2", nn.Linear(mlp_dim, mlp_dim)),
            ("bn2", nn.SyncBatchNorm(mlp_dim)),
            ("relu2", nn.ReLU(inplace=True)),
            ("layer3", nn.Linear(mlp_dim, out_dim)),
        ]))

    def extract_features(self, loader: torch.utils.data.DataLoader, **kwargs):
        """
           Extract global average pooled visual features.
           Args:
               loader: Dataset loader to serve ``(X, y)`` tuples.
               kwargs: given to `encoder.forward()`
           Returns: Pair (Z,y) corresponding to extracted features and corresponding labels
        """
        X, y = [], []
        for X_, y_ in loader:
            if isinstance(X_, torch.Tensor): # needs to cast it as list of one modality
                X_ = [X_]
            X_ = [x.to(self.device) if isinstance(x, torch.Tensor) else x for x in X_]
            y_ = y_.to(self.device)
            with torch.inference_mode():
                # compute output
                output = self.encoder(X_, **kwargs)
                X.extend(output.view(len(output), -1).detach().cpu())
                y.extend(y_.detach().cpu())
        torch.cuda.empty_cache()
        return torch.stack(X, dim=0).to(self.device), torch.stack(y, dim=0).to(self.device)

    def forward(self, x1: List[torch.Tensor], x2: List[torch.Tensor]):
        """
        Input:
            x1, x2: List of tensors (one for each modality)
        Output:
            z1, z2: latent representations
        """

        # compute features for all modalities
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        z1 = self.head(z1)
        z2 = self.head(z2)
        return {'aug1_embed': z1,
                'aug2_embed': z2}


class MMSimCLR(MMSiamese):
    def __init__(self,
                 encoder: nn.Module,
                 projection: nn.Module,
                 optim_kwargs: Dict,
                 loss_kwargs: Dict):
        """
        Args:
            encoder: Multi-modal encoder (e.g. MMFusion)
            projection: MLP projector to the latent space
            optim_kwargs: Optimization hyper-parameters
            loss_kwargs: Hyper-parameters for the InfoNCE loss.
        """
        super().__init__(encoder, projection, optim_kwargs)
        self.loss = InfoNCE(**loss_kwargs)


class CoMM(MMSiamese):
    """
        Model interaction between X1, X2 and X through I(Z1, Z), I(Z2, Z) and I(Z, Z')
    """
    def __init__(self,
                 encoder: nn.Module,
                 projection: nn.Module,
                 optim_kwargs: Dict,
                 loss_kwargs: Dict):
        """
        Args:
            encoder: Multi-modal encoder (e.g. MMFusion)
            projection: MLP projector to the latent space
            optim_kwargs: Optimization hyper-parameters
            loss_kwargs: Hyper-parameters for the Decoupled Uniformity loss.
        """
        super().__init__(encoder, projection, optim_kwargs)
        self.loss = MMInfoNCE(**loss_kwargs)

    def forward(self, x1: List[torch.Tensor], x2: List[torch.Tensor]):
        # compute features for all modalities
        all_masks = self.gen_all_possible_masks(len(x1))
        z1 = self.encoder(x1, mask_modalities=all_masks)
        z2 = self.encoder(x2, mask_modalities=all_masks)
        z1 = [self.head(z) for z in z1]
        z2 = [self.head(z) for z in z2]
        return {'aug1_embed': z1,
                'aug2_embed': z2,
                "prototype": -1}

    def gen_all_possible_masks(self, n_mod: int):
        """
        :param n_mod: int
        :return: a list of `n_mod` + 1 boolean masks [Mi] such that all but one bool are False.
            A last bool mask is added where all bool are True
        Examples:
        *   For n_mod==2:
            masks == [[True, False], [False, True], [True, True]]
        *   For n_mod == 3:
            masks == [[True, False, False], [False, True, False], [False, False, True], [True, True, True]]
        """
        masks = []
        for L in range(n_mod):
            mask = [s == L for s in range(n_mod)]
            masks.append(mask)
        masks.append([True for _ in range(n_mod)])
        return masks
