# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils

import numpy as np
import os

import torchattacks

from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)

def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Measure accuracy on PGD-perturbed inputs for one pass over ``data_loader``.

    Despite its name (kept for compatibility with callers), this function does
    not train: no loss is back-propagated and ``optimizer`` is never stepped.
    For every batch it builds an L-infinity projected-gradient-descent attack
    (cleverhans, step ``epsilon / 5``, 10 iterations) and records top-1/top-5
    accuracy of ``model`` on the perturbed samples.

    The attack budget comes from a fixed schedule of 8 epsilons
    (0, 4/256, 8/256, ..., 28/256, each truncated to 4 decimals) indexed by
    ``epoch % 8``. Every epoch therefore maps to a valid budget, and the
    logged epsilon is the one actually applied.

    Args:
        model: network to attack; ``model.module`` is used for the attack when
            ``args.distributed`` is truthy.
        criterion, optimizer, loss_scaler, max_norm, model_ema, mixup_fn:
            accepted for interface compatibility; unused.
        data_loader: iterable of ``(samples, targets)`` batches.
        device: device the batches are moved to.
        epoch: selects the epsilon from the schedule.
        set_training_mode: ignored; the model is always put in eval mode.
        args: namespace read for ``distributed``; ``None`` (or a namespace
            without that attribute) is treated as non-distributed.

    Returns:
        dict mapping each logged meter name (lr, loss, acc1, acc5) to its
        global average. ``loss`` and ``lr`` are constant placeholders
        (0 and 0.1) because nothing is optimized here.
    """
    # Always evaluate in eval mode; ``set_training_mode`` is deliberately ignored.
    model.train(False)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    # Epsilon schedule: 0, 4/256, ..., 28/256 truncated to 4 decimal places.
    eps_list = [int(10000 * a / 256) / 10000 for a in np.arange(0, 29, 4)]
    # Use the same modulo for the log line and the attack so they always agree
    # (previously epoch >= len(eps_list) logged one value and then raised
    # IndexError).
    epsilon = eps_list[epoch % len(eps_list)]
    print('********epoch: {}, epsilon: {}'.format(epoch, epsilon))

    # The attack and the per-batch attributes below target the unwrapped module.
    distributed = args is not None and getattr(args, 'distributed', False)
    model_without_ddp = model.module if distributed else model

    model.eval_mode = True

    for batch_idx, (samples, targets) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Per-batch state stored on the model; presumably consumed by the
        # model's forward / image-saving logic — confirm in the model code.
        model_without_ddp.batch_idx = batch_idx
        model_without_ddp.batch_cnt = batch_idx
        model_without_ddp.target = targets
        model_without_ddp.save_images = False

        batch_size = samples.shape[0]

        with torch.cuda.amp.autocast():
            # L-inf PGD with step epsilon/5 and 10 iterations. No labels are
            # passed (y=None); per the cleverhans docs the attack then uses the
            # model's own predictions as labels.
            samples_ = projected_gradient_descent(
                model_without_ddp, samples, epsilon, epsilon / 5, 10, np.inf)
            outputs = model(samples_)

        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        # Placeholder values: no loss is computed and no learning rate applies.
        metric_logger.update(loss=0)
        metric_logger.update(lr=0.1)
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, epoch, args_distributed):
    """Evaluate a multi-fixation classifier on ``data_loader``.

    ``model`` returns logits with a fixation axis at dim 2, presumably shaped
    ``(batch, num_classes, num_fixations)`` — confirm against the model. The
    reported loss is the sum of the cross-entropy over fixations
    ``0 .. num_fixations - 1``. Top-1/top-5 accuracy is measured on the last
    fixation only.

    When ``epoch % 10 == 9`` the model's ``save_images`` flag is set to True.
    ``saved_images/<epoch>`` is created up front on a best-effort basis.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: network to evaluate; ``model.module`` is used for attribute
            access when ``args_distributed`` is truthy.
        device: device the batches are moved to.
        epoch: current epoch; stored on the model and used in the image
            directory name.
        args_distributed: whether ``model`` is a wrapper exposing ``.module``.

    Returns:
        dict mapping each meter name (loss, acc1, acc5) to its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    model_without_ddp = model.module if args_distributed else model
    model_without_ddp.epoch = epoch

    # Best-effort: a directory that cannot be created must not abort
    # evaluation. Only filesystem errors are ignored; the former bare
    # ``except`` also hid unrelated bugs.
    try:
        os.makedirs(os.path.join('saved_images', str(epoch)), exist_ok=True)
    except OSError:
        pass

    save_images = epoch % 10 == 9

    for batch_idx, (images, target) in enumerate(
            metric_logger.log_every(data_loader, 10, header)):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        model_without_ddp.batch_idx = batch_idx
        model_without_ddp.target = target
        # Re-assigned every batch, as before, in case the model changes it.
        model_without_ddp.save_images = save_images

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            # Sum the cross-entropy over every fixation along dim 2.
            loss = criterion(output[:, :, 0], target)
            for fix_idx in np.arange(1, model_without_ddp.num_fixations):
                loss += criterion(output[:, :, fix_idx], target)

        # Accuracy on the final fixation. For a 3-D output the slice is already
        # (batch, classes). The former ``.squeeze()`` also dropped the batch
        # dim of a size-1 batch, which broke ``accuracy``.
        acc1, acc5 = accuracy(output[:, :, -1], target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
