import time
import os
import importlib
import numpy as np
import torch
import torch.nn as nn
import torchvision
import math
import csv
from pathlib import Path

from utils.logging import AverageMeter, ProgressMeter
from utils.eval import accuracy
from utils.hw import hw_loss, hw_flops_loss
from utils.model import map_shortcut_rate
from utils.utils import rate_act_func
import torch.nn.functional as F
from sharpness.eigenvalues import estimate_largest_eigenvector
from utils.s2ap import trades_loss


# TODO: add adversarial accuracy.
def train(
        model, device, train_loader, criterion, optimizer, epoch, args, writer, frozen_gamma, s2ap_adversary
):
    """Run one training epoch with adversarial examples and optional weight perturbation.

    For every batch the function:

    1. crafts ``x_adv`` with ``trades_loss`` (used here purely as an
       adversarial-example generator),
    2. once ``epoch >= args.s2ap_warmup``, asks ``s2ap_adversary`` for a
       perturbation and applies it to the model,
    3. computes the training loss selected by ``args.adv_loss``
       (``'pgd'``: cross-entropy on ``x_adv``; ``'trades'``: ``criterion`` on
       the clean batch plus ``args.beta`` times a batch-averaged KL term),
    4. optionally adds ``gamma * loss_hw`` when ``args.soft_hw`` is set,
    5. takes an optimizer step and then undoes the perturbation from step 2.

    Args:
        model: Network being trained; put into ``train()`` mode here.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of batches; ``data[0]`` is used as the images
            and ``data[1]`` as the targets.
        criterion: Loss applied to clean logits in the ``'trades'`` branch
            (and forwarded to the eigenvalue estimator).
        optimizer: Optimizer that is stepped once per batch.
        epoch: Current epoch index; drives the banner, the perturbation
            warm-up check and the tensorboard step.
        args: Configuration namespace. Attributes read here include
            ``warmup_epochs``, ``exp_mode``, ``adv_loss``, ``batch_size``,
            ``step_size``, ``epsilon``, ``num_steps``, ``beta``, ``clip_min``,
            ``clip_max``, ``distance``, ``s2ap_warmup``, ``k``, ``trainer``,
            ``eigenvalues``, ``result_dir``, ``exp_name``, ``soft_hw``,
            ``prune_reg`` and ``print_freq``.
        writer: Tensorboard-style writer (``add_image`` is called on it and it
            is handed to ``ProgressMeter.write_to_tensorboard``).
        frozen_gamma: Forwarded unchanged to the hardware-loss function.
        s2ap_adversary: Object providing ``calc_s2ap``, ``perturb`` and
            ``restore``.

    Raises:
        ValueError: If ``args.adv_loss`` is neither ``'pgd'`` nor ``'trades'``.
    """
    # Fail early with a clear message; otherwise the loop would only break
    # later on an unbound ``loss`` variable.
    if args.adv_loss not in ('pgd', 'trades'):
        raise ValueError(
            "Unsupported adv_loss {!r}; expected 'pgd' or 'trades'".format(args.adv_loss)
        )

    # In this function the warm-up setting only changes the banner text.
    warmup_epochs = args.warmup_epochs
    if epoch < warmup_epochs:
        print(
            " ->->->->->->->->->-> One epoch with Nat Warm-Up [Warmup Epoch: {}] <-<-<-<-<-<-<-<-<-<-<-<-<-<-<-".format(
                epoch)
        )
    else:
        print(
            " ->->->->->->->->->-> One epoch with TRADES-s2ap [AT Epoch: {}] <-<-<-<-<-<-<-<-<-<-".format(
                epoch - warmup_epochs)
        )

    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    hw_losses = AverageMeter("HW-Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top5 = AverageMeter("Acc_5", ":6.2f")
    if args.exp_mode == 'pretrain':
        info_list = [batch_time, losses, top1, top5]
    else:
        info_list = [batch_time, losses, hw_losses, top1, top5]
    progress = ProgressMeter(
        len(train_loader),
        info_list,
        prefix="Epoch: [{}]".format(epoch),
    )

    # Meters for the adversarial side of training.
    train_top1_adv = AverageMeter("Adv_Acc_1", ":6.2f")
    train_adv_losses = AverageMeter("Adv_Loss", ":.4f")
    train_adv_progress = ProgressMeter(
        len(train_loader),
        [train_top1_adv, train_adv_losses],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    # Loop-invariant objects are built once.
    if args.adv_loss == 'trades':
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        criterion_kl = nn.KLDivLoss(reduction='sum')
    else:
        natural_criterion = nn.CrossEntropyLoss()

    # One flag drives both perturb and restore, so every perturbation applied
    # before the loss computation is undone after the optimizer step.
    apply_s2ap = epoch >= args.s2ap_warmup
    adaptive = 'ada' in args.trainer

    eigen_csv_path = None
    if args.eigenvalues:
        eigen_csv_path = Path(args.result_dir) / args.exp_name / args.exp_mode / "eigenvalues.csv"

    for i, data in enumerate(train_loader):

        images, target = data[0].to(device), data[1].to(device)

        # basic properties of training data
        if i == 0:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(f"Training images range: {[torch.min(images), torch.max(images)]}")

        # Clean forward pass on the unperturbed model; only used for the
        # clean accuracy meters below.
        output = model(images)

        x_adv = trades_loss(
            model=model,
            x_natural=images,
            y=target,
            device=device,
            optimizer=optimizer,
            step_size=args.step_size,
            epsilon=args.epsilon,
            perturb_steps=args.num_steps,
            beta=args.beta,
            clip_min=args.clip_min,
            clip_max=args.clip_max,
            distance=args.distance,
        )

        # perturb model
        if apply_s2ap:
            s2ap = s2ap_adversary.calc_s2ap(inputs_adv=x_adv,
                                            inputs_clean=images,
                                            targets=target,
                                            beta=args.beta,
                                            exp_mode=args.exp_mode,
                                            k=args.k)
            if adaptive:
                # The adaptive variant is indexed as (perturbation, s_max, s_min).
                s2ap_adversary.perturb(s2ap[0], k=args.k, exp_mode=args.exp_mode, s_max=s2ap[1], s_min=s2ap[2])
            else:
                s2ap_adversary.perturb(s2ap, k=args.k, exp_mode=args.exp_mode)

        # compute loss after perturbing the model
        optimizer.zero_grad()
        if args.adv_loss == 'pgd':
            # robust loss: cross-entropy on the adversarial batch
            logits = model(x_adv)
            loss = natural_criterion(logits, target)
        else:
            logits = model(images)
            loss_natural = criterion(logits, target)
            # Summed KL divided by the batch size gives a per-sample average.
            loss_robust = (1.0 / len(images)) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                             F.softmax(model(images), dim=1))
            loss = loss_natural + args.beta * loss_robust
            if eigen_csv_path is not None and i % 10 == 0:
                _, lambda_est = estimate_largest_eigenvector(
                    model, criterion1=criterion, criterion2=criterion_kl, images=images,
                    x_adv=x_adv, beta=args.beta, labels=target, v=None, steps=5
                )
                # Append one (epoch, batch index, estimate) row per sampled batch.
                with open(eigen_csv_path, mode='a', newline='') as csv_file:
                    csv.writer(csv_file).writerow([epoch, i, lambda_est.item()])

        if args.soft_hw:
            hw_loss_func = hw_flops_loss if args.prune_reg == 'channel' else hw_loss
            gamma, loss_hw, _ = hw_loss_func(
                model=model,
                device=device,
                optimizer=optimizer,
                args=args,
                epoch=epoch,
                frozen_gamma=frozen_gamma
            )
            hw_losses.update(loss_hw.item(), images.size(0))

            loss = loss + gamma * loss_hw

        # measure accuracy and record loss
        batch_n = images.size(0)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch_n)
        top1.update(acc1[0], batch_n)
        top5.update(acc5[0], batch_n)

        # Adversarial accuracy is a metric only, so no autograd graph is
        # needed for this extra forward pass.
        with torch.no_grad():
            adv_logits = model(x_adv)
        train_adv_acc1, _ = accuracy(adv_logits, target, topk=(1, 5))
        train_adv_losses.update(loss.item(), batch_n)
        train_top1_adv.update(train_adv_acc1[0], batch_n)

        # Clear anything accumulated since the previous zero_grad (e.g. by
        # the helper calls above) before the real backward pass.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # restore: mirror the perturb step exactly
        if apply_s2ap:
            if adaptive:
                s2ap_adversary.restore(s2ap[0], k=args.k, exp_mode=args.exp_mode, s_max=s2ap[1], s_min=s2ap[2])
            else:
                s2ap_adversary.restore(s2ap, k=args.k, exp_mode=args.exp_mode)

        # Map shortcut layer rates for channel prune:
        if args.prune_reg == 'channel':
            map_shortcut_rate(model, args)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            global_step = epoch * len(train_loader) + i
            progress.display(i)
            progress.write_to_tensorboard(writer, "train", global_step)
            train_adv_progress.write_to_tensorboard(writer, "train", global_step)

        # write a sample of training images to tensorboard (helpful for debugging)
        if i == 0:
            writer.add_image(
                "training-images",
                torchvision.utils.make_grid(images[0: len(images) // 4]),
            )

    # Report the activated rate of every module that exposes a k_rate attribute.
    for m_name, m in model.named_modules():
        if hasattr(m, "k_rate"):
            k = rate_act_func(m.k_score.data, m.k_min)
            print(f'{m_name}: {k.data}')
