import numpy as np
import os
import torch
from ddpm_torch.toy import *
from ddpm_torch.utils import seed_all, infer_range
from torch.optim import Adam, lr_scheduler
from matplotlib import pyplot as plt
from argparse import ArgumentParser
import wandb
from ddpm_torch.toy import GenToyDataset
# from ddpm_torch.toy_utils import Evaluator_1D
from metrics import compute_metrics

def parse_arguments():

    parser = ArgumentParser()

    parser.add_argument("--dataset", choices=["gaussian8", "gaussian25", "swissroll", "gaussian25_imbalanced", "gaussian2_1d", 
                                               "gaussian_mixture_2d", "gaussian_nd_zeros", 
                                               "gaussian_nd_more_modes", "gaussian_nd_odd_even", 
                                               "gaussian2d_composition_test",
                                               "gaussian25_rotated",
                                               "gaussian25_no_std",
                                               "gaussian3_1d"], default="gaussian8")
    parser.add_argument("--num_sampled_images", default=10_000_000, type=int)
    parser.add_argument("--root", default="~/datasets", type=str, help="root directory of datasets")
    parser.add_argument("--beta1", default=0.9, type=float, help="beta_1 in Adam")
    parser.add_argument("--beta2", default=0.999, type=float, help="beta_2 in Adam")
    parser.add_argument("--lr-warmup", default=0, type=int, help="number of warming-up epochs")
    parser.add_argument("--batch-size", default=10000, type=int)
    parser.add_argument("--timesteps", default=1000, type=int, help="number of diffusion steps")
    parser.add_argument("--beta-schedule", choices=["quad", "linear", "warmup10", "warmup50", "jsd"], default="linear") 
    parser.add_argument("--beta-start", default=0.001, type=float)
    parser.add_argument("--beta-end", default=0.2, type=float)
    parser.add_argument("--model-mean-type", choices=["mean", "x_0", "eps"], default="eps", type=str)
    parser.add_argument("--model-var-type", choices=["learned", "fixed-small", "fixed-large"], default="fixed-large", type=str)  # noqa
    parser.add_argument("--loss-type", choices=["kl", "mse"], default="mse", type=str)
    parser.add_argument("--image-dir", default="./images/train", type=str)
    parser.add_argument("--exp_str", default="0", type=str)
    parser.add_argument("--chkpt-dir", default="./chkpts", type=str)
    parser.add_argument("--chkpt-intv", default=100, type=int, help="frequency of saving a checkpoint")
    parser.add_argument("--eval-intv", default=10, type=int)
    parser.add_argument("--seed", default=1234, type=int, help="random seed")
    parser.add_argument("--resume", action="store_true", help="to resume training from a checkpoint")
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--mid-features", default=128, type=int)
    parser.add_argument("--num-temporal-layers", default=3, type=int)
    parser.add_argument("--generations", default=2, type=int)
    parser.add_argument("--log_results", action="store_true", help="log results to wandb")

    args = parser.parse_args()
    return args

def main():
    """Draw samples from each generation's saved DDPM checkpoint and store them.

    For every generation index ``gen`` in ``range(args.generations)`` this:

    1. builds a ``GaussianDiffusion`` and a ``Decoder`` from the CLI options,
    2. loads ``{chkpt_dir}/ddpm_{dataset}_gen_{gen}.pt`` into the model,
    3. draws ``args.num_sampled_images`` samples in batches of at most
       1,000,000 via ``diffusion.p_sample_save``,
    4. saves the samples and the second array returned by ``p_sample_save``
       as ``.npy`` files in ``chkpt_dir``,
    5. saves a histogram (1-D datasets) or scatter plot (otherwise) as a PNG
       and, when ``--log_results`` is set, logs that image to wandb.
    """
    args = parse_arguments()
    args.store_name = "_".join([
        "resample", str(int(args.num_sampled_images / 1e6)), "m",
        args.chkpt_dir.split("/")[-1],
    ])
    # set seed
    seed_all(args.seed)
    print(args)

    if args.log_results:
        wandb.init(project="synthetic", entity="neurips", name=args.store_name)
        wandb.config.update(args)
        wandb.run.log_code(".")

    # Data dimensionality: the three "gaussian_nd_*" datasets use 32 features,
    # names containing "1d" use 1, everything else uses 2.
    dataset = args.dataset
    if dataset in ("gaussian_nd_zeros", "gaussian_nd_more_modes", "gaussian_nd_odd_even"):
        in_features = 32
    elif "1d" in dataset:
        in_features = 1
    else:
        in_features = 2

    # Honour --device for the model, the checkpoint load and the sampler alike.
    device = torch.device(args.device)

    def sample_fn(n, model, diffusion, shape, device):
        """Draw ``n`` samples of per-sample ``shape``; return both outputs as NumPy arrays."""
        # Sampling is inference only, so no autograd graph is needed.
        with torch.no_grad():
            sample, noise = diffusion.p_sample_save(
                denoise_fn=model, shape=(n,) + shape, device=device, noise=None)
        return sample.cpu().numpy(), noise.cpu().numpy()

    number_of_samples = args.num_sampled_images
    eval_batch_size = 1_000_000
    # Per-sample shape must match the model's input width.
    shape = (in_features,)

    for gen in range(args.generations):
        # diffusion parameters
        betas = get_beta_schedule(
            args.beta_schedule,
            beta_start=args.beta_start,
            beta_end=args.beta_end,
            timesteps=args.timesteps,
        )
        diffusion = GaussianDiffusion(
            betas=betas,
            model_mean_type=args.model_mean_type,
            model_var_type=args.model_var_type,
            loss_type=args.loss_type,
        )

        model = Decoder(in_features, args.mid_features, args.num_temporal_layers)
        model.to(device)

        # map_location lets a checkpoint saved on another device load here.
        chkpt = torch.load(
            f"{args.chkpt_dir}/ddpm_{dataset}_gen_{gen}.pt", map_location=device)
        model.load_state_dict(chkpt["model"])
        model.eval()

        sample_batches = []
        noise_batches = []
        for start in range(0, number_of_samples, eval_batch_size):
            print(start)
            # The last batch may be smaller so that exactly number_of_samples
            # samples are produced (the output file names advertise that count).
            batch_n = min(eval_batch_size, number_of_samples - start)
            samples, noise_v = sample_fn(batch_n, model, diffusion, shape, device)
            sample_batches.append(samples)
            noise_batches.append(noise_v)

        if sample_batches:
            # Join whole batches along the sample axis instead of extending a
            # Python list row by row.
            x_gen = np.concatenate(sample_batches, axis=0)
            noise_vec = np.concatenate(noise_batches, axis=0)
        else:
            # Nothing requested (num_sampled_images <= 0): keep going with
            # empty arrays rather than failing in np.concatenate.
            # NOTE(review): the noise array is assumed to share the sample
            # shape here - confirm against p_sample_save.
            x_gen = np.empty((0,) + shape, dtype=np.float32)
            noise_vec = np.empty((0,) + shape, dtype=np.float32)
        print(x_gen.shape)

        out_prefix = f"{args.chkpt_dir}/gen_dataset_{gen}_{args.num_sampled_images}"
        np.save(f"{out_prefix}_save.npy", x_gen)
        np.save(f"{out_prefix}_save_noise_vec.npy", noise_vec)

        if args.dataset == "gaussian2_1d" or args.dataset == "gaussian3_1d":
            # Set log scale
            plt.yscale("log")
            plt.hist(x_gen, bins=100, alpha=0.7, edgecolor='black')
        else:
            plt.scatter(*np.hsplit(x_gen, 2), s=0.5, alpha=0.7)
        plt.tight_layout()
        # The file name (including its spelling) is kept as-is so existing
        # outputs and any downstream consumers keep matching.
        figure_path = f"{args.chkpt_dir}/generatd_dataset_{gen}_{args.num_sampled_images}.png"
        plt.savefig(figure_path)
        plt.close()

        if args.log_results:
            wandb.log({"Gen": wandb.Image(figure_path, caption=f"Gen {gen}")})

# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
