import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from train import copy_model_param, param_ema
from train.gradient_based import maml_inner_adapt, maximize_inner_adapt
from evals import accuracy
# Use the current CUDA device when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from utils import jsd_loss

def check(P):
    """Validate the run options ``P`` for this training step.

    Raises:
        AssertionError: if ``P.num_shots_global`` is non-zero.

    Returns:
        Always ``True`` (the original name for this flag was
        ``filename_with_today_date``; presumably callers use it to decide
        whether output filenames include today's date -- confirm at caller).
    """
    # Reject configurations with a non-zero global shot count.
    assert P.num_shots_global == 0
    return True

def maml_barlow_step(P, step, model, criterion, optimizer, batch, metric_logger, logger, attack_module):
    """Run one outer (meta) update of two-view MAML, optionally with adversarial queries.

    For every task in ``batch`` the model is adapted with ``maml_inner_adapt``
    on each of the two support views (or once, shared by both views, when
    ``P.inner_ablation``). The adapted parameters are evaluated on the query
    set. The per-task query losses are averaged over ``num_tasks`` into a
    single outer loss, which is back-propagated and applied with
    ``optimizer``; ``param_ema`` follows when ``P.ema`` is set. Statistics
    are pushed to ``metric_logger`` and, if ``logger`` is not None, to
    ``logger``.

    When ``P.adv and P.qry_attack``, adversarial query views are crafted with
    ``attack_module.perturb`` (see ``_perturb_queries``). When
    ``P.qry_attack`` is set, the outer loss is assembled from the terms
    selected by ``P.loss_type`` (see ``_adv_query_loss``). Otherwise it is
    ``criterion`` on the clean first query view.

    Args:
        P: run-configuration namespace (flags such as ``adv``, ``qry_attack``,
            ``loss_type``, ``inner_lr``, ``inner_steps``, ``ema`` are read).
        step: global step number; used only for logging.
        model: meta-learner, called with ``params=`` / ``inner_update_type=``.
        criterion: loss applied as ``criterion(logits, targets)``.
        optimizer: outer-loop optimizer.
        batch: dict with ``'train'`` and ``'test'`` entries that are zipped
            per task into ``(inputs, targets)``. Presumably ``inputs[0]`` /
            ``inputs[1]`` are two augmented views and ``inputs[2]`` is the
            un-augmented sample (used when ``P.no_aug``) -- confirm with the
            data loader.
        metric_logger: object exposing ``meters[name].update(value)`` and the
            ``*.global_avg`` attributes read below.
        logger: optional object with ``write_log_nohead(dict, step=...)``.
        attack_module: object with a ``perturb(...)`` method; only used on the
            adversarial-query path.
    """
    import copy  # only needed on the adversarial-query path

    stime = time.time()
    model.train()

    inner_loss = 0.
    acc = 0.
    outer_loss = torch.tensor(0., device=device)
    class_loss_total = 0.
    # Logged-only statistics are plain Python floats (like ``inner_loss``),
    # so the meters receive numbers rather than device tensors.
    v1_total_loss = 0.
    v2_total_loss = 0.
    cross_total_loss = 0.
    num_tasks = batch['test'][1].size(0)
    use_adv_query = P.adv and P.qry_attack

    # Parameter-free loss modules: build once instead of once per task.
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    kl_div = nn.KLDivLoss(reduction='sum')
    cos_sim = nn.CosineSimilarity()

    for task_idx, (train_inputs, train_target, test_inputs, test_target) \
            in enumerate(zip(*batch['train'], *batch['test'])):

        train_input1 = train_inputs[0].to(device, non_blocking=True)
        train_input2 = train_inputs[1].to(device, non_blocking=True)
        train_target = train_target.to(device, non_blocking=True)
        test_input1 = test_inputs[0].to(device, non_blocking=True)
        test_input2 = test_inputs[1].to(device, non_blocking=True)
        test_target = test_target.to(device, non_blocking=True)

        if P.no_aug:
            train_input1 = train_inputs[2].to(device, non_blocking=True)
        # Targets for the first learner; they are doubled alongside the inputs
        # under img_aug_only, while the second learner keeps the originals
        # (it is still adapted on train_input2 alone).
        train_target1 = train_target
        if P.img_aug_only:
            train_input1 = torch.cat((train_input1, train_input2))
            train_target1 = torch.cat((train_target, train_target))

        params1, loss_train = maml_inner_adapt(
            model, criterion, train_input1, train_target1, P.inner_lr, P.inner_steps,
            inner_update_type=P.inner_update_type, first_order=P.first_order,
        )
        if P.inner_ablation:
            # Ablation: both views share the single adapted learner.
            params2, loss_train2 = params1, loss_train
        else:
            params2, loss_train2 = maml_inner_adapt(
                model, criterion, train_input2, train_target, P.inner_lr, P.inner_steps,
                inner_update_type=P.inner_update_type, first_order=P.first_order,
            )

        # Query forward. logits_qry / logits_adv / z_qry / z_adv are tuples
        # indexed by view (0 -> view 1, 1 -> view 2).
        if use_adv_query:
            # perturb() is given a deep copy of the model (shared by both
            # views' attacks), not the live model being optimized.
            attack_model = copy.deepcopy(model)
            adv_input1, adv_input2 = _perturb_queries(
                P, attack_module, attack_model, test_input1, test_input2, params1, params2)
            if '2' in P.loss_type:
                logits_qry, logits_adv, _, z_qry, z_adv, _ = model(
                    (test_input1, test_input2), adv=(adv_input1, adv_input2), sprt=None,
                    qry_num=2, adv_num=2, sprt_num=0, params=params1, params2=params2,
                    feat=True, inner_update_type=P.inner_update_type)
            else:
                qry_out, adv_out, _, z_qry_out, z_adv_out, _ = model(
                    test_input1, adv=adv_input1, sprt=None,
                    qry_num=1, adv_num=1, sprt_num=0, params=params1, params2=params2,
                    feat=True, inner_update_type=P.inner_update_type)
                logits_qry, logits_adv = (qry_out,), (adv_out,)
                z_qry, z_adv = (z_qry_out,), (z_adv_out,)
        else:
            logits_qry = (model(test_input1, params=params1, inner_update_type=P.inner_update_type),)
            logits_adv = z_qry = z_adv = ()

        if P.qry_attack:
            loss_test, clean_loss, v1_loss, v2_loss, cross_loss = _adv_query_loss(
                P, criterion, kl_div, cos_sim, test_target, logits_qry, logits_adv, z_qry, z_adv)
        else:
            loss_test = clean_loss = criterion(logits_qry[0], test_target)
            v1_loss = v2_loss = cross_loss = 0.

        inner_loss += (loss_train.item() + loss_train2.item()) / (2.0 * num_tasks)
        outer_loss += loss_test / num_tasks
        class_loss_total += clean_loss.item() / num_tasks
        v1_total_loss += v1_loss / num_tasks
        v2_total_loss += v2_loss / num_tasks
        cross_total_loss += cross_loss / num_tasks
        if not P.regression:
            acc += accuracy(logits_qry[0], test_target, topk=(1,))[0].item() / num_tasks
        # NOTE(review): this stops after P.full_batch + 1 tasks, yet every
        # average above is still divided by num_tasks -- confirm intended.
        if P.subset and task_idx + 1 > P.full_batch:
            break

    """ outer gradient step """
    optimizer.zero_grad()
    outer_loss.backward()
    optimizer.step()

    if P.ema:
        """ exponential weight average """
        param_ema(P, model)

    """ track stat """
    metric_logger.meters['batch_time'].update(time.time() - stime)
    metric_logger.meters['meta_train_cls'].update(inner_loss)
    metric_logger.meters['meta_test_cls'].update(outer_loss.item())
    metric_logger.meters['class_loss_total'].update(class_loss_total)
    metric_logger.meters['train_acc'].update(acc)
    metric_logger.meters['v1_loss'].update(v1_total_loss)
    metric_logger.meters['v2_loss'].update(v2_total_loss)
    metric_logger.meters['cross_loss'].update(cross_total_loss)

    if logger is not None:
        logger.write_log_nohead({
            'step': step,
            'train/meta_train': metric_logger.meta_train_cls.global_avg,
            'train/meta_test': metric_logger.meta_test_cls.global_avg,
            'train/train_acc': metric_logger.train_acc.global_avg,
            'train/class_loss': metric_logger.class_loss_total.global_avg,
            'train/v1_loss': metric_logger.v1_loss.global_avg,
            'train/v2_loss': metric_logger.v2_loss.global_avg,
            'train/cross_loss': metric_logger.cross_loss.global_avg,
        }, step=step)

def _perturb_queries(P, attack_module, attack_model, test_input1, test_input2, params1, params2):
    """Craft adversarial query views with ``attack_module.perturb``.

    View 1 is always attacked, with ``params1``. View 2 is attacked, with
    ``params2``, unless ``P.class_attack`` is set and ``'2'`` is absent from
    ``P.loss_type``; in that case ``None`` is returned in its place. With
    ``P.param_attack`` (and no ``P.class_attack``), each call also receives
    the other view's adapted parameters as ``param2``.

    Returns:
        ``(adv_input1, adv_input2)``.
    """
    kwargs1 = {'inner_update_type': P.inner_update_type, 'params': params1}
    kwargs2 = {'inner_update_type': P.inner_update_type, 'params': params2}
    if not P.class_attack and P.param_attack:
        kwargs1['param2'] = params2
        kwargs2['param2'] = params1

    adv_input1 = attack_module.perturb(attack_model, test_input1, test_input2, **kwargs1)
    adv_input2 = None
    if not P.class_attack or '2' in P.loss_type:
        adv_input2 = attack_module.perturb(attack_model, test_input2, test_input1, **kwargs2)
    return adv_input1, adv_input2

def _adv_query_loss(P, criterion, kl_div, cos_sim, test_target, logits_qry, logits_adv, z_qry, z_adv):
    """Sum the query-loss terms selected by substrings of ``P.loss_type``.

    ``logits_qry``, ``logits_adv``, ``z_qry`` and ``z_adv`` are sequences
    indexed by view (0 -> view 1, 1 -> view 2).

    Terms (added in this order):
      * ``cec1`` / ``cec2``: ``criterion`` on the clean logits of view 1 / 2.
      * ``ceadv1`` / ``ceadv2``: ``criterion`` on the adversarial logits.
      * ``kl1`` / ``kl2``: ``P.advw * KL(softmax(clean) || softmax(adv))``,
        summed and divided by the batch size.
      * ``coscc``: ``1 - mean cos(z_qry[0], z_qry[1])``.
      * ``cosadvc``: ``1 - 0.5 * (mean cos(z_qry[0], z_adv[0]) + mean cos(z_qry[1], z_adv[1]))``.
      * ``cosadvadv``: ``1 - mean cos(z_adv[0], z_adv[1])``.

    Returns:
        ``(loss, clean_loss, v1_loss, v2_loss, cross_loss)``: the
        differentiable total (``0`` if no term is selected), the clean
        view-1 ``criterion`` tensor (always computed), and the view-1,
        view-2 and cosine-term values as Python floats for logging.
    """
    loss_type = P.loss_type
    loss = 0
    v1_loss = 0.
    v2_loss = 0.
    cross_loss = 0.

    clean_loss = criterion(logits_qry[0], test_target)
    if 'cec1' in loss_type:
        loss += clean_loss
        v1_loss += clean_loss.item()
    if 'cec2' in loss_type:
        ce_clean2 = criterion(logits_qry[1], test_target)
        loss += ce_clean2
        v2_loss += ce_clean2.item()

    if 'ceadv1' in loss_type:
        ce_adv1 = criterion(logits_adv[0], test_target)
        loss += ce_adv1
        v1_loss += ce_adv1.item()
    if 'ceadv2' in loss_type:
        ce_adv2 = criterion(logits_adv[1], test_target)
        loss += ce_adv2
        v2_loss += ce_adv2.item()

    if 'kl1' in loss_type:
        kl1 = P.advw * kl_div(F.log_softmax(logits_adv[0], dim=1),
                              F.softmax(logits_qry[0], dim=1)) * (1. / logits_qry[0].size(0))
        loss += kl1
        v1_loss += kl1.item()
    if 'kl2' in loss_type:
        kl2 = P.advw * kl_div(F.log_softmax(logits_adv[1], dim=1),
                              F.softmax(logits_qry[1], dim=1)) * (1. / logits_qry[1].size(0))
        loss += kl2
        v2_loss += kl2.item()

    if 'coscc' in loss_type:
        term = 1.0 - cos_sim(z_qry[0], z_qry[1]).mean()
        cross_loss += term.item()
        loss += term
    if 'cosadvc' in loss_type:
        term = 1.0 - 0.5 * (cos_sim(z_qry[0], z_adv[0]).mean() + cos_sim(z_qry[1], z_adv[1]).mean())
        cross_loss += term.item()
        loss += term
    if 'cosadvadv' in loss_type:
        term = 1.0 - cos_sim(z_adv[0], z_adv[1]).mean()
        cross_loss += term.item()
        loss += term

    return loss, clean_loss, v1_loss, v2_loss, cross_loss
