import math

import torch

from train.gradient_based import maml_inner_adapt
from data.shapenet1d import degree_loss
from evals import accuracy
from utils import MetricLogger

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def check(P):
    """Return the hard-coded ``filename_with_today_date`` flag.

    ``P`` is accepted for interface compatibility but is never consulted;
    the result is always ``True``.
    """
    use_dated_filename = True
    return use_dated_filename


def _to_device(inputs, unwrap_list):
    """Move ``inputs`` to the module-level ``device`` (non-blocking copy).

    When ``unwrap_list`` is true and ``inputs`` has no ``.to`` method (e.g. a
    list/tuple wrapper -- presumably what some miniimagenet loaders yield;
    TODO confirm), its first element is moved instead. Any other failure
    propagates instead of being silently swallowed.
    """
    if unwrap_list and not hasattr(inputs, 'to'):
        inputs = inputs[0]
    return inputs.to(device, non_blocking=True)


def _task_score(P, outputs, target, loss):
    """Return the per-task score that is averaged into the ``acc`` meters.

    Classification (``not P.regression``): top-1 accuracy. Regression: the
    negated error, so larger is better for every dataset -- negative
    ``degree_loss`` for ``shapenet`` and negative ``loss`` for ``pose``.

    Raises:
        NotImplementedError: for a regression dataset other than
            ``shapenet`` or ``pose``.
    """
    if not P.regression:
        return accuracy(outputs, target, topk=(1, ))[0].item()
    if P.dataset == 'shapenet':
        return - degree_loss(outputs, target).item()
    if P.dataset == 'pose':
        return - loss.item()
    raise NotImplementedError()


def test_classifier(P, model, loader, criterion, steps, logger=None, attack_module=None, log_txt=None):
    """Evaluate a meta-learned model on few-shot test tasks.

    For every task in every batch of ``loader`` the model is adapted on the
    support (``'train'``) split with first-order ``maml_inner_adapt``
    (``P.inner_steps_test`` steps, step size ``P.inner_lr``) and scored on
    the query (``'test'``) split. If ``P.ema`` is truthy or ``P.load_path``
    contains ``'ema'``, the same adaptation is repeated starting from
    ``P.moving_average`` and reported under ``ema_*`` meters. When
    ``attack_module`` is given, the query inputs are also perturbed and the
    robust accuracy/loss recorded.

    Args:
        P: run configuration namespace (dataset, inner-loop, EMA and
            test-size settings are read from it).
        model: meta-learned model; its train/eval mode is restored on exit,
            even if evaluation raises.
        loader: iterable of batches, each a dict whose ``'train'`` and
            ``'test'`` entries are ``(inputs, targets)`` pairs.
        criterion: loss function called as ``criterion(outputs, targets)``.
        steps: global step passed to ``logger.write_log_nohead``.
        logger: optional logger exposing ``write_log_nohead(dict, step=...)``.
        attack_module: optional object exposing
            ``perturb(model, inputs, targets, params=...)``.
        log_txt: optional writable text file. When given together with
            ``attack_module``, the clean and robust summaries are also
            written to it.

    Returns:
        The averaged per-task score, or ``(score, adversarial_accuracy)``
        when ``attack_module`` is given.
    """
    import copy  # only needed by the adversarial branch

    metric_logger = MetricLogger(delimiter="  ")

    ema_flag = bool(P.ema or (P.load_path is not None and 'ema' in P.load_path))
    if ema_flag:
        inner_step_ema = P.moving_inner_lr if hasattr(P, 'moving_inner_lr') else P.inner_lr

    unwrap_list = P.dataset == 'miniimagenet'

    # Switch to evaluate mode; the previous mode is restored in ``finally``.
    mode = model.training
    model.eval()
    try:
        for n, batch in enumerate(loader):
            # Stop once ``max_test_task`` tasks have been covered. ``>=`` (not
            # ``>``) so no extra batch is evaluated when the limit is reached
            # exactly.
            if n * P.test_batch_size >= P.max_test_task:
                break

            train_inputs, train_targets = batch['train']
            train_inputs = _to_device(train_inputs, unwrap_list)
            train_targets = train_targets.to(device, non_blocking=True)

            test_inputs, test_targets = batch['test']
            test_inputs = _to_device(test_inputs, unwrap_list)
            test_targets = test_targets.to(device, non_blocking=True)

            for train_input, train_target, test_input, test_target in zip(
                    train_inputs, train_targets, test_inputs, test_targets):

                # Inner-loop adaptation on the support set (current weights).
                params, loss_train = maml_inner_adapt(
                    model, criterion, train_input, train_target, P.inner_lr, P.inner_steps_test,
                    first_order=True, inner_update_type=P.inner_update_type,
                )
                # Same adaptation starting from the EMA weights.
                if ema_flag:
                    params_ema, loss_train_ema = maml_inner_adapt(
                        model, criterion, train_input, train_target, inner_step_ema, P.inner_steps_test,
                        first_order=True, params=P.moving_average, inner_update_type=P.inner_update_type,
                    )

                # Outer (query-set) evaluation with the adapted parameters.
                with torch.no_grad():
                    outputs_test = model(test_input, params=params, inner_update_type=P.inner_update_type)
                    if ema_flag:
                        outputs_test_ema = model(test_input, params=params_ema, inner_update_type=P.inner_update_type)

                loss = criterion(outputs_test, test_target)
                acc = _task_score(P, outputs_test, test_target, loss)
                metric_logger.meters['loss_train'].update(loss_train.item())
                metric_logger.meters['loss'].update(loss.item())
                metric_logger.meters['acc'].update(acc)

                if ema_flag:
                    loss_ema = criterion(outputs_test_ema, test_target)
                    acc_ema = _task_score(P, outputs_test_ema, test_target, loss_ema)
                    metric_logger.meters['ema_loss_train'].update(loss_train_ema.item())
                    metric_logger.meters['ema_loss'].update(loss_ema.item())
                    metric_logger.meters['ema_acc'].update(acc_ema)

                if attack_module is not None:
                    # The attack runs on a copy, presumably so any state it
                    # touches does not leak into ``model``.
                    deep_copy_model = copy.deepcopy(model)
                    test_adv_input = attack_module.perturb(deep_copy_model, test_input, test_target, params=params)
                    with torch.no_grad():
                        outputs_adv_test = model(test_adv_input, params=params, inner_update_type=P.inner_update_type)

                    adv_loss = criterion(outputs_adv_test, test_target)
                    advacc = accuracy(outputs_adv_test, test_target, topk=(1, ))[0].item()
                    metric_logger.meters['adv_loss'].update(adv_loss.item())
                    metric_logger.meters['advacc'].update(advacc)
    finally:
        model.train(mode)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Dataset: ", P.dataset)
    print(' * [Acc@1 %.3f] [LossOut %.3f] [LossIn %.3f]' %
          (metric_logger.acc.global_avg, metric_logger.loss.global_avg,
           metric_logger.loss_train.global_avg))

    if ema_flag:
        print(' * EMA [Acc@1 %.3f] [LossOut %.3f] [LossIn %.3f]' %
              (metric_logger.ema_acc.global_avg, metric_logger.ema_loss.global_avg,
               metric_logger.ema_loss_train.global_avg))

    if attack_module is not None:
        print(' * Robust [Acc@1 %.3f] [LossOut %.3f]' %
              (metric_logger.advacc.global_avg, metric_logger.adv_loss.global_avg))
        if log_txt is not None:
            log_txt.write(' * [Acc@1 %.3f] [LossOut %.3f] [LossIn %.3f]\n' %
                          (metric_logger.acc.global_avg, metric_logger.loss.global_avg,
                           metric_logger.loss_train.global_avg))
            log_txt.write(' * Robust [Acc@1 %.3f] [LossOut %.3f]\n' %
                          (metric_logger.advacc.global_avg, metric_logger.adv_loss.global_avg))

    if logger is not None:
        logger.write_log_nohead({
            'eval/acc': metric_logger.acc.global_avg,
            'eval/loss_test': metric_logger.loss.global_avg,
            'eval/loss_train': metric_logger.loss_train.global_avg,
        }, step=steps)
        if attack_module is not None:
            logger.write_log_nohead({
                'eval/advacc': metric_logger.advacc.global_avg,
                'eval/adv_loss': metric_logger.adv_loss.global_avg,
            }, step=steps)
        if ema_flag:
            logger.write_log_nohead({
                'eval/ema_acc': metric_logger.ema_acc.global_avg,
                'eval/ema_loss_test': metric_logger.ema_loss.global_avg,
                'eval/ema_loss_train': metric_logger.ema_loss_train.global_avg,
            }, step=steps)

    if attack_module is not None:
        return metric_logger.acc.global_avg, metric_logger.advacc.global_avg
    return metric_logger.acc.global_avg
