import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from absl import app, flags
from ml_collections.config_flags import config_flags
from tqdm.auto import tqdm

import gen_neg_toy
import wandb
from gen_neg_toy import (data, dispatch_model, dispatch_model_from_path,
                         evaluation, script_utils)
from gen_neg_toy.configs._default import (add_distill_configs,
                                          get_default_configs)
from gen_neg_toy.loss import dispatch_loss
from gen_neg_toy.ng_utils import compute_infraction, get_dataset_min_distance
from gen_neg_toy.utils import expand_tensor_dims_as, infinite_loader, logging
from gen_neg_toy.utils.random import set_random_seed

# Project logging hook from gen_neg_toy.utils.logging, called once at import
# time. NOTE(review): its exact effect is defined in that module, not here.
logging.support_unobserve()


# Command-line interface: the whole training configuration is exposed as the
# overridable `--config.*` flag group, plus two wandb-related flags.
FLAGS = flags.FLAGS
config_dict = get_default_configs()
# The return value is ignored, so this presumably extends `config_dict` in
# place with the `distill` options read by train() -- TODO confirm.
add_distill_configs(config_dict)
config_flags.DEFINE_config_dict("config", config_dict, "Training configuration.")
flags.DEFINE_list("tags", [], "Tags to add to the run.")
flags.DEFINE_string("wandb_name", None, "wandb name.")
# Disabled: `config.distill.checkpoint` is not enforced as a required flag at
# parse time, even though train() passes it to dispatch_model_from_path.
# flags.mark_flags_as_required(["config.distill.checkpoint"])


@torch.no_grad()
def evaluate(config, train_set, test_set, val_loader, model, n_samples):
    """Sample from ``model`` and gather evaluation metrics for logging.

    Three things are measured, in this order:

    1. ``n_samples`` samples are drawn with ``script_utils.draw_samples``; the
       samples are scored with ``evaluation.infraction`` and rendered as a
       histogram and a scatter plot.
    2. ``script_utils.elbo`` is averaged over all batches of ``val_loader``.
    3. ``script_utils.elbo`` is computed once on the first
       ``min(1000, len(train_set))`` items of ``train_set``.

    Args:
        config: Configuration; ``config.device``, ``config.sampling.steps``
            and ``config.sampling.S_churn`` are read.
        train_set: Indexable dataset yielding ``(x0, validity)`` pairs.
        test_set: Unused here; kept for interface compatibility.
        val_loader: Iterable of ``(x0, validity)`` batches.
        model: Model to evaluate.
        n_samples: Number of samples to draw.

    Returns:
        A dict whose scalar metrics are keyed ``eval/<name>`` and whose
        ``wandb.Image`` visualisations are keyed ``vis/<name>``.
    """
    device = config.device
    num_steps = config.sampling.steps

    # --- Sampling (timed) ---
    sampling_start = time.time()
    samples, nfe = script_utils.draw_samples(
        model,
        n_samples,
        device=device,
        num_steps=num_steps,
        S_churn=config.sampling.S_churn,
    )
    sampling_time = time.time() - sampling_start

    metrics = evaluation.infraction(samples)
    figures = {
        "samples": wandb.Image(evaluation.visualize_hist(samples)),
        "samples_scatter": wandb.Image(evaluation.visualize_scatter(samples)),
    }
    metrics["sampling_time"] = sampling_time
    metrics["nfe"] = nfe

    # --- ELBO on the validation loader (timed) ---
    elbo_start = time.time()
    val_elbos = [
        script_utils.elbo(
            model, (batch_x, batch_validity), device, num_steps=num_steps
        )
        for batch_x, batch_validity in val_loader
    ]
    metrics["elbo"] = np.mean(val_elbos)
    metrics["elbo_time"] = time.time() - elbo_start

    # --- ELBO on a fixed prefix of the training set ---
    subset_size = min(1000, len(train_set))
    train_x, train_validity = zip(*(train_set[i] for i in range(subset_size)))
    train_x = torch.as_tensor(np.stack(train_x)).float()
    train_validity = torch.as_tensor(train_validity).float()
    metrics["elbo_train"] = script_utils.elbo(
        model, (train_x, train_validity), device, num_steps=num_steps
    )

    # Release the figures created by the visualisation helpers.
    plt.close("all")

    prefixed = {f"eval/{key}": value for key, value in metrics.items()}
    prefixed.update({f"vis/{key}": value for key, value in figures.items()})
    return prefixed


def train(config, iter_start=0):
    """Train a student model by distilling a frozen teacher model.

    The student is built from ``config.model`` and the teacher is loaded from
    ``config.distill.checkpoint``. At every step a batch is drawn from the
    training set and, with probability ``config.distill.mix_gt``, the regular
    ``loss_fn`` is applied to the data; otherwise the student is regressed
    onto the teacher's output at a noised input, weighted by
    ``loss_fn.compute_weight(sigma)``.

    Evaluation, wandb logging and checkpointing run at the intervals given in
    ``config.training``. Checkpoints are written to
    ``results/checkpoints/<wandb run id>/`` and every checkpoint (intermediate
    and final) stores the optimizer state, so any of them can be used as
    ``config.training.init_checkpoint``.

    Args:
        config: Training configuration (``seed``, ``device``, ``data``,
            ``model``, ``optim``, ``loss``, ``training``, ``distill`` and
            ``sampling`` entries are read).
        iter_start: Index of the first iteration; the loop covers
            ``range(iter_start, iter_start + config.training.n_iters)``.

    Returns:
        The evaluation log dict produced during the last iteration, or an
        empty dict if that evaluation failed or no iteration was run.

    Raises:
        AssertionError: If the checkpoint directory for this run already
            exists.
    """
    if config.seed is not None:
        set_random_seed(config.seed)

    log_handler = logging.LoggingHandler()

    checkpoints_root = Path("results/checkpoints") / wandb.run.id
    assert (
        not checkpoints_root.exists()
    ), f"Checkpoint directory {checkpoints_root} already exists."

    ## Load the dataset ##
    train_set, test_set = data.get_datasets(config.data)
    # A `.npy` distillation dataset replaces the default training set.
    if config.distill.dataset is not None and config.distill.dataset.endswith(".npy"):
        train_set = data.SyntheticDataset(
            config.distill.dataset, config.distill.train_set_size, labels=1
        )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=config.training.batch_size, shuffle=True, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=config.training.batch_size, shuffle=False
    )
    inf_train_loader = infinite_loader(train_loader)
    print(
        f"Training set min distance to boundries = {get_dataset_min_distance(train_loader)}"
    )
    print(
        f"Validation set min distance to boundries = {get_dataset_min_distance(val_loader)}"
    )

    ## Initialize the model and SDE ##
    model = dispatch_model(config.model)
    model = model.to(config.device)

    ## Load the teacher model ##
    teacher_model, teacher_model_config = dispatch_model_from_path(
        config.distill.checkpoint,
        strict=(config.distill.classifier is None),
        classifier=config.distill.classifier,
    )
    teacher_model = teacher_model.to(config.device)
    # The teacher only provides regression targets; it is never updated.
    teacher_model.eval()
    teacher_model.requires_grad_(False)

    ## Initialize the optimizer ##
    # Parameters under the `classifier.` prefix are excluded from optimization.
    optimizer = getattr(torch.optim, config.optim.optimizer)(
        [v for k, v in model.named_parameters() if not k.startswith("classifier.")],
        lr=config.optim.lr,
    )
    loss_fn = dispatch_loss(config.loss)

    if config.training.init_checkpoint is not None:
        print(f"Initializing model from checkpoint {config.training.init_checkpoint}")
        model.load(config.training.init_checkpoint, strict=False)
        checkpoint = torch.load(
            config.training.init_checkpoint, map_location=lambda storage, loc: storage
        )
        # Checkpoints written without optimizer state (e.g. older intermediate
        # saves) can still initialise the weights; the optimizer starts fresh.
        optimizer_state_dict = checkpoint.get("optimizer_state_dict")
        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
        else:
            print(
                f"No optimizer state found in checkpoint "
                f"{config.training.init_checkpoint}; using a fresh optimizer state."
            )

    ## Log some info ##
    wandb.log({"parameters": model.count_parameters()}, commit=False)
    # Fixed-seed subsample (drawn with replacement) used only for visualisation.
    dataset = np.array(
        [
            train_set[i][0]
            for i in np.random.RandomState(123).choice(
                range(len(train_set)), size=10000
            )
        ]
    )
    neg_dataset = np.array([x for x, validity in train_set if validity == 0])
    wandb.log(
        {"vis/dataset": wandb.Image(evaluation.visualize_hist(dataset))}, commit=False
    )
    plt.close("all")
    wandb.log(
        {"vis/dataset_scatter": wandb.Image(evaluation.visualize_scatter(dataset))},
        commit=False,
    )
    plt.close("all")
    if len(neg_dataset) > 0:
        wandb.log(
            {
                "vis/neg_dataset": wandb.Image(
                    evaluation.visualize_scatter(neg_dataset[:10000])
                )
            },
            commit=False,
        )
        plt.close("all")

    ## Training loop ##
    # Bound up front so the final `return` is valid even when n_iters == 0.
    eval_log_dict = {}
    for iteration in tqdm(
        range(iter_start, iter_start + config.training.n_iters), mininterval=10.0
    ):
        log_dict = {}
        eval_log_dict = {}
        # Prepare data (the validity labels are not used by the distillation loss)
        x0, _ = next(inf_train_loader)
        x0 = x0.float()

        model_kwargs = {}
        x0 = x0.to(config.device)

        # Forward-Backward passes
        optimizer.zero_grad()
        if np.random.rand() < config.distill.mix_gt:
            # Ground-truth branch: regular training loss on the data.
            loss = loss_fn(model, x0, model_kwargs=model_kwargs)
        else:
            # Distillation branch: match the teacher's output at a noised input.
            sigma = loss_fn.sample_sigma(len(x0), x0.device)
            weight = expand_tensor_dims_as(loss_fn.compute_weight(sigma), x0)
            n = torch.randn_like(x0) * expand_tensor_dims_as(sigma, x0)
            xt = x0 + n
            student_out = model(xt, sigma, **model_kwargs)
            with torch.no_grad():
                teacher_out = teacher_model(xt, sigma, **model_kwargs).detach()
            loss = weight * ((student_out - teacher_out) ** 2)
        # Sum over non-batch dimensions, then average over the batch.
        loss = loss.view(len(loss), -1).sum(dim=-1).mean()
        loss.backward()
        if config.training.clip_grad_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), config.training.clip_grad_norm
            ).item()
            log_dict["grad_norm"] = grad_norm
        optimizer.step()

        log_dict["loss"] = loss.item()
        log_handler(log_dict)
        # Logging images
        # There is no `iteration > iter_start` guard, so an evaluation can
        # fire on the very first iteration; the last iteration always evaluates.
        if (
            iteration % config.training.eval_interval == 0
        ) or iteration == iter_start + config.training.n_iters - 1:
            try:
                eval_log_dict = evaluate(
                    config,
                    train_set,
                    test_set,
                    val_loader,
                    model,
                    n_samples=config.training.n_eval_samples,
                )
                eval_log_dict["iteration"] = iteration
                wandb.log(eval_log_dict, commit=False)
            except Exception as e:
                # Evaluation is best-effort: a failure is reported but must
                # not abort training.
                print(f"Ignored exception: {e}")
        # Training stats and logging to wandb
        if iteration % config.training.log_interval == 0:
            log_dict = log_handler.flush()
            log_dict["iteration"] = iteration
            wandb.log(log_dict)
        # Save the model
        if (
            config.training.save_interval is not None
            and iteration % config.training.save_interval == 0
            and iteration > iter_start
        ):
            checkpoints_root.mkdir(parents=True, exist_ok=True)
            # Store the optimizer state as well, so that an intermediate
            # checkpoint can be resumed from just like the final one.
            model.save(
                checkpoints_root / f"iter_{iteration}.pt",
                config=config,
                optimizer_state_dict=optimizer.state_dict(),
            )
    # Save the final model
    checkpoints_root.mkdir(parents=True, exist_ok=True)
    model.save(
        checkpoints_root / f"iter_{iter_start + config.training.n_iters}.pt",
        config=config,
        optimizer_state_dict=optimizer.state_dict(),
    )

    return eval_log_dict


def main(argv):
    """absl entry point: start the logging run, train, then flush wandb.

    Args:
        argv: Positional command-line arguments passed by absl; unused.
    """
    del argv  # Unused.

    run_config = FLAGS.config
    logging.init(
        config=run_config.to_dict(),
        tags=FLAGS.tags,
        name=FLAGS.wandb_name,
    )
    train(run_config)
    # Final empty log call so metrics logged with commit=False get committed.
    wandb.log({})


# Script entry point: absl parses the command-line flags and then calls main.
if __name__ == "__main__":
    app.run(main)
