import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from train import copy_model_param, param_ema
from train.gradient_based import maml_inner_adapt
from evals import accuracy
# Compute device for all tensors in this module: the default CUDA device when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from utils import jsd_loss

def check(P):
    """Validate the run configuration ``P`` for this training step.

    Requires ``P.num_shots_global`` to equal 0 (enforced with ``assert``, so the
    check is skipped under ``python -O``).

    Returns:
        Always ``True``. The original names this flag ``filename_with_today_date``;
        how the caller uses it is not visible here.
    """
    global_shots = P.num_shots_global
    assert global_shots == 0
    return True

def trades_criterion(input_clean, input_adv, target, advw=1.0):
    """TRADES-style loss: clean cross-entropy plus a weighted KL consistency term.

    Computes ``CE(input_clean, target) + advw * KL(p_clean || p_adv)``, where
    ``p_* = softmax(input_*, dim=1)``. The KL is summed over classes and averaged
    over the batch, which is the mathematical KL divergence per sample.

    Args:
        input_clean: logits for the clean inputs (assumed ``(batch, num_classes)``;
            softmax is taken over dim 1).
        input_adv: logits for the adversarial inputs, same shape as ``input_clean``.
        target: class-index targets accepted by ``F.cross_entropy``.
        advw: weight of the KL term.

    Returns:
        A scalar loss tensor.
    """
    ce = F.cross_entropy(input_clean, target)
    # 'batchmean' (sum / batch size) is the true KL divergence. The default
    # 'mean' additionally divides by the number of classes, silently shrinking
    # the adversarial term by a factor of num_classes.
    kl = F.kl_div(
        F.log_softmax(input_adv, dim=1),
        F.softmax(input_clean, dim=1),
        reduction='batchmean',
    )
    return ce + advw * kl

def _maml_query_loss(P, criterion, test_target, logits_qry, logits_adv, z_qry, z_adv):
    """Sum the query-set loss terms selected by substrings of ``P.loss_type``.

    Terms are added in a fixed order: ``cec1``, ``cec2``, ``ceadv1``, ``ceadv2``,
    ``kl1``, ``kl2`` (only if ``'kl'`` is present), then ``coscc``, ``cosadvc`` and
    ``cosadvadv`` (only if ``'cos'`` is present).

    Args:
        P: config; reads ``loss_type`` and ``advw``.
        criterion: classification loss applied to logits and ``test_target``.
        test_target: query targets.
        logits_qry / logits_adv: ``(view1, view2)`` clean / adversarial query logits.
        z_qry / z_adv: ``(view1, view2)`` clean / adversarial query features.

    Returns:
        ``(loss_test, clean_loss)``. ``loss_test`` is the sum of the selected terms,
        or the int ``0`` if none is selected. ``clean_loss`` is the clean
        cross-entropy of view 1, which is always computed for logging.
    """
    logits_qry1, logits_qry2 = logits_qry
    logits_adv1, logits_adv2 = logits_adv
    z_qry1, z_qry2 = z_qry
    z_adv1, z_adv2 = z_adv
    loss_type = P.loss_type

    loss_test = 0
    clean_loss = criterion(logits_qry1, test_target)
    if 'cec1' in loss_type:
        loss_test += clean_loss
    if 'cec2' in loss_type:
        loss_test += criterion(logits_qry2, test_target)
    if 'ceadv1' in loss_type:
        loss_test += criterion(logits_adv1, test_target)
    if 'ceadv2' in loss_type:
        loss_test += criterion(logits_adv2, test_target)

    if 'kl' in loss_type:
        # KL(clean || adv) summed over all elements and divided by the batch size.
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        kl_loss = nn.KLDivLoss(reduction='sum')
        if 'kl1' in loss_type:
            loss_test += P.advw * kl_loss(F.log_softmax(logits_adv1, dim=1), F.softmax(logits_qry1, dim=1)) * (1. / logits_qry1.size(0))
        if 'kl2' in loss_type:
            loss_test += P.advw * kl_loss(F.log_softmax(logits_adv2, dim=1), F.softmax(logits_qry2, dim=1)) * (1. / logits_qry2.size(0))

    if 'cos' in loss_type:
        cos_sim = nn.CosineSimilarity()
        if 'coscc' in loss_type:
            loss_test += 1.0 - cos_sim(z_qry1, z_qry2).mean()
        if 'cosadvc' in loss_type:
            loss_test += 1.0 - 0.5 * (cos_sim(z_qry1, z_adv1).mean() + cos_sim(z_qry2, z_adv2).mean())
        if 'cosadvadv' in loss_type:
            loss_test += 1.0 - cos_sim(z_adv1, z_adv2).mean()

    return loss_test, clean_loss


def maml_selfsup_step(P, step, model, criterion, optimizer, batch, metric_logger, logger, attack_module):
    """Run one adversarial self-supervised MAML meta-training step and log stats.

    For each task, the two support views are attacked and the model is adapted
    on each adversarial view (inner loop). The two query views are then attacked
    against the corresponding adapted parameters, and the loss terms selected by
    ``P.loss_type`` are accumulated. A single optimizer step is taken on the
    task-averaged outer loss.

    Args:
        P: run config. Reads ``ema``, ``mode``, ``inner_lr``, ``inner_steps``,
            ``inner_update_type``, ``first_order``, ``qry_attack``,
            ``dynamic_attack``, ``loss_type``, ``advw``, ``regression``,
            ``subset``, ``full_batch`` and ``print_step``. When ``P.ema`` is set,
            ``moving_average`` (and, for ``'metasgd'`` modes, ``moving_inner_lr``)
            are lazily created on ``P``.
        step: global step, used for periodic logging.
        model: meta-learner, trained in place by ``optimizer``.
        criterion: loss used for inner adaptation and query cross-entropy terms.
        optimizer: outer-loop optimizer.
        batch: dict whose ``'train'`` and ``'test'`` entries are each an
            ``(inputs, targets)`` pair indexed per task. Per-task inputs hold two
            views at ``[0]`` and ``[1]``; ``batch['test'][1].size(0)`` is the
            number of tasks.
        metric_logger: exposes ``meters[...]`` and attribute access to meters.
            ``data_time`` is read for logging but must be updated by the caller.
        logger: exposes ``log_dirname``, ``scalar_summary`` and ``log``.
        attack_module: provides ``perturb(model, x, x_other, ...)``.

    Raises:
        ValueError: if ``P.qry_attack`` or ``P.dynamic_attack`` is falsy. The
            outer loss needs query logits/features, which are only produced on
            that path. Previously this failed mid-step with an UnboundLocalError.
    """
    import copy

    if not (P.qry_attack and P.dynamic_attack):
        raise ValueError(
            "maml_selfsup_step requires P.qry_attack and P.dynamic_attack to be enabled"
        )

    stime = time.time()
    model.train()

    inner_loss = 0.
    acc = 0.
    outer_loss = torch.tensor(0., device=device)
    class_loss_total = 0.
    num_tasks = batch['test'][1].size(0)

    if P.ema:
        if not hasattr(P, 'moving_average'):
            P.moving_average = copy_model_param(model)
        if 'metasgd' in P.mode and not hasattr(P, 'moving_inner_lr'):
            P.moving_inner_lr = copy_model_param(None, params=P.inner_lr)
        inner_lr_ema = P.moving_inner_lr if hasattr(P, 'moving_inner_lr') else P.inner_lr

    for task_idx, (train_inputs, train_target, test_inputs, test_target) \
            in enumerate(zip(*batch['train'], *batch['test'])):

        train_input1 = train_inputs[0].to(device, non_blocking=True)
        train_input2 = train_inputs[1].to(device, non_blocking=True)
        train_target = train_target.to(device, non_blocking=True)

        test_input1 = test_inputs[0].to(device, non_blocking=True)
        test_input2 = test_inputs[1].to(device, non_blocking=True)
        test_target = test_target.to(device, non_blocking=True)

        if P.ema:
            # NOTE(review): the adapted teacher params are never used in this
            # function. The call is kept so behaviour, including any side effects
            # of maml_inner_adapt, stays unchanged.
            _params_teacher, _ = maml_inner_adapt(
                model, criterion, train_input1, train_target, inner_lr_ema, P.inner_steps,
                inner_update_type=P.inner_update_type, first_order=True, params=P.moving_average,
            )

        # Support attacks run against a snapshot of the model rather than the
        # live one (presumably because perturb() may modify the model it gets).
        attack_model = copy.deepcopy(model)
        adv_train_input1 = attack_module.perturb(attack_model, train_input1, train_input2)
        adv_train_input2 = attack_module.perturb(attack_model, train_input2, train_input1)

        # Inner-loop adaptation: one set of fast weights per adversarial view.
        params1, _ = maml_inner_adapt(
            model, criterion, adv_train_input1, train_target, P.inner_lr, P.inner_steps,
            inner_update_type=P.inner_update_type, first_order=P.first_order,
        )
        # NOTE(review): only the second view's inner loss feeds the
        # 'meta_train_cls' metric. The first view's loss is discarded.
        params2, loss_train = maml_inner_adapt(
            model, criterion, adv_train_input2, train_target, P.inner_lr, P.inner_steps,
            inner_update_type=P.inner_update_type, first_order=P.first_order,
        )

        # Query attack: each view is attacked against its own adapted params,
        # using a fresh snapshot taken after adaptation.
        attack_model = copy.deepcopy(model)
        adv_test_input1 = attack_module.perturb(attack_model, test_input1, test_input2, inner_update_type=P.inner_update_type, params=params1)
        adv_test_input2 = attack_module.perturb(attack_model, test_input2, test_input1, inner_update_type=P.inner_update_type, params=params2)

        qry = (test_input1, test_input2)
        adv = (adv_test_input1, adv_test_input2)
        (logits_qry1, logits_qry2), (logits_adv1, logits_adv2), _, (z_qry1, z_qry2), (z_adv1, z_adv2), _ = model(
            qry, adv=adv, sprt=None, qry_num=2, adv_num=2, sprt_num=0,
            params=params1, params2=params2, feat=True, inner_update_type=P.inner_update_type,
        )

        loss_test, clean_loss = _maml_query_loss(
            P, criterion, test_target,
            (logits_qry1, logits_qry2), (logits_adv1, logits_adv2),
            (z_qry1, z_qry2), (z_adv1, z_adv2),
        )

        # Averages are taken over num_tasks even if the subset break below
        # stops the loop early.
        inner_loss += loss_train.item() / num_tasks
        outer_loss += loss_test / num_tasks
        class_loss_total += clean_loss.item() / num_tasks
        if not P.regression:
            acc += accuracy(logits_qry1, test_target, topk=(1,))[0].item() / num_tasks
        if P.subset and task_idx + 1 > P.full_batch:
            break

    # Outer (meta) gradient step.
    optimizer.zero_grad()
    outer_loss.backward()
    optimizer.step()

    if P.ema:
        # Exponential moving average of the weights.
        param_ema(P, model)

    # Track statistics.
    metric_logger.meters['batch_time'].update(time.time() - stime)
    metric_logger.meters['meta_train_cls'].update(inner_loss)
    metric_logger.meters['meta_test_cls'].update(outer_loss.item())
    metric_logger.meters['train_acc'].update(acc)
    metric_logger.meters['class_loss_total'].update(class_loss_total)

    if step % P.print_step == 0:
        logger.log_dirname(f"Step {step}")
        logger.scalar_summary('train/meta_train_cls',
                                metric_logger.meta_train_cls.global_avg, step)
        logger.scalar_summary('train/meta_test_cls',
                                metric_logger.meta_test_cls.global_avg, step)
        logger.scalar_summary('train/train_acc',
                                metric_logger.train_acc.global_avg, step)
        logger.scalar_summary('train/class_loss_total',
                                metric_logger.class_loss_total.global_avg, step)
        logger.scalar_summary('train/batch_time',
                                metric_logger.batch_time.global_avg, step)

        logger.log('[TRAIN] [Step %3d] [Time %.3f] [Data %.3f] '
                    '[MetaTrainLoss %f] [MetaTestLoss %f] [MetaTestClass %f]' %
                    (step, metric_logger.batch_time.global_avg, metric_logger.data_time.global_avg,
                        metric_logger.meta_train_cls.global_avg, metric_logger.meta_test_cls.global_avg, 
                        metric_logger.class_loss_total.global_avg))
