import argparse
import json
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

from tqdm import tqdm
from scripts.main import get_device, get_final_model_path, get_latest_checkpoint
from mixed_diffusion.data_loading.data_loading import get_data
from mixed_diffusion.helpers import get_beta_schedule
from mixed_diffusion.models.get_model import get_model
from mixed_diffusion.sampling import sample_images

from torch.utils.data import DataLoader

import numpy as np

from mixed_diffusion.utils import save_3d_tensor
from mixed_diffusion.visualize import plot_heatmap, plot_scatters_from_dict
from mixed_diffusion.wasserstein_distance import wasserstein_distance_from_samples


def binary_array(i, length=5):
    """Return the binary digits of ``i`` as a list of ints, most significant first.

    The representation is left-padded with zeros to at least ``length``
    digits. Values that need more than ``length`` bits are not truncated, so
    the result can be longer than ``length``.
    """
    bits = f"{i:0{length}b}"
    return list(map(int, bits))


def main(args):
    """Sample from a trained diffusion model for every binary conditioning vector.

    Loads ``config.json`` and the model weights from ``args.model_dir``, draws
    a fixed number of samples for each binary conditioning vector of length
    ``config["num_classes"]`` (normalized to sum to 1, except the all-zero
    vector), and writes a scatter plot of all samples to
    ``{args.result_dir}/random_samples.png``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``result_dir``, ``model_dir``, ``batch_size`` and ``umap``; the
            whole namespace is also passed to ``get_data`` and ``get_device``.

    Raises:
        FileNotFoundError: If neither a final model nor a checkpoint is found
            in ``args.model_dir``.
    """
    print("Running the process with the following arguments:", args)

    if os.path.exists(args.result_dir):
        print(f"Result directory {args.result_dir} already exists.")
    os.makedirs(args.result_dir, exist_ok=True)

    with open(f"{args.model_dir}/config.json", "r") as f:
        config = json.load(f)

    print("Loaded config:")
    print(json.dumps(config, indent=4))

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    train, _, info = get_data(args, transform)
    # NOTE(review): the loader is named "test" but wraps the train split. It is
    # only used to fetch one batch for its shape, so this is left as-is.
    test_dataloader = DataLoader(train, batch_size=args.batch_size, shuffle=False)

    # A single batch is enough: only its shape (and device) is used below.
    x0, _ = next(iter(test_dataloader))

    print(f"Loaded data with shape {x0.shape}")

    device = get_device(args)

    model = get_model(config, x0.shape, device)

    # Prefer the final model; fall back to the most recent checkpoint.
    model_path = get_final_model_path(args.model_dir)
    if model_path is None:
        model_path = get_latest_checkpoint(args.model_dir)
    if model_path is None:
        raise FileNotFoundError(
            f"No final model or checkpoint found in {args.model_dir}"
        )
    model.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=True)
    )
    print(f"Model loaded from {model_path}")

    # Generate a fixed number of random samples per conditioning vector.
    num_random_samples = 100
    print(
        f"Generating {num_random_samples} random samples with random noise initialization and timestep {config['noise_step']}..."
    )
    data = {}
    num_classes = config["num_classes"]

    # There are 2 ** num_classes binary vectors of length num_classes.
    # (num_classes ** 2 only matches that count for 2 or 4 classes; otherwise
    # it skips vectors or yields indices whose binary form is too long.)
    for i in tqdm(range(2**num_classes)):

        # Base conditioning vector: the binary representation of i.
        # NOTE(review): placed on x0.device as in the original, not on
        # `device`; presumably sample_images moves it as needed -- confirm.
        base_conditioning_vector = torch.tensor(
            binary_array(i, num_classes),
            dtype=torch.float32,
            device=x0.device,
        )
        # Normalize so the active entries sum to 1; the all-zero vector is
        # left untouched to avoid dividing by zero.
        total = base_conditioning_vector.sum()
        if total > 0:
            base_conditioning_vector /= total
        # One identical conditioning row per generated sample.
        conditioning_vector = base_conditioning_vector.repeat(num_random_samples, 1)

        random_samples, _ = sample_images(
            config,
            model,
            image_size=x0.shape[1:],
            num_samples=num_random_samples,
            conditioning_vector=conditioning_vector,
        )
        # Key each group of samples by the zero-padded bit string of i.
        data[f"{i:0{num_classes}b}"] = random_samples

    # UMAP only when requested; PCA otherwise (args.pca itself is not read).
    visualization_method = "umap" if args.umap else "pca"

    plot_scatters_from_dict(
        data,
        title_prefix="",
        save_path=f"{args.result_dir}/random_samples.png",
        archetypes=info.get("archetypes", None),
        method=visualization_method,
    )


if __name__ == "__main__":

    # Command-line entry point: parse options, normalize paths, run main().
    parser = argparse.ArgumentParser(
        description=(
            "Sample from a trained diffusion model for every binary "
            "conditioning vector and plot the samples."
        )
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default="/home/ubuntu/data",
        help="Path to the data directory.",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="/home/ubuntu/models/mixed_diffusion",
        help="Directory to load the diffusion model from",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default=".",
        help="Directory to save the results.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        help="Dataset to load.",
    )
    parser.add_argument("--mps", action="store_true", help="Use MPS.")
    parser.add_argument("--grid_plot", action="store_true", help="Use grid plot.")
    parser.add_argument("--show", action="store_true", help="Show the images.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
    parser.add_argument(
        "--wasserstein_distance",
        action="store_true",
        help="Compute Wasserstein distance.",
    )
    parser.add_argument(
        "--config_file", type=str, help="Path to the config file for data generation."
    )
    # NOTE: main() only checks --umap; PCA is used whenever --umap is absent,
    # so --pca is accepted but has no effect on its own.
    parser.add_argument(
        "--pca",
        action="store_true",
        help="Use PCA for dimensionality reduction.",
    )
    parser.add_argument(
        "--umap",
        action="store_true",
        help="Use UMAP for dimensionality reduction.",
    )
    args = parser.parse_args()
    # Strip trailing slashes so paths built as f"{dir}/name" have no "//".
    args.model_dir = args.model_dir.rstrip("/")
    args.result_dir = args.result_dir.rstrip("/")
    # Not a CLI option; attached so code receiving `args` can read it.
    # It is not read in this file -- presumably consumed by get_data; confirm.
    args.num_samples = 10000

    main(args)
