import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, ResNet34_Weights
import random
import time
from models.modules.vilt_module import (TextProcessor, ImageProcessor, JointProcessor)
from models.modules.gmc_module import AffectEncoder
from models.trainers.model_evaluation_metrics import *
from models.losses.nce_loss import NCELoss, IndividualLoss

class SuperViLT(nn.Module):
    """Base model for supervised GMC-style training on top of ViLT-like processors.

    Holds the model components (processors, shared encoder, modality-specific
    encoders/reconstructors, classification head) as placeholders that
    subclasses populate. Implements encoding, the forward pass and the combined
    contrastive + supervised objective.

    The last entry of ``processors`` (and of ``encoder_mods`` /
    ``reconstructor_mods``) is the joint processor, which receives the whole
    multimodal input; every earlier entry receives only its own modality.
    """

    # Datasets whose targets are passed to ``nn.CrossEntropyLoss`` and are
    # therefore cast to ``long`` (class indices) first.
    _CLASS_INDEX_DATASETS = ('food101', 'hatememes')

    def __init__(self, name, common_dim, latent_dim, dataset='ucf101'):
        super(SuperViLT, self).__init__()

        self.name = name
        self.common_dim = common_dim
        self.latent_dim = latent_dim
        self.dataset = dataset

        # Placeholders; subclasses replace these with real modules.
        self.text_processor = None
        self.image_processor = None
        self.processors = [
            self.text_processor,
            self.image_processor,
        ]

        self.encoder_share = None
        self.classifier = None
        self.pre_classifier = None

        self.encoder_mods = []
        self.reconstructor_mods = []
        # NOTE: the "critertion" spelling is kept because other code may
        # reference these attribute names.
        self.nce_critertion = NCELoss(dataset=dataset)
        self.ind_critertion = IndividualLoss(dataset=dataset)
        # Supervised criterion; only defined for the datasets handled here.
        if self.dataset == 'mmimdb':
            self.criterion = nn.BCEWithLogitsLoss()
        elif self.dataset in self._CLASS_INDEX_DATASETS:
            self.criterion = nn.CrossEntropyLoss()

    def get_reconstruct_parameters(self):
        """Return the parameters of the modality-specific encoders and reconstructors.

        Requires ``encoder_mods`` and ``reconstructor_mods`` to be modules
        (e.g. ``nn.ModuleList``), as set up by subclasses. The plain-list
        placeholders created in ``__init__`` have no ``parameters()``.
        """
        params = []
        params.extend(self.encoder_mods.parameters())
        params.extend(self.reconstructor_mods.parameters())
        return params

    def _classify(self, shared_latent, specific_latent):
        """Apply the classification head.

        Returns:
            ``(logits, head_input)``, where ``head_input`` is
            ``pre_classifier(shared_latent) + specific_latent`` (the tensor fed
            to ``self.classifier``).
        """
        head_input = self.pre_classifier(shared_latent)
        head_input += specific_latent
        return self.classifier(head_input), head_input

    def encode(self, x, return_reps=False):
        """Classify a possibly incomplete multimodal observation.

        Args:
            x: Sequence of per-modality inputs; a missing modality is ``None``.
            return_reps: If True, also return the tensor fed to the classifier.

        Returns:
            ``(logits, reps)``; ``reps`` is ``None`` unless ``return_reps``.
            When every modality is present, the joint processor and joint
            modality-specific encoder are used. Otherwise only the first
            present modality is used.

        Raises:
            ValueError: If every entry of ``x`` is ``None``.
        """
        # Identity checks avoid relying on ``__eq__`` of the inputs, which is
        # ambiguous for array-like values.
        if all(mod is not None for mod in x):
            joint_projected = self.processors[-1](x)
            latent = self.encoder_share(joint_projected)
            latent_i = self.encoder_mods[-1](joint_projected)
            logits, reps = self._classify(latent, latent_i)
            return logits, (reps if return_reps else None)

        for id_mod, mod in enumerate(x):
            if mod is None:
                continue
            projected_mod = self.processors[id_mod](mod).squeeze(-1).squeeze(-1)
            latent_i = self.encoder_mods[id_mod](projected_mod)
            latent = self.encoder_share(projected_mod)
            logits, reps = self._classify(latent, latent_i)
            return logits, (reps if return_reps else None)

        raise ValueError("encode() received no observed modality: every entry of x is None")

    def forward(self, x):
        """Run every processor and its encoders on ``x``.

        Args:
            x: Sequence of per-modality inputs. Each non-joint processor ``i``
                receives ``x[i]``; the joint (last) processor receives ``x``.

        Returns:
            Tuple ``(outputs, batch_representations, batch_projected_i,
            batch_reconstructed_i)`` of lists with one entry per processor
            (joint last): classifier logits, shared-encoder representations,
            processor outputs, and reconstructor outputs (built from the
            concatenated shared + specific representations).
        """
        batch_representations = []
        batch_projected_i = []
        batch_reconstructed_i = []
        outputs = []

        joint_idx = len(self.processors) - 1
        for processor_idx, processor in enumerate(self.processors):
            # The joint processor consumes all modalities at once.
            mod_input = x if processor_idx == joint_idx else x[processor_idx]
            projected_mod = processor(mod_input)
            mod_representations = self.encoder_share(projected_mod)
            mod_representations_i = self.encoder_mods[processor_idx](projected_mod)

            mod_reconstructed_i = self.reconstructor_mods[processor_idx](
                torch.cat((mod_representations, mod_representations_i), dim=1))
            batch_projected_i.append(projected_mod)
            batch_reconstructed_i.append(mod_reconstructed_i)
            batch_representations.append(mod_representations)

            logits, _ = self._classify(mod_representations, mod_representations_i)
            outputs.append(logits)

        return outputs, batch_representations, batch_projected_i, batch_reconstructed_i

    def _compute_metrics(self, prediction, target):
        """Return the dataset-specific metrics dict for ``prediction``.

        Raises:
            ValueError: If ``self.dataset`` has no metric configured.
        """
        if self.dataset == 'mmimdb':
            return calculate_f1(prediction, target)
        if self.dataset == 'food101':
            return calculate_accuracy(prediction, target)
        if self.dataset == 'hatememes':
            return calculate_auroc(prediction, target)
        raise ValueError(f"Unsupported dataset for SuperViLT: {self.dataset!r}")

    def _supervised_loss(self, prediction, target, target_mask):
        """Apply ``self.criterion`` to the samples selected by ``target_mask``.

        Returns a zero tensor on ``target``'s device when no sample is selected.
        """
        if not torch.any(target_mask):
            return torch.tensor(0.0, device=target.device)
        labels = target[target_mask]
        if self.dataset in self._CLASS_INDEX_DATASETS:
            labels = labels.long()
        return self.criterion(prediction[target_mask], labels)

    def _gmc_losses(self, predictions, target, target_mask, batch_representations,
                    batch_projected_i, batch_reconstructed_i, temperature):
        """Compute the training objective and the reconstruction loss.

        Every term is averaged over ``len(batch_representations)``. The
        contrastive terms pair the joint (last) representation with each
        representation, including itself.

        Returns:
            ``(loss, rec_loss)``: ``loss`` is the averaged supervised plus
            contrastive loss; ``rec_loss`` is the averaged reconstruction
            loss, which is not included in ``loss``.
        """
        num_reps = len(batch_representations)
        joint_rep = batch_representations[-1]

        joint_nce_loss_sum = 0
        joint_sup_loss_sum = 0
        joint_ind_loss_sum = 0
        for mod in range(num_reps - 1):
            joint_nce_loss_sum += self.nce_critertion(joint_rep, batch_representations[mod],
                                                      target, predictions[mod], target_mask, temperature)
            joint_ind_loss_sum += self.ind_critertion(batch_projected_i[mod], batch_reconstructed_i[mod])
            joint_sup_loss_sum += self._supervised_loss(predictions[mod], target, target_mask)

        joint_nce_loss_sum += self.nce_critertion(joint_rep, joint_rep,
                                                  target, predictions[-1], target_mask, temperature)
        joint_ind_loss_sum += self.ind_critertion(batch_projected_i[-1], batch_reconstructed_i[-1])
        joint_ind_loss_sum /= num_reps
        joint_nce_loss_sum /= num_reps

        # The joint prediction's supervised term is counted exactly once, like
        # each modality's term, before averaging.
        joint_sup_loss_sum += self._supervised_loss(predictions[-1], target, target_mask)
        joint_sup_loss_sum /= num_reps

        loss = torch.mean(joint_sup_loss_sum + joint_nce_loss_sum)
        return loss, joint_ind_loss_sum

    def super_gmc_loss(self, predictions, target, target_mask, batch_representations, batch_projected_i, batch_reconstructed_i, temperature):
        """Compute the supervised GMC objective and a metrics dict for logging.

        Args:
            predictions: Logits per representation, joint prediction last (as
                returned by ``forward``).
            target: Targets for the batch.
            target_mask: Boolean mask selecting the samples that contribute to
                the supervised loss.
            batch_representations: Shared-encoder representations, joint last.
            batch_projected_i: Processor outputs, joint last.
            batch_reconstructed_i: Reconstructor outputs, joint last.
            temperature: Forwarded to the contrastive criterion.

        Returns:
            ``(loss, tqdm_dict)``. ``tqdm_dict`` holds the dataset metrics for
            the joint prediction plus ``"loss"`` and ``"rec_loss"``.

        Raises:
            ValueError: If ``self.dataset`` is not supported.
        """
        tqdm_dict = self._compute_metrics(predictions[-1], target)
        loss, rec_loss = self._gmc_losses(predictions, target, target_mask, batch_representations,
                                          batch_projected_i, batch_reconstructed_i, temperature)

        tqdm_dict["loss"] = loss
        tqdm_dict['rec_loss'] = rec_loss

        return loss, tqdm_dict

    def training_step(self, data, target_data, mask_data, train_params, epoch):
        """Forward ``data`` and return ``(loss, tqdm_dict)``; ``epoch`` is unused."""
        temperature = train_params.temperature
        # Forward pass through the encoders
        outputs, batch_representations, batch_projected_i, batch_reconstructed_i = self.forward(data)

        # Compute contrastive + supervised loss
        loss, tqdm_dict = self.super_gmc_loss(outputs, target_data, mask_data,
                                              batch_representations, batch_projected_i, batch_reconstructed_i, temperature)

        return loss, tqdm_dict

    def validation_step(self, data, target_data, mask_data, train_params, epoch):
        """Forward ``data`` and return only the metrics dict (which includes ``"loss"``)."""
        temperature = train_params.temperature

        # Forward pass through the encoders
        outputs, batch_representations, batch_projected_i, batch_reconstructed_i = self.forward(data)
        # Compute contrastive + supervised loss; the loss is also stored in tqdm_dict
        _, tqdm_dict = self.super_gmc_loss(outputs, target_data, mask_data, batch_representations,
                                           batch_projected_i, batch_reconstructed_i, temperature)

        return tqdm_dict

# ViLT model for the image-text datasets (mmimdb, food101, hatememes)
class ViLT(SuperViLT):
    """ViLT-based model for the image-text datasets mmimdb, food101 and hatememes.

    Builds the components of the model:
    - text, image and joint processors;
    - an encoder shared by all processors;
    - one modality-specific encoder and one reconstructor per processor;
    - a classification head producing ``num_classes`` logits.
    """

    # Datasets for which processors are constructed.
    _SUPPORTED_DATASETS = ('mmimdb', 'food101', 'hatememes')

    def __init__(self, name, num_classes, common_dim, latent_dim, dataset='ucf101'):
        # Fail fast with a clear message: for any other dataset no processors
        # would be built, and construction would later fail with an
        # AttributeError on ``self.joint_processor``.
        if dataset not in self._SUPPORTED_DATASETS:
            raise ValueError(
                f"ViLT supports datasets {self._SUPPORTED_DATASETS}, got {dataset!r}")
        super(ViLT, self).__init__(name, common_dim, latent_dim, dataset=dataset)

        self.text_processor = TextProcessor(common_dim=common_dim, latent_dim=latent_dim)
        self.image_processor = ImageProcessor(common_dim=common_dim, latent_dim=latent_dim)

        self.joint_processor = JointProcessor(common_dim=common_dim, latent_dim=latent_dim)

        # Order matters: modality processors first, joint processor last
        # (SuperViLT.forward feeds the whole input to the last processor).
        self.processors = nn.ModuleList([
            self.text_processor,
            self.image_processor,
            self.joint_processor])

        # Encoder applied to every processor's output.
        self.encoder_share = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            AffectEncoder(common_dim=latent_dim, latent_dim=latent_dim)
        )

        # Modality-specific encoders, one per processor (text, image, joint).
        self.encoder_mods = nn.ModuleList([AffectEncoder(latent_dim, latent_dim)
                                           for _ in range(len(self.processors))])

        # Reconstructors, one per processor; their input is the concatenation
        # of the shared and specific representations, hence latent_dim * 2.
        self.reconstructor_mods = nn.ModuleList([nn.Sequential(
                                                nn.Linear(latent_dim*2, latent_dim),
                                                nn.GELU(),
                                                nn.Linear(latent_dim, latent_dim)
                                            ) for _ in range(len(self.processors))])

        # Classifier. nn.Dropout(0.0) is a no-op, but keeping it preserves the
        # Sequential's layer indices (and hence its state_dict keys).
        self.pre_classifier = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            nn.Dropout(0.0),
            nn.Linear(latent_dim, latent_dim),
        )
        self.classifier = nn.Linear(latent_dim, num_classes)
    
