import math
import sys
from typing import Iterable

import torch
from torch import detach, nn

import util.misc as misc
import util.lr_sched as lr_sched
from torch.nn import functional as F
from eval import add_imgs, imgs_grid, test_reconstruction
from torch.distributions.categorical import Categorical


def update_average(model_tgt, model_src, beta=0.99):
    """Blend ``model_src``'s state into ``model_tgt`` as an exponential moving average.

    Each floating-point (or complex) entry of ``model_tgt.state_dict()`` is
    replaced by ``beta * tgt + (1 - beta) * src``. Entries of any other dtype
    (integer / bool tensors, e.g. step counters kept as buffers) are copied
    from ``model_src`` unchanged: multiplying them by ``beta`` would yield
    fractional values that cannot be stored back in their integer dtype
    without losing information.

    Args:
        model_tgt: module holding the running average; updated via
            ``load_state_dict``.
        model_src: module whose current state is mixed in. It is only read.
        beta: weight kept from the previous average (``1 - beta`` goes to
            ``model_src``).

    Returns:
        ``model_tgt`` (the same object that was passed in).

    Raises:
        KeyError: if ``model_src``'s state dict lacks an entry that
            ``model_tgt``'s state dict has.
    """
    with torch.no_grad():
        param_dict_src = model_src.state_dict()
        param_dict_tgt = model_tgt.state_dict()

        # Re-assigning values of existing keys keeps the dict size unchanged,
        # so it is safe to do while iterating over it.
        for p_name, p_tgt in param_dict_tgt.items():
            p_src = param_dict_src[p_name]
            assert (p_src is not p_tgt)
            if p_tgt.is_floating_point() or p_tgt.is_complex():
                param_dict_tgt[p_name] = beta * p_tgt.data + (1. - beta) * p_src.data
            else:
                # Non-float state is not averaged; mirror the source value.
                param_dict_tgt[p_name] = p_src.data.clone()

        model_tgt.load_state_dict(param_dict_tgt)
    return model_tgt


def vit_train_one_epoch(ae_net: torch.nn.Module, data_loader: Iterable,
                    optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler,
                    log_writer=None, args=None, avg_ae_net=None, scheduler=None, test_data=None,
                   ):
    """Run one training epoch of ``ae_net`` and return the updated networks.

    For every ``(samples, y)`` batch the function updates the learning rate
    according to ``args.set_scheduler``, calls ``ae_net`` on the batch, builds
    the total loss ``x_loss + args.vit_weight * z_loss + args.bar_weight *
    bar_loss`` from the returned dict, back-propagates, optionally clips the
    gradient norm, steps the optimizer and, when ``avg_ae_net`` is given,
    updates it with ``update_average`` (beta=0.9999). Every 2000 global steps,
    if ``test_data`` is given, reconstructions are produced with
    ``test_reconstruction`` and saved through ``add_imgs``.

    Args:
        ae_net: network being trained. It is called as
            ``ae_net(imgs, detach_bar=..., z_loss_type=...)`` and must return
            a mapping with ``'x_loss'``, ``'z_loss'`` and ``'bar_loss'``.
        data_loader: sized iterable yielding ``(samples, y)`` pairs; only
            ``samples`` is used.
        optimizer: optimizer stepped once per batch.
        device: device the input batch is moved to.
        epoch: index of the current epoch, used for step counting and logging.
        loss_scaler: accepted for interface compatibility; not used here.
        log_writer: optional writer exposing ``log_dir``, ``add_scalar`` and
            ``add_image`` (presumably a TensorBoard ``SummaryWriter``). When
            ``None``, all writer logging is skipped.
        args: namespace providing ``print_freq``, ``accum_iter``, ``epochs``,
            ``set_scheduler``, ``detach_bar``, ``z_loss_type``, ``vit_weight``,
            ``bar_weight``, ``clip_grad``, ``noise_std``, ``distributed`` and
            ``dir_img``. Required despite the ``None`` default.
        avg_ae_net: optional averaged copy of ``ae_net`` to update each step.
        scheduler: object with ``step(global_step)``; required only when
            ``args.set_scheduler == 'linear'``.
        test_data: optional batch used for periodic reconstruction dumps.

    Returns:
        Tuple ``(avg_ae_net, ae_net)``.
    """

    ae_net.train(True)

    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:6f}'))
    metric_logger.add_meter('globalIT', misc.SmoothedValue(window_size=1, fmt='{value:4d}'))
    metric_logger.add_meter('loss', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    # NOTE(review): accum_iter only gates how often the LR schedule is
    # refreshed; gradients are zeroed every step, so no accumulation happens.
    accum_iter = args.accum_iter
    has_writer = log_writer is not None
    if has_writer:
        print('log_dir: {}'.format(log_writer.log_dir))

    # Loop-invariant step counts.
    steps_per_epoch = len(data_loader)
    total_train_step = steps_per_epoch * args.epochs

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        real_imgs = samples.to(device, non_blocking=True)

        global_train_step = data_iter_step + steps_per_epoch * epoch
        metric_logger.update(globalIT=global_train_step)

        # --- learning-rate schedule -------------------------------------
        at_accum_boundary = (data_iter_step % accum_iter == 0)
        if (args.set_scheduler == 'warmup') and at_accum_boundary:
            lr_sched.cosine_learning_rate_with_linear_warmup(optimizer, global_train_step, total_train_step, args)
        if args.set_scheduler == 'linear':
            scheduler.step(global_train_step)
        if (args.set_scheduler == 'warmupC') and at_accum_boundary:
            lr_sched.adjust_learning_rate_C(optimizer, data_iter_step / steps_per_epoch + epoch, args)
        current_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=current_lr)
        if has_writer:
            log_writer.add_scalar('optimizer/lr', current_lr, global_train_step)

        # --- forward ----------------------------------------------------
        optimizer.zero_grad()

        outdict = ae_net(real_imgs, detach_bar=args.detach_bar, z_loss_type=args.z_loss_type)

        loss = outdict['x_loss'] + args.vit_weight * outdict['z_loss'] + args.bar_weight * outdict['bar_loss']

        # Read each scalar once; every .item() call on a device tensor
        # forces a host/device synchronisation.
        x_loss_value = outdict['x_loss'].item()
        z_loss_value = outdict['z_loss'].item()

        if has_writer:
            log_writer.add_scalar('bar/bar_loss', outdict['bar_loss'].item(), global_train_step)
            log_writer.add_scalar('loss/z_loss', z_loss_value, global_train_step)
            log_writer.add_scalar('loss/x_loss', x_loss_value, global_train_step)

        metric_logger.update(x_loss=x_loss_value)
        metric_logger.update(z_loss=z_loss_value)

        # --- backward / optimizer step ----------------------------------
        loss.backward()

        if args.clip_grad > 0.:
            torch.nn.utils.clip_grad_norm_(ae_net.parameters(), args.clip_grad)
        optimizer.step()

        loss_value = loss.item()
        if has_writer:
            log_writer.add_scalar('loss/total_loss', loss_value, global_train_step)
        metric_logger.update(loss=loss_value)

        if avg_ae_net is not None:
            avg_ae_net = update_average(avg_ae_net, ae_net, beta=0.9999)

        # --- periodic reconstruction dump -------------------------------
        if (global_train_step % 2000 == 0) and (test_data is not None):
            imgs_all = test_reconstruction(test_data, ae_net, args.noise_std, args.distributed)
            add_imgs(imgs_all.data, args.dir_img + 'IT%d.jpg' % global_train_step, nrow=test_data.shape[0])
            if has_writer:
                # NOTE(review): the image is logged with `epoch` as its step,
                # while every scalar uses global_train_step — confirm intent.
                log_writer.add_image('out_img/test_all', imgs_grid(imgs_all.data, test_data.shape[0]), epoch)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return avg_ae_net, ae_net