import json
import os
from collections import defaultdict
from itertools import combinations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from .tasks._eval_protocols import *
from models.trainers.model_evaluation_metrics import *
from util_scripts.wandb_logger import WandbLogger
from util_scripts.train_callbacks import ModelSaverLoaderCallback


class ModelEvaluation(nn.Module):
    """Evaluate a trained multimodal model under every combination of present modalities.

    On construction the model is passed through ``ModelSaverLoaderCallback.load_cpkt``
    and switched to eval mode. ``evaluate`` then runs the test set once per non-empty
    subset of modality indices; modalities outside the subset are passed to
    ``model.encode`` as ``None``.
    """

    # Number of input modalities each supported dataset yields, matching how the
    # evaluation loops unpack batches: MOSEI/MOSI -> (text, audio, vision),
    # image/text datasets -> (text, image). Used when ``modalities`` is None.
    _DEFAULT_NUM_MODALITIES = {
        'mosei': 3,
        'mosi': 3,
        'mmimdb': 2,
        'food101': 2,
        'hatememes': 2,
    }

    def __init__(self, model, dataset, test_loader, opt, modalities=None, last_cpkt=False):
        """Load the checkpoint and precompute the modality combinations to evaluate.

        Args:
            model: network to evaluate; passed to ``load_cpkt`` and replaced by its result.
            dataset: dataset name; selects batch unpacking and the metric in ``evaluate``.
            test_loader: iterable of test batches.
            opt: options object. This class reads ``device``, ``classification``,
                ``result_path``, ``labeled_ratio`` and ``save_reps`` from it; ``opt`` is
                also handed to ``WandbLogger`` and ``ModelSaverLoaderCallback``.
            modalities: number of input modalities (an int). When None it is inferred
                from ``dataset``.
            last_cpkt: forwarded as ``last=`` to ``load_cpkt`` (presumably selects the
                last rather than the best checkpoint -- TODO confirm).

        Raises:
            ValueError: if ``modalities`` is None and ``dataset`` is not supported.
        """
        super().__init__()

        # Resolve the modality count before any side effects (logger creation,
        # checkpoint loading); ``range(None)`` would otherwise crash later on.
        if modalities is None:
            if dataset not in self._DEFAULT_NUM_MODALITIES:
                raise ValueError(
                    f"Cannot infer the number of modalities for dataset {dataset!r}; "
                    "pass `modalities` explicitly."
                )
            modalities = self._DEFAULT_NUM_MODALITIES[dataset]

        self.dataset = dataset
        self.test_modalities = modalities
        self.test_loader = test_loader
        self.device = opt.device
        self.logger = WandbLogger(opt)
        self.classification = opt.classification
        self.callback = ModelSaverLoaderCallback(opt.result_path, 'model', opt=opt)
        self.model = self.callback.load_cpkt(model, last=last_cpkt)
        self.labeled_ratio = opt.labeled_ratio
        self.save_reps = opt.save_reps
        self.model.eval()

        # Every non-empty subset of modality indices, smallest subsets first.
        self.modes = self._modality_combinations(self.test_modalities)

    @staticmethod
    def _modality_combinations(num_modalities):
        """Return all non-empty index tuples over ``range(num_modalities)``, by increasing size."""
        indices = range(num_modalities)
        return [
            combo
            for size in range(1, num_modalities + 1)
            for combo in combinations(indices, size)
        ]

    @staticmethod
    def _select_modalities(data, mods, device):
        """Move the entries of ``data`` whose index is in ``mods`` to ``device``; replace the rest with None."""
        # Dropped modalities are never sent to the device.
        return [x.to(device) if j in mods else None for j, x in enumerate(data)]

    def evaluate(self):
        """Evaluate the model on each modality combination in ``self.modes``.

        For every combination the test set is run once under ``torch.no_grad()`` and
        the dataset's metric function is called with the predictions, the targets
        and ``self.logger``.

        Raises:
            ValueError: if ``self.dataset`` is not supported, or if
                ``self.labeled_ratio`` is 0.0, which evaluation does not support.
        """
        if self.dataset not in self._DEFAULT_NUM_MODALITIES:
            raise ValueError(f"Unsupported dataset for evaluation: {self.dataset!r}")
        # Fail fast with a clear message instead of part-way through the first batch.
        if self.labeled_ratio == 0.0:
            raise ValueError(
                "Cannot evaluate with labeled_ratio == 0.0: the model produces no predictions to score."
            )

        with torch.no_grad():
            for mods in self.modes:
                print('Evaluating modalities: ', mods)
                if self.dataset in ('mosei', 'mosi'):
                    self._evaluate_sentiment(mods)
                else:
                    self._evaluate_image_text(mods)

    def _evaluate_sentiment(self, mods):
        """Collect predictions for one modality combination on MOSEI/MOSI and score them with ``eval_mosei``."""
        results, truths = [], []
        for batch_X, batch_Y, _batch_meta in tqdm(self.test_loader):
            _sample_ind, text, audio, vision = batch_X
            target_data, _ = batch_Y
            # squeeze(-1) removes the trailing label dimension when there is a single label
            truths.append(target_data.squeeze(-1).to(self.device))
            input_data = self._select_modalities([text, audio, vision], mods, self.device)
            results.append(self.model.encode(input_data))

        eval_mosei(torch.cat(results), torch.cat(truths), self.logger, True, self.classification)

    def _evaluate_image_text(self, mods):
        """Collect predictions for one modality combination on an image/text dataset and score them.

        When ``self.save_reps`` is set, the representations returned by
        ``model.encode`` are also saved together with the targets.
        """
        results, truths, reps = [], [], []
        for batch_X, batch_Y in tqdm(self.test_loader):
            image, text = batch_X
            target_data, _ = batch_Y
            # squeeze(-1) removes the trailing label dimension when there is a single label
            truths.append(target_data.float().squeeze(-1).to(self.device))
            # Modality index 0 is text and index 1 is image for these datasets.
            input_data = self._select_modalities([text, image], mods, self.device)
            preds, rep = self.model.encode(input_data, return_reps=self.save_reps)
            results.append(preds)
            if self.save_reps:
                # Only hold on to representations when they will be written out.
                reps.append(rep)

        results = torch.cat(results)
        truths = torch.cat(truths)
        if self.dataset == 'mmimdb':
            calculate_f1(results, truths, self.logger)
        elif self.dataset == 'food101':
            calculate_accuracy(results, truths, self.logger)
        elif self.dataset == 'hatememes':
            calculate_auroc(results, truths, self.logger)

        if self.save_reps:
            self._save_reps(mods, torch.cat(reps), truths)

    def _save_reps(self, mods, reps, truths):
        """Write ``reps`` and ``truths`` for one modality combination to ``.npy`` files.

        Files go to ``logs/<dataset>/eval/reps_<labeled_ratio>/`` and are suffixed
        with the modality indices (e.g. ``reps_0_1.npy``) so that each combination
        gets its own pair of files instead of overwriting the previous one.
        """
        out_dir = os.path.join('logs', self.dataset, 'eval', f'reps_{self.labeled_ratio}')
        os.makedirs(out_dir, exist_ok=True)
        print('Saving representations...')
        key = '_'.join(str(m) for m in mods)
        np.save(os.path.join(out_dir, f'reps_{key}.npy'), reps.detach().cpu().numpy())
        np.save(os.path.join(out_dir, f'truths_{key}.npy'), truths.detach().cpu().numpy())