# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import time
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    """Train ``model`` for one pass over ``data_loader`` under CUDA autocast.

    Args:
        model: network to train; ``model.train(set_training_mode)`` is called first.
        criterion: loss, called as ``criterion(samples, outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer stepped once per batch through ``loss_scaler``.
        device: device every batch is moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler: callable performing backward, optional clipping and the optimizer step,
            called as ``loss_scaler(loss, optimizer, clip_grad=..., parameters=...,
            create_graph=...)`` (presumably timm's ``NativeScaler`` — confirm at call site).
        max_norm: gradient-norm clipping threshold; ``None`` or a value ``<= 0`` disables
            clipping.
        model_ema: if given, updated from ``model`` after every optimizer step.
        mixup_fn: if given, applied to each batch (samples and targets) before the forward.
        set_training_mode: forwarded to ``model.train()``.

    Returns:
        dict mapping every logged metric name (``'loss'``, ``'lr'``) to its global average
        over the epoch, after synchronizing the meters between processes.

    Exits the process with status 1 if a non-finite loss is encountered.
    """
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # This attribute is added by timm on one optimizer (adahessian); it cannot change during
    # the epoch, so look it up once.
    is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    # timm's NativeScaler clips whenever clip_grad is not None, and clipping to a norm of 0
    # zeroes every gradient. Treat None / non-positive thresholds as "no clipping".
    clip_grad = max_norm if max_norm is not None and max_norm > 0 else None

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss_scaler(loss, optimizer, clip_grad=clip_grad,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def _accumulation_optimizer_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                                 scaler, max_norm: Optional[float]) -> None:
    """Apply accumulated, loss-scaled gradients with one AMP optimizer step, then clear them.

    Args:
        model: model whose parameters are clipped when clipping is enabled.
        optimizer: optimizer holding the accumulated gradients.
        scaler: the ``GradScaler`` that scaled the losses during backward.
        max_norm: gradient-norm clipping threshold; ``None`` or ``<= 0`` disables clipping.
    """
    if max_norm is not None and max_norm > 0:
        # Clip the true (unscaled) gradients. GradScaler.step() records that unscale_()
        # already ran for this optimizer and does not unscale a second time.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # step() skips the parameter update if gradients contain inf/NaN; update() then
    # adjusts the loss scale for the next iteration.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def train_one_epoch_with_accumulation(model: torch.nn.Module, criterion: DistillationLoss,
                                        data_loader: Iterable, optimizer: torch.optim.Optimizer,
                                        device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                                        model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                                        set_training_mode=True, accumulation_steps: int = 2):
    """Train ``model`` for one epoch, stepping the optimizer every ``accumulation_steps`` batches.

    Gradients of ``accumulation_steps`` consecutive micro-batches are summed (each loss is
    divided by ``accumulation_steps``) before a single AMP optimizer step. Gradients left
    over from an incomplete final window are applied in one extra step after the loop.

    Args:
        model: network to train; ``model.train(set_training_mode)`` is called first.
        criterion: loss, called as ``criterion(samples, outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` micro-batches.
        optimizer: optimizer stepped once per accumulation window.
        device: device every batch is moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler: object whose ``_scaler`` attribute is a ``GradScaler``
            (e.g. timm's ``NativeScaler``); only ``_scaler`` is used.
        max_norm: gradient-norm clipping threshold; ``None`` or ``<= 0`` disables clipping.
        model_ema: if given, updated from ``model`` after every optimizer step.
        mixup_fn: if given, applied to each batch (samples and targets) before the forward.
        set_training_mode: forwarded to ``model.train()``.
        accumulation_steps: micro-batches per optimizer step; ``1`` disables accumulation.

    Returns:
        dict mapping every logged metric name (``'loss'``, ``'lr'``) to its global average
        over the epoch, after synchronizing the meters between processes. The logged loss
        is the undivided per-batch loss.

    Raises:
        ValueError: if ``accumulation_steps < 1``.

    Exits the process with status 1 if a non-finite loss is encountered.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1, got {}".format(accumulation_steps))

    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # Backward and stepping go through the GradScaler wrapped by loss_scaler, because the
    # scaler's combined backward+step call cannot defer the step across micro-batches.
    scaler = loss_scaler._scaler
    # This attribute is added by timm on one optimizer (adahessian).
    is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

    optimizer.zero_grad()
    iter_count = 0

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Dividing by accumulation_steps makes the gradient summed over a full window the
        # average of the micro-batch gradients (identical to no accumulation when it is 1).
        scaler.scale(loss / accumulation_steps).backward(create_graph=is_second_order)
        iter_count += 1

        stepped = iter_count % accumulation_steps == 0
        if stepped:
            _accumulation_optimizer_step(model, optimizer, scaler, max_norm)

        torch.cuda.synchronize()
        # Weights change only on optimizer steps, so update the EMA once per step; updating
        # on every micro-batch would apply the EMA decay accumulation_steps times per step.
        if stepped and model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Apply gradients left over from an incomplete final window.
    # NOTE: those gradients were still divided by accumulation_steps, so this last update is
    # proportionally smaller than a full-window update.
    if iter_count % accumulation_steps != 0:
        _accumulation_optimizer_step(model, optimizer, scaler, max_norm)
        if model_ema is not None:
            model_ema.update(model)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` with cross-entropy loss and top-1 / top-5 accuracy.

    Args:
        data_loader: iterable of ``(images, target)`` batches; targets are presumably
            integer class indices (required by ``CrossEntropyLoss`` and ``accuracy`` as
            used here — confirm against the dataset).
        model: network to evaluate; switched to eval mode.
        device: device every batch is moved to.

    Returns:
        dict mapping ``'loss'``, ``'acc1'`` and ``'acc5'`` to their sample-weighted global
        averages, after synchronizing the meters between processes.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        # CrossEntropyLoss (default reduction='mean') returns the batch mean; weight it by
        # batch size like the accuracies so a smaller final batch does not skew the average.
        metric_logger.meters['loss'].update(loss.item(), n=batch_size)
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def throughput(data_loader, model, device, times=30):
    """Measure inference throughput (images per second) on the first batch of a loader.

    The first batch is moved to ``device`` and run through ``model`` ``times`` times.
    The throughput is ``times * batch_size / elapsed_seconds``. No warm-up passes are run.

    Args:
        data_loader: iterable of ``(images, _)`` batches; only the first batch is used.
        model: network to benchmark; switched to eval mode.
        device: device (``torch.device`` or string) the batch is moved to.
        times: number of forward passes to time.

    Returns:
        Images per second as a float, or ``None`` if ``data_loader`` yields no batch.
    """
    model.eval()

    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return None
    images, _ = first_batch

    images = images.to(device, non_blocking=True)
    batch_size = images.shape[0]
    # CUDA kernels run asynchronously, so the timer must be bracketed by synchronizations;
    # on other devices torch.cuda.synchronize() is unnecessary (and fails without CUDA).
    on_cuda = torch.device(device).type == "cuda"

    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(times):
        model(images)
    if on_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return times * batch_size / elapsed