import math
import os
import sys
from typing import Iterable

import numpy
import pandas as pd
import torch

import util.misc as misc
import util.lr_sched as lr_sched


# Unlearning methods understood by train_one_epoch. Anything else is rejected before
# training starts, instead of failing later on undefined loss variables.
_SUPPORTED_METHODS = ('ours', 'max_loss', 'retain_label', 'noise_label', 'iclr')


def _exit_if_not_finite(loss):
    """Print the loss and terminate the process with status 1 if it is NaN or infinite."""
    if not math.isfinite(loss):
        print("Loss is {}, stopping training".format(loss))
        sys.exit(1)


def train_one_epoch(model: torch.nn.Module, model_clone: torch.nn.Module, criterion: torch.nn.Module, method: str, mask_ratio: float,
                    data_loader_forget: Iterable, data_loader_retain: Iterable, data_loader_gaussian: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    """Run one unlearning epoch over forget / retain / Gaussian-noise batches.

    Args:
        model: Network being updated. It must provide ``forward_encoder``,
            ``forward_decoder``, ``new_forward_encoder`` and ``forward_loss``.
        model_clone: Reference network. It is only queried under ``torch.no_grad()``
            here, to produce targets.
        criterion: Loss applied to latent pairs. Used by ``'iclr'`` and by the
            (currently disabled) encoder-only path of ``'ours'``.
        method: One of ``'ours'``, ``'max_loss'``, ``'retain_label'``,
            ``'noise_label'``, ``'iclr'``.
        mask_ratio: Forwarded to ``model.forward_encoder``.
        data_loader_forget, data_loader_retain, data_loader_gaussian: Loaders whose
            items unpack to ``(samples, targets)``. They are consumed together through
            ``metric_logger.log_every``, presumably in lockstep (confirm in util.misc).
        optimizer: For ``'ours'`` this must be a custom optimizer whose ``step``
            accepts ``forget_closure``, ``retain_closure``, ``mode`` and
            ``g_constraint`` and returns ``(loss_forget, loss_retain)``. For the
            other methods any standard ``torch.optim`` optimizer works.
        device: Device that all samples/targets are moved to.
        epoch: Current epoch index. Used for the LR schedule, logging and the CSV name.
        loss_scaler: Accepted for interface compatibility; not used in this function.
        log_writer: Optional TensorBoard-style writer (``log_dir``, ``add_scalar``).
        args: Must provide ``accum_iter`` plus whatever
            ``lr_sched.adjust_learning_rate`` reads.

    Returns:
        Dict mapping each metric-logger meter name to its global average.

    Raises:
        ValueError: If ``method`` is not supported.

    Side effects:
        Writes ``experiments/{epoch}.csv`` with the forget/retain losses sampled
        every 24 iterations. The directory is created if it does not exist.
    """
    if method not in _SUPPORTED_METHODS:
        raise ValueError(
            "Unsupported method {!r}; expected one of {}".format(method, _SUPPORTED_METHODS))

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    # NOTE(review): every branch below zeroes gradients right before backward() and
    # steps the optimizer on every iteration. accum_iter therefore only rescales the
    # loss and gates the LR update / TensorBoard logging; gradients are not accumulated
    # across iterations. Confirm this is intended.
    accum_iter = args.accum_iter

    # Loss samples (every 24 iterations), dumped to CSV at the end of the epoch.
    data_list = []

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # NOTE(review): loss_scaler is never used. autocast below runs without gradient
    # scaling, so fp16 gradients may underflow. Confirm this is acceptable.

    # Note: the original ICLR paper uses the L2 norm (the Euclidean distance between the
    # two vectors) as its loss. Here MSE (the mean of the squared L2 norm) is used
    # instead, via `criterion` (presumably nn.MSELoss; confirm at the call site).
    for data_iter_step, (obj_forget, obj_retain, obj_gaussian) in enumerate(
        metric_logger.log_every(data_loader_forget, data_loader_retain, data_loader_gaussian, print_freq, header)):

        samples_forget, targets_forget = obj_forget
        samples_retain, targets_retain = obj_retain
        samples_gaussian, targets_gaussian = obj_gaussian

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_forget) + epoch, args)

        samples_forget = samples_forget.to(device, non_blocking=True)
        targets_forget = targets_forget.to(device, non_blocking=True)

        samples_retain = samples_retain.to(device, non_blocking=True)
        targets_retain = targets_retain.to(device, non_blocking=True)

        samples_gaussian = samples_gaussian.to(device, non_blocking=True)
        targets_gaussian = targets_gaussian.to(device, non_blocking=True)

        if method == 'ours':
            # mode_1 / mode_2 / mode_a / mode_b are hard-coded experiment switches.
            # With the current values, both closures take the decoder-logits path
            # (mode_1 != 'encoder'), targets come from model_clone (mode_2 == 'clone'),
            # and the optimizer runs in 'all' mode.
            def forget_closure():
                mode_1 = '1encoder'
                mode_2 = 'clone'
                with torch.cuda.amp.autocast():
                    with torch.enable_grad():
                        optimizer.zero_grad()
                        if mode_1 == 'encoder':
                            # Pull the forget-batch latents towards the reference model's
                            # latents for the Gaussian batch, using the same mask. Both
                            # mode_2 settings use model_clone here.
                            latent_forget, _, token_drop_mask, token_all_mask = model.forward_encoder(samples_forget, mask_ratio)
                            with torch.no_grad():  # targets only; no gradients
                                latent_forget_clone, _, _, _ = model_clone.new_forward_encoder(samples_gaussian, token_drop_mask, token_all_mask)
                            loss = criterion(latent_forget, latent_forget_clone)
                        else:
                            # Train the decoder on the forget batch to predict the Gaussian
                            # batch's gt_indices under the same mask.
                            latent_forget, _, token_drop_mask, token_all_mask = model.forward_encoder(samples_forget, mask_ratio)
                            logits = model.forward_decoder(latent_forget, token_drop_mask, token_all_mask)
                            target_model = model_clone if mode_2 == 'clone' else model
                            with torch.no_grad():  # targets only; no gradients
                                _, gt_indices_clone, _, _ = target_model.new_forward_encoder(samples_gaussian, token_drop_mask, token_all_mask)
                            loss = model.forward_loss(gt_indices_clone, logits, token_all_mask)

                _exit_if_not_finite(loss)
                loss /= accum_iter
                loss.backward()
                return loss

            def retain_closure():
                mode_1 = '1encoder'
                mode_2 = 'clone'
                with torch.cuda.amp.autocast():
                    with torch.enable_grad():
                        optimizer.zero_grad()
                        if mode_1 == 'encoder':
                            # Keep the retain-batch latents close to the reference latents.
                            latent_retain, _, token_drop_mask, token_all_mask = model.forward_encoder(samples_retain, mask_ratio)
                            target_model = model_clone if mode_2 == 'clone' else model
                            with torch.no_grad():  # targets only; no gradients
                                latent_retain_clone, _, _, _ = target_model.new_forward_encoder(samples_retain, token_drop_mask, token_all_mask)
                            loss = criterion(latent_retain, latent_retain_clone)
                        else:
                            # Keep predicting model_clone's gt_indices for the retain batch.
                            latent_retain, _, token_drop_mask, token_all_mask = model.forward_encoder(samples_retain, mask_ratio)
                            logits = model.forward_decoder(latent_retain, token_drop_mask, token_all_mask)
                            with torch.no_grad():  # targets only; no gradients
                                _, gt_indices_clone, _, _ = model_clone.new_forward_encoder(samples_retain, token_drop_mask, token_all_mask)
                            loss = model.forward_loss(gt_indices_clone, logits, token_all_mask)

                _exit_if_not_finite(loss)
                loss /= accum_iter
                loss.backward()
                return loss

            optimizer.zero_grad()

            mode_a = 'all'
            mode_b = 'f'
            if mode_a == 'one':
                if mode_b == 'f':
                    loss_forget, loss_retain = optimizer.step(
                        forget_closure=forget_closure, retain_closure=retain_closure,
                        mode='one', g_constraint=0.1)
                else:
                    # Swapped roles: the retain objective is passed as the "forget" closure.
                    loss_forget, loss_retain = optimizer.step(
                        forget_closure=retain_closure, retain_closure=forget_closure,
                        mode='one', g_constraint=0.1)
            elif mode_a == 'all':
                loss_forget, loss_retain = optimizer.step(
                    forget_closure=forget_closure, retain_closure=retain_closure,
                    mode='all', g_constraint=5.8)

        elif method in ('max_loss', 'retain_label', 'noise_label'):
            # Which batch supplies the targets for the forget-batch logits:
            #   max_loss     -> the forget batch itself (that loss is then maximised)
            #   retain_label -> the retain batch
            #   noise_label  -> the Gaussian-noise batch
            forget_target_samples = {
                'max_loss': samples_forget,
                'retain_label': samples_retain,
                'noise_label': samples_gaussian,
            }[method]
            with torch.cuda.amp.autocast():
                with torch.enable_grad():
                    optimizer.zero_grad()
                    latent_forget, _, token_drop_mask_forget, token_all_mask_forget = model.forward_encoder(samples_forget, mask_ratio)
                    logits_forget = model.forward_decoder(latent_forget, token_drop_mask_forget, token_all_mask_forget)
                    with torch.no_grad():  # targets only; no gradients
                        # Only the target this method uses is computed, not all three.
                        _, gt_indices_clone_forget, _, _ = model_clone.new_forward_encoder(forget_target_samples, token_drop_mask_forget, token_all_mask_forget)

                    latent_retain, _, token_drop_mask_retain, token_all_mask_retain = model.forward_encoder(samples_retain, mask_ratio)
                    logits_retain = model.forward_decoder(latent_retain, token_drop_mask_retain, token_all_mask_retain)
                    with torch.no_grad():  # targets only; no gradients
                        _, gt_indices_clone_r, _, _ = model_clone.new_forward_encoder(samples_retain, token_drop_mask_retain, token_all_mask_retain)

                    loss_forget = model.forward_loss(gt_indices_clone_forget, logits_forget, token_all_mask_forget)
                    loss_retain = model.forward_loss(gt_indices_clone_r, logits_retain, token_all_mask_retain)

                    # max_loss does gradient ascent on the forget loss; the label-substitution
                    # methods descend on it. Both use weight 0.25.
                    if method == 'max_loss':
                        loss = loss_retain - loss_forget * 0.25
                    else:
                        loss = loss_retain + loss_forget * 0.25

                    _exit_if_not_finite(loss)

                loss /= accum_iter
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif method == 'iclr':
            with torch.cuda.amp.autocast():
                with torch.enable_grad():
                    optimizer.zero_grad()
                    latent_forget, _, token_drop_mask_forget, token_all_mask_forget = model.forward_encoder(samples_forget, mask_ratio)
                    latent_retain, _, token_drop_mask_retain, token_all_mask_retain = model.forward_encoder(samples_retain, mask_ratio)
                    with torch.no_grad():  # targets only; no gradients
                        latent_forget_clone, _, _, _ = model_clone.new_forward_encoder(samples_gaussian, token_drop_mask_forget, token_all_mask_forget)
                        latent_retain_clone, _, _, _ = model_clone.new_forward_encoder(samples_retain, token_drop_mask_retain, token_all_mask_retain)

                    # Forget: match the reference latents of the Gaussian batch.
                    # Retain: match the reference latents of the retain batch itself.
                    loss_forget = criterion(latent_forget, latent_forget_clone)
                    loss_retain = criterion(latent_retain, latent_retain_clone)

                    loss = loss_retain + loss_forget * 0.25

                    _exit_if_not_finite(loss)

                loss /= accum_iter
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        # Guarded so the loop also runs on hosts without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if data_iter_step % 24 == 0:
            data_list.append({'f1': loss_forget.detach().cpu().numpy(), 'f2': loss_retain.detach().cpu().numpy()})

        metric_logger.update(loss_forget=loss_forget)
        metric_logger.update(loss_retain=loss_retain)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_forget_reduce = misc.all_reduce_mean(loss_forget)
        loss_retain_reduce = misc.all_reduce_mean(loss_retain)

        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            # We use epoch_1000x as the x-axis in tensorboard. This calibrates
            # different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader_forget) + epoch) * 1000)
            log_writer.add_scalar('loss_forget', loss_forget_reduce, epoch_1000x)
            log_writer.add_scalar('loss_retain', loss_retain_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # NOTE(review): in multi-process runs every process writes this same path.
    # Consider restricting the write to the main process.
    df_loss = pd.DataFrame(data_list)
    os.makedirs('experiments', exist_ok=True)
    df_loss.to_csv(f'experiments/{epoch}.csv', index=False)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}