import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.optim as optim
from sacred import Experiment
from termcolor import colored

import wandb
import advas
from advas.utils import WandbWrapper
from experiment_loader import data_loader, experiment_ingredient
from objective_dispatch import objective_dispatch, objective_ingredient
from generator_dispatch import generator_dispatch, generator_ingredient
from proxy_method import proxy_method, proxy_method_ingredient

# Name shared by the sacred experiment and the wandb project.
PROJECT_NAME = "gan-regularizer"

# SLURM identifiers taken from the environment; each is "" when unset.
JOB_ID = os.environ.get('SLURM_JOB_ID', "")
PROC_ID = os.environ.get('SLURM_PROCID', "")
ARRAY_ID = os.environ.get('SLURM_ARRAY_JOB_ID', "")
ARRAY_TASK = os.environ.get('SLURM_ARRAY_TASK_ID', "")

# Sacred experiment assembled from the ingredients of the helper modules.
ex = Experiment(
    PROJECT_NAME,
    ingredients=[
        experiment_ingredient,
        objective_ingredient,
        generator_ingredient,
        proxy_method_ingredient,
    ],
)

# For baseline code see: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix


@ex.named_config
def cuda():
    # Named sacred config: when selected, it sets the `cuda` entry to True
    # in place of the `cuda = False` default declared in `cfg`.
    cuda = True


@ex.config
def cfg():
    # Default sacred configuration. Every local assigned below becomes a
    # config entry that sacred injects into the captured functions and `main`.

    num_iterations = 1000  # passed to the trainer as the 'generator' iteration count
    proxy_iterations = 1  # passed to the trainer as the 'proxy' iteration count
    train_proxy_every = 1  # forwarded unchanged to `gan_trainer.train` in `main`

    batch_size = 128  # given to both WandbWrapper and advas.TrainGan

    # Accepted by `main` but not used by the code visible in this file.
    load_models = False
    save_models = False
    save_data = False

    optimizer_type = "ADAM"  # "ADAM", "RMSPROP" or "SGD" (see optimizer_dispatch)
    lr = 1e-4  # learning rate handed to the optimizer built by optimizer_dispatch

    num_workers = 0  # given to both WandbWrapper and advas.TrainGan

    cuda = False  # True requests the 'cuda' device in `main`
    wandb_id = None  # id of a wandb run to resume; None makes `main` generate a new id
    tag='tester'  # used as the single wandb tag when truthy
    n_metric_limit = None  # forwarded to WandbWrapper as N_limit
    swd_limit = 10000  # forwarded to WandbWrapper as swd_limit

@ex.capture
def seed_all(seed, _log):
    """Seed every random number generator used by this script from ``seed``.

    The seed is logged first and then applied, in order, to NumPy, Python's
    ``random`` module, PyTorch's default generator and all CUDA devices.
    """
    _log.info(f"Seed: {seed}")

    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


@ex.capture
def update_run(network, _run):
    """Placeholder hook; not implemented and always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError


@ex.capture
def optimizer_dispatch(p, optimizer_type, lr, _log):
    """Build the optimizer selected by ``optimizer_type`` for model ``p``.

    Args:
        p: object exposing ``parameters()``; its parameters are optimized.
        optimizer_type: one of ``"ADAM"``, ``"RMSPROP"`` or ``"SGD"``.
        lr: learning rate passed to the optimizer.
        _log: logger injected by sacred.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: if ``optimizer_type`` is not a supported name.
    """
    factories = {
        # Adam uses beta1 = 0.5 rather than the library default.
        "ADAM": lambda params: optim.Adam(params, lr=lr, betas=(0.5, 0.999)),
        "RMSPROP": lambda params: optim.RMSprop(params, lr=lr),
        "SGD": lambda params: optim.SGD(params, lr=lr),
    }

    build = factories.get(optimizer_type)
    if build is None:
        raise NotImplementedError("Unknown optimizer")

    optimizer = build(p.parameters())
    _log.info(f"Running generator with optimizer: {optimizer_type}")
    return optimizer


@ex.automain
def main(batch_size, train_proxy_every, num_iterations, proxy_iterations,
         num_workers, save_data, save_models, load_models, cuda,
         n_metric_limit, wandb_id, tag, swd_limit, _config, _log):
    """Run one GAN training experiment end to end.

    Seeds the random number generators, loads the data, starts (or resumes)
    a wandb run, builds the generator, proxy model, optimizer and objective,
    and trains them with ``advas.TrainGan``.

    All arguments are injected by sacred from the experiment configuration.
    ``save_data``, ``save_models`` and ``load_models`` are accepted for
    interface compatibility but are not used by this function.

    Raises:
        ValueError: if ``wandb_id`` was given but the run started by wandb
            reports a different id.
    """
    # set all seeds
    seed_all()
    # uncomment below for reproducibility (https://pytorch.org/docs/stable/notes/randomness.html)
    # torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True  # set to False for reproducibility

    # Build a run-name suffix from the SLURM identifiers: array jobs take
    # precedence over plain jobs.
    job_suffix = ""
    if ARRAY_ID:
        job_suffix += f"_task_{ARRAY_ID}"
        job_suffix += f"_{ARRAY_TASK}"
    elif JOB_ID:
        job_suffix += f"_job_{JOB_ID}"
        if PROC_ID:
            job_suffix += f"_{PROC_ID}"

    datasets, data_shape, exp_name = data_loader()

    # Reuse the given wandb id when resuming, otherwise create a new one.
    resuming = wandb_id is not None
    run_id = wandb_id if resuming else wandb.util.generate_id()
    wandb.init(project=PROJECT_NAME, name=exp_name+job_suffix, dir="/tmp",
               group=exp_name, id=run_id, resume='allow',
               tags=[tag] if tag else None)
    if not resuming:
        # Only a fresh run records the sacred config in wandb.
        for key, value in _config.items():
            wandb.config[key] = value

    # Pick the device that is actually usable, so the log message and the
    # device handed to the trainer always agree.
    if cuda and torch.cuda.is_available():
        device = 'cuda'
        _log.info("Running on GPU")
    else:
        if cuda:
            _log.warning("CUDA was requested but is not available; "
                         "falling back to CPU")
        device = 'cpu'
        _log.info("Running on CPU")

    wandbwrapper = WandbWrapper(wandb, batch_size=batch_size,
                                num_workers=num_workers, device=device,
                                N_limit=n_metric_limit, swd_limit=swd_limit)

    if resuming:
        if wandb.run.id != wandb_id:
            raise ValueError("Wandb ids do not match!")
        loaded_objects = wandbwrapper.load_objects(wandb.run.path)
    else:
        loaded_objects = None

    p = generator_dispatch(data_shape)
    proxy_model, proxy_optimizer = proxy_method(data_shape)

    optimizer = optimizer_dispatch(p)
    objective, normtype = objective_dispatch()

    # batch_size is passed twice on purpose: once after the generator
    # optimizer and once after the proxy optimizer, as in the original call.
    gan_trainer = advas.TrainGan(p, datasets, optimizer, batch_size,
                                 proxy_model, proxy_optimizer, batch_size,
                                 objective, num_workers=num_workers,
                                 normtype=normtype, device=device,
                                 wandbwrapper=wandbwrapper)

    gan_trainer.load_objects(loaded_objects)

    # train model; the returned models are not needed afterwards
    gan_trainer.train(num_iterations={'generator': num_iterations,
                                      'proxy': proxy_iterations},
                      train_proxy_every=train_proxy_every)

    msg = "\nFinished Experiment!\n===================="
    _log.info(msg)
