# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils

import numpy as np
import os


from torchvision import transforms



import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import face
from scipy.ndimage import zoom
from scipy.special import logsumexp
import torch

import deepgaze_pytorch

# Module-level device string naming the CUDA backend.
DEVICE = 'cuda'



def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Run one training pass over ``data_loader`` and return the averaged meters.

    The model output is indexed as ``outputs[:, :, k]`` for every ``k`` below
    ``num_fixations`` (an attribute of the unwrapped model); the criterion is
    applied to each slice and the results are summed into one loss.

    Args:
        model: Network to train; ``model.module`` is used as the unwrapped
            model when ``args.distributed`` is truthy.
        criterion: Called as ``criterion(samples, outputs_slice, targets)``.
        data_loader: Iterable yielding ``(samples, targets)`` batches.
        optimizer: Optimizer whose first param group's ``lr`` is logged.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the log header.
        loss_scaler: Callable performing backward/step; invoked with the loss,
            the optimizer, ``clip_grad``, ``parameters`` and ``create_graph``.
        max_norm: Forwarded to ``loss_scaler`` as ``clip_grad``.
        model_ema: Optional EMA tracker updated after every step.
        mixup_fn: Optional callable applied to ``(samples, targets)``.
        set_training_mode: Forwarded to ``model.train``.
        args: Optional namespace; ``distributed`` and ``bce_loss`` are read
            from it. ``None`` (or a missing attribute) is treated as ``False``.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Exits the process with status 1 if the loss becomes non-finite.
    """
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # ``args`` defaults to None, so read the flags defensively instead of
    # dereferencing it directly.
    is_distributed = getattr(args, 'distributed', False)
    use_bce_loss = getattr(args, 'bce_loss', False)

    model_without_ddp = model.module if is_distributed else model

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # The (pre-mixup) labels are exposed on the model itself.
        model_without_ddp.target = targets
        model_without_ddp.save_images = False

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_bce_loss:
            # Binarise the targets: any positive entry becomes 1.
            targets = targets.gt(0.0).type(targets.dtype)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            # Sum the criterion over every fixation slice of the output.
            loss = criterion(samples, outputs[:, :, 0], targets)
            for fix_idx in range(1, model_without_ddp.num_fixations):
                loss = loss + criterion(samples, outputs[:, :, fix_idx], targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


# Side length (pixels) used for the DeepGaze input, centre bias and the
# flat-index -> (row, col) decoding of its prediction.
_DG_RESOLUTION = 1024
# Pixel coordinates are quantised to this many bins per axis.
_FIX_GRID = 16
# Pixel coordinate (used for both x and y) of the first fixation.
_INITIAL_FIXATION = 110


def _ensure_dir(path):
    """Create ``path`` (with parents) if missing; warn rather than raise on failure."""
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        # Directory creation was best-effort in the original code; keep it
        # non-fatal but make the failure visible.
        print('Warning: could not create directory {}: {}'.format(path, err))


def _deepgaze_fixations(model_dg, image_tensor, centerbias_tensor, num_fixations, device):
    """Roll out ``num_fixations`` grid fixations using DeepGaze argmax predictions.

    The history starts with three random pixel coordinates followed by the
    fixed initial fixation. Each step feeds the history entries selected by
    ``model_dg.included_fixations`` to the model, takes the argmax of
    ``prediction[0]`` and appends it as the next fixation.

    Returns:
        List of ``[x, y]`` grid cells. The first entry and every ``y`` are
        NumPy floats while later ``x`` values are Python ints; this mix is
        kept as-is because the consumer of the list is not visible here.
    """
    history_x = np.append(np.random.randint(0, _DG_RESOLUTION, [3]), _INITIAL_FIXATION)
    history_y = np.append(np.random.randint(0, _DG_RESOLUTION, [3]), _INITIAL_FIXATION)

    start_cell = np.floor((_INITIAL_FIXATION / _DG_RESOLUTION) * _FIX_GRID)
    fix_list = [[start_cell, start_cell]]

    for _ in range(num_fixations - 1):
        selected = model_dg.included_fixations
        x_hist_tensor = torch.tensor(history_x[selected][np.newaxis]).to(device)
        y_hist_tensor = torch.tensor(history_y[selected][np.newaxis]).to(device)

        log_density_prediction = model_dg(image_tensor, centerbias_tensor,
                                          x_hist_tensor, y_hist_tensor)

        # One device->host transfer, then decode the flat index on the host.
        # NOTE(review): decoding assumes the prediction is _DG_RESOLUTION wide.
        flat_idx = int(torch.argmax(log_density_prediction[0]))
        loc_y, loc_x = divmod(flat_idx, _DG_RESOLUTION)

        history_x = np.append(history_x, np.int64(loc_x))
        history_y = np.append(history_y, np.int64(loc_y))

        norm_x = int(np.floor((loc_x / _DG_RESOLUTION) * _FIX_GRID))
        norm_y = np.floor((loc_y / _DG_RESOLUTION) * _FIX_GRID)
        fix_list.append([norm_x, norm_y])

    return fix_list


def _summed_fixation_loss(criterion, output, target, num_fixations):
    """Sum ``criterion(output[:, :, k], target)`` over ``k`` in ``range(num_fixations)``."""
    loss = criterion(output[:, :, 0], target)
    for fix_idx in range(1, num_fixations):
        loss = loss + criterion(output[:, :, fix_idx], target)
    return loss


@torch.no_grad()
def evaluate(data_loader, model, device, epoch, args_distributed):
    """Evaluate ``model`` with fixations generated by DeepGaze III.

    For every batch flagged in the ``valid_subet.npy`` mask, a fixation list
    is rolled out with DeepGaze III on the image rescaled by
    ``Resize(1024)`` and stored on the model as ``fix_list`` before the
    forward pass. The cross-entropy loss is summed over all fixation slices
    ``output[:, :, k]``; accuracy is measured on the last slice.

    Args:
        data_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; ``model.module`` is used as the unwrapped
            model when ``args_distributed`` is truthy.
        device: Device for the batches, the DeepGaze model and its inputs.
        epoch: Stored on the model and used to name the image directory.
        args_distributed: Whether ``model`` is a distributed wrapper.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Side effects:
        Sets ``epoch``, ``fix_list``, ``batch_idx``, ``batch_cnt``, ``target``
        and ``save_images`` on the unwrapped model, creates the output
        directories, and saves ``ind_stats`` when ``flag_stats`` is truthy.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    model_without_ddp = model.module if args_distributed else model
    model_without_ddp.epoch = epoch
    num_fixations = model_without_ddp.num_fixations

    # Creates 'saved_images/' as well as the per-epoch sub-directory.
    _ensure_dir('saved_images/' + str(epoch))

    STATS_DIR = 'Generated_results/Medium_sized_backbone/Deep_Gaze_Files/'
    _ensure_dir(STATS_DIR)

    rescale = transforms.Resize(_DG_RESOLUTION, interpolation=3)
    valid_subset = np.load('Generated_results/Medium_sized_backbone/RG_Files/valid_subet.npy')
    print(np.sum(valid_subset))

    # Keep the saliency model on the same device as the images it receives.
    model_dg = deepgaze_pytorch.DeepGazeIII(pretrained=True).to(device)
    # All-zero centre bias with a leading batch axis of 1 (float64, matching
    # what np.zeros produced originally).
    centerbias_tensor = torch.zeros((1, _DG_RESOLUTION, _DG_RESOLUTION),
                                    dtype=torch.float64, device=device)

    batches = metric_logger.log_every(data_loader, 10, header)
    for batch_idx, (images, target) in enumerate(batches):
        if not valid_subset[batch_idx]:
            # Skipped batches still record a loss of 0.
            # NOTE(review): this presumably pulls the averaged loss down;
            # confirm whether that is intended.
            metric_logger.update(loss=0)
            continue

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # NOTE(review): the rollout uses a batch-1 centre bias/history and only
        # prediction[0], so this looks like it assumes batch size 1 - confirm.
        fix_list = _deepgaze_fixations(model_dg, rescale(images), centerbias_tensor,
                                       num_fixations, device)
        print('fix_list: {}'.format(fix_list))

        # Hand the per-batch state to the model before the forward pass.
        model_without_ddp.fix_list = fix_list
        model_without_ddp.batch_idx = batch_idx
        model_without_ddp.batch_cnt = batch_idx
        model_without_ddp.target = target
        model_without_ddp.save_images = False

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = _summed_fixation_loss(criterion, output, target, num_fixations)

        acc1, acc5 = accuracy(output[:, :, -1], target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    if model_without_ddp.flag_stats:
        np.save(STATS_DIR+'without_IOR_Random_Fixations'+'_'+str(model_without_ddp.num_fixations)+'.npy', model_without_ddp.ind_stats)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
