# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils

import numpy as np
import os



from torchvision import transforms



import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import face
from scipy.ndimage import zoom
from scipy.special import logsumexp
import torch

#import deepgaze_pytorch
import deepgaze_pytorch

# Module-level device string. It is fixed to 'cuda' (the current/default CUDA
# device) and does not follow the `device` argument passed to the train/eval
# functions in this file.
DEVICE: str = 'cuda'

from saliency_models import gbvs, ittikochneibur
import cv2
import time
#from matplotlib import pyplot as plt

from numpy import asarray

# Inverse of the ImageNet normalisation: maps a normalised CHW tensor back to
# an RGB PIL image. Normalize((x - m') / s') with m' = -m/s and s' = 1/s
# computes x * s + m, undoing the forward (x - m) / s.
_imagenet_mean = (0.485, 0.456, 0.406)
_imagenet_std = (0.229, 0.224, 0.225)
inv_transform = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(_imagenet_mean, _imagenet_std)],
        std=[1.0 / s for s in _imagenet_std],
    ),
    transforms.ToPILImage(mode='RGB'),
])

def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Train *model* for one epoch and return the averaged training metrics.

    The model produces one set of logits per fixation, indexed as
    ``outputs[:, :, fix_idx]``. The optimised loss is the sum of ``criterion``
    over all ``num_fixations`` of them.

    Args:
        model: network to train. When ``args.distributed`` is true it is
            assumed to be a wrapper whose ``.module`` is the real model.
        criterion: called as ``criterion(samples, logits, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer whose gradients are zeroed each step. The step
            itself is presumably done inside ``loss_scaler``, since no explicit
            ``optimizer.step()`` appears here.
        device: device every batch is moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler: callable invoked as
            ``loss_scaler(loss, optimizer, clip_grad=..., parameters=..., create_graph=...)``.
        max_norm: gradient-clipping value forwarded to ``loss_scaler``.
        model_ema: optional EMA wrapper updated after every step.
        mixup_fn: optional Mixup/CutMix applied to each batch.
        set_training_mode: value passed to ``model.train``.
        args: namespace read for ``distributed`` and ``bce_loss``. A missing
            attribute, or ``args=None`` (the default), is treated as False.

    Returns:
        dict mapping every logged meter name to its global average.
    """
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # getattr keeps the documented default args=None usable.
    distributed = getattr(args, 'distributed', False)
    bce_loss = getattr(args, 'bce_loss', False)

    model_without_ddp = model.module if distributed else model

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Exposed to the model *before* mixup, so it sees the original labels.
        # How the model consumes these attributes is not visible here.
        model_without_ddp.target = targets
        model_without_ddp.save_images = False

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if bce_loss:
            # Binarise (possibly soft) targets for a BCE-style criterion.
            targets = targets.gt(0.0).type(targets.dtype)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs[:, :, 0], targets)
            for fix_idx in range(1, model_without_ddp.num_fixations):
                # Out-of-place add: never mutates a tensor autograd may need.
                loss = loss + criterion(samples, outputs[:, :, fix_idx], targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = getattr(optimizer, 'is_second_order', False)
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        # Only synchronise when CUDA exists; the unconditional call fails on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def _ensure_dir(path):
    """Best-effort creation of directory *path*, including missing parents.

    An already-existing directory is not an error. Any other ``OSError`` is
    reported and ignored, so evaluation is not aborted just because a
    directory could not be created.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        print('Could not create directory {}: {}'.format(path, err))


def _predict_fixations(img, model_dg, centerbias_tensor, num_fixations, device,
                       init_xy=(24, 24), inhibition_radius=24, patch_size=16,
                       penalty=10000):
    """Choose a fixation sequence on *img* from DeepGaze's log-density map.

    The first fixation is the fixed point *init_xy*. Each further fixation is
    the arg-max of the predicted log-density after a large *penalty* has been
    subtracted inside a square of half-width *inhibition_radius* around every
    previous fixation (inhibition of return).

    Args:
        img: H x W x 3 image array. Assumed to be the uint8 output of
            ``inv_transform`` (TODO confirm).
        model_dg: saliency model called as ``model_dg(image, centerbias)``.
        centerbias_tensor: log centre-bias tensor, assumed to match img's H x W.
        num_fixations: total number of fixations returned, including the
            initial one. Values below 1 still yield the initial fixation.
        device: device the image tensor is moved to.
        init_xy: initial fixation in pixel coordinates ``(x, y)``.
        inhibition_radius: half-width, in pixels, of each masked square.
        patch_size: pixel size of one grid cell used to quantise fixations.
        penalty: value written into the mask for visited regions.

    Returns:
        list of ``[x, y]`` int pairs in patch-grid coordinates
        (``pixel // patch_size``).
    """
    image_tensor = torch.tensor(img.transpose(2, 0, 1)[np.newaxis]).to(device)
    log_density_prediction = model_dg(image_tensor, centerbias_tensor)
    log_density_prediction = log_density_prediction.cpu().numpy()
    print(log_density_prediction.shape)

    height, width = log_density_prediction.shape[-2:]
    # A single-image batch is fed in, so the leading dims collapse to a 2-D map.
    density_2d = log_density_prediction.reshape(height, width)
    done_fix = np.zeros((height, width))

    def inhibit(x, y):
        # Mark the (clipped) square around (x, y) as visited.
        top = max(0, y - inhibition_radius)
        bottom = min(height, y + inhibition_radius + 1)
        left = max(0, x - inhibition_radius)
        right = min(width, x + inhibition_radius + 1)
        done_fix[top:bottom, left:right] = penalty

    init_x, init_y = init_xy
    inhibit(init_x, init_y)
    # Plain ints throughout, so every entry of fix_list has the same type.
    fix_list = [[int(init_x // patch_size), int(init_y // patch_size)]]

    for _ in range(num_fixations - 1):
        loc_y, loc_x = np.unravel_index(np.argmax(density_2d - done_fix), done_fix.shape)
        loc_x, loc_y = int(loc_x), int(loc_y)
        inhibit(loc_x, loc_y)
        fix_list.append([loc_x // patch_size, loc_y // patch_size])
    return fix_list


@torch.no_grad()
def evaluate(data_loader, model, device, epoch, args_distributed):
    """Evaluate the multi-fixation model, using DeepGaze IIE to pick fixations.

    For every batch that ``valid_subset.npy`` marks as valid, a fixation
    sequence is predicted from the first image of the batch and stored on the
    model as ``fix_list``. The model is then run on the batch, and the summed
    per-fixation cross-entropy plus the top-1 accuracy of every fixation output
    are logged. Batches not marked valid are skipped and contribute nothing to
    the metrics.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: model to evaluate. When *args_distributed* is true it is assumed
            to be a wrapper whose ``.module`` is the real model.
        device: device used for the model inputs and for DeepGaze.
        epoch: epoch index, stored on the model and used for the
            ``saved_images/<epoch>`` directory.
        args_distributed: whether *model* is wrapped (``model.module``).

    Returns:
        dict mapping every logged meter name (``loss``, ``acc1_<k>``) to its
        global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    _ensure_dir('saved_images/')

    model_without_ddp = model.module if args_distributed else model
    model_without_ddp.epoch = epoch
    _ensure_dir('saved_images/' + str(epoch))

    STATS_DIR = 'Generated_results/Medium_sized_backbone/RG_Files/'
    _ensure_dir(STATS_DIR)

    # One flag per batch index: only flagged batches are evaluated.
    valid_subset = np.load(os.path.join(STATS_DIR, 'valid_subet.npy'))
    print(np.sum(valid_subset))

    model.time_taken = []

    # you can use DeepGazeI or DeepGazeIIE
    model_dg = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(device)
    img_size = 224  # spatial size the centre bias is built for
    centerbias_template = np.zeros((1024, 1024))
    centerbias = zoom(centerbias_template,
                      (img_size / centerbias_template.shape[0],
                       img_size / centerbias_template.shape[1]),
                      order=0, mode='nearest')
    # Normalise so the centre bias is a proper log-density.
    centerbias -= logsumexp(centerbias)
    centerbias_tensor = torch.tensor(centerbias[np.newaxis]).to(device)

    num_fixations = model_without_ddp.num_fixations

    for batch_idx, (images, target) in enumerate(
            metric_logger.log_every(data_loader, 100, header)):
        if not valid_subset[batch_idx]:
            # Logging a placeholder loss of 0 here would drag the averaged
            # loss down, so skipped batches record nothing.
            continue
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Fixations are planned on the first image only and shared by the whole
        # batch; presumably evaluation runs with batch size 1 (TODO confirm).
        img = asarray(inv_transform(images[0].cpu()))
        print(np.shape(img))

        start_time = time.time()
        fix_list = _predict_fixations(img, model_dg, centerbias_tensor,
                                      num_fixations, device)
        print('fix_list: {}'.format(fix_list))
        model_without_ddp.fix_list = fix_list
        end_time = time.time()
        model.time_taken.append(end_time - start_time)
        print('Mean time taken: {}'.format(np.mean(model.time_taken)))

        model_without_ddp.batch_idx = batch_idx
        model_without_ddp.batch_cnt = batch_idx
        model_without_ddp.target = target
        model_without_ddp.save_images = False

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output[:, :, 0], target)
            for fix_idx in range(1, num_fixations):
                loss = loss + criterion(output[:, :, fix_idx], target)

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        # One top-1 meter per fixation output: acc1_1 ... acc1_<num_fixations>.
        for fix_idx in range(num_fixations):
            acc1, _ = accuracy(output[:, :, fix_idx], target, topk=(1, 5))
            metric_logger.meters['acc1_{}'.format(fix_idx + 1)].update(
                acc1.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    for fix_idx in range(1, num_fixations + 1):
        top1 = getattr(metric_logger, 'acc1_{}'.format(fix_idx))
        print('* Fix@{} {:.3f}'.format(fix_idx, top1.global_avg))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
