import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from absl import app, flags
from ml_collections.config_flags import config_flags
from tqdm.auto import tqdm

import gen_neg_toy
import wandb
from gen_neg_toy import data, dispatch_model, evaluation, script_utils
from gen_neg_toy.configs._default import get_default_configs
from gen_neg_toy.loss import dispatch_loss
from gen_neg_toy.ng_utils import get_dataset_min_distance
from gen_neg_toy.utils import infinite_loader, logging
from gen_neg_toy.utils.random import RNG, set_random_seed

# Project-specific logging hook from gen_neg_toy.utils.logging, invoked once at
# import time. NOTE(review): presumably enables running without recording to
# wandb ("unobserved" runs) -- confirm against the helper's implementation.
logging.support_unobserve()

# Command-line flags for this script.
FLAGS = flags.FLAGS
# --config: experiment configuration; defaults to get_default_configs().
config_flags.DEFINE_config_dict("config", get_default_configs())
# --tags: list flag (empty by default), forwarded to logging.init in main().
flags.DEFINE_list("tags", [], "Tags to add to the run.")
# --wandb_name: optional run name, forwarded to logging.init in main().
flags.DEFINE_string("wandb_name", None, "wandb name.")


@torch.no_grad()
def evaluate(config, train_set, test_set, val_loader, model, n_samples):
    """Draw samples from ``model`` and collect evaluation metrics for logging.

    Runs with gradients disabled. Three groups of quantities are produced:

    * sample statistics from ``evaluation.infraction`` plus histogram and
      scatter visualizations of the drawn samples, the sampling wall-clock
      time and the ``nfe`` value returned by ``script_utils.draw_samples``;
    * the mean ELBO over all batches of ``val_loader`` and the time it took;
    * the ELBO on the first ``min(1000, len(train_set))`` training examples.

    Args:
        config: Experiment config; ``config.device``, ``config.sampling.steps``
            and ``config.sampling.S_churn`` are read here.
        train_set: Indexable dataset yielding ``(x0, validity)`` pairs.
        test_set: Unused; kept so existing callers do not break.
        val_loader: Iterable of ``(x0, validity)`` batches.
        model: The model to evaluate.
        n_samples: Number of samples to draw.

    Returns:
        dict: Scalar metrics under ``eval/<name>`` keys and ``wandb.Image``
        visualizations under ``vis/<name>`` keys.
    """
    ## Sampling ##
    t_0 = time.time()

    samples, nfe = script_utils.draw_samples(
        model,
        n_samples,
        device=config.device,
        num_steps=config.sampling.steps,
        S_churn=config.sampling.S_churn,
    )

    sampling_time = time.time() - t_0
    log_dict_vis = {}
    log_dict = evaluation.infraction(samples)
    log_dict_vis["samples"] = wandb.Image(evaluation.visualize_hist(samples))
    log_dict_vis["samples_scatter"] = wandb.Image(evaluation.visualize_scatter(samples))
    log_dict["sampling_time"] = sampling_time
    log_dict["nfe"] = nfe

    ## Validation ELBO (mean of the per-batch values) ##
    t_0 = time.time()
    elbo = []
    for x0, validity in val_loader:
        elbo.append(
            script_utils.elbo(
                model, (x0, validity), config.device, num_steps=config.sampling.steps
            )
        )
    elbo = np.mean(elbo)
    elbo_time = time.time() - t_0
    log_dict["elbo"] = elbo
    log_dict["elbo_time"] = elbo_time

    ## Training ELBO on a fixed prefix of the training set ##
    # Only the first (at most) 1000 examples are used, evaluated as one batch.
    n_train_eval = min(1000, len(train_set))
    x0, validity = zip(*[train_set[i] for i in range(n_train_eval)])
    x0 = torch.as_tensor(np.stack(x0)).float()
    validity = torch.as_tensor(validity).float()
    log_dict["elbo_train"] = script_utils.elbo(
        model, (x0, validity), config.device, num_steps=config.sampling.steps
    )

    # The visualization helpers create matplotlib figures; release them all.
    plt.close("all")
    log_dict = {f"eval/{k}": v for k, v in log_dict.items()}
    log_dict.update({f"vis/{k}": v for k, v in log_dict_vis.items()})
    return log_dict


def train(config, iter_start=0):
    """Train the model described by ``config`` and return the last eval log.

    Steps, in order:

    1. Seed the RNGs when ``config.seed`` is set.
    2. Load the train/test datasets, optionally merging a negative dataset
       into the training set (checkerboard dataset only).
    3. Build the model, the optimizer (over every parameter whose name does
       not start with ``"classifier."``) and the loss.
    4. Optionally restore model weights (non-strict) and optimizer state from
       ``config.training.init_checkpoint``.
    5. Log the parameter count and dataset visualizations to wandb.
    6. Run ``config.training.n_iters`` iterations starting at ``iter_start``,
       with periodic evaluation, logging and checkpointing.
    7. Save a final checkpoint.

    Checkpoints are written under ``results/checkpoints/<wandb run id>``.
    Every checkpoint, periodic or final, stores the optimizer state so that
    any of them can later be passed as ``init_checkpoint`` (which requires
    the ``"optimizer_state_dict"`` entry).

    Args:
        config: Experiment configuration.
        iter_start: Index of the first iteration; iteration numbers used for
            logging and checkpoint names are offset by it.

    Returns:
        dict: The evaluation log produced at the final iteration, or an empty
        dict if that evaluation raised or if no iteration was run.

    Raises:
        AssertionError: If the checkpoint directory already exists, or if a
            negative dataset is requested for a dataset other than
            ``"checkerboard"``.
    """
    if config.seed is not None:
        set_random_seed(config.seed)

    log_handler = logging.LoggingHandler()

    checkpoints_root = Path("results/checkpoints") / wandb.run.id
    assert (
        not checkpoints_root.exists()
    ), f"Checkpoint directory {checkpoints_root} already exists."

    ## Load the dataset ##
    train_set, test_set = data.get_datasets(config.data)
    neg_train_examples_cnt = 0
    if config.data.neg_dataset is not None:
        # Add the negative examples
        assert (
            config.data.dataset == "checkerboard"
        ), "Negative dataset is only supported for the checkerboard dataset for now."
        old_train_set_size = len(train_set)
        train_set = data.merge_neg_dataset(
            train_set, config.data.neg_dataset, config.data.neg_dataset_size
        )
        neg_train_examples_cnt = len(train_set) - old_train_set_size
        wandb.log({"neg_dataset_size": neg_train_examples_cnt}, commit=False)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=config.training.batch_size, shuffle=True, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=config.training.batch_size, shuffle=False
    )
    inf_train_loader = infinite_loader(train_loader)
    print(f"Training set min distance to boundries = {get_dataset_min_distance(train_loader)}")
    print(f"Validation set min distance to boundries = {get_dataset_min_distance(val_loader)}")

    ## Initialize the model and SDE ##
    model = dispatch_model(config.model)
    model = model.to(config.device)

    ## Initialize the optimizer ##
    # Parameters under the "classifier." prefix are excluded from optimization.
    trainable_params = [
        param
        for name, param in model.named_parameters()
        if not name.startswith("classifier.")
    ]
    optimizer = getattr(torch.optim, config.optim.optimizer)(
        trainable_params, lr=config.optim.lr
    )
    loss_fn = dispatch_loss(config.loss)

    if config.training.init_checkpoint is not None:
        print(f"Initializing model from checkpoint {config.training.init_checkpoint}")
        model.load(config.training.init_checkpoint, strict=False)
        # Load onto CPU storage; the checkpoint must carry the optimizer state.
        optimizer_state_dict = torch.load(
            config.training.init_checkpoint,
            map_location=lambda storage, loc: storage,
        )["optimizer_state_dict"]
        optimizer.load_state_dict(optimizer_state_dict)

    ## Log some info ##
    wandb.log({"parameters": model.count_parameters()}, commit=False)
    # Fixed-seed subsample (drawn with replacement) so the plot is reproducible.
    dataset = np.array([train_set[i][0]
            for i in np.random.RandomState(123).choice(range(len(train_set)), size=10000)])
    neg_dataset = np.array([x for x, validity in train_set if validity == 0])
    wandb.log(
        {"vis/dataset": wandb.Image(evaluation.visualize_hist(dataset))}, commit=False
    )
    plt.close("all")
    wandb.log(
        {"vis/dataset_scatter": wandb.Image(evaluation.visualize_scatter(dataset))},
        commit=False,
    )
    plt.close("all")
    if len(neg_dataset) > 0:
        wandb.log(
            {
                "vis/neg_dataset": wandb.Image(
                    evaluation.visualize_scatter(neg_dataset[:10000])
                )
            },
            commit=False,
        )
        plt.close("all")

    ## Training loop ##
    # Bound up front so the final `return` is valid even when n_iters == 0.
    eval_log_dict = {}
    for iteration in tqdm(
        range(iter_start, iter_start + config.training.n_iters), mininterval=10.0
    ):
        log_dict = {}
        # Reset every iteration: the returned dict reflects the last iteration only.
        eval_log_dict = {}
        # Prepare data. The validity labels are not consumed by the loss call
        # below, so they are discarded here.
        x0, _ = next(inf_train_loader)
        x0 = x0.float()

        model_kwargs = {}

        x0 = x0.to(config.device)

        # Forward-Backward passes
        optimizer.zero_grad()
        loss = loss_fn(model, x0, model_kwargs=model_kwargs)
        # Sum over all non-batch dimensions, then average over the batch.
        loss = loss.view(len(loss), -1).sum(dim=-1).mean()
        loss.backward()
        if config.training.clip_grad_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), config.training.clip_grad_norm
            ).item()
            log_dict["grad_norm"] = grad_norm
        optimizer.step()

        log_dict["loss"] = loss.item()
        log_handler(log_dict)
        # Evaluate periodically and always on the final iteration.
        if (
            iteration % config.training.eval_interval == 0 # and iteration > iter_start
        ) or iteration == iter_start + config.training.n_iters - 1:
            # Best effort: a failing evaluation must not abort training.
            try:
                eval_log_dict = evaluate(
                    config,
                    train_set,
                    test_set,
                    val_loader,
                    model,
                    n_samples=config.training.n_eval_samples,
                )
                eval_log_dict["iteration"] = iteration
                wandb.log(eval_log_dict, commit=False)
            except Exception as e:
                print(f"Ignored exception: {e}")
        # Training stats and logging to wandb
        if iteration % config.training.log_interval == 0:
            log_dict = log_handler.flush()
            log_dict["iteration"] = iteration
            wandb.log(log_dict)
        # Save the model
        if (
            config.training.save_interval is not None
            and iteration % config.training.save_interval == 0
            and iteration > iter_start
        ):
            checkpoints_root.mkdir(parents=True, exist_ok=True)
            # Include the optimizer state so this checkpoint can be used as
            # `init_checkpoint`, which reads "optimizer_state_dict".
            model.save(
                checkpoints_root / f"iter_{iteration}.pt",
                config=config,
                optimizer_state_dict=optimizer.state_dict(),
            )
    # Save the final model
    checkpoints_root.mkdir(parents=True, exist_ok=True)
    model.save(
        checkpoints_root / f"iter_{iter_start + config.training.n_iters}.pt",
        config=config,
        optimizer_state_dict=optimizer.state_dict(),
    )

    return eval_log_dict


def main(argv):
    """absl entry point: initialise logging, train, then issue a final log call.

    Args:
        argv: Positional command-line arguments supplied by ``app.run``;
            not used here.
    """
    init_kwargs = {
        "config": FLAGS.config.to_dict(),
        "tags": FLAGS.tags,
        "name": FLAGS.wandb_name,
    }
    logging.init(**init_kwargs)

    train(FLAGS.config)

    # NOTE(review): presumably commits entries that were logged with
    # commit=False during training -- confirm against the wandb.log docs.
    wandb.log({})


# Script entry point: hand control to absl, which calls main.
if __name__ == "__main__":
    app.run(main)
