import json
from functools import partial
from pathlib import Path

import numpy as np
import torch
from absl import app, flags
from gen_neg_toy.utils.random import set_random_seed
from ml_collections.config_dict import placeholder
from ml_collections.config_flags import config_flags
from tqdm import tqdm

import gen_neg_toy
from gen_neg_toy import data, dispatch_model_from_path, script_utils
from gen_neg_toy.configs._default import get_default_configs
from gen_neg_toy.ng_utils import compute_infraction, compute_distance_to_boundries
from gen_neg_toy.sampling import edm_sampler
from gen_neg_toy.sde_lib import VPSDE

FLAGS = flags.FLAGS

# Command-line configuration: the default "sampling" configs extended with the
# options of this negative-dataset generation script. Fields created with
# `placeholder(T)` default to None and can be set from the command line,
# e.g. `--config.out=my_dataset`.
config_dict = get_default_configs(config_names=["sampling"])
config_dict.checkpoint = placeholder(str)  # Path to the model checkpoint to sample from (required).
# Number of negative (infracting) samples to collect; the same number of
# positive samples is collected as well.
config_dict.n_samples = 1000000
# NOTE(review): seeding is disabled here and in main (set_random_seed is commented
# out), so runs are not reproducible. Re-enable both together.
#config_dict.seed = 321  # Random seed
# Classifier checkpoint path(s) for guidance; multiple paths are comma-separated.
config_dict.classifier = placeholder(str)
# Classifier guidance weight. If not given, will not use classifier guidance.
# NOTE(review): not read by main in this script (only saved to the output JSON);
# presumably consumed elsewhere -- confirm.
config_dict.classifier_weight = placeholder(float)
# Output file stem. Files are written under results/neg_dataset/ as
# <out>.npy (negatives), <out>_pos.npy (positives) and <out>.json (config).
config_dict.out = placeholder(str)
# Name (without ".npy") of a dataset under results/neg_dataset/ whose samples are
# prepended to the new negatives. The base file itself remains unchanged; the
# combined dataset is written to a separate file.
config_dict.base_dataset = placeholder(str)
# If given, keep a random fraction of the base dataset: int(value * len(base)) rows
# (e.g. 0.5 keeps half). If that is 0 rows, the base dataset is dropped entirely.
config_dict.base_dataset_subsample = placeholder(float)
config_dict.slack = placeholder(float)  # Slackness of the infraction; forwarded as `slack=` to compute_infraction.
config_flags.DEFINE_config_dict("config", config_dict, "Negative dataset generation configuration.")
flags.mark_flags_as_required(["config.checkpoint", "config.out"])


def _load_base_samples(config, out_dir):
    """Return the optional base negative dataset, possibly randomly subsampled.

    Returns ``None`` when ``config.base_dataset`` is unset, or when
    ``config.base_dataset_subsample`` keeps zero rows. Otherwise loads
    ``out_dir / f"{config.base_dataset}.npy"``. If ``config.base_dataset_subsample``
    is set, only a random subset of ``int(base_dataset_subsample * len(base))``
    rows is kept.
    """
    if config.base_dataset is None:
        return None
    base_samples = np.load(out_dir / f"{config.base_dataset}.npy")
    if config.base_dataset_subsample is None:
        return base_samples
    # NOTE: np.random is not explicitly seeded here (set_random_seed is commented
    # out in main), so the kept subset may differ between runs.
    perm = np.random.permutation(len(base_samples))
    n_base_samples = int(config.base_dataset_subsample * len(base_samples))
    if n_base_samples <= 0:
        return None
    return base_samples[perm[:n_base_samples]]


def _load_model(config):
    """Load the model from ``config.checkpoint`` on ``config.device``, in eval mode with frozen parameters.

    ``config.classifier`` (comma-separated paths) is forwarded to
    ``dispatch_model_from_path``. Strict loading is used only when no
    classifier is given.
    """
    classifier = None if config.classifier is None else config.classifier.split(",")
    model, _ = dispatch_model_from_path(config.checkpoint, strict=(classifier is None), classifier=classifier)
    model = model.to(config.device)
    model.eval()
    model.requires_grad_(False)
    return model


def _generate(model, config, batch_size=10000):
    """Sample until at least ``config.n_samples`` negatives and positives are collected.

    Each batch is split with the mask returned by ``compute_infraction``:
    rows where it is True are negatives, the remaining rows are positives.
    Returns ``(negatives, positives)`` as numpy arrays, each truncated to
    ``config.n_samples`` rows.

    NOTE: the loop only stops once both quotas are met. If the model never
    yields infracting (or never yields non-infracting) samples, it does not
    terminate.
    """
    n_target = config.n_samples
    neg_chunks, pos_chunks = [], []
    n_neg = n_pos = 0
    with tqdm(total=n_target, desc="Generating negative samples") as pbar:
        while n_neg < n_target or n_pos < n_target:
            samples, _ = script_utils.draw_samples(
                model,
                n_samples=batch_size,
                device=config.device,
                num_steps=config.sampling.steps,
                S_churn=config.sampling.S_churn,
            )
            infraction = compute_infraction(samples, slack=config.slack)
            valid = ~infraction
            # `.any()` reduces the mask in one vectorized call. Python's builtin
            # sum() iterated it element by element.
            if n_pos < n_target and valid.any():
                samples_pos = samples[valid]
                pos_chunks.append(samples_pos.cpu().numpy())
                n_pos += len(samples_pos)
            if n_neg < n_target and infraction.any():
                samples_neg = samples[infraction]
                neg_chunks.append(samples_neg.cpu().numpy())
                # Clamp so the bar never exceeds its total.
                pbar.update(min(len(samples_neg), n_target - n_neg))
                n_neg += len(samples_neg)
            # The bar tracks negatives only; also show positive progress so the
            # run does not look stalled once the negative quota is reached.
            pbar.set_postfix(positives=min(n_pos, n_target))
    neg_dataset = np.concatenate(neg_chunks, axis=0)[:n_target]
    pos_dataset = np.concatenate(pos_chunks, axis=0)[:n_target]
    return neg_dataset, pos_dataset


def main(argv):
    """Sample negative and positive datasets from a trained model and save them.

    Batches are drawn from the model at ``config.checkpoint`` and split with
    ``compute_infraction`` until ``config.n_samples`` negatives (infracting
    samples) and ``config.n_samples`` positives have been collected.

    Outputs written under ``results/neg_dataset/``:

    * ``{out}.npy`` -- the new negatives. When ``config.base_dataset`` is set,
      they are prefixed by the (optionally subsampled) ``{base_dataset}.npy``
      samples.
    * ``{out}_pos.npy`` -- the new positives.
    * ``{out}.json`` -- the configuration used for this run.

    Args:
        argv: Positional command-line arguments passed by ``app.run`` (unused).

    Raises:
        ValueError: If ``config.n_samples`` is not positive.
    """
    config = FLAGS.config
    # NOTE(review): seeding is disabled (config.seed is commented out in the
    # config definition), so sampling and base subsampling are not reproducible.
    #set_random_seed(config.seed)
    if config.n_samples <= 0:
        # Fail before loading the model. Otherwise np.concatenate would later
        # fail on the empty chunk lists with an opaque error.
        raise ValueError(f"config.n_samples must be positive, got {config.n_samples}.")

    out_dir = Path("results/neg_dataset/")
    neg_dataset_path = out_dir / f"{config.out}.npy"
    pos_dataset_path = out_dir / f"{config.out}_pos.npy"
    out_config_path = out_dir / f"{config.out}.json"

    base_samples = _load_base_samples(config, out_dir)
    model = _load_model(config)
    neg_dataset, pos_dataset = _generate(model, config)

    new_samples_cnt = len(neg_dataset)
    if base_samples is not None:
        neg_dataset = np.concatenate([base_samples, neg_dataset], axis=0)

    # Save the datasets
    out_dir.mkdir(exist_ok=True, parents=True)
    with open(neg_dataset_path, "wb") as f:
        np.save(f, neg_dataset)
    with open(pos_dataset_path, "wb") as f:
        np.save(f, pos_dataset)
    print(
        f"Saved the generated negative examples at {neg_dataset_path}. (Total negative examples: {len(neg_dataset)}. New examples: {new_samples_cnt})"
    )
    print(f"Saved the generated positive examples at {pos_dataset_path}. (Total positive examples: {len(pos_dataset)})")

    # Save the config. Build the dict before opening the file, so a failure does
    # not leave an empty JSON behind. The device is stringified so json.dump can
    # serialize it (presumably a torch.device).
    config_d = config.to_dict()
    if "device" in config_d:
        config_d["device"] = str(config_d["device"])
    with open(out_config_path, "w", encoding="utf-8") as f:
        json.dump(config_d, f, indent=4)
    print(f"Saved the configs at {out_config_path}")
    print("Done!")


if __name__ == "__main__":
    # Hand control to absl: it parses the command-line flags (including the
    # --config.* overrides) and then calls main with the remaining argv.
    app.run(main=main)
