# -*- coding: utf-8 -*-
# @Date    : 2019-07-25
# @Author  : Xinyu Gong (xy_gong@tamu.edu)
# @Link    : None
# @Version : 0.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import cfg
import models, models_search
import datasets
from functions import train, validate, LinearLrDecay, load_params, copy_params
from utils.utils import set_log_dir, save_checkpoint, create_logger
from utils.inception_score import _init_inception
from utils.fid_score import create_inception_graph, check_or_download_inception

from time import time
from datetime import datetime
import random
import torch
import os
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from copy import deepcopy
import math

# Enable cuDNN and let it autotune convolution algorithms for the fixed
# input sizes used in training (chained assignment: enabled, then benchmark).
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True


def original_main(gen_net, dis_net, train_loader, gen_avg_param, args):
    """Initialise AutoGAN training and stream the values yielded by ``train``.

    Adapted from the upstream AutoGAN training entry point. Its resume,
    validation and checkpointing code is disabled here because ``main`` does
    all evaluation and logging.

    This is a generator. It yields ``(0, 0, 0)`` once setup is finished, then
    re-yields everything ``functions.train`` yields, epoch by epoch.

    Args:
        gen_net: generator network; initialised in place by ``weights_init``.
        dis_net: discriminator network; initialised in place by ``weights_init``.
        train_loader: training data loader. Its length is used to convert
            ``args.max_iter`` into a number of epochs.
        gen_avg_param: EMA parameter container shared with the caller and
            passed to ``train``. It is refreshed in place after weight init.
        args: hyper-parameter namespace. NOTE: ``args.max_epoch`` is
            overwritten in place.

    Raises:
        NotImplementedError: unknown ``args.init_type``.
        ValueError: unknown ``args.optim_type``.
    """
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    def weights_init(m):
        """Initialise Conv2d weights per ``args.init_type`` and BatchNorm2d affine params."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # ``nn.init.xavier_uniform`` (no trailing underscore) is the
                # deprecated alias of this in-place initialiser.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # main() snapshots ``gen_avg_param`` with copy_params *before* this
    # generator starts running, i.e. before ``weights_init`` above. Refresh
    # it in place so the EMA starts from the initialised weights (as upstream
    # AutoGAN does), while the caller's reference to the same container stays
    # valid. Assumes copy_params returns a list -- confirm if that changes.
    gen_avg_param[:] = copy_params(gen_net)

    def trainable(net):
        """Parameters of ``net`` that require gradients."""
        return filter(lambda p: p.requires_grad, net.parameters())

    if args.optim_type == 'adam':
        gen_optimizer = torch.optim.Adam(trainable(gen_net), args.g_lr, (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(trainable(dis_net), args.d_lr, (args.beta1, args.beta2))
    elif args.optim_type == 'rmsprop':
        gen_optimizer = torch.optim.RMSprop(trainable(gen_net), args.g_lr)
        dis_optimizer = torch.optim.RMSprop(trainable(dis_net), args.d_lr)
    elif args.optim_type == 'sgd':
        gen_optimizer = torch.optim.SGD(trainable(gen_net), args.g_lr)
        dis_optimizer = torch.optim.SGD(trainable(dis_net), args.d_lr)
    else:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers keep working.
        raise ValueError(f"optim_type {args.optim_type} not recognised.")

    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic)

    # Epoch number for dis_net: scaled by n_critic, or derived from max_iter.
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # Upstream drew a fixed validation latent here. Validation is disabled,
    # so the tensor is not built, but the draw is kept so that later NumPy
    # random draws are unchanged.
    np.random.normal(0, 1, (25, args.latent_dim))

    # First yield signals "setup done"; main() treats it as iteration 0.
    yield 0, 0, 0

    start_epoch = 0
    lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
    for epoch in range(int(start_epoch), int(args.max_epoch)):
        yield from train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch,
                         steps_so_far=0, schedulers=lr_schedulers)


# OUR CODE ------------------------------------

from sacred import Experiment
import wandb
from advas.utils import WandbWrapper


PROJECT_NAME = 'gan_regularizer'

# SLURM identifiers ("" when unset); main() turns them into a wandb
# run-name suffix.
JOB_ID = os.getenv('SLURM_JOB_ID', "")
PROC_ID = os.getenv('SLURM_PROCID', "")
ARRAY_ID = os.getenv('SLURM_ARRAY_JOB_ID', "")
ARRAY_TASK = os.getenv('SLURM_ARRAY_TASK_ID', "")

ex = Experiment(PROJECT_NAME)

torch.backends.cudnn.benchmark = True  # set to False for reproducibility

# Fallback values for the AutoGAN ``args`` namespace. main() copies an entry
# into ``args`` only when the sacred config does not already define that key.
default_args = {
    'arch': None,
    'baseline_decay': 0.9,
    'beta1': 0.0,
    'beta2': 0.9,
    'bottom_width': 4,
    'channels': 3,
    'controller': 'controller',
    'ctrl_lr': 0.00035,
    'ctrl_sample_batch': 1,
    'ctrl_step': 30,
    'd_lr': 0.0002,
    'd_spectral_norm': True,
    'data_path': './data',
    'dataset': 'cifar10',
    'df_dim': 128,
    'dis_batch_size': 64,
    'dis_model': 'autogan_cifar10_a',
    'dynamic_reset_threshold': 0.001,
    'dynamic_reset_window': 500,
    'entropy_coeff': 0.001,
    'eval_batch_size': 100,
    'exp_name': 'autogan_cifar10_a',
    'g_lr': 0.0002,
    'g_spectral_norm': False,
    'gen_batch_size': 128,
    'gen_model': 'autogan_cifar10_a',
    'gf_dim': 256,
    'grow_step1': 25,
    'grow_step2': 55,
    'hid_size': 100,
    'img_size': 32,
    'init_type': 'xavier_uniform',
    'latent_dim': 128,
    'load_path': None,
    'lr_decay': False,
    'max_epoch': 320.0,
    'max_iter': 50000,
    'max_search_iter': 90,
    'n_classes': 10,
    'n_critic': 5,
    'num_candidate': 10,
    'num_eval_imgs': 50000,
    'num_workers': 6,
    'path_helper': {
        'prefix': 'logs/autogan_cifar10_a_2020_09_26_20_34_22',
        'ckpt_path': 'logs/autogan_cifar10_a_2020_09_26_20_34_22/Model',
        'log_path': 'logs/autogan_cifar10_a_2020_09_26_20_34_22/Log',
        'sample_path': 'logs/autogan_cifar10_a_2020_09_26_20_34_22/Samples',
    },
    'print_freq': 100,
    'random_seed': 12345,
    'rl_num_eval_img': 5000,
    'shared_epoch': 15,
    'topk': 5,
    'val_freq': 20,
}

@ex.capture
def seed_all(seed, _log):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) from ``seed``.

    ``seed`` and ``_log`` are supplied by sacred through ``@ex.capture``.
    """
    _log.info(f"Seed: {seed}")
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


@ex.config
def cfg():
    """Default sacred configuration for AutoGAN training on CIFAR-10.

    Every variable assigned below becomes a sacred config entry. ``main``
    receives the entries whose names match its parameters, plus the full
    dict as ``_config``. ``_config`` is also turned into the AutoGAN ``args``
    namespace, with ``default_args`` filling in any missing keys.

    NOTE: the decorator rebinds the module-level name ``cfg``, shadowing the
    ``cfg`` module imported at the top of the file.
    """

    dataset = 'cifar10'
    bottom_width = 4
    img_size = 32
    max_iter = 250000  # training length; original_main converts max_iter * n_critic into epochs
    gen_model = 'autogan_cifar10_a'  # submodule of ``models`` providing Generator
    dis_model = 'autogan_cifar10_a'  # submodule of ``models`` providing Discriminator
    latent_dim = 128  # size of the N(0, 1) latent vector fed to the generator
    gf_dim = 256
    df_dim = 128
    g_spectral_norm = False
    d_spectral_norm = True
    g_lr = 2e-4
    d_lr = 2e-4
    beta1 = 0.0  # Adam beta1 (only used when optim_type == 'adam')
    beta2 = 0.9  # Adam beta2 (only used when optim_type == 'adam')
    init_type = 'xavier_uniform'  # Conv2d init: 'normal', 'orth' or 'xavier_uniform'
    n_critic = 5  # presumably discriminator steps per generator step; also scales epochs and LR schedule
    val_freq = 20  # only referenced by upstream validation code that is commented out in this file
    exp_name = 'autogan_cifar10_a'  # part of the wandb run name

    gen_batch_size = 128  # also forwarded to WandbWrapper and used to compute iters_per_epoch
    dis_batch_size = 64
    num_workers = 0  # forwarded to WandbWrapper
    data_path = './data'

    sampled_architecture = None  # if set, use models_search.shared_gan with this arch instead of models.<name>
    regularizer_strength = 0  # not read in this file; presumably consumed via args elsewhere -- confirm
    advas_remove_hinge_loss = True  # not read in this file; presumably consumed via args elsewhere -- confirm
    instance_norm = False  # not read in this file; presumably consumed via args elsewhere -- confirm
    optim_type = 'adam'  # one of 'adam', 'rmsprop', 'sgd'
    wandb_id = None  # None: seed RNGs and start a new wandb run; otherwise must match the active run id
    N_limit = None  # forwarded to WandbWrapper
    tag = 'autogan-tester'  # wandb run tag (None for no tags)
    # Not read by main(); the inline comments show alternative values.
    local_restore = None  # os.path.join(os.environ['SCRATCH'], 'gan-ckpts', 'autogan') 
    local_save = None  # os.path.join(os.environ['SCRATCH'], 'gan-ckpts', 'autogan')


# NOTE: helpers used by main() must be defined above ``@ex.automain``. When
# this file is run as a script, automain starts the experiment at decoration
# time, so anything defined after it would not exist yet.


def _job_suffix(array_id, array_task, job_id, proc_id):
    """Return a wandb run-name suffix built from SLURM ids ("" outside SLURM)."""
    if array_id:
        return f"_task_{array_id}_{array_task}"
    if job_id:
        return f"_job_{job_id}" + (f"_{proc_id}" if proc_id else "")
    return ""


def _is_log_scale_point(i):
    """Return True when ``i`` is 0 or a positive power of two.

    Uses an exact integer bit test. The previous ``math.log(i, 2) % 1 == 0``
    relied on floating-point division, which is not guaranteed to be exact
    for every power of two.
    """
    return i == 0 or (i > 0 and (i & (i - 1)) == 0)


@ex.automain
def main(dataset, bottom_width, img_size, max_iter, gen_model, dis_model, latent_dim,
         gf_dim, df_dim, g_spectral_norm, d_spectral_norm, g_lr, d_lr, beta1, beta2, init_type,
         n_critic, val_freq, exp_name, wandb_id, gen_batch_size, dis_batch_size, num_workers, data_path,
         N_limit, sampled_architecture, regularizer_strength, advas_remove_hinge_loss, tag, instance_norm,
         optim_type, _config, _log):
    """Sacred entry point: build AutoGAN nets and data, train, and log to wandb.

    Named parameters are injected by sacred from ``cfg``. ``_config`` holds
    the full config and is turned into the ``args`` namespace used by the
    AutoGAN code; keys it lacks fall back to ``default_args``.

    If ``wandb_id`` is None, all RNGs are seeded and a new wandb run is
    started. Otherwise an already-active wandb run with that id is required.

    Raises:
        ValueError: ``wandb_id`` is given but no wandb run is active, or the
            active run has a different id.
    """
    # --- wandb run / seeding --------------------------------------------------
    if wandb_id is None:
        seed_all()
        job_suffix = _job_suffix(ARRAY_ID, ARRAY_TASK, JOB_ID, PROC_ID)
        wandb.init(project=PROJECT_NAME, name='autogan' + exp_name + job_suffix,
                   dir="/tmp", group='autogan', id=wandb.util.generate_id(), resume='allow',
                   tags=None if tag is None else [tag])
        for key, value in _config.items():
            wandb.config[key] = value
    else:
        # Resuming: the run is presumably initialised by whoever launched us.
        if wandb.run is None:
            raise ValueError(f"wandb_id={wandb_id!r} given but no wandb run is active.")
        if wandb.run.id != wandb_id:
            raise ValueError("Wandb ids do not match!")

    # --- AutoGAN args namespace ---------------------------------------------------
    args = argparse.Namespace(**_config)
    for argname, argval in default_args.items():
        if not hasattr(args, argname):
            setattr(args, argname, argval)

    # --- networks ---------------------------------------------------------------------
    if args.sampled_architecture is None:
        # Attribute lookup instead of eval() on config-supplied strings.
        gen_net = getattr(models, args.gen_model).Generator(args=args).cuda()
        dis_net = getattr(models, args.dis_model).Discriminator(args=args).cuda()
    else:
        gen_net = models_search.shared_gan.Generator(args=args).cuda()
        dis_net = models_search.shared_gan.Discriminator(args=args).cuda()
        gen_net.set_arch(args.sampled_architecture, cur_stage=2)
        dis_net.cur_stage = 2

    print(gen_net)
    print(dis_net)

    class Generator():
        """Sampling wrapper around ``gen_net`` used by the wandb metrics."""
        net = gen_net

        def sample(self, batch_shape=torch.Size([1]), fixed=False):
            """Return ``batch_shape[0]`` images; ``fixed`` uses the cached latents."""
            self.net.eval()
            if fixed:
                z = self.sample_fixed_latent(batch_shape[0])
            else:
                z = torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_shape[0], args.latent_dim)))
            imgs = self.net(z)
            self.net.train()
            return imgs

        def eval(self,):
            pass

        def parameters(self,):
            # self.net (not gen_net), so GeneratorEMA reports the EMA copy's
            # parameters after update_ema().
            return self.net.parameters()

        def sample_fixed_latent(self, n):
            """Load (creating on first use) a fixed batch of ``n`` N(0, 1) latents."""
            name = f'./latents_{n}_{args.latent_dim}.pt'
            if not os.path.exists(name):
                print('Generating new latent parameters.')
                z = torch.randn(n, args.latent_dim).float()
                torch.save(z, name)
            z = torch.load(name)
            return z.cuda()

    class GeneratorEMA(Generator):
        def update_ema(self):
            """Point ``net`` at a fresh copy of ``gen_net`` loaded with the EMA weights."""
            self.net = deepcopy(gen_net)
            load_params(self.net, gen_avg_param)

    generator = Generator()
    generator_ema = GeneratorEMA()

    # --- data ---------------------------------------------------------------------------
    image_dataset = datasets.ImageDataset(args)
    train_loader = image_dataset.train
    sup_dataset = train_loader.dataset

    class UnsupDataset():
        """View of ``sup_dataset`` returning only element 0 of each item (presumably the image)."""
        def __getitem__(self, i):
            return sup_dataset[i][0]

        def __len__(self):
            return len(sup_dataset)

    unsup_dataset = UnsupDataset()

    wandb_boi = WandbWrapper(wandb, gen_batch_size, num_workers, N_limit=N_limit,
                             device='cuda', normalize_metric_input=False)
    # Shared (by closure) with GeneratorEMA and passed to original_main.
    gen_avg_param = copy_params(gen_net)

    # Loop-invariant; clamped to >= 1 so the modulo tests below never divide
    # by zero when the dataset is smaller than one batch.
    iters_per_epoch = max(1, len(train_loader.dataset) // gen_batch_size)
    print('\niters per epoch\n', iters_per_epoch)
    fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'

    # --- training + logging loop ---------------------------------------------------------
    total_running_time = 0
    start_iter_time = time()
    training = original_main(gen_net, dis_net, train_loader, gen_avg_param, args=args)
    for i, (_epoch, advas_loss, advas_normalizer) in enumerate(training):
        if i % 500 == 0:
            print('iteration', i)
        # Only time spent inside training counts; logging below is excluded.
        total_running_time += time() - start_iter_time

        # Schedule: small (1k-sample) evals at i == 0, powers of two and every
        # 5 * iters_per_epoch steps; large (50k-sample) evals every
        # 50 * iters_per_epoch steps; images every 500; summaries every 50.
        do_big_log = i != 0 and i % (50 * iters_per_epoch) == 0
        do_little_log = i % (5 * iters_per_epoch) == 0 or _is_log_scale_point(i)
        do_add_images = do_little_log or do_big_log or i % 500 == 0
        do_add_anything = do_little_log or do_big_log or i % 50 == 0

        if do_big_log:
            inception, fid = validate(args, fid_stat=fid_stat, gen_net=gen_net, N=50000, do_fid=True)
            wandb.log({'inception-50000': inception, 'fid-50000': fid}, commit=False)
            wandb_boi.swd_metric(generator, sup_dataset, N=16384)
            generator_ema.update_ema()
            inception, fid = validate(args, fid_stat=fid_stat, gen_net=generator_ema.net, do_fid=True, N=50000)
            wandb.log({'inception-50000-ema': inception, 'fid-50000-ema': fid}, commit=False)
            wandb_boi.swd_metric(generator_ema, sup_dataset, N=50000, label='ema',
                                 allowed_gpu_mem_per_1000_images=1e6)
        if do_little_log:
            N = 1000
            inception, fid = validate(args, fid_stat=fid_stat, gen_net=gen_net, do_fid=True, N=N)
            wandb.log({'inception-1000': inception, 'fid-1000': fid}, commit=False)
            wandb_boi.swd_metric(generator, sup_dataset, N=N)
            generator_ema.update_ema()
            inception, fid = validate(args, fid_stat=fid_stat, gen_net=generator_ema.net, do_fid=True, N=N)
            wandb.log({'inception-1000-ema': inception, 'fid-1000-ema': fid}, commit=False)
            wandb_boi.swd_metric(generator_ema, sup_dataset, N=N, label='ema')

        if do_add_images:
            images = (list(generator.sample([3]))
                      + list(generator.sample([3], fixed=True))
                      + list(generator_ema.sample([3]))
                      + list(generator_ema.sample([3], fixed=True))
                      + [unsup_dataset[i % len(unsup_dataset)]])
            names = ([f'random_{img_no}' for img_no in range(3)]
                     + [f'fixed_{img_no}' for img_no in range(3)]
                     + [f'ema_random_{img_no}' for img_no in range(3)]
                     + [f'ema_fixed_{img_no}' for img_no in range(3)]
                     + ['data'])
            wandb_boi.add_images(images, names, iteration=i)
        if do_add_anything:
            wandb_boi.track_summary_stats(generator, sup_dataset)
            wandb.log({'advas_loss': advas_loss, 'advas_normalizer': advas_normalizer}, commit=False)
            wandb_boi.log(iteration=i, running_time=total_running_time)

        start_iter_time = time()
