import math
import sys
from typing import Iterable

import torch
from torch import nn

import util.misc as misc
import util.lr_sched as lr_sched
from torch.nn import functional as F
from eval import add_imgs, imgs_grid, test_incomplete_inference, patch_test_reconstruction, test_jointELBO
from torch.distributions.categorical import Categorical
from util.misc import save_checkpoint


def update_average(model_tgt, model_src, beta=0.99):
    """Update ``model_tgt`` in place as an exponential moving average of ``model_src``.

    Every floating-point (or complex) entry of the target's ``state_dict`` is
    replaced by ``beta * target + (1 - beta) * source``. Entries of any other
    dtype (for example the integer ``num_batches_tracked`` counter of
    BatchNorm layers) are copied from the source unchanged: blending them
    would yield a fractional value that is truncated when it is cast back to
    the buffer's integer dtype by ``load_state_dict``.

    Args:
        model_tgt: Module holding the running average; modified in place.
        model_src: Module whose current state is folded into the average.
            Its ``state_dict`` must contain every key of ``model_tgt``'s.
        beta: Weight kept from the previous average.

    Returns:
        ``model_tgt``, after loading the updated state.

    Raises:
        KeyError: If a key of the target state is missing from the source.
        AssertionError: If target and source share the same tensor object.
    """
    with torch.no_grad():
        param_dict_src = model_src.state_dict()
        param_dict_tgt = model_tgt.state_dict()

        for p_name, p_tgt in param_dict_tgt.items():
            p_src = param_dict_src[p_name]
            assert p_src is not p_tgt, \
                "update_average: source and target share tensor '{}'".format(p_name)
            if p_tgt.is_floating_point() or p_tgt.is_complex():
                param_dict_tgt[p_name] = beta * p_tgt.data + (1. - beta) * p_src.data
            else:
                # Integer/bool buffers cannot be meaningfully averaged;
                # mirror the source value instead.
                param_dict_tgt[p_name] = p_src.data.clone()

        model_tgt.load_state_dict(param_dict_tgt)
    return model_tgt


def set_requires_grad(net, requires_grad=False):
    """Set the ``requires_grad`` flag of every parameter of ``net``.

    Args:
        net: Module whose parameters are updated in place.
        requires_grad: Value assigned to each parameter's ``requires_grad``.
    """
    weights = net.parameters()
    for weight in weights:
        weight.requires_grad = requires_grad

def set_requires_grad_D(net, requires_grad=False):
    """Set ``requires_grad`` on all parameters of ``net`` except positional embeddings.

    Parameters whose qualified name contains the substring ``'pos_embed'``
    are left untouched; every other parameter gets the given flag.

    Args:
        net: Module whose parameters are updated in place.
        requires_grad: Value assigned to each affected parameter's
            ``requires_grad``.
    """
    for name, weight in net.named_parameters():
        if 'pos_embed' in name:
            # Leave positional-embedding parameters at their current setting.
            continue
        weight.requires_grad = requires_grad


def train_one_epoch(vae_net: torch.nn.Module, data_loader: Iterable,
                    optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler,
                    log_writer=None, args=None, avg_vae_net=None, scheduler=None,
                    test_data=None, p_list_ST=None, r_ind=None, data_loader_test=None):
    """Train ``vae_net`` for one epoch with a specific S/T ratio setting.

    Each iteration adjusts the learning rate according to
    ``args.set_scheduler``, runs a forward/backward pass, optionally clips
    gradients, steps the optimizer and updates the EMA copy. Every 2000
    global steps sample images are written (and, if ``data_loader_test`` is
    given, the test ELBO is evaluated); every ``args.model_save_freq`` global
    steps a checkpoint is saved.

    Args:
        vae_net: Model being trained. Called as
            ``vae_net(imgs, Sratio=..., Tratio=..., betaKL=...)`` and expected
            to return ``(loss, NLL, KLloss)``.
        data_loader: Yields ``(samples, labels)`` batches; must support
            ``len()``. Labels are not used.
        optimizer: Optimizer stepped once per batch.
        device: Device the input batch is moved to.
        epoch: Zero-based epoch index, used to derive the global step.
        loss_scaler: Unused; kept for interface compatibility.
        log_writer: Optional scalar writer exposing ``log_dir`` and
            ``add_scalar``. When ``None`` no scalars are written.
        args: Run configuration. Reads ``print_freq``, ``accum_iter``,
            ``set_scheduler``, ``Sratio``/``Tratio`` (a negative value means
            "draw from ``SampleSratio``/``SampleTratio``"), ``betaKL``,
            ``clip_grad``, ``distributed``, ``dir_img``, ``model_save_freq``,
            ``model`` and ``dir_model``.
        avg_vae_net: Optional EMA copy of ``vae_net``. When ``None`` the EMA
            update and EMA evaluation are skipped and ``None`` is stored
            under ``'avg_vae_state_dict'`` in checkpoints.
        scheduler: Stepped with the global step when
            ``args.set_scheduler == 'linear'``.
        test_data: Data passed to ``test_incomplete_inference`` for the
            periodic image dump; the dump is skipped if this, ``p_list_ST``
            or ``r_ind`` is ``None``.
        p_list_ST: Forwarded to ``test_incomplete_inference``; its length is
            the number of images per row in the saved grid.
        r_ind: Forwarded to ``test_incomplete_inference``.
        data_loader_test: Optional loader passed to ``test_jointELBO``.

    Returns:
        Tuple ``(avg_vae_net, vae_net)``.
    """
    vae_net.train(True)

    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:6f}'))
    metric_logger.add_meter('globalIT', misc.SmoothedValue(window_size=1, fmt='{value:4d}'))
    metric_logger.add_meter('loss', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))
    metric_logger.add_meter('NLL', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))
    metric_logger.add_meter('KLloss', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter
    has_writer = log_writer is not None
    if has_writer:
        print('log_dir: {}'.format(log_writer.log_dir))

    steps_per_epoch = len(data_loader)

    # Labels are never consumed, so they are not transferred to the device.
    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        real_imgs = samples.to(device, non_blocking=True)

        global_train_step = data_iter_step + steps_per_epoch * epoch
        metric_logger.update(globalIT=global_train_step)

        # Learning-rate schedule; the warmup variants only fire on
        # accumulation boundaries.
        if (args.set_scheduler == 'warmup') and (data_iter_step % accum_iter == 0):
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / steps_per_epoch + epoch, args)
        if args.set_scheduler == 'linear':
            scheduler.step(global_train_step)
        if (args.set_scheduler == 'warmupC') and (data_iter_step % accum_iter == 0):
            lr_sched.adjust_learning_rate_C(optimizer, data_iter_step / steps_per_epoch + epoch, args)
        if (args.set_scheduler == 'warmupIT') and (data_iter_step % accum_iter == 0):
            lr_sched.adjust_learning_rate_IT(optimizer, global_train_step, args)

        current_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=current_lr)
        if has_writer:
            log_writer.add_scalar('optimizer/lr', current_lr, global_train_step)

        optimizer.zero_grad()

        # A negative configured ratio means "sample a ratio for this step".
        Sratio = args.SampleSratio.sample() if args.Sratio < 0. else args.Sratio
        Tratio = args.SampleTratio.sample() if args.Tratio < 0. else args.Tratio
        loss, NLL, KLloss = vae_net(real_imgs, Sratio=Sratio, Tratio=Tratio, betaKL=args.betaKL)
        loss.backward()

        if args.clip_grad > 0.:
            torch.nn.utils.clip_grad_norm_(vae_net.parameters(), args.clip_grad)
        optimizer.step()

        # Read each scalar back once and reuse it for both loggers.
        loss_value = loss.item()
        nll_value = NLL.item()
        kl_value = KLloss.item()
        if has_writer:
            log_writer.add_scalar('loss/NELBO', loss_value, global_train_step)
            log_writer.add_scalar('loss/NLL', nll_value, global_train_step)
            log_writer.add_scalar('loss/KLloss', kl_value, global_train_step)
        metric_logger.update(loss=loss_value)
        metric_logger.update(NLL=nll_value)
        metric_logger.update(KLloss=kl_value)

        if avg_vae_net is not None:
            avg_vae_net = update_average(avg_vae_net, vae_net, beta=0.9999)

        if (global_train_step % 2000 == 0) and (test_data is not None) and (
                p_list_ST is not None) and (r_ind is not None):

            imgs_all = test_incomplete_inference(test_data, vae_net, p_list_ST, r_ind, args.distributed)
            add_imgs(imgs_all.data, args.dir_img + 'It%d.jpg' % global_train_step, nrow=len(p_list_ST))

            # The EMA model is optional; only evaluate it when provided.
            if avg_vae_net is not None:
                imgs_all = test_incomplete_inference(test_data, avg_vae_net, p_list_ST, r_ind, args.distributed)
                add_imgs(imgs_all.data, args.dir_img + 'It%d_avg.jpg' % global_train_step, nrow=len(p_list_ST))

            if data_loader_test is not None:
                testELBO, _, _ = test_jointELBO(data_loader_test, vae_net, args.betaKL, device,
                                                args.distributed)
                test_elbo_value = testELBO.item()
                if has_writer:
                    log_writer.add_scalar('loss/testELBO', test_elbo_value, global_train_step)
                metric_logger.update(testELBO=test_elbo_value)

                if avg_vae_net is not None:
                    testELBOavg, _, _ = test_jointELBO(data_loader_test, avg_vae_net, args.betaKL,
                                                       device, args.distributed)
                    # Same reduction for both sinks so they always agree.
                    test_elbo_avg_value = testELBOavg.mean().item()
                    if has_writer:
                        log_writer.add_scalar('loss/testELBOavg', test_elbo_avg_value, global_train_step)
                    metric_logger.update(testELBOavg=test_elbo_avg_value)

        if (global_train_step + 1) % args.model_save_freq == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'vae_model': args.model,
                'avg_vae_state_dict': avg_vae_net.state_dict() if avg_vae_net is not None else None,
                'vae_state_dict': vae_net.state_dict(),
            }, 0, args.dir_model, filename="checkpoint" + str(global_train_step))

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return avg_vae_net, vae_net


def train_one_epoch_Patch(patch_ae: torch.nn.Module, vae_net: torch.nn.Module, data_loader: Iterable,
                          optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler,
                          log_writer=None, args=None, avg_vae_net=None, scheduler=None,
                          test_data=None, p_list_ST=None, r_ind=None):
    """Train ``vae_net`` for one epoch on patch codes produced by ``patch_ae``.

    Each batch of images is encoded with ``patch_ae.encoder`` and reshaped to
    ``(batch, patch_ae.num_patch, patch_ae.latent_dim)``; ``vae_net`` is then
    trained on those codes. Learning-rate scheduling, gradient clipping, EMA
    update, periodic image dumps (every 2000 global steps) and checkpointing
    (every ``args.model_save_freq`` global steps) follow the same pattern as
    the image-space training loop.

    Args:
        patch_ae: Patch autoencoder exposing ``encoder``, ``num_patch`` and
            ``latent_dim``. ``encoder(imgs)`` must return a pair whose second
            element is the code tensor.
        vae_net: Model being trained. Called as
            ``vae_net(codes, Sratio=..., Tratio=..., betaKL=...)`` and
            expected to return ``(loss, NLL, KLloss, reconstruction)``.
        data_loader: Yields ``(samples, labels)`` batches; must support
            ``len()``. Labels are not used.
        optimizer: Optimizer stepped once per batch.
        device: Device the input batch is moved to.
        epoch: Zero-based epoch index, used to derive the global step.
        loss_scaler: Unused; kept for interface compatibility.
        log_writer: Optional scalar writer exposing ``log_dir`` and
            ``add_scalar``. When ``None`` no scalars are written.
        args: Run configuration. Reads ``print_freq``, ``accum_iter``,
            ``set_scheduler``, ``Sratio``/``Tratio`` (a negative value means
            "draw from ``SampleSratio``/``SampleTratio``"), ``betaKL``,
            ``clip_grad``, ``distributed``, ``dir_img``, ``model_save_freq``,
            ``model`` and ``dir_model``.
        avg_vae_net: Optional EMA copy of ``vae_net``. When ``None`` the EMA
            update is skipped and ``None`` is stored under
            ``'avg_vae_state_dict'`` in checkpoints.
        scheduler: Stepped with the global step when
            ``args.set_scheduler == 'linear'``.
        test_data: Data passed to ``patch_test_reconstruction`` for the
            periodic image dump; the dump is skipped if this, ``p_list_ST``
            or ``r_ind`` is ``None``.
        p_list_ST: Forwarded to ``patch_test_reconstruction``; its length is
            the number of images per row in the saved grid.
        r_ind: Forwarded to ``patch_test_reconstruction``.

    Returns:
        Tuple ``(avg_vae_net, vae_net)``.
    """
    vae_net.train(True)

    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:6f}'))
    metric_logger.add_meter('globalIT', misc.SmoothedValue(window_size=1, fmt='{value:4d}'))
    metric_logger.add_meter('loss', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))
    metric_logger.add_meter('NLL', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))
    metric_logger.add_meter('KLloss', misc.SmoothedValue(window_size=20, fmt='{value:.1e} ({global_avg:.1e})'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter
    has_writer = log_writer is not None
    if has_writer:
        print('log_dir: {}'.format(log_writer.log_dir))

    steps_per_epoch = len(data_loader)

    # Labels are never consumed, so they are not transferred to the device.
    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        real_imgs = samples.to(device, non_blocking=True)
        batch_size = real_imgs.shape[0]

        global_train_step = data_iter_step + steps_per_epoch * epoch
        metric_logger.update(globalIT=global_train_step)

        # Learning-rate schedule; the warmup variants only fire on
        # accumulation boundaries.
        if (args.set_scheduler == 'warmup') and (data_iter_step % accum_iter == 0):
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / steps_per_epoch + epoch, args)
        if args.set_scheduler == 'linear':
            scheduler.step(global_train_step)
        if (args.set_scheduler == 'warmupC') and (data_iter_step % accum_iter == 0):
            lr_sched.adjust_learning_rate_C(optimizer, data_iter_step / steps_per_epoch + epoch, args)
        if (args.set_scheduler == 'warmupIT') and (data_iter_step % accum_iter == 0):
            lr_sched.adjust_learning_rate_IT(optimizer, global_train_step, args)

        current_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=current_lr)
        if has_writer:
            log_writer.add_scalar('optimizer/lr', current_lr, global_train_step)

        optimizer.zero_grad()

        # A negative configured ratio means "sample a ratio for this step".
        Sratio = args.SampleSratio.sample() if args.Sratio < 0. else args.Sratio
        Tratio = args.SampleTratio.sample() if args.Tratio < 0. else args.Tratio

        # NOTE(review): the encoder runs with autograd enabled; if patch_ae
        # is meant to be frozen, wrapping this in torch.no_grad() would save
        # memory -- confirm against the caller before changing.
        _, bar_imgs = patch_ae.encoder(real_imgs)
        bar_imgs = bar_imgs.reshape([batch_size, patch_ae.num_patch, patch_ae.latent_dim])

        # The fourth output (reconstructed codes) is not needed for training.
        loss, NLL, KLloss, _ = vae_net(bar_imgs, Sratio=Sratio, Tratio=Tratio, betaKL=args.betaKL)
        loss.backward()

        if args.clip_grad > 0.:
            torch.nn.utils.clip_grad_norm_(vae_net.parameters(), args.clip_grad)
        optimizer.step()

        # Read each scalar back once and reuse it for both loggers.
        loss_value = loss.item()
        nll_value = NLL.item()
        kl_value = KLloss.item()
        if has_writer:
            log_writer.add_scalar('loss/NELBO', loss_value, global_train_step)
            log_writer.add_scalar('loss/NLL', nll_value, global_train_step)
            log_writer.add_scalar('loss/KLloss', kl_value, global_train_step)

        metric_logger.update(loss=loss_value)
        metric_logger.update(NLL=nll_value)
        metric_logger.update(KLloss=kl_value)

        if avg_vae_net is not None:
            avg_vae_net = update_average(avg_vae_net, vae_net, beta=0.9999)

        if (global_train_step % 2000 == 0) and (test_data is not None) and (
                p_list_ST is not None) and (r_ind is not None):
            imgs_all = patch_test_reconstruction(test_data, patch_ae, vae_net, p_list_ST, r_ind, args.distributed)
            add_imgs(imgs_all.data, args.dir_img + 'It%d.jpg' % global_train_step, nrow=len(p_list_ST))

        if (global_train_step + 1) % args.model_save_freq == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'vae_model': args.model,
                'avg_vae_state_dict': avg_vae_net.state_dict() if avg_vae_net is not None else None,
                'vae_state_dict': vae_net.state_dict(),
            }, 0, args.dir_model, filename="checkpoint" + str(global_train_step))

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return avg_vae_net, vae_net
