import time
from collections import OrderedDict

import torch
import torch.nn as nn

from common.utils import is_resume
from data.shapenet1d import AzimuthLoss
from utils import MetricLogger, save_checkpoint, save_checkpoint_step
from attack.representation_adv import RepresentationAdv
from attack.attack_lib import AttackPGD, AttackFGSM

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def learnable_lr(P, model, optimizer):
    """Make the inner-loop learning rate learnable per parameter (Meta-SGD).

    Only acts when ``P.mode == 'metasgd'``; for any other mode this is a no-op.
    In that case ``P.inner_lr`` (expected to be a number on entry) is replaced
    by an ``OrderedDict`` keyed by every name yielded from
    ``model.meta_named_parameters()``. Each value is a tensor initialised to the
    old ``P.inner_lr``, with the matching parameter's dtype, placed on the
    module-level ``device`` and created with ``requires_grad=True``. The tensors
    are then appended to ``optimizer`` as a new parameter group so the outer
    optimiser updates them too.

    Args:
        P: run configuration; ``P.mode`` is read and ``P.inner_lr`` may be
            overwritten in place.
        model: must provide ``meta_named_parameters()`` yielding
            ``(name, tensor)`` pairs.
        optimizer: a ``torch.optim.Optimizer`` that receives the new group.
    """
    if P.mode != 'metasgd':
        return

    init_lr = P.inner_lr
    per_param_lr = OrderedDict()
    for param_name, param in model.meta_named_parameters():
        per_param_lr[param_name] = torch.tensor(
            init_lr, dtype=param.dtype, device=device, requires_grad=True)

    P.inner_lr = per_param_lr
    optimizer.add_param_group({'params': per_param_lr.values()})


def _build_criterion(P):
    """Return the training/eval loss for ``P.dataset``.

    'pose' -> ``nn.MSELoss``, 'shapenet' -> ``AzimuthLoss``, anything else ->
    ``nn.CrossEntropyLoss``.
    """
    if P.dataset == 'pose':
        return nn.MSELoss()
    if P.dataset == 'shapenet':
        return AzimuthLoss()
    return nn.CrossEntropyLoss()


def _build_train_attack(P):
    """Return the attack module used during training, or ``None`` if ``P.adv`` is off."""
    if not P.adv:
        return None
    if (P.barlow or P.selfsup) and not P.class_attack:
        # barlow/selfsup without class_attack use RepresentationAdv, which also
        # takes an attack_loss_type.
        return RepresentationAdv(
            epsilon=P.epsilon,
            alpha=P.alpha,
            min_val=P.min_val,
            max_val=P.max_val,
            max_iters=P.max_iters,
            attack_type=P.attack_type,
            attack_loss_type=P.attack_loss_type,
            random_start=P.random_start,
        )
    # Every other adversarial setting uses AttackPGD with the configured
    # hyper-parameters. This covers barlow/selfsup + class_attack, adml, and the
    # default adversarial case.
    return AttackPGD(
        epsilon=P.epsilon,
        alpha=P.alpha,
        min_val=P.min_val,
        max_val=P.max_val,
        max_iters=P.max_iters,
        attack_type=P.attack_type,
        random_start=P.random_start,
    )


def _build_test_attack_kwargs(P):
    """Return the extra kwargs for ``test_func``: a fixed PGD attack when ``P.adv`` is set.

    The evaluation attack deliberately ignores the training hyper-parameters.
    It uses epsilon=8/255, alpha=8/2550, 20 iterations and min/max values 0.0/1.0.
    Presumably this assumes inputs in [0, 1]; confirm against the data pipeline.
    """
    if not P.adv:
        return {}
    return {'attack_module': AttackPGD(
        epsilon=8.0 / 255.0,
        alpha=8.0 / 2550.0,
        min_val=0.0,
        max_val=1.0,
        max_iters=20,
        attack_type=P.attack_type,
        random_start=P.random_start,
    )}


def _evaluate_and_checkpoint(P, step, model, optimizer, criterion, test_func,
                             test_loader, logger, kwargs_test, adversarial,
                             best, bestadv):
    """Evaluate once, log the metrics, and save a "best" checkpoint on improvement.

    With ``adversarial`` set, model selection follows adversarial accuracy and
    clean accuracy is only tracked. Otherwise it follows clean accuracy.

    Returns:
        The updated ``(best, bestadv)`` pair.
    """
    if adversarial:
        # NOTE(review): the two original loops disagreed on the tuple order;
        # (acc, advacc) is the order used by the subsampling path, which was
        # the only one that could run. Confirm against test_func.
        acc, advacc = test_func(P, model, test_loader, criterion, step,
                                logger=logger, **kwargs_test)
        if bestadv < advacc:
            bestadv = advacc
            # The checkpoint records the clean best as of *before* this eval,
            # matching the original ordering.
            save_checkpoint(P, step, best, model.state_dict(),
                            optimizer.state_dict(), logger.log_dir, is_best=True)
        if best < acc:
            best = acc

        logger.write_log_nohead({
                'eval/best_acc': best,
                'eval/acc': acc,
                'eval/bestadv_acc': bestadv,
                'eval/advacc': advacc}, step=step)

        print('[EVAL] [Step %3d] [Acc %5.2f] [AdvAcc %5.2f] [Best %5.2f][AdvBest %5.2f]'
              % (step, acc, advacc, best, bestadv))
    else:
        acc = test_func(P, model, test_loader, criterion, step,
                        logger=logger, **kwargs_test)
        if best < acc:
            best = acc
            save_checkpoint(P, step, best, model.state_dict(),
                            optimizer.state_dict(), logger.log_dir, is_best=True)

        logger.write_log_nohead({
                'eval/best_acc': best,
                'eval/acc': acc}, step=step)

        print('[EVAL] [Step %3d] [Acc %5.2f] [Best %5.2f]' % (step, acc, best))
    return best, bestadv


def meta_trainer(P, train_func, test_func, model, optimizer, train_loader, test_loader, logger):
    """Run the outer meta-training loop with periodic evaluation and checkpointing.

    Steps ``start_step .. P.outer_steps`` (inclusive) are trained in both loader
    modes. ``start_step`` comes from ``is_resume``. A "best" checkpoint is saved
    whenever the selection metric improves (every ``P.eval_step`` steps), a step
    checkpoint every ``P.save_step`` steps, and a final checkpoint at the end.

    Args:
        P: run configuration. Attributes read here include ``mode``,
            ``dataset``, ``adv`` and the attack hyper-parameters,
            ``subsampling``, ``outer_steps``, ``eval_step``, ``save_step``,
            ``use_sampler`` and ``limit_train_batches``.
        train_func: called once per step as ``train_func(P, step, model,
            criterion, optimizer, batch, metric_logger=..., logger=...,
            attack_module=...)``.
        test_func: returns ``acc`` when no training attack is configured and
            ``(acc, advacc)`` otherwise.
        model, optimizer: updated in place.
        train_loader: iterated pass by pass when ``P.subsampling`` is set,
            otherwise advanced with ``next``, so it must be an iterator.
        test_loader: forwarded to ``test_func``.
        logger: must expose ``log_dir`` and ``write_log_nohead``.

    Raises:
        RuntimeError: in subsampling mode, if a full pass over ``train_loader``
            yields no trainable batch. This would otherwise loop forever.
    """
    metric_logger = MetricLogger(delimiter="  ")

    # Meta-SGD: turn the inner-loop learning rate into learnable tensors.
    learnable_lr(P, model, optimizer)

    # Resume option: may restore model/optimizer state and the step counter.
    _, start_step, best, _ = is_resume(P, model, optimizer)
    bestadv = 0

    criterion = _build_criterion(P)
    attack_module = _build_train_attack(P)
    kwargs = {'attack_module': attack_module}
    kwargs_test = _build_test_attack_kwargs(P)
    adversarial = attack_module is not None

    def run_step(step, train_batch):
        """Train on one batch, then evaluate / checkpoint if this step is due."""
        nonlocal best, bestadv
        train_func(P, step, model, criterion, optimizer, train_batch,
                   metric_logger=metric_logger, logger=logger, **kwargs)

        if step % P.eval_step == 0:
            best, bestadv = _evaluate_and_checkpoint(
                P, step, model, optimizer, criterion, test_func, test_loader,
                logger, kwargs_test, adversarial, best, bestadv)

        if step % P.save_step == 0:
            save_checkpoint_step(P, step, best, model.state_dict(),
                                 optimizer.state_dict(), logger.log_dir)

    if P.subsampling:
        # `step` counts completed steps; the next trained step is step + 1, so
        # this starts training at start_step like the non-subsampling branch.
        step = start_step - 1
        while step < P.outer_steps:
            trained_this_pass = False
            # Timer is started before each fetch so data_time measures loading.
            stime = time.time()
            for i, train_batch in enumerate(train_loader):
                # Stop inside the pass too; otherwise the loop overshoots
                # outer_steps, or never ends with an infinite sampler.
                if step >= P.outer_steps:
                    break
                if (not P.use_sampler) and (i == P.limit_train_batches):
                    print(step, i)
                    break
                step += 1
                trained_this_pass = True
                metric_logger.meters['data_time'].update(time.time() - stime)

                run_step(step, train_batch)
                stime = time.time()

            if not trained_this_pass:
                raise RuntimeError(
                    'train_loader produced no trainable batch in a full pass; '
                    'cannot reach P.outer_steps')
    else:
        for step in range(start_step, P.outer_steps + 1):
            stime = time.time()
            train_batch = next(train_loader)
            metric_logger.meters['data_time'].update(time.time() - stime)

            run_step(step, train_batch)

    # Save the last model.
    save_checkpoint(P, P.outer_steps, best, model.state_dict(),
                    optimizer.state_dict(), logger.log_dir)
