import itertools
from argparse import Namespace
from pathlib import Path

import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import torch
from absl import app, flags
from ml_collections.config_flags import config_flags
from torch.utils.data import Dataset
from tqdm.auto import tqdm

import wandb
from gen_neg_toy import data
from gen_neg_toy.classifier import dispatch_model, dispatch_model_from_path
from gen_neg_toy.configs._default import get_default_configs
from gen_neg_toy.loss import dispatch_loss
from gen_neg_toy.ng_utils import compute_infraction
from gen_neg_toy.utils import expand_tensor_dims_as, logging, infinite_loader
from gen_neg_toy.utils.random import RNG

# Module-level setup: runs at import time, before absl parses the command line.
# NOTE(review): `logging` here is gen_neg_toy.utils.logging, not the stdlib module;
# what support_unobserve() does is not visible from this file.
logging.support_unobserve()

FLAGS = flags.FLAGS
# Start from the project's default "data" config section and extend it with the
# training/model sections used by this script. main() also reads fields that are
# not set here (e.g. config.device, config.data.train_set_size,
# config.data.neg_dataset), so presumably they come from these defaults -- verify
# against get_default_configs.
config_dict = get_default_configs(config_names=["data"])
config_dict.data.dataset = "uniform"
config_dict.training = ml_collections.ConfigDict()
config_dict.training.batch_size = 8192 # Batch size
config_dict.training.n_iters = 50000 # Number of training iterations
config_dict.training.log_interval = 1000 # Logging interval
config_dict.training.eval_interval = 10000 # Evaluation interval
# Checkpoint every `save_interval` iterations. Left as a typed placeholder (None
# until overridden), in which case main() only writes the final checkpoint.
config_dict.training.save_interval = ml_collections.config_dict.placeholder(int)
config_dict.training.bridge_loss = 0 # If non-zero, use bridge loss
config_dict.training.baseline_classifier = ml_collections.config_dict.placeholder(str) # Path to the previous iteration classifier, if any
config_dict.training.alpha = ml_collections.config_dict.placeholder(float) # The probability of model generating clean examples (1 - infraction rate)
config_dict.training.out_dir = ml_collections.config_dict.placeholder(str) # Path to store the checkpoints at. If not given, uses wandb id to make a unique path
config_dict.training.edm_sigma = 0 # If non-zero, uses EDM's continuous sigma sampling procedure. If zero, uses VE sigma sampling procedure.
config_dict.model = ml_collections.ConfigDict()
config_dict.model.precond = "edm_simple" # Preconditioner to use
config_dict.model.classifier_param = "sigmoid" # Choices: ["rescaled", "sigmoid", "softplus"]
# Expose the whole config as --config.<field> command-line overrides.
config_flags.DEFINE_config_dict(
    "config", config_dict, "Training configuration.")
# Comma-separated list flag; main() forwards it to logging.init(tags=...).
flags.DEFINE_list("tags", [], "Tags to add to the run.")


class UniformDataset(Dataset):
    """2-D points drawn uniformly from a square of side ``width`` centred at 0.

    Two modes, selected by ``size``:

    * ``size`` given: ``size`` points are sampled once in the constructor and
      indexing returns the stored rows.
    * ``size`` omitted: nothing is stored; every ``__getitem__`` call draws a
      fresh point (the index is ignored) and the reported length is 100000.
    """

    def __init__(self, size=None):
        self.width = 2.5
        if size is None:
            # Lazy mode: sample on access, advertise a nominal length.
            self.data = None
            self.size = 100000
        else:
            self.size = size
            self.data = self._sample((size, 2))

    def _sample(self, shape):
        """Return a tensor of ``shape`` with entries uniform in [-width/2, width/2)."""
        return torch.rand(shape) * self.width - self.width / 2

    def __getitem__(self, idx):
        if self.data is None:
            return self._sample((2,))
        return self.data[idx]

    def __len__(self):
        return self.size

    @property
    def alpha(self):
        """Reciprocal of the square's area, ``1 / width**2``.

        NOTE(review): presumably the fraction of samples that count as valid,
        assuming the valid region has unit area -- confirm against the validator.
        """
        area = self.width ** 2
        return 1 / area


def _perturb_batch(x, target, sigma, use_bridge_target):
    """Noise a batch and, for the bridge loss, re-derive its targets.

    Args:
        x: Clean input batch (tensor).
        target: Labels that came with ``x``; returned untouched unless
            ``use_bridge_target`` is set.
        sigma: Per-example noise scale, broadcast against ``x`` via
            ``expand_tensor_dims_as``.
        use_bridge_target: When true, a second independent noise draw of the
            same scale is added on top of the noised batch and the target
            becomes ``1 - compute_infraction(...)`` of that doubly-noised point.

    Returns:
        ``(xt, target)``: the noised batch and the target to train/evaluate on.
    """
    noise = torch.randn_like(x) * expand_tensor_dims_as(sigma, x)
    xt = x + noise
    if use_bridge_target:
        # Compute bridge target
        noise = torch.randn_like(x) * expand_tensor_dims_as(sigma, x)
        target = 1 - compute_infraction(xt + noise).float()
    return xt, target


def _batch_accuracy(classifier, pred, sigma, target):
    """Return the mean (0-dim tensor) of ``(out_to_p(pred, sigma) > 0.5) == target``."""
    hits = (classifier.out_to_p(pred, sigma) > 0.5) == target.bool()
    return hits.float().mean()


def main(argv):
    """Train the classifier described by ``FLAGS.config`` and save checkpoints.

    Steps: initialise logging, build the training set (optionally merged with a
    negative dataset), hold out a validation split, then run
    ``config.training.n_iters`` Adam steps on noised inputs. Training metrics are
    logged every ``log_interval`` iterations, validation accuracy over a grid of
    noise levels every ``eval_interval`` iterations, and checkpoints are written
    every ``save_interval`` iterations (if set) plus once at the end.

    Args:
        argv: Positional command-line arguments handed over by ``app.run``;
            unused.

    Raises:
        AssertionError: If ``training.alpha`` is set but not strictly inside
            (0, 1), or if ``training.out_dir`` already exists.
        ValueError: If ``training.alpha`` is set but the positive fraction of
            the training set is unknown (no negative dataset was configured).
    """
    config = FLAGS.config
    assert config.training.alpha is None or (config.training.alpha > 0 and config.training.alpha < 1)
    logging.init(config=config.to_dict(), tags=FLAGS.tags)
    log_handler = logging.LoggingHandler()
    if config.training.out_dir is None:
        checkpoints_root = Path("results/checkpoints_classifier") / wandb.run.id
    else:
        checkpoints_root = Path(config.training.out_dir)
        assert not checkpoints_root.exists(), f"Output directory {checkpoints_root} already exists"

    classifier = dispatch_model(config.model).to(config.device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=3e-3)

    # Fraction of positive (label 1) examples in the training set. Stays None
    # when no negative dataset is merged in, because nothing below measures it
    # in that case (previously this left the name unbound and crashed at the
    # print statement).
    train_set_alpha = None
    if config.data.dataset == "uniform":
        train_set = data.ValidatedDataset(UniformDataset(), cache_labels=False)
        train_set_alpha = train_set.alpha
    elif config.data.dataset.endswith(".npy"):
        train_set = data.SyntheticDataset(config.data.dataset, config.data.train_set_size, labels=1)
        if config.data.neg_dataset is not None:
            # Add the negative examples
            neg_dataset = data.SyntheticDataset(config.data.neg_dataset, config.data.neg_dataset_size, labels=0)
            neg_train_examples_cnt = len(neg_dataset)
            train_set = torch.utils.data.ConcatDataset([train_set, neg_dataset])
            wandb.log({"neg_dataset_size": neg_train_examples_cnt}, commit=False)
            train_set_alpha = 1 - neg_train_examples_cnt / len(train_set)
    else:
        train_set, _ = data.get_datasets(
            config.data.dataset, train_set_size=config.data.train_set_size
        )
        if config.data.neg_dataset is not None:
            # Add the negative examples
            old_train_set_size = len(train_set)
            train_set = data.merge_neg_dataset(
                train_set, config.data.neg_dataset, config.data.neg_dataset_size
            )
            neg_train_examples_cnt = len(train_set) - old_train_set_size
            wandb.log({"neg_dataset_size": neg_train_examples_cnt}, commit=False)
            train_set_alpha = 1 - neg_train_examples_cnt / len(train_set)
    print(f"Training set alpha = {train_set_alpha}")
    if config.training.alpha is not None and train_set_alpha is None:
        # The loss reweighting below divides by train_set_alpha; fail early
        # with a clear message instead of crashing inside the training loop.
        raise ValueError(
            "training.alpha reweighting needs the training-set positive fraction, "
            "which is only known for the 'uniform' dataset or when "
            "data.neg_dataset is set."
        )

    # Hold out 10% of the data (at least 1, at most 1000 examples) for
    # validation. The split runs inside RNG(123); presumably this fixes the
    # seed so the split is reproducible -- confirm in gen_neg_toy.utils.random.
    with RNG(123):
        val_set_size = max(min(len(train_set) // 10, 1000), 1)
        wandb.log({"val_set_size": val_set_size}, commit=False)
        train_set, val_set = torch.utils.data.random_split(
            train_set, [len(train_set) - val_set_size, val_set_size]
        )

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=config.training.batch_size, shuffle=True,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=config.training.batch_size, shuffle=False
    )
    # Training is driven by iteration count, not epochs, so the loader is
    # wrapped to be consumed with next() for as long as needed.
    train_loader = infinite_loader(train_loader)

    edm_loss_fn = dispatch_loss(Namespace(name="edm"))

    # Loop-invariant values, computed once.
    use_bridge_target = config.training.bridge_loss != 0
    log_sigma_min = np.log(0.002)
    log_sigma_max = np.log(80)
    eval_sigma_vals = np.logspace(np.log10(0.002), np.log10(80), 100)

    wandb.log({"parameters": classifier.count_parameters()}, commit=False)
    for iteration in range(config.training.n_iters):
        x, target = next(train_loader)
        log_dict = {}

        optimizer.zero_grad()
        target = target.float().to(config.device)
        x = x.float().to(config.device)
        if config.training.edm_sigma == 0:
            # Log-uniform noise level in [0.002, 80).
            sigma = torch.rand(len(x), device=config.device) * (log_sigma_max - log_sigma_min) + log_sigma_min
            sigma = sigma.exp()
        else:
            sigma = edm_loss_fn.sample_sigma(len(x), config.device)
        xt, target = _perturb_batch(x, target, sigma, use_bridge_target)
        pred = classifier(xt, sigma).flatten()
        # NOTE(review): the criterion is fed float64 targets -- presumably it
        # works in double precision; confirm against the classifier module.
        loss = classifier.criterion(pred, target.to(torch.float64), sigma, reduction="none")
        if config.training.alpha is not None:
            # Rescale per-example losses: positives by alpha / (2 * train_set_alpha),
            # negatives by (1 - alpha) / (2 * (1 - train_set_alpha)).
            loss[target == 1] *= config.training.alpha / (train_set_alpha * 2)
            loss[target == 0] *= (1 - config.training.alpha) / ((1 - train_set_alpha) * 2)
        loss = loss.mean()
        loss.backward()
        optimizer.step()

        # The handler accumulates per-iteration values; they are flushed to
        # wandb only on logging iterations.
        log_dict["loss"] = loss.item()
        log_handler(log_dict)

        if iteration % config.training.log_interval == 0:
            # Accuracy of the current training batch (predictions made before
            # this iteration's optimizer step).
            acc = _batch_accuracy(classifier, pred, sigma, target)
            print(f"#{iteration:<10}: {loss.item():<25} | {acc.item():<25}")
            log_dict = log_handler.flush()
            log_dict["acc"] = acc.item()
            log_dict["iteration"] = iteration
            wandb.log(log_dict)

        if iteration % config.training.eval_interval == 0:
            # Validation accuracy at 100 log-spaced noise levels; each level's
            # value is the mean of the per-batch accuracies.
            acc_vals = []
            with torch.no_grad():
                for sigma_scalar in eval_sigma_vals:
                    batch_accs = []
                    for x, target in val_loader:
                        target = target.float().to(config.device)
                        x = x.float().to(config.device)
                        sigma = x.new_ones(len(x)) * sigma_scalar
                        xt, target = _perturb_batch(x, target, sigma, use_bridge_target)
                        pred = classifier(xt, sigma).flatten()
                        batch_accs.append(_batch_accuracy(classifier, pred, sigma, target).item())
                    acc_vals.append(np.mean(batch_accs))
            fig = plt.figure()
            fig.suptitle(f"iteration = {iteration}")
            plt.plot(eval_sigma_vals, acc_vals)
            plt.xscale("log")
            plt.grid("on")
            # Also report accuracy per quarter of the sigma grid (q0 = lowest
            # noise levels, q3 = highest).
            quart_len = len(acc_vals) // 4
            wandb.log({"eval/plot": wandb.Image(fig),
                       "eval/acc": np.mean(acc_vals),
                       "eval/acc_q0": np.mean(acc_vals[:quart_len]),
                       "eval/acc_q1": np.mean(acc_vals[quart_len:quart_len*2]),
                       "eval/acc_q2": np.mean(acc_vals[quart_len*2:quart_len*3]),
                       "eval/acc_q3": np.mean(acc_vals[quart_len*3:]),},
                      commit=False)
            plt.close("all")
        if iteration > 0 and config.training.save_interval is not None and iteration % config.training.save_interval == 0:
            checkpoints_root.mkdir(parents=True, exist_ok=True)
            classifier.save(checkpoints_root / f"iter_{iteration}.pt", config=config)
    checkpoints_root.mkdir(parents=True, exist_ok=True)
    # Name the final checkpoint after the configured iteration count rather
    # than the loop variable, which is unbound when n_iters == 0.
    classifier.save(checkpoints_root / f"final_{config.training.n_iters}.pt", config=config)


# Script entry point: absl parses the command-line flags (including --config.*
# overrides and --tags) and then calls main().
if __name__ == "__main__":
    app.run(main)
