import argparse
import datetime
import json
import numpy as np
import os
from copy import deepcopy

import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import scipy.io as sio
import random


import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.misc import save_checkpoint

from functions import LinearLrDecay
from eval import  add_imgs, imgs_grid, test_incomplete_inference

import models_BigLearnVAE
from engine_pretrain import train_one_epoch
from torch.distributions.beta import Beta
from functools import partial


def seed_torch(seed=1029):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN into deterministic,
    non-autotuned mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    rng_seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in rng_seeders:
        seed_fn(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


def get_args_parser():
    """Build the command-line parser for BigLearn-VAE training.

    ``add_help=False`` is kept so the parser can be used as a parent parser.

    Returns:
        argparse.ArgumentParser: parser holding every hyper-parameter that the
        training script reads from ``args``.
    """
    parser = argparse.ArgumentParser('BigLearn-VAE', add_help=False)

    # Hyperparams
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gpus', default='0', type=str,
                        help='value exported as CUDA_VISIBLE_DEVICES, e.g. "0,1"')

    # Data parameters
    parser.add_argument('--data_path',
                        default='./data/',
                        type=str, help='dataset path')
    parser.add_argument('--batch_size', default=128,
                        type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')

    parser.add_argument('--input_size', default=32, type=int,
                        help='images input size')
    parser.add_argument('--in_chans', default=3, type=int,
                        help='images input channels')
    parser.add_argument('--patch_size', default=4, type=int,
                        help='images patch size')

    # Model parameters
    parser.add_argument('--model', default='TwoViT', type=str, metavar='MODEL',
                        choices=['NaiveVAE', 'TwoViT', 'UniViT'],
                        help='Name of VAE model to train')
    parser.add_argument('--z_dim', default=32, type=int,
                        help='z dim')
    parser.add_argument('--embed_dim', default=128, type=int,
                        help='Encoder embed dim')
    parser.add_argument('--depth', default=4, type=int,
                        help='Encoder depth')
    parser.add_argument('--num_heads', default=8, type=int,
                        help='Encoder num_heads')
    parser.add_argument('--decoder_embed_dim', default=128, type=int,
                        help='Decoder embed dim')
    parser.add_argument('--decoder_depth', default=4, type=int,
                        help='Decoder depth')
    parser.add_argument('--decoder_num_heads', default=8, type=int,
                        help='Decoder num_heads')
    parser.add_argument('--mlp_ratio', default=2, type=int,
                        help='Mlp ratio')
    parser.add_argument('--ELastNorm', default='LN', type=str,
                        choices=['NO', 'LN', 'IN', 'LR'],
                        help='The setting of the last norm of the Encoder')
    parser.add_argument('--DLastNorm', default='LN', type=str,
                        choices=['NO', 'LN', 'LR', 'BN'],
                        help='The setting of the last norm of the Discriminator')
    parser.add_argument('--drop', default=0.0, type=float,
                        help='dropout ratio for the Linear layer in Discriminator')

    # Pretraining
    parser.add_argument('--resume',
                        default=None,
                        help='resume from checkpoint')

    # Loss parameters
    parser.add_argument('--betaKL', default=1., type=float,
                        help='the hyperparameter beta for beta-vae')
    parser.add_argument('--Sratio', default=-1., type=float,
                        help='Model to data: source ratio. a in [0,1]: constant Sratio. -1: random')
    parser.add_argument('--Tratio', default=-1., type=float,
                        help='Model to data: target ratio. a in [0,1]: constant Tratio. -1: random')
    parser.add_argument('--AProbAlpha', default=0.5, type=float,
                        help='Beta parameter alpha for A.')
    parser.add_argument('--AProbBeta', default=3., type=float,
                        help='Beta parameter beta for A.')
    parser.add_argument('--BProbAlpha', default=3., type=float,
                        help='Beta parameter alpha for B.')
    parser.add_argument('--BProbBeta', default=0.5, type=float,
                        help='Beta parameter beta for B.')
    parser.add_argument('--CProbAlpha', default=3., type=float,
                        help='Beta parameter alpha for C.')
    parser.add_argument('--CProbBeta', default=0.5, type=float,
                        help='Beta parameter beta for C.')

    # Optimizer params
    # NOTE: --lr has a non-None default, so the "lr = blr * eff_bs / 256"
    # fallback in main() only applies if this default is changed to None.
    parser.add_argument('--lr', type=float,
                        default=1e-4,
                        metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--set_scheduler', default='warmupIT', type=str,
                        choices=['none', 'linear', 'warmup', 'warmupIT'])
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=2e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_iters', type=int, default=10000, metavar='N',
                        help='iters to warmup LR')
    parser.add_argument('--cooling_iters', default=50000, type=int)
    parser.add_argument('--epochs', default=500, type=int)

    parser.add_argument('--max_iter', default=1000000, type=int,
                        help="total iterations; sets the decay horizon of the 'linear' LR scheduler")
    # main() reads args.d_iter for the 'linear' scheduler; without this flag
    # that code path raised AttributeError.
    parser.add_argument('--d_iter', default=1, type=int,
                        help="factor on max_iter for the 'linear' scheduler (decay steps = max_iter * d_iter)")
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # NOTE(review): main() does not currently pass this value to the optimizer.
    parser.add_argument('--weight_decay', type=float, default=0.,
                        help='weight decay (default: 0.0)')
    parser.add_argument('--clip_grad', type=float, default=2.,
                        help='clip_grad (default: 2.0)')
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Distributed training parameters
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--showFreq', type=int, default=100, help='Iters to show the generation during train')

    # Output params
    parser.add_argument('--output_dir',
                        default='./output_MNIST2/',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir',
                        default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--print_freq', default=100, type=int)
    parser.add_argument('--model_save_freq', default=20000, type=int)

    return parser


def _mnist_transform(input_size):
    """Return the MNIST preprocessing pipeline shared by the train and test sets.

    Grayscale digits are replicated to 3 channels, resized and center-cropped
    to ``input_size``, converted to tensors and normalized with mean/std 0.5.
    """
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def _build_vae_net(args):
    """Instantiate the VAE architecture selected by ``args.model``.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    common = dict(
        img_size=args.input_size, patch_size=args.patch_size, in_chans=args.in_chans, z_dim=args.z_dim,
        embed_dim=args.embed_dim, depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        ELastNorm=args.ELastNorm, DLastNorm=args.DLastNorm)
    decoder = dict(
        decoder_embed_dim=args.decoder_embed_dim, decoder_depth=args.decoder_depth,
        decoder_num_heads=args.decoder_num_heads)
    if args.model == 'NaiveVAE':
        return models_BigLearnVAE.NaiveVariationalAutoEncoder(**common, **decoder)
    if args.model == 'TwoViT':
        return models_BigLearnVAE.VariationalAutoEncoder2ViT(**common, **decoder)
    if args.model == 'UniViT':
        # The single-ViT variant takes no separate decoder hyper-parameters.
        return models_BigLearnVAE.VariationalAutoEncoder1ViT(**common)
    raise ValueError(f"Unknown model: {args.model!r}")


def _unwrap(net):
    """Return the wrapped module of a DistributedDataParallel network, else ``net``."""
    if isinstance(net, torch.nn.parallel.DistributedDataParallel):
        return net.module
    return net


def main(args):
    """Train a BigLearn-VAE on MNIST, saving sample images and a checkpoint every epoch.

    Args:
        args: namespace from ``get_args_parser()`` that additionally carries
            ``dir_img`` / ``dir_model`` (output path prefixes ending in '/') and
            whatever ``train_one_epoch`` reads from it.
    """
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # Only rank 0 writes image files and checkpoints, so ranks do not race on the same paths.
    is_main = misc.get_rank() == 0

    # Offset the seed by rank so that processes do not draw identical randomness.
    seed = args.seed + misc.get_rank()
    seed_torch(seed=seed)

    # NOTE(review): this overrides the `benchmark = False` set by seed_torch(),
    # trading bit-exact reproducibility for speed. Kept as in the original.
    cudnn.benchmark = True

    dataset_train = torchvision.datasets.MNIST(
        args.data_path, train=True, download=True, transform=_mnist_transform(args.input_size))
    dataset_test = torchvision.datasets.MNIST(
        args.data_path, train=False, download=True, transform=_mnist_transform(args.input_size))

    print(dataset_train)

    if args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # The sampler must be handed to the DataLoader. With `shuffle=True` instead,
    # every rank iterates the whole dataset, and the loader's implicit
    # RandomSampler has no `set_epoch`.
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        shuffle=False,
        drop_last=False,
    )

    vae_net = _build_vae_net(args)
    vae_net.to(device)
    # Second copy of the network, presumably an averaged/EMA model that
    # train_one_epoch maintains (TODO confirm). It starts identical to vae_net.
    avg_vae_net = deepcopy(vae_net).to(device)

    if args.resume is not None:
        # Load onto the CPU so a checkpoint written on any GPU can be read;
        # load_state_dict then copies the tensors onto the models' device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        vae_net.load_state_dict(checkpoint['vae_state_dict'], strict=True)
        avg_vae_net.load_state_dict(checkpoint['avg_vae_state_dict'], strict=True)
        # NOTE(review): optimizer state is not checkpointed, so AdamW's moment
        # estimates restart from scratch on resume.

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        vae_net = torch.nn.parallel.DistributedDataParallel(vae_net, device_ids=[args.gpu], find_unused_parameters=True)
        avg_vae_net = torch.nn.parallel.DistributedDataParallel(avg_vae_net, device_ids=[args.gpu],
                                                                find_unused_parameters=True)

    loss_scaler = NativeScaler()
    # beta1 = 0 disables AdamW's first-moment momentum.
    # NOTE(review): args.weight_decay is not forwarded, so AdamW applies its
    # library default weight decay rather than the CLI value. Left unchanged
    # to preserve existing training behavior.
    optimizer = torch.optim.AdamW(vae_net.parameters(), lr=args.lr, betas=(0.0, 0.999), eps=1e-8)

    if args.set_scheduler == 'linear':
        # `d_iter` may be missing from the parsed args; default it to 1 rather
        # than raising AttributeError.
        decay_steps = args.max_iter * getattr(args, 'd_iter', 1)
        scheduler = LinearLrDecay(optimizer, args.lr, 0.0, 0, decay_steps)
    else:
        # NOTE(review): 'warmup'/'warmupIT' also land here; presumably
        # train_one_epoch handles those schedules itself (confirm).
        scheduler = None

    # Fixed batch of test images, re-used every epoch for visualisation.
    bs_test = 10
    test_data = next(iter(data_loader_test))[0][:bs_test].to(device)
    if is_main:
        add_imgs(test_data, args.dir_img + 'real.jpg', nrow=10)
    print('test_data shape', test_data.shape)

    p_list_ST = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.]
    # L = number of image patches; P = number of ratios in p_list_ST.
    L, P = (args.input_size // args.patch_size) ** 2, len(p_list_ST)
    # Three fixed random permutations of the L patch indices, shape (3, L).
    r_ind = torch.stack([torch.randperm(L) for _ in range(3)]).to(device)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)

        avg_vae_net, vae_net = train_one_epoch(
            vae_net, data_loader_train, optimizer, device, epoch, loss_scaler,
            log_writer=log_writer, args=args, avg_vae_net=avg_vae_net, scheduler=scheduler,
            test_data=test_data, p_list_ST=p_list_ST, r_ind=r_ind, data_loader_test=data_loader_test)

        # Visualise both the raw and the averaged network on the fixed test batch.
        for net, suffix, tag in ((vae_net, '', 'out_img/test_all'),
                                 (avg_vae_net, '_avg', 'out_img/test_all_avg')):
            imgs_all = test_incomplete_inference(test_data, net, p_list_ST, r_ind, args.distributed)
            if is_main:
                add_imgs(imgs_all.data, args.dir_img + 'Epoch_%d%s.jpg' % (epoch, suffix), nrow=P)
            if log_writer is not None:
                log_writer.add_image(tag, imgs_grid(imgs_all.data, P), epoch)

        if is_main:
            # Save unwrapped state dicts (no DDP 'module.' prefix) so that the
            # strict resume path above, which loads into unwrapped models, works.
            save_checkpoint({
                'epoch': epoch + 1,
                'vae_model': args.model,
                'avg_vae_state_dict': _unwrap(avg_vae_net).state_dict(),
                'vae_state_dict': _unwrap(vae_net).state_dict(),
            }, 0, args.dir_model, filename="checkpoint")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # Parse CLI flags and restrict the visible GPUs before any CUDA work starts.
    args = get_args_parser().parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    # Encode the key hyper-parameters in the run directory name so that
    # different configurations write to different directories.
    run_tag = (
        f"{args.model}{args.depth}{args.decoder_depth}"
        f"Ps{args.patch_size}Z{args.z_dim}bKL{args.betaKL}"
        f"Bz{args.batch_size}LR{args.lr}"
        f"S{args.Sratio}T{args.Tratio}S{args.seed}"
        f"E{args.embed_dim}H{args.num_heads}{args.ELastNorm}{args.DLastNorm}"
    )
    args.output_dir = args.output_dir + run_tag

    args.dir_img = args.output_dir + '/img/'
    args.dir_model = args.output_dir + '/model/'
    args.log_dir = args.output_dir if args.log_dir is None else args.log_dir
    if args.output_dir:
        for sub_dir in (args.dir_img, args.dir_model):
            Path(sub_dir).mkdir(parents=True, exist_ok=True)

    # Beta distributions attached to args; presumably sampled by the training
    # loop when --Sratio / --Tratio are -1 ("random"). TODO confirm.
    args.SampleSratio = Beta(args.AProbAlpha, args.AProbBeta)
    args.SampleTratio = Beta(args.BProbAlpha, args.BProbBeta)

    main(args)
