# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import os
import sys
from typing import Iterable

import numpy
import pandas as pd
import torch

import util.lr_sched as lr_sched
import util.misc as misc


# Methods accepted by ``train_one_epoch``. Anything else is rejected before the
# loop starts instead of failing later with an UnboundLocalError on the losses.
_UNLEARN_METHODS = ('ours', 'max_loss', 'retain_label', 'noisy_label', 'iclr')

# Every N-th iteration, the per-batch forget/retain losses are recorded for the CSV dump.
_CSV_RECORD_EVERY = 4


def _abort_if_not_finite(loss):
    """Print the loss and exit the process with status 1 if it is NaN or infinite."""
    if not math.isfinite(loss):
        print("Loss is {}, stopping training".format(loss))
        sys.exit(1)


def _backward_and_step(optimizer, loss):
    """Clear stale gradients, back-propagate ``loss`` and apply one optimizer step."""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def _make_ours_closure(model, reference_model, criterion, optimizer, inputs, targets,
                       mask_ratio, accum_iter, *, level='encoder', reference_decoder=True):
    """Build the zero-argument closure evaluated by the 'ours' optimizer.

    ``model.forward_encoder`` encodes ``inputs`` with a random mask of ``mask_ratio``.
    Under ``torch.no_grad()``, ``reference_model.new_forward_encoder`` then encodes
    ``targets`` while reusing the same mask / ``ids_restore`` / ``ids_keep``.

    Args:
        level: ``'encoder'`` computes ``criterion(latent, reference_latent)``.
            Any other value decodes with ``model.forward_decoder`` and scores via
            ``model.forward_loss``.
        reference_decoder: only used when ``level != 'encoder'``. If true, the
            reference model also decodes and its prediction is the target.
            Otherwise ``targets`` are the target directly (no reference pass).

    Returns:
        A closure that zeroes the gradients, back-propagates ``loss / accum_iter``
        and returns that scaled loss tensor.
    """
    def closure():
        with torch.cuda.amp.autocast():
            with torch.enable_grad():
                optimizer.zero_grad()
                latent, mask, ids_restore, ids_keep = model.forward_encoder(inputs, mask_ratio)
                if level == 'encoder':
                    with torch.no_grad():
                        ref_latent, _, _, _ = reference_model.new_forward_encoder(
                            targets, mask, ids_restore, ids_keep)
                    loss = criterion(latent, ref_latent)
                else:
                    pred = model.forward_decoder(latent, ids_restore)
                    if reference_decoder:
                        with torch.no_grad():
                            ref_latent, _, ref_ids_restore, _ = reference_model.new_forward_encoder(
                                targets, mask, ids_restore, ids_keep)
                            ref_pred = reference_model.forward_decoder(ref_latent, ref_ids_restore)
                        loss = model.forward_loss(ref_pred, pred, mask)
                    else:
                        loss = model.forward_loss(targets, pred, mask)

        _abort_if_not_finite(loss)
        # Scale under enable_grad so autograd records the division even if the
        # optimizer calls this closure from a no-grad context. Otherwise the
        # returned loss would be scaled while the gradients were not.
        with torch.enable_grad():
            loss = loss / accum_iter
        loss.backward()
        return loss

    return closure


def _ours_step(model, model_clone, criterion, optimizer, samples_forget, samples_retain,
               samples_gaussian, mask_ratio, accum_iter, *, level='encoder', use_clone=True,
               step_mode='one', swap_objectives=False):
    """Run one update of the 'ours' method through the project's custom optimizer.

    The forget closure pulls the model's representation of ``samples_forget``
    towards the reference representation of ``samples_gaussian``. The retain
    closure pulls the representation of ``samples_retain`` towards the reference
    representation of the same images. The reference is ``model_clone`` when
    ``use_clone`` is true, otherwise ``model`` itself.

    NOTE(review): ``optimizer`` must accept ``step(forget_closure=..., retain_closure=...,
    mode=..., g_constraint=...)`` and return a ``(loss_forget, loss_retain)`` pair.
    This is a project-specific API that cannot be verified from here.

    Returns:
        ``(loss_forget, loss_retain)`` as returned by ``optimizer.step``.

    Raises:
        ValueError: if ``step_mode`` is neither ``'one'`` nor ``'all'``.
    """
    reference = model_clone if use_clone else model
    forget_closure = _make_ours_closure(
        model, reference, criterion, optimizer, samples_forget, samples_gaussian,
        mask_ratio, accum_iter, level=level, reference_decoder=use_clone)
    # The retain objective always reconstructs the retain images themselves at decoder level.
    retain_closure = _make_ours_closure(
        model, reference, criterion, optimizer, samples_retain, samples_retain,
        mask_ratio, accum_iter, level=level, reference_decoder=False)

    optimizer.zero_grad()
    if step_mode == 'one':
        optimizer.beta = 5
        if swap_objectives:
            forget_closure, retain_closure = retain_closure, forget_closure
        return optimizer.step(
            forget_closure=forget_closure, retain_closure=retain_closure,
            mode='one', g_constraint=0.1)
    if step_mode == 'all':
        optimizer.alpha = 5
        return optimizer.step(
            forget_closure=forget_closure, retain_closure=retain_closure,
            mode='all', g_constraint=0.60)
    raise ValueError("unknown step_mode {!r}; expected 'one' or 'all'".format(step_mode))


def _baseline_step(model, optimizer, method, samples_forget, samples_retain, samples_gaussian,
                   mask_ratio, accum_iter, forget_weight=0.25):
    """Run one update for 'max_loss', 'retain_label' or 'noisy_label'.

    Both batches go through ``model(x, mask_ratio)``. The retain predictions are
    scored against the retain images. The forget predictions are scored against:
    the forget images ('max_loss'), the retain images ('retain_label') or the
    gaussian images ('noisy_label').

    'max_loss' minimises ``loss_retain - forget_weight * loss_forget`` (ascent on
    the forget set). The label variants minimise ``loss_retain + forget_weight * loss_forget``.

    Returns:
        ``(loss_forget, loss_retain)`` as unscaled loss tensors.
    """
    forget_targets = {
        'max_loss': samples_forget,
        'retain_label': samples_retain,
        'noisy_label': samples_gaussian,
    }[method]

    with torch.cuda.amp.autocast():
        with torch.enable_grad():
            optimizer.zero_grad()
            _, pred_forget, mask_forget = model(samples_forget, mask_ratio)
            _, pred_retain, mask_retain = model(samples_retain, mask_ratio)

            loss_forget = model.forward_loss(forget_targets, pred_forget, mask_forget)
            loss_retain = model.forward_loss(samples_retain, pred_retain, mask_retain)

            if method == 'max_loss':
                loss = loss_retain - loss_forget * forget_weight
            else:
                loss = loss_retain + loss_forget * forget_weight

            _abort_if_not_finite(loss)

        loss /= accum_iter
    _backward_and_step(optimizer, loss)
    return loss_forget, loss_retain


def _iclr_step(model, model_clone, criterion, optimizer, samples_forget, samples_retain,
               samples_gaussian, mask_ratio, accum_iter, forget_weight=1):
    """Run one update for the 'iclr' method (joint latent-matching loss).

    The encoder latents of the forget batch are matched to ``model_clone``'s
    latents of the gaussian batch. The encoder latents of the retain batch are
    matched to ``model_clone``'s latents of the same retain batch. Each pair
    reuses the trainable model's mask and index tensors.

    Returns:
        ``(loss_forget, loss_retain)`` as unscaled loss tensors.
    """
    with torch.cuda.amp.autocast():
        with torch.enable_grad():
            optimizer.zero_grad()
            latent_forget, mask_forget, ids_restore_forget, ids_keep_forget = model.forward_encoder(
                samples_forget, mask_ratio)
            latent_retain, mask_retain, ids_restore_retain, ids_keep_retain = model.forward_encoder(
                samples_retain, mask_ratio)
            with torch.no_grad():
                ref_forget, _, _, _ = model_clone.new_forward_encoder(
                    samples_gaussian, mask_forget, ids_restore_forget, ids_keep_forget)
                ref_retain, _, _, _ = model_clone.new_forward_encoder(
                    samples_retain, mask_retain, ids_restore_retain, ids_keep_retain)

            loss_forget = criterion(latent_forget, ref_forget)
            loss_retain = criterion(latent_retain, ref_retain)
            loss = loss_retain + loss_forget * forget_weight

            _abort_if_not_finite(loss)

        loss /= accum_iter
    _backward_and_step(optimizer, loss)
    return loss_forget, loss_retain


def train_one_epoch(model: torch.nn.Module, model_clone: torch.nn.Module, criterion: torch.nn.Module, method: str, mask_ratio: float,
                    data_loader_forget: Iterable, data_loader_retain: Iterable, data_loader_gaussian: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    """Run one unlearning epoch over lock-stepped forget / retain / gaussian batches.

    Each loader yields ``(samples, targets)`` pairs. Only ``samples`` are used.
    Batches are drawn together through ``metric_logger.log_every``.

    Args:
        model: trainable model. Presumably MAE-style, since it must expose
            ``forward_encoder`` / ``new_forward_encoder`` / ``forward_decoder`` /
            ``forward_loss``. Verify against the model class.
        model_clone: frozen reference model, queried only under ``torch.no_grad()``.
        criterion: loss used to compare encoder latents ('ours', 'iclr').
        method: one of ``'ours'``, ``'max_loss'``, ``'retain_label'``,
            ``'noisy_label'`` or ``'iclr'``.
        mask_ratio: masking ratio passed to the model's encoder.
        optimizer: a custom closure-based optimizer for 'ours'; a standard
            ``step()`` optimizer for the other methods.
        device: device the sample tensors are moved to.
        epoch: epoch index, used for the LR schedule, tensorboard x-axis and CSV name.
        loss_scaler: accepted for interface compatibility; not used here.
        log_writer: optional tensorboard-style writer (``add_scalar``, ``log_dir``).
        args: must provide ``accum_iter``. Also forwarded to
            ``lr_sched.adjust_learning_rate``.

    Returns:
        ``{meter_name: global_avg}`` for every meter of the metric logger.

    Side effects:
        Writes every ``_CSV_RECORD_EVERY``-th iteration's losses to
        ``experiments/{epoch}.csv`` (columns ``f1``=forget, ``f2``=retain).
        Exits the process if a loss becomes non-finite.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    if method not in _UNLEARN_METHODS:
        raise ValueError("unknown method {!r}; expected one of {}".format(method, _UNLEARN_METHODS))

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    # NOTE(review): every method steps the optimizer on every iteration. accum_iter only
    # scales the loss and delays zero_grad, so no real gradient accumulation happens.
    accum_iter = args.accum_iter

    data_list = []

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    batches = metric_logger.log_every(
        data_loader_forget, data_loader_retain, data_loader_gaussian, print_freq, header)
    for data_iter_step, (obj_forget, obj_retain, obj_gaussian) in enumerate(batches):
        samples_forget, _ = obj_forget
        samples_retain, _ = obj_retain
        samples_gaussian, _ = obj_gaussian

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_forget) + epoch, args)

        samples_forget = samples_forget.to(device, non_blocking=True)
        samples_retain = samples_retain.to(device, non_blocking=True)
        samples_gaussian = samples_gaussian.to(device, non_blocking=True)

        if method == 'ours':
            loss_forget, loss_retain = _ours_step(
                model, model_clone, criterion, optimizer, samples_forget, samples_retain,
                samples_gaussian, mask_ratio, accum_iter)
        elif method == 'iclr':
            loss_forget, loss_retain = _iclr_step(
                model, model_clone, criterion, optimizer, samples_forget, samples_retain,
                samples_gaussian, mask_ratio, accum_iter)
        else:
            loss_forget, loss_retain = _baseline_step(
                model, optimizer, method, samples_forget, samples_retain, samples_gaussian,
                mask_ratio, accum_iter)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        # synchronize() raises on hosts without CUDA, so only call it when CUDA is present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if data_iter_step % _CSV_RECORD_EVERY == 0:
            data_list.append({'f1': loss_forget.detach().cpu().numpy(), 'f2': loss_retain.detach().cpu().numpy()})

        metric_logger.update(loss_forget=loss_forget)
        metric_logger.update(loss_retain=loss_retain)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_forget_reduce = misc.all_reduce_mean(loss_forget)
        loss_retain_reduce = misc.all_reduce_mean(loss_retain)

        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            # epoch_1000x is the tensorboard x-axis; it keeps curves comparable
            # when the batch size (and hence steps per epoch) changes.
            epoch_1000x = int((data_iter_step / len(data_loader_forget) + epoch) * 1000)
            log_writer.add_scalar('loss_forget', loss_forget_reduce, epoch_1000x)
            log_writer.add_scalar('loss_retain', loss_retain_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # to_csv does not create missing parent directories.
    os.makedirs('experiments', exist_ok=True)
    df_loss = pd.DataFrame(data_list)
    df_loss.to_csv(f'experiments/{epoch}.csv', index=False)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}