# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional
# from fvcore.nn import FlopCountAnalysis


import torch
import torch.nn.functional as F

from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from collections import defaultdict
import utils


def train_one_epoch(model: torch.nn.Module, model_org: torch.nn.Module, criterion: F.cross_entropy,
                    data_loader: Iterable, device: torch.device, epoch: int,
                    loss_scaler=None, optimizer=None, emsa_optimizer=None, layerwise = -1, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, wandb=None):
    """Train ``model`` for one epoch using the E-MSA update rule.

    For every batch the states and co-states are computed through
    ``emsa_optimizer.compute_x`` / ``emsa_optimizer.compute_p`` and the
    parameters are updated with ``emsa_optimizer.update_model``. When
    ``model_ema`` is None and an ``optimizer`` is given, an additional
    gradient step is taken through ``loss_scaler``; otherwise ``model_ema``
    is updated. The batch is then re-evaluated with the updated model under
    ``torch.no_grad`` to log the post-update loss and accuracy.

    Args:
        model: network being trained; also used for the post-update re-evaluation.
        model_org: unused; kept for interface compatibility.
        criterion: loss function called as ``criterion(outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        device: device batches are moved to.
        epoch: epoch index, used in the log header.
        loss_scaler: called as ``loss_scaler(loss, optimizer, clip_grad=..., parameters=..., create_graph=...)``
            when ``optimizer`` is given and ``model_ema`` is None.
        optimizer: optional extra optimizer (see ``loss_scaler``).
        emsa_optimizer: required; provides ``compute_x``, ``compute_p`` and ``update_model``.
        layerwise: forwarded to ``emsa_optimizer.update_model``.
        max_norm: forwarded to ``loss_scaler`` as ``clip_grad``.
        model_ema: optional EMA wrapper updated after each E-MSA step.
        mixup_fn: optional mixup applied to each batch (targets may become soft).
        set_training_mode: passed to ``model.train``.
        wandb: forwarded to the metric logger and ``update_model``.

    Returns:
        ``(acc, avg_loss, magnitudes)``. ``acc`` is the post-update training
        accuracy over all processed samples. ``avg_loss`` is the mean
        post-update loss per batch; both are NaN if the loader was empty.
        ``magnitudes`` is a ``defaultdict(float)`` that this function never
        fills.

    Raises:
        ValueError: if ``emsa_optimizer`` is None.
    """
    if emsa_optimizer is None:
        # compute_x / compute_p / update_model are all required below.
        raise ValueError("train_one_epoch requires an emsa_optimizer")

    model.train(set_training_mode)

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    total_loss = 0.
    n_batches = 0
    all_train_corrects = []
    # NOTE(review): never written in this function; returned (empty) because
    # callers unpack three values.
    magnitudes = defaultdict(float)

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header,
                                                    wandb=wandb, epoch=epoch):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        # Allow gradient computation w.r.t. the input batch.
        samples.requires_grad = True

        # Anomaly detection slows training considerably but reports the op
        # that produced NaN/inf during backward; kept as in the original.
        with torch.autograd.set_detect_anomaly(True):
            outputs, x_array = emsa_optimizer.compute_x(samples)
            loss = criterion(outputs, targets)
            p_list = emsa_optimizer.compute_p(loss, x_array)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Update model parameters with the E-MSA rule.
        with torch.autograd.set_detect_anomaly(True):
            emsa_optimizer.update_model(x_array, p_list, model, layerwise, wandb = wandb)

        # torch.cuda.synchronize() raises when CUDA is unavailable.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if model_ema is not None:
            model_ema.update(model)
        elif optimizer is not None:
            optimizer.zero_grad()
            # timm adds `is_second_order` to second-order optimizers (e.g. AdaHessian).
            is_second_order = getattr(optimizer, 'is_second_order', False)
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=is_second_order)

        # Re-evaluate the batch with the updated parameters for logging.
        with torch.no_grad():
            outputs = model(samples)
            loss_value = criterion(outputs, targets).item()
            # Mixup produces soft (batch, classes) targets; compare against their argmax
            # instead of broadcasting a (batch,) prediction against a 2-D tensor.
            hard_targets = targets.argmax(dim=-1) if targets.dim() > 1 else targets
            all_train_corrects.append(torch.argmax(outputs, dim=-1) == hard_targets)
            total_loss += loss_value
            metric_logger.update(loss=loss_value)
        n_batches += 1

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    if n_batches == 0:
        # Nothing was processed: torch.cat([]) would raise, so report NaN explicitly.
        return math.nan, math.nan, magnitudes

    acc = torch.cat(all_train_corrects).float().mean().item()
    avg_loss = total_loss / n_batches
    return acc, avg_loss, magnitudes



def layer_select(model: torch.nn.Module, criterion: F.cross_entropy,
                    data_loader: Iterable, device: torch.device, epoch: int,
                    optimizer=None, emsa_optimizer=None,
                    wandb=None, LMSA = False):
    """Score every layer over ``data_loader`` and print/return the per-layer averages.

    For each batch the states ``x_array`` and co-states ``p_list`` are computed
    through ``emsa_optimizer``. For each optimizer ``i`` returned by
    ``emsa_optimizer.create_optimizers(model, -1)``, gradients are enabled
    only for that optimizer's parameters and a score is computed:

    * ``LMSA=True``: the state ``x_array[i+1]`` is perturbed by
      ``lr * (x x^T) p``. Here ``lr`` is ``trace(M) / (||M||^2 * 2 / loss_l2) * num_class``
      capped at 10, with ``M = multiply(xxT, ppT)``. The result is propagated
      through the remaining blocks and scored with ``criterion``.
    * ``LMSA=False``: ``H`` from ``emsa_optimizer.compute_H`` is
      back-propagated. The score is ``sum ||grad|| / sum ||param||`` over that
      optimizer's parameters.

    Scores are summed over batches and divided by the number of batches.

    Args:
        model: network whose layers are scored.
        criterion: loss called as ``criterion(outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        device: device batches are moved to.
        epoch: epoch index, used in the log header.
        optimizer: unused (per-layer optimizers come from ``emsa_optimizer``);
            kept for interface compatibility.
        emsa_optimizer: required; provides ``compute_x``, ``compute_p``,
            ``compute_H``, ``create_optimizers``, ``multiply`` and ``all_blocks``.
        wandb: unused; kept for interface compatibility.
        LMSA: selects the scoring rule described above.

    Returns:
        list[float]: per-layer average scores (also printed). Layers whose
        optimizer is None score 0.

    Raises:
        ValueError: if ``emsa_optimizer`` is None, or (LMSA mode) a co-state is
            neither 2-D nor 4-D.

    The ``requires_grad`` flags of every parameter toggled here are restored
    before returning, so the model is left as it was found.
    """
    if emsa_optimizer is None:
        raise ValueError("layer_select requires an emsa_optimizer")

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    l_list = []
    n_batch = 0
    # id(param) -> (param, original requires_grad), recorded on first toggle.
    saved_requires_grad = {}

    def _set_requires_grad(opt, flag):
        # Toggle requires_grad for every parameter of `opt`, remembering the original flag.
        for param_group in opt.param_groups:
            for param in param_group["params"]:
                saved_requires_grad.setdefault(id(param), (param, param.requires_grad))
                param.requires_grad = flag

    try:
        for i_, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header,
                                                                       wandb=None, epoch=epoch)):
            samples = samples.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            # Allow gradient computation w.r.t. the input batch.
            samples.requires_grad = True

            with torch.autograd.set_detect_anomaly(True):
                outputs, x_array = emsa_optimizer.compute_x(samples)
                loss = criterion(outputs, targets)
                p_list = emsa_optimizer.compute_p(loss, x_array)

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            with torch.autograd.set_detect_anomaly(True):
                optimizer_list = emsa_optimizer.create_optimizers(model, -1)

                for i, layer_opt in enumerate(optimizer_list):
                    if layer_opt is None:
                        # Keep l_list indices aligned with optimizer_list.
                        if i_ == 0:
                            l_list.append(0.)
                        continue

                    for other in optimizer_list:
                        if other is not None:
                            other.zero_grad()

                    # Enable grad for the current layer first, then disable all others
                    # (a parameter shared with another optimizer ends up disabled).
                    _set_requires_grad(layer_opt, True)
                    for other in optimizer_list:
                        if other is not layer_opt and other is not None:
                            _set_requires_grad(other, False)

                    all_blocks = emsa_optimizer.all_blocks
                    # NOTE(review): class count is inferred from the block count
                    # (7 blocks -> 17 classes, otherwise 10); model/dataset specific.
                    num_class = 17 if len(all_blocks) == 7 else 10
                    layer_opt.zero_grad()
                    x = x_array[i].detach().clone()
                    p = p_list[i + 1].detach().clone()

                    if LMSA:
                        xxT = emsa_optimizer.multiply(x, x)
                        ppT = emsa_optimizer.multiply(p, p)
                        xxT_ppT = emsa_optimizer.multiply(xxT, ppT)
                        targets_one_hot = F.one_hot(targets.long(), num_classes=outputs.shape[-1])
                        loss_l2 = torch.mean((outputs - targets_one_hot) ** 2).detach().clone() / 10
                        r1 = torch.pow(torch.norm(xxT_ppT), 2) / loss_l2 * 2
                        r2 = torch.trace(xxT_ppT)
                        lr = min(r2 / r1 * num_class, 10)

                        x_next = x_array[i + 1].detach().clone()
                        if p.dim() == 4:
                            x_ = lr * torch.einsum("ab, bcde-> acde", xxT, p) + x_next
                        elif p.dim() == 2:
                            x_ = lr * torch.einsum("ab, bc-> ac", xxT, p) + x_next
                        else:
                            # Without this, x_ would be undefined or stale from the previous layer.
                            raise ValueError(
                                "layer_select: unsupported co-state rank {} for layer {}".format(p.dim(), i))

                        # Propagate the perturbed state through the remaining blocks.
                        for j, block in enumerate(all_blocks):
                            if j <= i:
                                continue
                            x_ = block(x_)
                            if j == len(all_blocks) - 2:
                                # NOTE(review): architecture-specific head handling; the
                                # 6-block case hard-codes an 8x8 pool and 640 features.
                                if len(all_blocks) == 6:
                                    x_ = F.avg_pool2d(x_, 8)
                                    x_ = x_.view(-1, 640)
                                else:
                                    x_ = torch.flatten(x_, 1)

                        metric = criterion(x_, targets)
                    else:
                        H, _ = emsa_optimizer.compute_H(x, p, all_blocks[i])
                        H.sum().backward()
                        theta_L2 = torch.zeros((), device=device)
                        grad_L2 = torch.zeros((), device=device)
                        # Pure bookkeeping: no graph is needed for the norms.
                        with torch.no_grad():
                            for param_group in layer_opt.param_groups:
                                for param in param_group["params"]:
                                    theta_L2 += torch.norm(param, p=2)
                                    # Parameters not reached by H.backward() have no grad.
                                    if param.grad is not None:
                                        grad_L2 += torch.norm(param.grad, p=2)
                        metric = grad_L2 / theta_L2

                    # Store a plain float so no autograd graph is kept alive across batches.
                    metric_value = float(metric.detach())
                    if i_ == 0:
                        l_list.append(metric_value)
                    else:
                        l_list[i] += metric_value
            n_batch = i_ + 1
    finally:
        # Leave the model's requires_grad flags as they were on entry.
        for param, flag in saved_requires_grad.values():
            param.requires_grad = flag

    averages = [score / n_batch for score in l_list]
    print(averages)
    return averages


@torch.no_grad()
def evaluate(data_loader, model, device, attn_only=False, batch_limit=0, wandb=None, epoch=None):
    """Run ``model`` over ``data_loader`` in eval mode and report loss, Acc@1 and Acc@5.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: called as ``model(images)`` and expected to return a pair whose
            first element is the logits used for the loss and accuracy.
        device: device each batch is moved to.
        attn_only: when True, also collect ``aux[0]`` and ``aux[1]`` from the
            model's second output as NumPy arrays for every batch
            (presumably attention-related tensors — TODO confirm).
        batch_limit: stop after this many batches when positive; a non-int or
            negative value is treated as 0, meaning no limit.
        wandb: forwarded to ``MetricLogger.log_every``.
        epoch: unused.

    Returns:
        dict mapping each meter name to its global average; when
        ``attn_only`` is True, ``(stats, (attn, pi))`` instead.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    logger = utils.MetricLogger(delimiter="  ")

    model.eval()

    # Anything other than a non-negative int disables the limit.
    limit = batch_limit if isinstance(batch_limit, int) and batch_limit >= 0 else 0
    collected_attn, collected_pi = [], []

    batches = logger.log_every(data_loader, 10, 'Test:', wandb=wandb, is_test=True)
    for step, (images, target) in enumerate(batches):
        if limit > 0 and step >= limit:
            break
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            if attn_only:
                output, aux = model(images)
                collected_attn.append(aux[0].detach().cpu().numpy())
                collected_pi.append(aux[1].detach().cpu().numpy())
                del aux
            else:
                output, _ = model(images)
            loss = loss_fn(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n_images = images.shape[0]
        logger.update(loss=loss.item())
        logger.meters['acc1'].update(acc1.item(), n=n_images)
        logger.meters['acc5'].update(acc5.item(), n=n_images)

    # Gather the stats from all processes before reporting.
    logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=logger.acc1, top5=logger.acc5, losses=logger.loss))

    stats = {name: meter.global_avg for name, meter in logger.meters.items()}
    return (stats, (collected_attn, collected_pi)) if attn_only else stats
