# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy

import util.misc as misc
import util.lr_sched as lr_sched
from spikingjelly.clock_driven import functional


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    lr_scheduler=None,
                    args=None, top_down=False, beta=0.,
                    ):
    """Train ``model`` for one epoch and return the epoch-averaged metrics.

    Args:
        model: network to train. When ``top_down`` is True it is called as
            ``model(samples, td=...)`` and must return ``(output, td)``.
        criterion: loss applied to ``(output, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches; must support ``len()``.
        optimizer: optimizer whose ``param_groups`` are read for LR logging.
        device: device the batches are moved to.
        epoch: zero-based epoch index; seeds the scheduler step counter and the
            tensorboard x-axis.
        loss_scaler: callable that backpropagates the loss and steps the optimizer;
            invoked as ``loss_scaler(loss, optimizer, clip_grad=..., parameters=...,
            create_graph=..., update_grad=...)``.
        max_norm: gradient-clipping threshold forwarded to ``loss_scaler``.
        mixup_fn: optional mixup/cutmix transform applied to every batch.
        log_writer: optional tensorboard-style writer (``add_scalar``, ``log_dir``).
        lr_scheduler: optional per-iteration scheduler exposing ``step_update``.
        args: namespace; reads ``accum_iter`` and, in top-down mode, ``time_step``.
        top_down: if True, unroll the model for ``args.time_step`` steps, feeding
            the returned ``td`` back in, and average the per-step losses.
        beta: unused; kept for interface compatibility.

    Returns:
        dict mapping each meter name to its global average over the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.8f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 200

    accum_iter = args.accum_iter
    # Global iteration counter for the per-iteration LR scheduler. It must be
    # advanced every batch (see below); otherwise the LR is frozen for the epoch.
    num_updates = epoch * len(data_loader)

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # log_every yields the batches and periodically prints every meter fed via .update().
    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            if top_down:
                # Unrolled forward: each step receives the previous step's top-down
                # signal (None on the first step); per-step losses are averaged.
                # No accuracy is computed here: with mixup the targets are soft
                # (B, C) and timm's accuracy() cannot consume them.
                loss = 0
                td = None
                for _ in range(args.time_step):
                    output, td = model(samples, td=td)
                    loss += criterion(output, targets)
                loss = loss / args.time_step
            else:
                outputs = model(samples)
                loss = criterion(outputs, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Gradient accumulation: scale the loss and only step/clear every accum_iter batches.
        update_grad = (data_iter_step + 1) % accum_iter == 0
        loss = loss / accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=update_grad)
        if update_grad:
            optimizer.zero_grad()

        torch.cuda.synchronize()
        # Reset the stateful spikingjelly modules so the next batch starts fresh.
        functional.reset_net(model)
        metric_logger.update(loss=loss_value)

        max_lr = max([0.0] + [group["lr"] for group in optimizer.param_groups])
        metric_logger.update(lr=max_lr)

        # Called on every process, not only the one holding log_writer; presumably a
        # cross-process reduction, so keep it unconditional (TODO confirm in util.misc).
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and update_grad:
            # epoch_1000x as the tensorboard x-axis calibrates curves across batch sizes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

        num_updates += 1
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=metric_logger.loss.avg)

    # Lookahead-style optimizers: sync slow/fast weights once at epoch end (as timm
    # does). Syncing every iteration would reduce Lookahead to k=1.
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, args, td=False, test_td=False):
    """Evaluate ``model`` on ``data_loader`` and return averaged loss / acc@1 / acc@5.

    Args:
        data_loader: iterable of batches; ``batch[0]`` is the input and
            ``batch[-1]`` the hard class-index target.
        model: network to evaluate. When ``td`` is True it is called as
            ``model(images, td=...)`` and must return ``(output, td_state)``.
        device: device the batches are moved to.
        args: namespace; reads ``time_step`` in top-down mode.
        td: if True, unroll the model for ``args.time_step`` steps, feeding the
            returned top-down state back in. The loss is the mean per-step
            cross-entropy (matching training). Accuracy uses the final step's output.
        test_td: unused; kept for interface compatibility.

    Returns:
        dict mapping each meter name ('loss', 'acc1', 'acc5') to its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'
    print_fre = 100

    # switch to evaluation mode
    model.eval()

    # Keep the caller's mode flag under its own name. The per-step top-down state
    # returned by the model previously reused `td` and overwrote the flag, breaking
    # the mode check from the second batch onward.
    use_top_down = td

    for batch in metric_logger.log_every(data_loader, print_fre, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            if use_top_down:
                # TET-style evaluation: average the loss over all unrolled steps.
                loss = 0
                td_state = None
                for _ in range(args.time_step):
                    output, td_state = model(images, td=td_state)
                    loss += criterion(output, target)
                loss = loss / args.time_step
            else:
                output = model(images)
                loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # Reset the stateful spikingjelly modules so the next batch starts fresh.
        functional.reset_net(model)

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
