import torch
import models.auxiliary.scheduler as sc
import copy
from sklearn.metrics import f1_score 
from tqdm import tqdm
import os
from IPython import embed
from models.search.darts.utils import count_parameters, save, save_pickle


def train_mmimdb_track_f1(
        model, architect,
        criterion, optimizer, scheduler, dataloaders,
        dataset_sizes, device, num_epochs,
        parallel, logger, plotter, args,
        f1_type='weighted', init_f1=0.0, th_fscore=0.3,
        status='search'):
    """Placeholder for the MM-IMDB training loop that tracks F1.

    The function currently has no implementation: every argument is
    accepted and ignored, nothing is trained, and ``None`` is returned.
    The parameter list mirrors the one used by ``test_mmimdb_track_f1``
    with the additional training-only arguments (``architect``,
    ``optimizer``, ``scheduler``, ``num_epochs``, ``plotter``, ``status``).
    """
    # Intentionally empty: kept so that callers importing this name still work.
    return None


def test_mmimdb_track_f1(model, criterion, dataloaders,
                         dataset_sizes, device,
                         parallel, logger, args,
                         f1_type='weighted', init_f1=0.0, th_fscore=0.3):
    """Evaluate ``model`` on the ``'test'`` split and return its F1 score.

    The model is switched to eval mode (and left in it), every batch of
    ``dataloaders['test']`` is run through it without gradient tracking, and
    the thresholded sigmoid outputs are compared with the labels.

    Args:
        model: Network called as ``model((text, image))``. If it returns a
            tuple, the last element is used as the logits. It must expose
            ``reshape_layers``, ``fusion_net`` and ``genotype()`` (on
            ``model.module`` when ``parallel`` is true).
        criterion: Loss callable invoked as ``criterion(output, label)``.
        dataloaders: Mapping with a ``'test'`` entry yielding dict batches
            with the keys ``'image'``, ``'text'`` and ``'label'``.
        dataset_sizes: Mapping with a ``'test'`` entry used to normalise the
            accumulated loss.
        device: Device the batch tensors are moved to.
        parallel: Whether ``model`` is wrapped so that the real network lives
            in ``model.module``.
        logger: Object with an ``info`` method; receives the loss/F1 summary,
            the fusion parameter count and the genotype.
        args: Unused; kept for interface compatibility.
        f1_type: ``average`` argument for the epoch-level ``f1_score``.
        init_f1: Unused (the returned score is always the measured one);
            kept for interface compatibility.
        th_fscore: Threshold applied to ``sigmoid(output)`` to obtain the
            binary predictions.

    Returns:
        The F1 score over the whole test split, averaged per ``f1_type``.
    """
    phase = 'test'
    model.eval()  # Set model to evaluate mode

    all_preds = []
    all_labels = []
    running_loss = 0.0

    # Evaluation only: nothing is back-propagated here, so disable autograd
    # to avoid building (and holding on to) a graph for every batch.
    with torch.no_grad(), tqdm(dataloaders[phase]) as progress:
        # Iterating the bar itself advances it once per batch.
        for data in progress:
            image = data['image'].to(device)
            text = data['text'].to(device)
            label = data['label'].to(device)

            output = model((text, image))
            if isinstance(output, tuple):
                output = output[-1]

            loss = criterion(output, label)
            batch_loss = loss.item()

            # presumably multi-label logits / multi-hot labels — confirm
            # against the dataset; each output is thresholded independently.
            preds_cpu = (torch.sigmoid(output) > th_fscore).detach().cpu()
            label_cpu = label.detach().cpu()
            all_preds.append(preds_cpu)
            all_labels.append(label_cpu)

            # statistics: weight the batch loss by the batch size
            running_loss += batch_loss * image.size(0)

            # y_true first, y_pred second, as f1_score's signature expects.
            batch_f1 = f1_score(label_cpu.numpy(), preds_cpu.numpy(),
                                average='samples')

            progress.set_postfix_str(
                'batch_loss: {:.03f}, batch_f1: {:.03f}'.format(batch_loss, batch_f1))

    epoch_loss = running_loss / dataset_sizes[phase]

    y_pred = torch.cat(all_preds, dim=0).numpy()
    y_true = torch.cat(all_labels, dim=0).numpy()
    epoch_f1 = f1_score(y_true, y_pred, average=f1_type)

    logger.info('{} Loss: {:.4f}, {} F1: {:.4f}'.format(
                    phase, epoch_loss, f1_type, epoch_f1))

    # When `parallel` is set the actual network sits behind `.module`.
    net = model.module if parallel else model

    num_params = sum(count_parameters(layer) for layer in net.reshape_layers)
    num_params += count_parameters(net.fusion_net)
    logger.info("Fusion Model Params: {}".format(num_params) )

    genotype = net.genotype()
    logger.info(str(genotype))

    return epoch_f1
