import argparse
import os
import json
import math
import shutil
import utils

from allmodels import load_imagenet_generator, load_mnist_data, load_cifar10_data
from generator.FCN import *
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from custom_logging import get_logger
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np
import utils

g_logger = get_logger(__name__)


def to_img(x, im_mean, im_std):
    """Undo per-channel normalisation (if any) and return a detached CPU tensor.

    Args:
        x: Batch of images as a tensor. The code reads ``x.shape[1]`` as the
            channel count and broadcasts the statistics as ``(1, C, 1, 1)``,
            so a 4-D ``(N, C, H, W)`` layout is expected.
        im_mean: Per-channel means (``C`` values), or ``None``.
        im_std: Per-channel standard deviations (``C`` values), or ``None``.

    Returns:
        ``x * im_std + im_mean`` when both statistics are given, otherwise
        ``x`` unchanged; in both cases detached from autograd and moved to
        the CPU.
    """
    if im_mean is not None and im_std is not None:
        stat_shape = (1, x.shape[1], 1, 1)
        # Build the statistics on the device `x` already lives on instead of
        # forcing CUDA, so the helper also works for CPU tensors and for
        # tensors on a non-default GPU. Broadcasting over the batch dimension
        # replaces the explicit `.repeat`.
        mean = torch.as_tensor(im_mean, device=x.device).view(stat_shape)
        std = torch.as_tensor(im_std, device=x.device).view(stat_shape)
        x = x * std + mean

    return x.detach().cpu()


if __name__ == '__main__':
    # Script entry point: for every class index of the configured dataset,
    # train an Encoder/Decoder auto-encoder pair, write preview image grids
    # after each epoch and save checkpoints every 5 epochs.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='generator/config/train_Imagenet_AE.json', help='config file')
    parser.add_argument('--gpuid', nargs='+', type=str, default=["0"])
    args = vars(parser.parse_args())

    # Restrict the visible GPUs to the ones requested on the command line.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args['gpuid'])

    with open(args['config']) as config_file:
        state = json.load(config_file)

    dataset = state["dataset"]

    # Fail fast, before any run directory is created, on a dataset this
    # script has no branch for.
    if dataset not in ('MNIST', 'CIFAR10', 'Imagenet'):
        raise ValueError(f"Unsupported dataset: {state['dataset']}")

    run_id = utils.get_time_stamp()

    # Every run gets its own <ckpt_path>/<dataset>/<run_id> directory, with
    # the preview images stored underneath it.
    state['ckpt_path'] = os.path.join(state['ckpt_path'], dataset, run_id)
    state['preview_path'] = os.path.join(state['ckpt_path'], 'preview')

    os.makedirs(state['ckpt_path'], exist_ok=True)
    os.makedirs(state['preview_path'], exist_ok=True)

    # Keep a copy of the config next to the checkpoints it produced.
    config_basename = os.path.basename(args['config'])
    shutil.copyfile(args['config'], os.path.join(state['ckpt_path'], config_basename))

    # MNIST and CIFAR10 train on all 10 classes; Imagenet trains on a random
    # subset of 10 of its 1000 classes.
    if dataset == 'Imagenet':
        class_indices = np.random.choice(list(range(1000)), 10, replace=False)
    else:
        class_indices = list(range(10))

    # Persist the chosen class indices alongside the checkpoints.
    classes_path = os.path.join(state['ckpt_path'], 'classes_ix.npy')
    np.save(classes_path, class_indices)

    for class_ix in class_indices:
        g_logger.info(f"Loading generator (val) dataset for {state['dataset']} index {class_ix}...")

        if dataset == "MNIST":
            resolution = [1, 28, 28]
            gen_loader, test_loader, _, _ = load_mnist_data(state, mode='generator', class_ix=class_ix)

            encoder = Encoder(resolution, compress_mode=state['compress_mode'], resize_dim=None)
            decoder = Decoder(resolution, compress_mode=state['compress_mode'], original_dim=None)
            optimizer = torch.optim.Adam(
                [{'params': encoder.parameters()},
                 {'params': decoder.parameters()}],
                lr=state['learning_rate_G'],
                weight_decay=state['weight_decay'])
            scheduler = MultiStepLR(optimizer, milestones=[100, 200, 300], gamma=0.1)

            im_mean = None
            im_std = None
        elif dataset == "CIFAR10":
            resolution = [3, 32, 32]
            gen_loader, test_loader, _, _ = load_cifar10_data(state, mode='generator', class_ix=class_ix)

            encoder = Encoder(resolution, compress_mode=state['compress_mode'], resize_dim=None)
            decoder = Decoder(resolution, compress_mode=state['compress_mode'], original_dim=None)
            # NOTE(review): unlike the MNIST branch, this optimizer uses
            # Adam's defaults and ignores `learning_rate_G` / `weight_decay`
            # from the config -- confirm this is intentional.
            optimizer = torch.optim.Adam(
                [{'params': encoder.parameters()},
                 {'params': decoder.parameters()}])
            scheduler = MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)

            im_mean = None
            im_std = None
        elif dataset == "Imagenet":
            resolution = [3, 224, 224]
            # NOTE(review): `load_imagenet_test` is not among this file's
            # explicit imports; it only resolves if the star import from
            # `generator.FCN` provides it -- verify. The returned loader is
            # not used below.
            _, test_loader = load_imagenet_test(state, normalize=False)
            gen_loader = load_imagenet_generator(state, normalize=False, class_ix=class_ix)

            encoder = Encoder(resolution, compress_mode=state['compress_mode'], resize_dim=state['encoder_resize'])
            decoder = Decoder(resolution, compress_mode=state['compress_mode'], original_dim=state['original_size'])
            # NOTE(review): Adam defaults here as well (see CIFAR10 branch).
            optimizer = torch.optim.Adam(
                [{'params': encoder.parameters()},
                 {'params': decoder.parameters()}])
            scheduler = MultiStepLR(optimizer, milestones=[15, 35, 65], gamma=0.1)

            im_mean = None
            im_std = None
        else:
            raise ValueError(f"Unsupported dataset: {state['dataset']}")

        # The wrapped copies are used for the forward passes; the bare
        # `encoder` / `decoder` are what gets checkpointed.
        enc_ = torch.nn.DataParallel(encoder).cuda()
        dec_ = torch.nn.DataParallel(decoder).cuda()

        criterion = nn.BCELoss()

        train_loader = gen_loader

        if dataset == "Imagenet":
            # The Imagenet loader is sized via `_size` rather than `len()`.
            # Assumes `_size` is the number of samples per epoch -- TODO
            # confirm. True division + ceil gives the exact batch count
            # (floor division followed by ceil was a no-op that needed "+ 1"
            # and over-counted on exact multiples).
            total_n = math.ceil(train_loader._size / state["batch_size"])
        else:
            total_n = len(train_loader)

        for epoch in range(1, state['epochs'] + 1):
            # Last batch seen in this epoch and its reconstructions; they stay
            # None when the loader yields nothing, so previews are skipped.
            batch = None
            output = None
            clean_out = None

            with tqdm(total=total_n) as pb:
                for data in train_loader:
                    if dataset == "Imagenet":
                        batch = data[0]["data"]
                    else:
                        batch, _ = data
                        batch = batch.cuda()

                    # ===================forward=====================
                    z = enc_(batch)
                    output = dec_(z)
                    loss = criterion(output, batch)
                    # ===================backward====================
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Normalise the logged norms by the number of samples in
                    # the batch; `len(data)` is the length of the loader's
                    # item (e.g. the (input, label) pair), not the batch size.
                    n_samples = len(batch)

                    # Diagnostics only: no autograd graph is needed here.
                    with torch.no_grad():
                        b_norm = torch.norm(batch - output, p=2)
                        # Reconstruction of the same batch after the
                        # optimizer step that was just taken.
                        clean_out = dec_(enc_(batch))
                        c_norm = torch.norm(batch - clean_out, p=2)

                    str_loss = f"{loss.item():.4f}"
                    l2_norm = f"{b_norm / n_samples:.4f}"
                    l2_norm_c = f"{c_norm / n_samples:.4f}"

                    pb.update(1)
                    pb.set_postfix(epoch=epoch, loss=str_loss, l2=l2_norm, clean=l2_norm_c)

            if dataset == "Imagenet":
                # The Imagenet loader has to be reset before it can be
                # iterated again in the next epoch.
                train_loader.reset()

            scheduler.step()

            cx_dir = os.path.join(state['preview_path'], str(class_ix))
            os.makedirs(cx_dir, exist_ok=True)

            if batch is not None:
                # Save up to 16 images of the last batch: the inputs, the
                # reconstruction from the training forward pass and the
                # reconstruction after the weight update. File names use the
                # same 1-based epoch number as the checkpoints below.
                pic = to_img(batch, im_mean, im_std)[:16]
                save_image(pic, os.path.join(cx_dir, f'input_grid_epoch-{epoch}.png'))

                pic = to_img(output, im_mean, im_std)[:16]
                save_image(pic, os.path.join(cx_dir, f'out_grid_epoch-{epoch}.png'))

                pic = to_img(clean_out, im_mean, im_std)[:16]
                save_image(pic, os.path.join(cx_dir, f'clean_grid_epoch-{epoch}.png'))

            # ===================log========================
            # Checkpoint the unwrapped modules every 5 epochs.
            if epoch % 5 == 0:
                torch.save(encoder.state_dict(), os.path.join(state['ckpt_path'], f'{class_ix}_conv_encoder_epoch-{epoch}.pth'))
                torch.save(decoder.state_dict(), os.path.join(state['ckpt_path'], f'{class_ix}_conv_decoder_epoch-{epoch}.pth'))

        print(f"Done, model is at {state['ckpt_path']}")
