import torch
import torch.nn as nn
from models.modules.gmc_module import (AffectGRUEncoder, 
                                       AffectJointProcessor, 
                                       AffectEncoder)
from models.trainers.model_evaluation_metrics import *
from models.losses.nce_loss import NCELoss, IndividualLoss

class SuperGMC(nn.Module):
    """Base class for supervised multimodal models with shared and per-view encoders.

    Subclasses are expected to populate ``processors``, ``encoder_share``,
    ``encoder_mods``, ``reconstructor_mods``, ``pre_classifier`` and
    ``classifier``.  Throughout this class the *last* entry of ``processors``
    (and of ``encoder_mods`` / ``reconstructor_mods``) is treated as the joint
    view and is fed the full list of modalities; every earlier entry is fed
    the modality stored at the same index of the input list.
    """

    def __init__(self, name, common_dim, latent_dim, n_classes=1, dataset='mosi'):
        super().__init__()

        self.name = name
        self.common_dim = common_dim
        self.latent_dim = latent_dim
        self.dataset = dataset
        self.n_classes = n_classes

        # Placeholders; concrete subclasses replace them with real modules.
        self.processors = []

        self.encoder_share = None
        self.encoder_mods = []
        self.reconstructor_mods = []
        self.pre_classifier = None
        self.classifier = None

        # NOTE: the attribute names keep their historical spelling
        # ("critertion") so external references keep working.
        self.nce_critertion = NCELoss(dataset=dataset, classification=n_classes > 1)
        self.ind_critertion = IndividualLoss(dataset=dataset)
        # Single output -> L1 loss; several outputs -> cross-entropy.
        if n_classes == 1:
            self.criterion = nn.L1Loss()
        else:
            self.criterion = nn.CrossEntropyLoss()

    def get_reconstruct_parameters(self):
        """Return the parameters of the per-view encoders and reconstructors.

        Returns:
            list: parameters of ``encoder_mods`` followed by those of
            ``reconstructor_mods``.
        """
        params = []
        for group in (self.encoder_mods, self.reconstructor_mods):
            if isinstance(group, nn.Module):
                params.extend(group.parameters())
            else:
                # Plain iterable of modules (the base-class default is an
                # empty list, which has no ``.parameters()``).
                for module in group:
                    params.extend(module.parameters())
        return params

    def _encode_view(self, view_idx, inputs, return_reps):
        """Run one view end to end and return its representation or prediction."""
        projected = self.processors[view_idx](inputs)
        latent = self.encoder_share(projected)
        if return_reps:
            # The classifier head is not needed for representations.
            return latent
        latent_i = self.encoder_mods[view_idx](projected)
        # Out-of-place add: never mutates ``latent`` (or an input aliasing it),
        # even when ``pre_classifier`` hands its input tensor straight back.
        return self.classifier(self.pre_classifier(latent) + latent_i)

    def encode(self, x, return_reps=False):
        """Encode a (possibly incomplete) list of modalities.

        Args:
            x: sequence of modality inputs; an entry may be ``None`` when that
                modality is missing.
            return_reps: when true, return the shared-encoder representation
                instead of the classifier output.

        Returns:
            The joint-view result when every modality is present; otherwise
            the mean over the results of the modalities that are present.

        Raises:
            ValueError: if ``x`` contains ``None`` entries only.
        """
        # Complete observation: use the joint (last) view on the whole list.
        if all(mod is not None for mod in x):
            return self._encode_view(-1, x, return_reps)

        mod_outputs = [
            self._encode_view(id_mod, mod, return_reps)
            for id_mod, mod in enumerate(x)
            if mod is not None
        ]
        if not mod_outputs:
            raise ValueError("encode() needs at least one non-None modality")
        return torch.stack(mod_outputs, dim=0).mean(0)

    def forward(self, x):
        """Run every view (each modality, then the joint view) on ``x``.

        Args:
            x: sequence of modality inputs, one per unimodal processor.

        Returns:
            tuple of four lists with one entry per view, in processor order:
            classifier outputs, shared representations, processor outputs,
            and reconstructions produced from the concatenated shared and
            view-specific representations.
        """
        outputs = []
        batch_representations = []
        batch_projected_i = []
        batch_reconstructed_i = []

        joint_idx = len(self.processors) - 1
        for processor_idx, processor in enumerate(self.processors):
            # The joint processor sees all modalities, the others only theirs.
            view_input = x if processor_idx == joint_idx else x[processor_idx]
            projected_mod = processor(view_input)

            mod_representations = self.encoder_share(projected_mod)
            mod_representations_i = self.encoder_mods[processor_idx](projected_mod)
            mod_reconstructed_i = self.reconstructor_mods[processor_idx](
                torch.cat((mod_representations, mod_representations_i), dim=1))

            batch_projected_i.append(projected_mod)
            batch_reconstructed_i.append(mod_reconstructed_i)
            batch_representations.append(mod_representations)

            # Out-of-place add so the tensors stored above are never modified.
            head_input = self.pre_classifier(mod_representations) + mod_representations_i
            outputs.append(self.classifier(head_input))

        return outputs, batch_representations, batch_projected_i, batch_reconstructed_i

    def _mean_supervised_loss(self, predictions, target, target_mask, n_views):
        """Average the masked supervised loss over all views.

        Each of the ``n_views - 1`` unimodal predictions and the joint
        prediction (``predictions[-1]``) contributes exactly once.  Returns a
        zero tensor on ``target.device`` when ``target_mask`` selects nothing.
        """
        # Evaluated once: the mask does not change between views.
        if not torch.any(target_mask):
            return torch.tensor(0.0, device=target.device)

        masked_target = target[target_mask]
        total = 0
        for mod in range(n_views - 1):
            total = total + self.criterion(predictions[mod][target_mask], masked_target)
        total = total + self.criterion(predictions[-1][target_mask], masked_target)
        return total / n_views

    def super_gmc_loss(self, predictions, target, target_mask, batch_representations, batch_projected_i, batch_reconstructed_i, temperature, batch_size, epoch):
        """Compute the training loss and the metrics dictionary.

        Args:
            predictions: per-view classifier outputs; the last one is the
                joint prediction.
            target: ground-truth targets.
            target_mask: boolean mask selecting the samples that take part in
                the supervised loss.
            batch_representations: per-view shared representations; the last
                one is the joint representation.
            batch_projected_i: per-view processor outputs.
            batch_reconstructed_i: per-view reconstructions.
            temperature: forwarded to the contrastive criterion.
            batch_size: unused; kept for interface compatibility.
            epoch: unused; kept for interface compatibility.

        Returns:
            tuple ``(loss, tqdm_dict)`` where ``loss`` is the mean of the
            supervised plus contrastive terms and ``tqdm_dict`` is the
            ``eval_mosei`` result extended with ``"loss"`` and ``"rec_loss"``.
        """
        n_views = len(batch_representations)
        joint_prediction = predictions[-1]
        joint_representation = batch_representations[-1]

        # ``eval_mosei`` is not defined in this file; presumably it is provided
        # by the wildcard import of the evaluation-metrics module.
        tqdm_dict = eval_mosei(joint_prediction, target, exclude_zero=True, classification=self.n_classes > 1)

        joint_nce_loss_sum = 0
        joint_ind_loss_sum = 0
        for mod in range(n_views - 1):
            joint_nce_loss_sum = joint_nce_loss_sum + self.nce_critertion(
                joint_representation, batch_representations[mod],
                target, predictions[mod], target_mask, temperature)
            joint_ind_loss_sum = joint_ind_loss_sum + self.ind_critertion(
                batch_projected_i[mod], batch_reconstructed_i[mod])

        # The joint view is also contrasted with itself and reconstructed.
        joint_nce_loss_sum = joint_nce_loss_sum + self.nce_critertion(
            joint_representation, joint_representation,
            target, joint_prediction, target_mask, temperature)
        joint_ind_loss_sum = joint_ind_loss_sum + self.ind_critertion(
            batch_projected_i[-1], batch_reconstructed_i[-1])
        joint_ind_loss_sum = joint_ind_loss_sum / n_views
        joint_nce_loss_sum = joint_nce_loss_sum / n_views

        # Every view counts once (the joint term used to be added twice on
        # labelled batches, skewing the average).
        joint_sup_loss_sum = self._mean_supervised_loss(predictions, target, target_mask, n_views)

        loss = torch.mean(joint_sup_loss_sum + joint_nce_loss_sum)

        tqdm_dict["loss"] = loss
        # NOTE(review): the reconstruction loss is only reported here, it is
        # not part of ``loss`` - confirm it is optimised elsewhere.
        tqdm_dict['rec_loss'] = joint_ind_loss_sum
        return loss, tqdm_dict

    def _run_step(self, data, target_data, mask_data, train_params, epoch):
        """Forward ``data`` through all views and evaluate ``super_gmc_loss``."""
        temperature = train_params.temperature
        batch_size = data[0].shape[0]

        outputs, batch_representations, batch_projected_i, batch_reconstructed_i = self.forward(data)
        return self.super_gmc_loss(outputs, target_data, mask_data, batch_representations,
                                   batch_projected_i, batch_reconstructed_i,
                                   temperature, batch_size, epoch)

    def training_step(self, data, target_data, mask_data, train_params, epoch):
        """Return ``(loss, tqdm_dict)`` for one training batch."""
        return self._run_step(data, target_data, mask_data, train_params, epoch)

    def validation_step(self, data, target_data, mask_data, train_params, epoch):
        """Return only the metrics dictionary for one validation batch."""
        _, tqdm_dict = self._run_step(data, target_data, mask_data, train_params, epoch)
        return tqdm_dict


# Affect
class AffectGMC(SuperGMC):
    """SuperGMC variant with language, audio, vision and joint processors.

    The three unimodal processors are GRU encoders whose ``input_dim`` depends
    on ``dataset`` and ``transfer``; the fourth processor is the joint one.
    """

    def __init__(self, name, common_dim, latent_dim, dataset='mosei', n_classes=1, transfer=False):
        super().__init__(name, common_dim, latent_dim, n_classes, dataset)

        # Only a non-transfer 'mosei' model uses the 300/74/35 input sizes;
        # every other configuration uses 300/5/20.
        if not transfer and dataset == 'mosei':
            language_dim, audio_dim, vision_dim = 300, 74, 35
        else:
            language_dim, audio_dim, vision_dim = 300, 5, 20
        # In transfer mode the joint processor is always built for 'mosi'.
        joint_dataset = 'mosi' if transfer else dataset

        def make_processor(input_dim):
            # NOTE(review): the keyword names look crossed
            # (latent_dim=common_dim, common_dim=latent_dim); kept exactly as
            # before - confirm against AffectGRUEncoder's signature.
            return AffectGRUEncoder(input_dim=input_dim, hidden_dim=30,
                                    latent_dim=common_dim, common_dim=latent_dim,
                                    timestep=50)

        # Creation order (language, audio, vision, joint) is kept so that
        # seeded initialisation stays reproducible.
        self.language_processor = make_processor(language_dim)
        self.audio_processor = make_processor(audio_dim)
        self.vision_processor = make_processor(vision_dim)
        self.joint_processor = AffectJointProcessor(common_dim, latent_dim, joint_dataset)

        view_modules = [
            self.language_processor,
            self.audio_processor,
            self.vision_processor,
            self.joint_processor,
        ]
        self.processors = nn.ModuleList(view_modules)
        n_views = len(view_modules)

        # Encoder shared by every view.
        self.encoder_share = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            AffectEncoder(common_dim=latent_dim, latent_dim=latent_dim),
        )

        # One view-specific encoder per processor.
        view_encoders = []
        for _ in range(n_views):
            view_encoders.append(AffectEncoder(common_dim=latent_dim, latent_dim=latent_dim))
        self.encoder_mods = nn.ModuleList(view_encoders)

        # One reconstructor per processor; its input is the concatenation of
        # the shared and the view-specific representation (2 * latent_dim).
        view_reconstructors = []
        for _ in range(n_views):
            view_reconstructors.append(nn.Sequential(
                nn.Linear(latent_dim * 2, common_dim),
                nn.GELU(),
                nn.Linear(common_dim, common_dim),
            ))
        self.reconstructor_mods = nn.ModuleList(view_reconstructors)

        # Classification head applied on top of the shared representation.
        self.pre_classifier = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            nn.Dropout(0.0),
            nn.Linear(latent_dim, latent_dim),
        )

        self.classifier = nn.Linear(latent_dim, n_classes)

