import argparse
import datetime
import json
import numpy as np
import os
from copy import deepcopy

import time
from pathlib import Path
from functools import partial
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import random
from functools import partial
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.misc import save_checkpoint

from functions import LinearLrDecay
from eval import add_imgs, imgs_grid, test_reconstruction

from engine_pretrain import vit_train_one_epoch
from data import MyCelebA
import models_PatchAE



def seed_torch(seed=1029):
    """Seed every random number generator used by training and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU generator,
    the current CUDA device and all CUDA devices) with ``seed``. It also records
    the seed in ``PYTHONHASHSEED``. That variable only influences Python
    interpreters started after this call, not the current one. Finally it turns
    off cuDNN autotuning and requests deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


def get_args_parser():
    """Build the command-line parser for Patch-AE training.

    The parser is created with ``add_help=False`` so it can be embedded as a
    parent parser; the ``__main__`` block calls ``parse_args()`` on it directly.

    Returns:
        argparse.ArgumentParser with data, model, optimizer, distributed and
        output options.
    """
    parser = argparse.ArgumentParser('Patch-AE', add_help=False)

    # Hyperparams
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gpus', default='0', type=str)

    # Data parameters
    parser.add_argument('--data_path',
                        default='./data/',
                        type=str, help='dataset path')
    parser.add_argument('--batch_size', default=128,
                        type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')

    parser.add_argument('--input_size', default=128, type=int,
                        help='images input size')
    parser.add_argument('--in_chans', default=3, type=int,
                        help='images input channels')
    parser.add_argument('--patch_size', default=16, type=int,
                        help='images patch size')

    # AE parameters
    parser.add_argument('--model', default='ViTPatchAE', type=str, metavar='MODEL',
                        choices=['ViTPatchAE'],
                        help='Name of model to train')
    # Used as the AE latent_dim and as UniViT's x_dim; appears as "bardim" in the run name.
    parser.add_argument('--latent_dim', default=768, type=int,
                        help='latent (bar) dim of the patch AE (also the UniViT x_dim)')
    parser.add_argument('--noise_std', default=1., type=float, help='std of gaussian noise')
    parser.add_argument('--bar_weight', default=0.01, type=float, help='bar_weight of the bar loss')
    parser.add_argument('--vit_weight', default=0.1, type=float, help='weight of the vit loss')
    parser.add_argument('--nfilter', default=32, type=int, help='nfilter')
    parser.add_argument('--detach_bar', action='store_true', help='whether detach bar')
    parser.add_argument('--z_loss_type', default='MSE', type=str, choices=['MSE', 'L4T'],
                        help='Loss for bar reconstruction')

    # UniViT parameters
    parser.add_argument('--z_dim', default=128, type=int,
                        help='z dim')
    parser.add_argument('--embed_dim', default=256, type=int,
                        help='Encoder embed dim')
    parser.add_argument('--depth', default=6, type=int,
                        help='Encoder depth')
    parser.add_argument('--num_heads', default=8, type=int,
                        help='Encoder num_heads')
    parser.add_argument('--mlp_ratio', default=2, type=int,
                        help='Mlp ratio')
    parser.add_argument('--ELastNorm', default='LN', type=str,
                        choices=['NO', 'LN', 'IN', 'LR'],
                        help='The setting of the last norm of the Encoder')
    parser.add_argument('--DLastNorm', default='LN', type=str,
                        choices=['NO', 'LN', 'LR', 'BN'],
                        help='The setting of the last norm of the Discriminator')
    parser.add_argument('--drop', default=0.0, type=float,
                        help='dropout ratio for the Linear layer in Discriminator')
    parser.add_argument('--betaKL', default=1., type=float,
                        help='the hyperparameter beta for beta-vae')

    # Pretraining
    parser.add_argument('--resume',
                        default=None,
                        help='resume from checkpoint')

    # Optimizer params
    parser.add_argument('--lr', type=float,
                        default=2e-5,
                        metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--set_scheduler', default='none', type=str,
                        choices=['none', 'linear', 'warmup', 'warmupC'])

    parser.add_argument('--weight_decay', type=float, default=0.,
                        help='weight decay (default: 0.)')
    parser.add_argument('--clip_grad', type=float, default=0.,
                        help='clip_grad (default: 0.)')
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--max_iter', default=5000000, type=int)
    # main() reads args.d_iter when building the linear lr decay; without this
    # option `--set_scheduler linear` failed with an AttributeError.
    parser.add_argument('--d_iter', default=1, type=int,
                        help='multiplier on max_iter giving the number of linear lr decay steps')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Distributed training parameters
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--showFreq', type=int, default=100, help='Iters to show the generation during train')

    # Output params
    parser.add_argument('--output_dir',
                        default='./',
                        help='base output path; a model/hyper-parameter tag is appended to it')
    parser.add_argument('--log_dir',
                        default=None,
                        help='path where to tensorboard log (defaults to the run output dir)')
    parser.add_argument('--print_freq', default=100, type=int)

    parser.add_argument('--eval_freq', default=20, type=int)

    return parser


def _unwrap_ddp(model):
    """Return the wrapped module when *model* is a DistributedDataParallel, else *model*."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


def main(args):
    """Train the patch auto-encoder configured by *args*.

    Builds CelebA train/test loaders and the ViT patch AE plus an averaged copy
    (``avg_ae_net``). It then trains from ``args.start_epoch`` to ``args.epochs``.
    Every ``args.eval_freq`` epochs it dumps reconstructions of a fixed test
    batch, and after every epoch it writes a checkpoint to ``args.dir_model``.

    Args:
        args: namespace from ``get_args_parser()``. It must also carry
            ``dir_img`` and ``dir_model``, which the ``__main__`` block sets.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # Only rank 0 writes files (images, checkpoints) so ranks do not race on the same paths.
    is_main_process = misc.get_rank() == 0

    # Offset the seed by rank so each process gets its own random stream.
    seed = args.seed + misc.get_rank()
    seed_torch(seed=seed)

    # NOTE: this re-enables the cuDNN autotuning that seed_torch() just turned off,
    # trading bit-exact reproducibility for speed.
    cudnn.benchmark = True

    transform_train = transforms.Compose([
        transforms.Resize(args.input_size),
        transforms.CenterCrop(args.input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    transform_test = transforms.Compose([
        transforms.Resize(args.input_size),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_train = MyCelebA(
        args.data_path,
        split='train',
        transform=transform_train,
        download=False)
    dataset_test = MyCelebA(
        args.data_path,
        split='test',
        transform=transform_test,
        download=False)

    print(dataset_train)

    if args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size * 2,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        shuffle=False,
        drop_last=False,
    )

    if args.model == 'ViTPatchAE':
        vit_model = models_PatchAE.UniViT(
            img_size=args.input_size, patch_size=args.patch_size, in_chans=args.in_chans,
            x_dim=args.latent_dim, z_dim=args.z_dim,
            embed_dim=args.embed_dim, depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            ELastNorm=args.ELastNorm, DLastNorm=args.DLastNorm)

        ae_net = models_PatchAE.ViTPatchAE(input_size=args.input_size,
                                           latent_dim=args.latent_dim,
                                           in_chans=args.in_chans,
                                           patch_size=args.patch_size,
                                           noise_std=args.noise_std,
                                           nfilter=args.nfilter,
                                           numlayer=(2, 0, 3),
                                           vit_model=vit_model
                                           )
    else:
        # Previously fell through to a NameError on `ae_net`.
        raise ValueError('Unsupported model: {}'.format(args.model))
    ae_net.to(device)

    checkpoint = None
    if args.resume is not None:
        # Load onto the CPU so no rank materializes tensors on the GPU they were
        # saved from; load_state_dict() copies them into the on-device parameters.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        ae_net.load_state_dict(checkpoint['ae_state_dict'], strict=False)

    # The averaged model starts as a copy of the (possibly resumed) AE.
    avg_ae_net = deepcopy(ae_net).to(device)
    if checkpoint is not None:
        avg_ae_net.load_state_dict(checkpoint['avg_ae_state_dict'], strict=False)

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    # --lr always has a default, so this branch runs only when main() is called
    # programmatically with lr=None. It then requires args.blr.
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        ae_net = torch.nn.parallel.DistributedDataParallel(ae_net, device_ids=[args.gpu], find_unused_parameters=True)
        avg_ae_net = torch.nn.parallel.DistributedDataParallel(avg_ae_net, device_ids=[args.gpu],
                                                               find_unused_parameters=True)

    loss_scaler = NativeScaler()
    # Pass --weight_decay through; it used to be parsed but ignored. The default
    # of 0 matches Adam's own default.
    optimizer = torch.optim.Adam(ae_net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.set_scheduler == 'linear':
        # Fall back to 1 when args.d_iter is absent, instead of raising AttributeError.
        decay_steps = args.max_iter * getattr(args, 'd_iter', 1)
        scheduler = LinearLrDecay(optimizer, args.lr, 0.0, 0, decay_steps)
    else:
        # NOTE(review): 'warmup' / 'warmupC' are accepted by the parser but get no scheduler here.
        scheduler = None

    bs_test = 10
    test_data = next(iter(data_loader_test))[0][:bs_test].to(device)
    if is_main_process:
        add_imgs(test_data, args.dir_img + 'real.jpg', nrow=10)
    print('test_data shape', test_data.shape)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.model == 'ViTPatchAE':
            avg_ae_net, ae_net = vit_train_one_epoch(
                ae_net, data_loader_train, optimizer, device, epoch, loss_scaler,
                log_writer=log_writer, args=args, avg_ae_net=avg_ae_net,
                scheduler=scheduler, test_data=test_data)

        if (epoch + 1) % args.eval_freq == 0:
            # Every rank still runs the reconstruction, as before, in case the
            # wrapped model does collective work in forward. Only rank 0 writes the outputs.
            imgs_all = test_reconstruction(test_data, ae_net, args.noise_std, args.distributed)
            if is_main_process:
                add_imgs(imgs_all.data, args.dir_img + 'Epoch_%d.jpg' % epoch, nrow=bs_test)
                if log_writer is not None:
                    log_writer.add_image('out_img/test_all', imgs_grid(imgs_all.data, bs_test), epoch)

        if is_main_process:
            # Save the unwrapped modules. A DDP wrapper has no `vit_model`
            # attribute, and its state_dict keys carry a 'module.' prefix. The
            # strict=False resume path above would silently skip those keys.
            ae_core = _unwrap_ddp(ae_net)
            save_checkpoint({
                'epoch': epoch + 1,
                'ae_model': args.model,
                'avg_ae_state_dict': _unwrap_ddp(avg_ae_net).state_dict(),
                'ae_state_dict': ae_core.state_dict(),
                'vit_state_dict': ae_core.vit_model.state_dict(),
            }, 0, args.dir_model, filename="checkpoint")

    if log_writer is not None:
        log_writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser().parse_args()

    # Restrict visible GPUs before anything initializes CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    # The run directory is output_dir plus a tag encoding the model and key
    # hyper-parameters. This is plain string concatenation, so output_dir
    # should end with a path separator.
    args.output_dir = (
        f'{args.output_dir}{args.model}Ps{args.patch_size}'
        f'bardim{args.latent_dim}barweight{args.bar_weight}nf{args.nfilter}'
    )

    args.dir_img = f'{args.output_dir}/img/'
    args.dir_model = f'{args.output_dir}/model/'
    # TensorBoard logs go into the run directory unless --log_dir was given.
    args.log_dir = args.output_dir if args.log_dir is None else args.log_dir

    if args.output_dir:
        for run_subdir in (args.dir_img, args.dir_model):
            Path(run_subdir).mkdir(parents=True, exist_ok=True)

    main(args)
