import argparse
import os
import json
import math
import shutil
import utils

from allmodels import MNIST, load_model, load_mnist_data, load_cifar10_data, CIFAR10, load_imagenet_train, load_imagenet_test

from generator.FCN import *
from generator.VAE import *
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from custom_logging import get_logger
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np
import utils

g_logger = get_logger(__name__)


def to_img(x, im_mean, im_std):
    """Undo per-channel normalisation (if any) and return the batch on the CPU.

    Args:
        x: Image batch indexed as (batch, channel, height, width); the channel
            count is read from ``x.shape[1]``.
        im_mean: Per-channel means (one value per channel), or ``None``.
        im_std: Per-channel standard deviations, or ``None``.

    Returns:
        A detached CPU tensor. When both ``im_mean`` and ``im_std`` are given
        the result is ``x * im_std + im_mean``; otherwise ``x`` is returned
        unchanged apart from the detach / device move.
    """
    if im_mean is not None and im_std is not None:
        channels = x.shape[1]
        # Create the statistics on the same device as ``x`` (instead of
        # unconditionally on the current CUDA device) so CPU tensors and
        # tensors on a non-default GPU work as well.  Broadcasting over the
        # (1, C, 1, 1) view replaces the explicit per-batch ``repeat``.
        mean = torch.tensor(im_mean, device=x.device).view(1, channels, 1, 1)
        std = torch.tensor(im_std, device=x.device).view(1, channels, 1, 1)
        x = x * std + mean

    return x.detach().cpu()


if __name__ == '__main__':
    # Training entry point for the generator (encoder/decoder) models.
    #
    # Reads a JSON config (``--config``), builds an AE / VAE / VAE-GAN for the
    # configured dataset, trains it for ``state['epochs']`` epochs and, at the
    # end of every epoch, writes preview image grids plus encoder/decoder
    # checkpoints under ``<ckpt_path>/<dataset>/<run_id>/``.
    #
    # Config keys read here: ckpt_path, dataset, architecture, compress_mode,
    # learning_rate_G, weight_decay, order, epochs; additionally disc_path
    # (VAE-GAN) and encoder_resize, original_size, batch_size (Imagenet).
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='generator/config/train_Imagenet_AE.json', help='config file')
    parser.add_argument('--gpuid', nargs='+', type=str, default=["0"])
    args = vars(parser.parse_args())

    # Restrict the visible GPUs before any of the CUDA calls below.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args['gpuid'])

    with open(args['config']) as config_file:
        state = json.load(config_file)

    dataset = state["dataset"]
    architecture = state['architecture']

    # Fail fast on configurations the code below cannot train, instead of
    # hitting a NameError (undefined encoder / discriminator) later on.
    if architecture not in ("AE", "VAE", "VAE-GAN"):
        raise NotImplementedError(f"Unsupported architecture: {architecture}")
    if architecture == "VAE-GAN" and dataset == "Imagenet":
        raise NotImplementedError("VAE-GAN has no discriminator configured for Imagenet")

    run_id = utils.get_time_stamp()

    # Every run gets its own directory; the paths are written back into
    # ``state`` because the data loaders below receive ``state`` as a whole.
    state['ckpt_path'] = os.path.join(state['ckpt_path'], dataset, run_id)
    state['preview_path'] = os.path.join(state['ckpt_path'], 'preview')
    os.makedirs(state['ckpt_path'], exist_ok=True)
    os.makedirs(state['preview_path'], exist_ok=True)

    # Keep a copy of the config next to the checkpoints it produced.
    config_basename = os.path.basename(args['config'])
    shutil.copyfile(args['config'], os.path.join(state['ckpt_path'], config_basename))

    # ------------------------------------------------------------------
    # Dataset-specific settings
    # ------------------------------------------------------------------
    g_logger.info(f"Loading generator (val) dataset for {state['dataset']}...")
    if dataset == "MNIST":
        gen_loader, _, _, _ = load_mnist_data(state, mode='generator')
        resolution = [1, 28, 28]
        resize_dim = None
        original_dim = None
        disc_cls = MNIST
    elif dataset == "CIFAR10":
        gen_loader, _, _, _ = load_cifar10_data(state, mode='generator')
        resolution = [3, 32, 32]
        resize_dim = None
        original_dim = None
        disc_cls = CIFAR10
    elif dataset == "Imagenet":
        # NOTE(review): load_imagenet_generator is not among the names this
        # file imports explicitly from allmodels; presumably it is provided by
        # one of the star imports -- confirm.
        gen_loader = load_imagenet_generator(state, normalize=False)
        resolution = [3, 224, 224]
        resize_dim = state['encoder_resize']
        original_dim = state['original_size']
        disc_cls = None
    else:
        raise ValueError(f"Unsupported dataset: {state['dataset']}")

    # No de-normalisation statistics are used for any dataset (the Imagenet
    # loader is created with normalize=False); to_img() skips the rescaling
    # when both are None.
    im_mean = None
    im_std = None

    # ------------------------------------------------------------------
    # Models
    # ------------------------------------------------------------------
    if architecture == "AE":
        encoder = Encoder(resolution, compress_mode=state['compress_mode'], resize_dim=resize_dim)
        decoder = Decoder(resolution, compress_mode=state['compress_mode'], original_dim=original_dim)
    else:  # "VAE" or "VAE-GAN"
        encoder = VariationalEncoder(resolution, compress_mode=state['compress_mode'], resize_dim=resize_dim)
        decoder = VariationalDecoder(resolution, compress_mode=state['compress_mode'], original_dim=original_dim)

    if architecture == "VAE-GAN":
        # The discriminator is the dataset's classifier network, loaded from
        # state['disc_path'] and trained further alongside the generator.
        disc = torch.nn.DataParallel(disc_cls()).cuda()
        load_model(disc, state['disc_path'])
        disc.train()
        disc_optimizer = torch.optim.Adam(disc.parameters())

    enc_ = torch.nn.DataParallel(encoder).cuda()
    dec_ = torch.nn.DataParallel(decoder).cuda()
    print(enc_)
    print(dec_)

    # Optimizers are constructed after the modules have been moved to the
    # GPU so they are guaranteed to reference the parameters being trained.
    enc_optimizer = torch.optim.Adam(encoder.parameters(),
                                     lr=state['learning_rate_G'],
                                     weight_decay=state['weight_decay'])
    dec_optimizer = torch.optim.Adam(decoder.parameters(),
                                     lr=state['learning_rate_G'],
                                     weight_decay=state['weight_decay'])

    # ------------------------------------------------------------------
    # Reconstruction criterion
    # ------------------------------------------------------------------
    def linf_me(x, target):
        """Return the L-inf norm of ``x - target`` divided by ``len(x)``."""
        n = len(x)
        l = torch.norm(x - target, np.inf)
        return l / n

    if state["order"] == "inf":
        print("Using linf criterion.")
        criterion = linf_me
    else:
        criterion = nn.BCELoss()

    train_loader = gen_loader

    # Number of batches per epoch; only used as the progress-bar total.
    if dataset == "Imagenet":
        # The Imagenet loader exposes its sample count through ``_size``
        # rather than ``len()``.  True ceiling division: the previous
        # ``ceil(a // b) + 1`` over-counted by one for evenly divisible sizes.
        total_n = math.ceil(train_loader._size / state["batch_size"])
    else:
        total_n = len(train_loader)

    # Loss weights: reconstruction (beta), adversarial (gamma, VAE-GAN only)
    # and the |logvar| penalty (eta).
    beta = 1.0
    gamma = 0.1
    eta = 0.1

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    for epoch in range(state['epochs']):
        with tqdm(total=total_n) as pb:
            for data in train_loader:
                if dataset == "Imagenet":
                    # Presumably already on the GPU (no .cuda() is applied
                    # here) -- depends on the loader; confirm.
                    inputs = data[0]["data"]
                else:
                    inputs, _ = data
                    inputs = inputs.cuda()

                if architecture == 'AE':
                    # =================forward=====================
                    z = enc_(inputs)
                    outputs = dec_(z)
                    loss = criterion(outputs, inputs)
                    # =================backward====================
                    enc_optimizer.zero_grad()
                    dec_optimizer.zero_grad()
                    loss.backward()
                    enc_optimizer.step()
                    dec_optimizer.step()

                    str_dict = {
                        "loss": f"{loss.item():.4f}"
                    }

                elif architecture == 'VAE':
                    # =================forward=====================
                    mu, logvar = enc_(inputs)
                    noise = reparameterize(mu, logvar)
                    outputs = dec_(noise)

                    # =================backward====================
                    lat_loss = latent_loss(mu, logvar)
                    rec_loss = criterion(outputs, inputs)
                    var_loss = torch.sum(torch.abs(logvar))

                    loss = lat_loss + beta*rec_loss + eta*var_loss

                    enc_optimizer.zero_grad()
                    dec_optimizer.zero_grad()
                    loss.backward()
                    enc_optimizer.step()
                    dec_optimizer.step()

                    str_dict = {
                        "loss": f"{loss.item():.4f}"
                    }

                elif architecture == 'VAE-GAN':
                    # =================forward=====================
                    mu, logvar = enc_(inputs)
                    noise = reparameterize(mu, logvar)
                    outputs = dec_(noise)

                    out_real_classes = disc(inputs)
                    out_gen_classes = disc(outputs)

                    # =================backward====================
                    lat_loss = latent_loss(mu, logvar)
                    rec_loss = criterion(outputs, inputs)
                    var_loss = torch.sum(torch.abs(logvar))
                    dec_loss = decoder_loss(out_gen_classes)
                    discloss = discriminator_loss(out_real_classes, out_gen_classes)

                    # NOTE(review): the zero_grad/backward order below is kept
                    # exactly as it was.  Each backward() accumulates into the
                    # .grad of every parameter reachable from that loss, and
                    # all step() calls happen only at the end, so unless the
                    # loss helpers detach their inputs the encoder/decoder
                    # gradients also contain contributions from the later
                    # losses -- confirm this is intended before reordering.
                    enc_optimizer.zero_grad()
                    enc_objective = lat_loss + beta*rec_loss + gamma*dec_loss + eta*var_loss
                    enc_objective.backward(retain_graph=True)

                    dec_optimizer.zero_grad()
                    dec_objective = beta*rec_loss + gamma*dec_loss
                    dec_objective.backward(retain_graph=True)

                    disc_optimizer.zero_grad()
                    disc_objective = gamma*discloss
                    disc_objective.backward()

                    enc_optimizer.step()
                    dec_optimizer.step()
                    disc_optimizer.step()

                    str_dict = {
                        'rec_loss': f"{rec_loss.item():.4f}",
                        'dec_loss': f"{dec_loss.item():.4f}",
                    }

                else:
                    raise NotImplementedError()

                # L2 reconstruction error of the batch, averaged over the
                # batch size (len(data) was the length of the loader's
                # tuple/list, not the number of samples).
                b_norm = torch.norm(inputs - outputs.detach(), p=2)
                str_dict["l2"] = f"{b_norm.item() / len(inputs):.4f}"

                pb.update(1)
                pb.set_postfix(**str_dict)

            if dataset == "Imagenet":
                # The Imagenet iterator has to be reset explicitly before it
                # can be iterated again in the next epoch.
                train_loader.reset()

            # Preview grids built from (up to) the first 16 samples of the
            # last batch of the epoch.
            pic = to_img(inputs, im_mean, im_std)[:16]
            save_image(pic, os.path.join(state['preview_path'], f'input_grid_epoch-{epoch + 1}.png'))

            pic = to_img(outputs, im_mean, im_std)[:16]
            save_image(pic, os.path.join(state['preview_path'], f'out_grid_epoch-{epoch + 1}.png'))

            # Checkpoints hold the unwrapped modules' weights (no
            # DataParallel "module." key prefix).
            torch.save(encoder.state_dict(), os.path.join(state['ckpt_path'], f'conv_encoder_epoch-{epoch + 1}.pth'))
            torch.save(decoder.state_dict(), os.path.join(state['ckpt_path'], f'conv_decoder_epoch-{epoch + 1}.pth'))

    print(f"Done, model is at {state['ckpt_path']}")
