import os
import itertools as it
import random
from datetime import datetime
import math
from time import time
from pathlib import Path

from retry.api import retry_call
import numpy as np
import torch
import torch.nn.functional as F
from sacred import Experiment
import torchvision
import torchvision.transforms as T

import wandb
from tqdm import tqdm
from stylegan2_pytooch import Trainer
import eval as evaluation
from advas.utils import WandbWrapper

from facenet_pytorch import InceptionResnetV1

PROJECT_NAME = "gan-regularizer"
JOB_ID = os.getenv('SLURM_JOB_ID', "")
PROC_ID = os.getenv('SLURM_PROCID', "")
ARRAY_ID = os.getenv('SLURM_ARRAY_JOB_ID', "")
ARRAY_TASK = os.getenv('SLURM_ARRAY_TASK_ID', "")

ex = Experiment(PROJECT_NAME)


class NanException(Exception):
    """Exception type that ``main`` tells ``retry_call`` to retry training on.

    Presumably signals a NaN encountered during a training step (inferred
    from the name only -- nothing in this file raises it).
    """


@ex.named_config
def cuda():
    """Named sacred config that sets no entries.

    Selecting it (``with cuda``) therefore changes nothing in the
    configuration as written here.
    """
    pass


@ex.config
def cfg():
    """Default sacred configuration for the CelebA StyleGAN2 experiment.

    Every local assigned below becomes a config entry that ``main`` can
    capture by name.
    NOTE(review): ``main`` also captures ``image_size``, which has no default
    here, so it presumably must be supplied at launch
    (``with image_size=...``) -- confirm.
    """

    # id of an existing wandb run to resume; None starts a fresh, seeded run
    wandb_id = None
    # forwarded as ``local_save`` when loading the checkpoint on resume
    local_restore = None
    # root directory handed to torchvision's CelebA dataset
    celeba_folder = '.'
    # forwarded to the wandb wrapper as ``N_limit``
    N_limit = None
    # exclusive upper bound of the training-loop iteration index
    num_iterations = 10000000
    # not read by ``main`` in this file
    proxy_iterations = 1
    # captured by ``main`` but not used in its body
    train_proxy_every = 1
    batch_size = 32         # from StyleGAN2
    gradient_accumulations = 1  # adjust so that it fits on GPU
    lr = 2e-4               # from paper
    # passed to both the Trainer and the wandb wrapper
    num_workers = 0
    # captured by ``main`` but not used in its body
    include_proxy_reg_terms = False
    # passed to the Trainer as ``reg_strength``; when 0, ``main`` logs the
    # regulariser term as 0
    regularizer_strength = 0
    # passed straight through to the Trainer
    include_reg = True
    # passed straight through to the Trainer
    sparse_gp_for_regularisation = False
    # single wandb tag for the run (no tags are set if falsy)
    tag = 'tester-celeba'
    # passed to the Trainer as ``optim_type``
    optim_type = 'adam'

@ex.capture
def seed_all(seed, _log):
    """Log ``seed`` and apply it to every random number generator in use.

    The same value is given, in order, to NumPy, Python's ``random``,
    torch's default generator and all CUDA devices.
    """
    _log.info(f"Seed: {seed}")

    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


@ex.capture
def update_run(network, _run):
    """Unimplemented placeholder; calling it always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError


def _slurm_job_suffix():
    """Return the run-name suffix identifying the SLURM job ("" outside SLURM).

    An array job id takes precedence over a plain job id.
    """
    if ARRAY_ID:
        return f"_task_{ARRAY_ID}_{ARRAY_TASK}"
    if JOB_ID:
        suffix = f"_job_{JOB_ID}"
        if PROC_ID:
            suffix += f"_{PROC_ID}"
        return suffix
    return ""


def _is_zero_or_power_of_two(step):
    """Return True when the non-negative int ``step`` is 0 or a power of two.

    Uses an exact integer bit test rather than ``math.log(step, 2) % 1 == 0``,
    which depends on floating-point rounding.
    """
    return step & (step - 1) == 0


def _output_locations():
    """Resolve where checkpoints and evaluation images are written.

    Returns:
        ``(local_saves, image_dir)``. ``local_saves`` is the list of
        ``local_save`` targets for checkpoints (``None`` first, then a
        directory); ``image_dir`` is the evaluation-image directory.

    Raises:
        KeyError: if ``/results`` does not exist and ``SCRATCH`` is unset.
    """
    results_dir = Path('/results')
    if results_dir.exists():
        return [None, results_dir], results_dir
    scratch = os.environ['SCRATCH']
    return ([None, os.path.join(scratch, 'gan-ckpts/celeba')],
            os.path.join(scratch, 'gan-images/celeba'))


@ex.automain
def main(batch_size, train_proxy_every, num_iterations,
         include_proxy_reg_terms, num_workers, N_limit,
         image_size, regularizer_strength, include_reg, lr, tag, optim_type,
         sparse_gp_for_regularisation, gradient_accumulations, celeba_folder,
         wandb_id, local_restore, _config, _log):
    """Train StyleGAN2 on CelebA, logging metrics and checkpoints via wandb.

    All arguments are injected by sacred from the experiment configuration.
    When ``wandb_id`` is None a fresh run is started (seeded, new wandb id,
    config uploaded); otherwise that wandb run is resumed from its last saved
    checkpoint. ``train_proxy_every`` and ``include_proxy_reg_terms`` are
    accepted but not used here.

    Raises:
        ValueError: if wandb resumes a run whose id differs from ``wandb_id``.
        RuntimeError: if the run to resume has no logged ``save_index``.
    """
    torch.backends.cudnn.benchmark = True  # set to False for reproducibility
    resuming = wandb_id is not None

    # Only fresh runs are seeded.
    if not resuming:
        seed_all()

    job_suffix = _slurm_job_suffix()
    run_id = wandb_id if resuming else wandb.util.generate_id()
    wandb.init(project=PROJECT_NAME, name='stylgan2'+'celeba'+job_suffix,
               tags=[tag] if tag else None, dir="/tmp", group='celeba',
               id=run_id, resume='allow')
    if resuming:
        if wandb.run.id != wandb_id:
            raise ValueError("Wandb ids do not match!")
        run_path = wandb.run.path
    else:
        for key, value in _config.items():
            wandb.config[key] = value

    device = 'cuda'

    if device == 'cuda' and torch.cuda.is_available():
        _log.info("Running on GPU")
    else:
        _log.info("Running on CPU")

    assert batch_size % gradient_accumulations == 0
    size_of_batch_chunk = batch_size // gradient_accumulations
    model = Trainer(
        'name',
        results_dir = './results',
        models_dir = './models',
        batch_size = size_of_batch_chunk,
        gradient_accumulate_every = gradient_accumulations,
        image_size = image_size,
        network_capacity = 16,  # default to not mess with network architecture
        reg_strength = regularizer_strength,
        include_reg = include_reg,
        sparse_gp_for_regularisation = sparse_gp_for_regularisation,
        transparent = False,
        lr = lr,   # we set this, but should use 2e-4
        ttur_mult = 1.5,
        num_workers = num_workers,
        save_every = 100000000000000000000,  # never save, we use wandb
        trunc_psi = 0.75,   # their own default
        fp16 = False,     # turn off a bunch of modifications
        cl_reg = False,
        fq_layers = [],
        fq_dict_size = 256,
        attn_layers = [],
        no_const = False,
        aug_prob = 0.,
        dataset_aug_prob = 0.,
        optim_type=optim_type,
    )

    transforms = T.Compose([
        T.Resize(image_size),
        T.CenterCrop(image_size),
        T.ToTensor(),
    ])

    def make_sup_unsup(dataset_class, *args, **kwargs):
        """Build the dataset twice: as-is, and as a variant yielding images only."""
        class UnsupDataset(dataset_class):
            def __getitem__(self, i):
                img, label = super().__getitem__(i)
                return img
        return dataset_class(*args, **kwargs), UnsupDataset(*args, **kwargs)

    sup_dataset, unsup_dataset = make_sup_unsup(
        torchvision.datasets.CelebA, root=celeba_folder, transform=transforms,
        download=True, split='all')
    model.set_data_src(unsup_dataset)

    # The two classes below are passed to the wandb wrapper as classes (never
    # instantiated here); static methods keep them callable either way.
    class Generator():
        """Adapter exposing the trainer's regular generator to the metrics."""
        @staticmethod
        def sample(batch_shape=torch.Size([1])):
            return model.get_samples(n_samples=batch_shape[0], gen_type='normal')

        @staticmethod
        def eval():
            pass

        @staticmethod
        def parameters():
            return model.GAN.parameters()

    class GeneratorEMA(Generator):
        """Same adapter, sampling with ``gen_type='ema'``."""
        @staticmethod
        def sample(batch_shape=torch.Size([1])):
            return model.get_samples(n_samples=batch_shape[0], gen_type='ema')

    class CelebAWrapper(WandbWrapper):
        """Wandb wrapper that compares neighbours in FaceNet embedding space."""
        has_extra_neighbours_funcs = True
        extra_neighbours_type = 'embedded'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.facenet = InceptionResnetV1(pretrained='vggface2').eval().cuda()

        def extra_process_neighbours(self, imgs):
            # Images are resized to 160 before being embedded.
            return self.facenet(F.interpolate(imgs, size=160))

        def extra_neighbour_distance(self, n1, n2):
            # Pairwise squared Euclidean distances via broadcasting.
            return ((n1.unsqueeze(0)-n2.unsqueeze(1))**2).sum(dim=-1)

    # now load checkpoint if necessary
    model.init_GAN()
    wandb_boi = CelebAWrapper(wandb, batch_size, num_workers, N_limit=N_limit, device=device)
    wandb_boi.set_neighbour_indices([6, 202560, 202594, 202597])
    if not resuming:
        initial_iter = 0
        total_running_time = 0
        save_index = 1
    else:
        api = wandb.Api()
        run = api.run(f'{os.environ["WANDB_ENTITY"]}/{PROJECT_NAME}/{wandb_id}')
        save_indices = [r['save_index'] for r in run.scan_history() if 'save_index' in r]
        if not save_indices:
            raise RuntimeError(
                f"wandb run {wandb_id!r} has no logged 'save_index'; cannot "
                "locate a checkpoint to resume from (old-style runs are not "
                "supported)")
        last_save_index = save_indices[-1]
        save_index = last_save_index + 1
        loaded = wandb_boi.load_objects(run_path, local_save=local_restore, with_index=last_save_index)
        model.load_state_dict(loaded['model'])
        initial_iter = loaded['iteration']
        total_running_time = loaded['running_time']
        model.steps = initial_iter

    # Loop invariants. Clamped to >= 1 so a dataset smaller than one batch
    # does not cause a modulo-by-zero.
    n_data = len(sup_dataset)
    iters_per_epoch = max(1, n_data // batch_size)
    iters_per_five_epochs = max(1, (5 * n_data) // batch_size)
    # Interval for the periodic evaluation-image dumps; equals the previously
    # hard-coded 5*6331 for full CelebA at batch size 32.
    eval_image_interval = 5 * iters_per_epoch
    # Resolved once up front so a missing SCRATCH fails before training.
    local_saves, image_dir = _output_locations()

    for i in tqdm(range(initial_iter, num_iterations), mininterval=10., desc='stylegan', initial=initial_iter, total=num_iterations):

        at_logscale_point = _is_zero_or_power_of_two(i)
        end_epoch = (i % iters_per_epoch == 0)
        end_fifth_epoch = (i % iters_per_five_epochs == 0)

        # choose what to log
        do_big_log = end_fifth_epoch and (i != 0)
        do_little_log = end_epoch or at_logscale_point
        do_add_images_and_ckpt = do_little_log or do_big_log or (i%500 == 0)
        do_add_anything = do_little_log or do_big_log or (i%50 == 0)

        if do_big_log:
            wandb_boi.fid_score(Generator, sup_dataset, N=len(sup_dataset))
            wandb_boi.inception_score(Generator, N=50000)
            wandb_boi.swd_metric(Generator, sup_dataset, N=16384)
            wandb_boi.fid_score(GeneratorEMA, sup_dataset, N=len(sup_dataset), label='ema')
            wandb_boi.inception_score(GeneratorEMA, N=50000, label='ema')
            wandb_boi.swd_metric(GeneratorEMA, sup_dataset, N=16384, label='ema')
        if do_little_log:
            N = 1000
            wandb_boi.fid_score(Generator, sup_dataset, N=N)
            wandb_boi.inception_score(Generator, N=N)
            wandb_boi.swd_metric(Generator, sup_dataset, N=N)
            wandb_boi.fid_score(GeneratorEMA, sup_dataset, N=N, label='ema')
            wandb_boi.inception_score(GeneratorEMA, N=N, label='ema')
            wandb_boi.swd_metric(GeneratorEMA, sup_dataset, N=N, label='ema')
        if do_add_images_and_ckpt:
            normal_images = model.get_samples(3, 'normal', do_same=False)
            ema_images = model.get_samples(3, 'ema', do_same=False)
            fixed_normal = model.get_samples(3, 'normal', do_same=True)
            fixed_ema = model.get_samples(3, 'ema', do_same=True)
            images = [*normal_images, *ema_images, *fixed_normal,
                      *fixed_ema, unsup_dataset[i%len(unsup_dataset)]]
            image_names = [f'normal-{k}' for k in range(3)] + \
                          [f'ema-{k}' for k in range(3)] + \
                          [f'fixed-normal-{k}' for k in range(3)] + \
                          [f'fixed-ema-{k}' for k in range(3)] + ['data']
            wandb_boi.add_images(images, image_names, iteration=i)
            wandb_boi.track_summary_stats(Generator, sup_dataset)

        pre_step_time = time()
        # NOTE(review): retries only trigger if model.train raises the
        # NanException class defined in this file -- confirm the Trainer does
        # not raise its own, distinct exception class instead.
        retry_call(model.train, tries=3, exceptions=NanException)
        prev_running_time = total_running_time
        total_running_time += time() - pre_step_time

        if i % 50 == 0:
            model.print_log()

        if do_add_images_and_ckpt:
            wandb.log({'save_index': save_index}, commit=False)
            for local_save in local_saves:
                wandb_boi.save_objects(
                    [model.get_state_dict(), i+1, total_running_time],
                    ['model', 'iteration', 'running_time'],
                    local_save=local_save, with_index=save_index)

        if do_add_anything:
            reg = 0 if regularizer_strength == 0 else model.prev_regulariser
            wandb.log({'advas_normalizer': model.reg_normalizer}, commit=False)
            wandb_boi.track_loss(loss=model.g_loss+reg, label='generator', part_losses=(model.g_loss, reg))
            wandb_boi.track_loss(loss=model.d_loss, label='proxy')
            wandb_boi.log(iteration=i, running_time=total_running_time)  # note running time is one iteration ahead

        # The running time modulo 5h wraps exactly when a 5h boundary is crossed.
        passed_5h_mark = (total_running_time % 18000) < (prev_running_time % 18000)
        if passed_5h_mark:
            n_hours = total_running_time // 3600
            evaluation.save_eval_images(
                image_dir,
                f'interp_{wandb.run.id}_{n_hours}h', model, 10, interp=True)
            evaluation.save_eval_images(
                image_dir,
                f'random_{wandb.run.id}_{n_hours}h', model, 100, interp=False)
        if i % eval_image_interval == 0 and i > 0:
            n_epochs = i // iters_per_epoch
            evaluation.save_eval_images(
                image_dir,
                f'interp_{wandb.run.id}_{n_epochs}epochs', model, 10, interp=True)
            evaluation.save_eval_images(
                image_dir,
                f'random_{wandb.run.id}_{n_epochs}epochs', model, 100, interp=False)

    msg = "\nFinished Experiment!\n===================="
    _log.info(msg)
