import os
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
from absl import app, flags
from gen_neg_toy.utils import logging
import subprocess

import wandb
from gen_neg_toy import (
    data,
    dispatch_model,
    dispatch_model_from_path,
    script_utils,
)
from gen_neg_toy.ng_utils import (
    compute_infraction,
    compute_infraction_differentiable,
)
from gen_neg_toy.evaluation import visualize_hist, visualize_scatter


FLAGS = flags.FLAGS
# Total pipeline iterations: iteration 0 is the unguided baseline, and one
# classifier is trained for each of iterations 1 .. n_iters - 1.
flags.DEFINE_integer("n_iters", 10, "Number of iterations.")
flags.DEFINE_integer("synth_dataset_size", 10000, "Size of synthetic datasets.")
flags.DEFINE_integer(
    "n_train_iters", 20000, "Number of training iterations for the classifier."
)
# Treated as a boolean: any non-zero value forwards the measured alpha to the
# classifier trainer; zero leaves the trainer's own default in place.
flags.DEFINE_integer("correct_ratio", 1, "If zero, will use 50/50 ratio for training the classifier")
# The default must be None. absl's required-flag validator only rejects None,
# so an "" default would silently satisfy the "required" check below.
flags.DEFINE_string("baseline_checkpoint", None, "Path to baseline checkpoint.")
flags.mark_flags_as_required(["baseline_checkpoint"])

logging.support_unobserve()


@torch.no_grad()
def test_model(checkpoint, classifier=None, n_trials=5):
    """Evaluate a model checkpoint, optionally combined with classifier(s).

    Runs ``n_trials`` ELBO evaluations on the test split and ``n_trials``
    rounds of sampling, and summarises both.

    Args:
        checkpoint: Path forwarded to ``dispatch_model_from_path``.
        classifier: Optional classifier checkpoint(s) forwarded to
            ``dispatch_model_from_path``. When given, the model is loaded with
            ``strict=False``.
        n_trials: Number of repetitions for both the ELBO and the sampling
            evaluations. Must be >= 1.

    Returns:
        dict with:
          * ``elbo_mean`` / ``elbo_std`` over all ELBO trials;
          * ``infraction_mean`` / ``infraction_std``: percentage of sampled
            points flagged by ``compute_infraction``, across trials;
          * ``infraction_dist_mean`` / ``infraction_dist_std`` /
            ``infraction_dist_max``: statistics of
            ``compute_infraction_differentiable(samples, norm_p=1)`` restricted
            to the flagged samples of all trials (NaN if none were flagged);
          * ``samples_scatter`` / ``samples_hist``: wandb images of the last
            trial's samples.

    Raises:
        ValueError: if ``n_trials`` < 1.
    """
    if n_trials < 1:
        # Every statistic and the final visualisation need at least one trial.
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    model, config = dispatch_model_from_path(
        checkpoint,
        strict=(classifier is None),
        classifier=classifier,
    )

    ## Load the dataset ##
    _, test_set = data.get_datasets(config.data)
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=config.training.batch_size, shuffle=False
    )

    ## ELBO computation ##
    elbos = np.stack(
        [
            script_utils.elbo_from_dataloader(
                model, val_loader, device=config.device, num_steps=100
            )
            for _ in range(n_trials)
        ]
    )

    ## Infraction statistics on freshly drawn samples ##
    infraction_rates = []  # percent of flagged samples, one entry per trial
    infraction_dists = []  # distance values of flagged samples, per trial
    for _ in range(n_trials):
        samples, _nfe = script_utils.draw_samples(
            model,
            n_samples=20000,
            device=config.device,
            S_churn=10,
            num_steps=100,
            rho=7,
            verbose=False,
            max_batch_size=2048,
        )
        infraction = compute_infraction(samples).cpu().numpy()
        infraction_rates.append(infraction.astype("float").mean() * 100)
        infraction_dist = (
            compute_infraction_differentiable(samples, norm_p=1).cpu().numpy()
        )
        # Use an explicit boolean mask. An integer 0/1 array would be
        # interpreted as fancy indices rather than a selection mask.
        mask = infraction.astype(bool)
        if mask.any():
            infraction_dists.append(infraction_dist[mask])

    if infraction_dists:
        all_dists = np.concatenate(infraction_dists)
        infraction_dist_mean = all_dists.mean()
        infraction_dist_std = all_dists.std()
        infraction_dist_max = all_dists.max()
    else:
        # No flagged samples in any trial. np.concatenate([]) would raise here.
        infraction_dist_mean = infraction_dist_std = infraction_dist_max = np.nan

    # Only the last trial's samples are visualised.
    samples_scatter = visualize_scatter(samples)
    samples_hist = visualize_hist(samples)

    return dict(
        elbo_mean=elbos.mean(),
        elbo_std=elbos.std(),
        infraction_mean=np.mean(infraction_rates),
        infraction_std=np.std(infraction_rates),
        infraction_dist_mean=infraction_dist_mean,
        infraction_dist_std=infraction_dist_std,
        infraction_dist_max=infraction_dist_max,
        samples_scatter=wandb.Image(samples_scatter),
        samples_hist=wandb.Image(samples_hist),
    )


def run_cmd(cmd):
    """Run ``cmd`` through the shell and raise if it exits unsuccessfully.

    Args:
        cmd: Sequence of command fragments. They are joined with single spaces
            and passed to the shell unquoted, so one fragment may hold a flag
            together with its value (e.g. ``"--config.out name"``) and the
            shell splits it. Values containing whitespace or shell
            metacharacters are therefore not safe here.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status. This replaces an ``assert``, which ``python -O`` strips,
            and it carries the failing command and its return code.
    """
    cmd = " ".join(cmd)
    print(f"### Running command: {cmd}")
    subprocess.run(cmd, shell=True, check=True)


def main(argv):
    """Run the iterative negative-data pipeline end to end.

    Iteration 0 evaluates the baseline checkpoint and generates a synthetic
    dataset from it. Each iteration ``i`` in ``1 .. n_iters - 1`` then:

    1. trains a classifier (in a subprocess) on the latest synthetic dataset;
    2. evaluates the baseline checkpoint combined with every classifier trained
       so far, and logs the results to wandb;
    3. generates the next synthetic dataset, forwarding those classifiers to
       the generator.

    Each evaluation measures an infraction rate, and ``1 - rate`` becomes the
    ``--config.training.alpha`` for the next classifier. It is not forwarded
    when ``--correct_ratio=0``.

    Args:
        argv: Leftover command-line arguments from absl (unused).
    """
    del argv  # Unused.

    logging.init(
        config=dict(
            n_iters=FLAGS.n_iters,
            synth_dataset_size=FLAGS.synth_dataset_size,
            n_train_iters=FLAGS.n_train_iters,
            correct_ratio=FLAGS.correct_ratio,
            baseline_checkpoint=FLAGS.baseline_checkpoint,
        ),
        project_name="gen_model_neg_iterative",
    )
    unique_name = wandb.run.id
    print(f"ID: {unique_name}")

    baseline_checkpoint = FLAGS.baseline_checkpoint
    cls_checkpoints_dir = Path(f"results/checkpoints_cls_iterative/{unique_name}")
    synth_dataset_size = FLAGS.synth_dataset_size
    n_train_iters = FLAGS.n_train_iters

    def evaluate(iteration, classifier=None):
        """Evaluate the baseline checkpoint, then log and print the results.

        Returns the next classifier's training alpha, i.e. the fraction of
        samples that were not flagged as infractions.
        """
        eval_results = test_model(baseline_checkpoint, classifier=classifier)
        wandb.log(eval_results)
        print(f"(iteration {iteration}) Evaluation results:")
        pprint(eval_results)
        return 1 - eval_results["infraction_mean"] / 100.0

    def gen_dataset(iteration, cls_checkpoints=None):
        """Generate a synthetic dataset in a subprocess and return its name.

        A non-empty ``cls_checkpoints`` is forwarded as ``--config.classifier``.
        """
        neg_dataset_name = f"pipeline/{unique_name}_{iteration}"
        cmd = [
            "python",
            "scripts/generate_dataset.py",
            f"--config.checkpoint {baseline_checkpoint}",
            f"--config.out {neg_dataset_name}",
            f"--config.n_samples {synth_dataset_size}",
        ]
        if cls_checkpoints:
            cmd.append(f"--config.classifier {','.join(cls_checkpoints)}")
        run_cmd(cmd)
        print(f">>> (iteration {iteration}) Samples from the model are generated.")
        return neg_dataset_name

    def train_classifier(iteration, neg_dataset_name, training_alpha):
        """Train a classifier in a subprocess.

        Returns the path where its final checkpoint is expected.
        """
        out_dir = f"{cls_checkpoints_dir}/iter_{iteration}"
        cmd = [
            "python",
            "scripts/train_classifier.py",
            f"--config.data.dataset results/neg_dataset/{neg_dataset_name}_pos.npy",
            f"--config.data.neg_dataset results/neg_dataset/{neg_dataset_name}.npy",
            f"--config.data.neg_dataset_size={synth_dataset_size}",
            f"--config.data.train_set_size={synth_dataset_size}",
        ]
        if FLAGS.correct_ratio != 0:
            cmd.append(f"--config.training.alpha {training_alpha}")
        cmd += [
            f"--config.training.n_iters {n_train_iters}",
            f"--config.training.out_dir {out_dir}",
            "--tags pipeline",
            "--unobserve",
        ]
        run_cmd(cmd)
        print(f">>> (iteration {iteration}) Classifier is trained.")
        # Build the path from this call's `iteration`, not from an enclosing
        # loop variable, so it always matches the directory passed above.
        return f"{out_dir}/final_{n_train_iters}.pt"

    # Iteration 0: evaluate and sample from the unconditional baseline.
    training_alpha = evaluate(0)
    neg_dataset_name = gen_dataset(0)

    cls_checkpoints = []
    for pipeline_iter in range(1, FLAGS.n_iters):
        cls_checkpoints.append(
            train_classifier(pipeline_iter, neg_dataset_name, training_alpha)
        )
        training_alpha = evaluate(pipeline_iter, classifier=cls_checkpoints)

        # Generate the next iteration's dataset using all classifiers so far.
        neg_dataset_name = gen_dataset(pipeline_iter, cls_checkpoints=cls_checkpoints)


if __name__ == "__main__":
    # Script entry point: hand control to absl, which parses the command-line
    # flags and then invokes `main` with the remaining argv.
    app.run(main=main)