# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys
from typing import Iterable

import torch
import torch.nn.functional as F
import numpy as np

import util.misc as misc
import util.lr_sched as lr_sched
from util.criterion import SimpleLpLoss


def train_one_epoch(model: torch.nn.Module, discriminator: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, opt_dis: torch.optim.Optimizer, opt_vae: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    log_writer=None, args=None):
    """Run one epoch of adversarial training for the VAE and its latent discriminator.

    Every iteration performs two passes:

    1. Discriminator pass: ``body`` is encoded, the latent is detached, and
       ``discriminator`` is trained to predict ``cls`` from it.
    2. VAE pass: ``model`` is trained on the reconstruction loss of the first
       ``args.num_queries`` query columns plus ``0.1`` times the loss of the
       remaining columns, plus the KL term (when the model returns one) and an
       adversarial term that asks the discriminator to predict ``1 - cls``.

    Both the discriminator weight and the adversarial weight are ramped with
    ``2 / (1 + exp(-10 * p)) - 1`` where ``p`` is the fraction of total
    training completed.

    Args:
        model: VAE. Called as ``model(points, body)`` and expected to return a
            mapping with ``'logits'`` and optionally ``'kl'``. May or may not
            be wrapped (DDP/DataParallel); ``encode`` is looked up on
            ``model.module`` when that attribute exists, else on ``model``.
        discriminator: Maps the latent to logits compared against ``cls``.
        criterion: Loss used for the reconstruction and classification terms.
        data_loader: Yields ``(points, labels, body, cls)`` batches.
        opt_dis: Optimizer for ``discriminator``.
        opt_vae: Optimizer for ``model``.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for lr schedule and loss ramps).
        loss_scaler: Callable invoked as ``loss_scaler(loss, optimizer,
            clip_grad=..., parameters=..., create_graph=..., update_grad=...)``.
            Presumably back-propagates ``loss`` and steps ``optimizer`` when
            ``update_grad`` is true -- confirm against its definition.
        max_norm: Forwarded to ``loss_scaler`` as ``clip_grad``.
        log_writer: Optional writer exposing ``log_dir`` and ``add_scalar``.
        args: Namespace providing ``weight_dtype``, ``accum_iter``,
            ``kl_weight``, ``epochs``, ``dis_factor``, ``adv_factor`` and
            ``num_queries``; it is also passed to the lr scheduler.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Raises:
        ValueError: If ``args.weight_dtype`` is not ``"fp16"``, ``"bf16"`` or
            ``"fp32"``.

    The process exits with status 1 if the VAE loss is not finite.

    NOTE(review): discriminator gradients are not accumulated across
    micro-batches; when ``args.accum_iter > 1`` its update uses only the
    micro-batch of the update step.
    """
    # Validate configuration first so a typo fails with a clear message
    # instead of an UnboundLocalError deep inside the loop.
    dtype_by_name = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if args.weight_dtype not in dtype_by_name:
        raise ValueError(
            "Unsupported weight_dtype {!r}; expected one of {}".format(
                args.weight_dtype, sorted(dtype_by_name)))
    weight_dtype = dtype_by_name[args.weight_dtype]

    discriminator.train(True)
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter
    kl_weight = args.kl_weight
    num_queries = args.num_queries
    num_batches = len(data_loader)
    threshold = 0  # logits >= 0 count as a positive prediction

    # `encode` lives on the underlying module when the model is wrapped.
    vae = getattr(model, "module", model)

    opt_dis.zero_grad()
    opt_vae.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (points, labels, body, cls) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        is_update_step = (data_iter_step + 1) % accum_iter == 0

        # Ramp the discriminator / adversarial weights from 0 towards their
        # configured factor as training progresses (p goes from 0 to 1).
        p = float(data_iter_step + epoch * num_batches) / args.epochs / num_batches
        ramp = 2. / (1. + np.exp(-10 * p)) - 1
        dis_weight = args.dis_factor * ramp
        adv_weight = args.adv_factor * ramp

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(opt_dis, data_iter_step / num_batches + epoch, args)
            lr_sched.adjust_learning_rate(opt_vae, data_iter_step / num_batches + epoch, args)

        points = points.to(device, non_blocking=True, dtype=weight_dtype)
        labels = labels.to(device, non_blocking=True, dtype=weight_dtype)
        body = body.to(device, non_blocking=True, dtype=weight_dtype)
        cls = cls.to(device, non_blocking=True, dtype=weight_dtype)

        # ---- Train feature discriminator ----
        # The VAE loss of the previous micro-batch flows through the
        # discriminator, so (assuming loss_scaler back-propagates the loss) it
        # leaves gradients on the discriminator parameters. Clear them so they
        # never leak into the discriminator step when accum_iter > 1.
        opt_dis.zero_grad()
        with torch.amp.autocast('cuda', enabled=False):
            # Detach so this pass only updates the discriminator.
            latent = vae.encode(body, return_kl=False).detach()
            dis_loss = dis_weight * criterion(discriminator(latent), cls)

        loss_scaler(dis_loss, opt_dis, clip_grad=max_norm,
                    parameters=discriminator.parameters(), create_graph=False,
                    update_grad=is_update_step)
        opt_dis.zero_grad()

        # ---- Train VAE ----
        with torch.amp.autocast('cuda', enabled=False):
            outputs = model(points, body)
            if 'kl' in outputs:
                loss_kl = outputs['kl']
                loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
            else:
                loss_kl = None

            logits = outputs['logits']
            loss_vol = criterion(logits[:, :num_queries], labels[:, :num_queries])
            loss_near = criterion(logits[:, num_queries:], labels[:, num_queries:])

            # Adversarial term: push the encoder so that the discriminator
            # predicts the flipped label.
            latent = vae.encode(body, return_kl=False)
            cls_logit = discriminator(latent)
            adv_loss = criterion(cls_logit, 1 - cls)

            loss = loss_vol + 0.1 * loss_near + adv_weight * adv_loss
            if loss_kl is not None:
                loss = loss + kl_weight * loss_kl

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Compute iou of occupancy on the first `num_queries` columns
        vol_logits = logits[:, :num_queries]
        vol_labels = labels[:, :num_queries]
        pred = (vol_logits >= threshold).to(vol_logits.dtype)
        intersection = (pred * vol_labels).sum(dim=1)
        # Epsilon keeps the ratio finite when both pred and labels are empty.
        union = (pred + vol_labels).gt(0).sum(dim=1) + 1e-5
        iou = (intersection * 1.0 / union).mean()

        # Compute accuracy of discriminator (on the VAE-pass logits)
        cls_pred = (cls_logit >= threshold).to(cls_logit.dtype)
        cls_accuracy = (cls_pred == cls).float().mean()

        loss = loss / accum_iter
        loss_scaler(loss, opt_vae, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=is_update_step)
        if is_update_step:
            opt_vae.zero_grad()
            opt_dis.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(cls_accuracy=cls_accuracy.item())

        if loss_kl is not None:
            metric_logger.update(loss_kl=loss_kl.item())

        metric_logger.update(iou=iou.item())

        max_lr = 0.
        for group in opt_vae.param_groups:
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / num_batches + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, discriminator, device, args):
    """Evaluate the VAE and the latent discriminator on ``data_loader``.

    For every batch this computes the ``BCEWithLogitsLoss`` between the model
    logits and ``labels``, the IoU of the thresholded logits (``>= 0``)
    against ``labels``, the discriminator accuracy against ``cls``, and the
    KL term when the model returns one.

    Args:
        data_loader: Yields ``(points, labels, body, cls)`` batches.
        model: VAE. Called as ``model(points, body)`` and expected to return a
            mapping with ``'logits'`` and optionally ``'kl'``. ``encode`` is
            looked up on ``model.module`` when that attribute exists
            (wrapped model), else on ``model`` itself.
        discriminator: Maps the latent to logits compared against ``cls``.
        device: Device the batches are moved to.
        args: Namespace providing ``weight_dtype`` (``"fp16"``, ``"bf16"`` or
            ``"fp32"``).

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Raises:
        ValueError: If ``args.weight_dtype`` is not a supported name.
    """
    # Validate configuration first so a typo fails with a clear message
    # instead of an UnboundLocalError inside the loop.
    dtype_by_name = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if args.weight_dtype not in dtype_by_name:
        raise ValueError(
            "Unsupported weight_dtype {!r}; expected one of {}".format(
                args.weight_dtype, sorted(dtype_by_name)))
    weight_dtype = dtype_by_name[args.weight_dtype]

    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    discriminator.eval()

    # `encode` lives on the underlying module when the model is wrapped.
    vae = getattr(model, "module", model)

    threshold = 0  # logits >= 0 count as a positive prediction

    for points, labels, body, cls in metric_logger.log_every(data_loader, 50, header):

        points = points.to(device, non_blocking=True, dtype=weight_dtype)
        labels = labels.to(device, non_blocking=True, dtype=weight_dtype)
        body = body.to(device, non_blocking=True, dtype=weight_dtype)
        cls = cls.to(device, non_blocking=True, dtype=weight_dtype)

        # compute output
        with torch.amp.autocast('cuda', enabled=False):
            latent = vae.encode(body, return_kl=False)
            cls_logit = discriminator(latent)

            outputs = model(points, body)
            if 'kl' in outputs:
                loss_kl = outputs['kl']
                loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
            else:
                loss_kl = None

            logits = outputs['logits']
            loss = criterion(logits, labels)

        # Discriminator accuracy
        cls_pred = (cls_logit >= threshold).to(cls_logit.dtype)
        cls_accuracy = (cls_pred == cls).float().mean()

        # IOU
        pred = (logits >= threshold).to(logits.dtype)
        intersection = (pred * labels).sum(dim=1)
        # The epsilon belongs in the denominator: it keeps the ratio finite
        # when both prediction and labels are empty, and does not bias the
        # score (same convention as the training loop).
        union = (pred + labels).gt(0).sum(dim=1) + 1e-5
        iou = (intersection * 1.0 / union).mean()

        batch_size = points.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.update(cls_accuracy=cls_accuracy.item())
        metric_logger.meters['iou'].update(iou.item(), n=batch_size)

        if loss_kl is not None:
            metric_logger.update(loss_kl=loss_kl.item())

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* iou {iou.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(iou=metric_logger.iou, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}