import argparse
import os
import pickle
import sys
from pathlib import Path

import numpy as np
import torch
import yaml

# Directory that contains this script; parameters.yml is read from here.
current_path = Path(os.path.abspath(__file__)).parent
# Parent of the script directory, used as the project root when resolving
# source directories.
project_root = current_path.parent
# Put the project root on the import path before the project-local imports
# that follow (presumably so they resolve when this file is run directly as a
# script -- confirm against the repository layout).
sys.path.append(project_root.as_posix())

from compression_autoencoder.autoencoders.rl_neural_continuous_autoencoder import (
    RLNeuralContinuousAutoencoder,
)
from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.evaluation import (
    multi_reward_evaluate,
    multi_reward_visualize,
)
from compression_autoencoder.utils.misc import (
    resolve_source_dir,
    set_seeds,
)
from scripts.constants import INPUT_SCALERS, WRAPPER_CLASSES

# Script-wide settings, read once at import time from the parameters.yml that
# sits next to this file. Parsed with PyYAML's safe loader.
with (current_path / "parameters.yml").open() as f:
    stored_params = yaml.safe_load(f)


def prep_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the latent-space visualization script.

    Defaults for ``--seed``, ``--num_points``, ``--num_envs``, ``--stats`` and
    ``--num_jobs`` are taken from the ``"defaults"`` section of the
    module-level ``stored_params`` mapping.

    Returns:
        The configured parser. ``--source_dir`` is the only required option.
    """
    cfg = stored_params["defaults"]
    cli = argparse.ArgumentParser(
        description="Visualize the latent space of a trained autoencoder."
    )

    # Options are registered in this order so the generated --help output
    # lists them in a stable, predictable sequence.
    option_specs = [
        (
            "--source_dir",
            {
                "type": str,
                "required": True,
                "help": "Location of directory that contains the trained autoencoder files",
            },
        ),
        (
            "--seed",
            {"type": int, "default": cfg["seed"], "help": "Seed for RNG"},
        ),
        (
            "--num_points",
            {
                "type": int,
                "default": cfg["num_points"],
                # NOTE: the original author annotated this help text with
                # "not really"; the wording is preserved unchanged.
                "help": "Number of points to consider for each latent dimension",
            },
        ),
        (
            "--num_envs",
            {
                "type": int,
                "default": cfg["num_envs"],
                "help": "Number of environments to use for evaluation",
            },
        ),
        (
            "--stats",
            {
                # NOTE(review): nargs="?" accepts at most ONE value (and yields
                # None when the flag is given without a value), while the help
                # text is plural. If several stats are meant to be passable,
                # this probably wants nargs="+" -- confirm against how
                # multi_reward_evaluate consumes stats_to_collect.
                "nargs": "?",
                "default": cfg["stats"],
                "help": "Mountain car rewards to evaluate",
            },
        ),
        (
            "--num_jobs",
            {
                "type": int,
                "default": cfg["num_jobs"],
                "help": "Number of CPU jobs to run in parallel",
            },
        ),
        (
            "--use_existing",
            {
                "action": "store_true",
                "help": "If specified, will use existing decoded weights if they exist",
            },
        ),
        (
            "--remove_color_bar",
            {
                "action": "store_true",
                "help": "If specified, will remove the color bar from the plots",
            },
        ),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli


def _rebuild_autoencoder(source_dir: Path, device: torch.device) -> tuple:
    """Reconstruct the trained autoencoder stored in ``source_dir``.

    The run configuration is followed back through the saved YAML files
    (``training_args.yml`` -> ``selection_args.yml`` -> ``args.yml``) to
    recover the policy architecture, then the autoencoder is instantiated and
    its weights are loaded from ``autoencoder.pth``.

    Args:
        source_dir: Directory holding ``training_args.yml`` and
            ``autoencoder.pth``.
        device: Device handed to both the sample policy and the autoencoder.

    Returns:
        A ``(model, layer_shapes, policy_args)`` tuple, where ``layer_shapes``
        is the list of policy layer shapes (as tuples) and ``policy_args`` is
        ``stored_params["policy"]``.
    """
    with open(source_dir / "training_args.yml") as f:
        training_args = yaml.load(f, Loader=yaml.SafeLoader)

    selected_dir = resolve_source_dir(
        training_args["source_dir"], project_root, current_path
    )
    with open(selected_dir / "selection_args.yml") as f:
        selection_args = yaml.load(f, Loader=yaml.SafeLoader)

    generation_dir = resolve_source_dir(
        selection_args["source_dir"], project_root, current_path
    )
    # NOTE(review): this is the only file read with FullLoader instead of
    # SafeLoader. FullLoader can construct more Python types; keep it only if
    # args.yml really needs that, and never point it at untrusted files.
    with open(generation_dir / "args.yml") as f:
        generation_args = yaml.load(f, Loader=yaml.FullLoader)

    policy_args = stored_params["policy"]
    policy_size = generation_args["policy_shape"]
    layer_shapes = [tuple(item) for item in policy_args["layer_shapes"][policy_size]]

    sample_policy = Policy(
        layer_shapes=layer_shapes,
        activation_func=generation_args["activation_func"],
        last_activation_func=generation_args["last_activation_func"],
        input_scaler=INPUT_SCALERS[stored_params["env"]],
        device=device,
    )

    model = RLNeuralContinuousAutoencoder(
        sample_policy=sample_policy,
        latent_dim=training_args["latent_dim"],
        encoder_layers_shapes=training_args["encoder_layers_shapes"],
        decoder_layers_shapes=training_args["decoder_layers_shapes"],
        activation_func=training_args["activation_func"],
        loss_func=training_args["loss_func"],
        input_scaler=INPUT_SCALERS[f"Autoencoder{stored_params['env']}"],
        device=device,
    )
    model.load(source_dir / "autoencoder.pth")
    return model, layer_shapes, policy_args


def _build_grid_codes(train_codes: np.ndarray, n_points: int) -> np.ndarray:
    """Create a regular grid of latent codes covering the training codes.

    Rows that fall outside the 1.5 * IQR fences in any dimension are treated
    as outliers and ignored when computing the grid extent. The extent is then
    padded by 1 on each side of every dimension.

    Args:
        train_codes: 2-D array with one latent code per row.
        n_points: Number of grid points along each latent dimension.

    Returns:
        Array of shape ``(n_points ** code_dim, code_dim)`` holding every grid
        point, where ``code_dim`` is ``train_codes.shape[1]``.
    """
    q1 = np.percentile(train_codes, 25, axis=0)
    q3 = np.percentile(train_codes, 75, axis=0)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr

    # Keep only codes that are within the fences in *all* dimensions.
    inlier_mask = np.all(
        (train_codes >= lower_bound) & (train_codes <= upper_bound), axis=1
    )
    inliers = train_codes[inlier_mask]
    if inliers.shape[0] == 0:
        # Every row is an outlier in at least one dimension; np.min/np.max on
        # an empty array would raise, so fall back to the unfiltered codes.
        inliers = train_codes

    mins = np.min(inliers, axis=0) - 1
    maxs = np.max(inliers, axis=0) + 1

    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(mins, maxs)]
    # Works for any code_dim; NumPy's default meshgrid indexing is kept so the
    # row order of the grid is unchanged.
    mesh = np.meshgrid(*axes)
    return np.stack([m.ravel() for m in mesh], axis=1)


def main() -> None:
    """Evaluate and visualize a grid of latent codes for a trained autoencoder.

    Without ``--use_existing`` the autoencoder is rebuilt, a grid of latent
    codes is decoded into policy weights and those policies are evaluated.
    With ``--use_existing`` the grid and statistics saved by a previous run
    are loaded from ``<source_dir>/latent`` instead.

    In both cases the results are plotted and the grid codes, statistics and
    command-line arguments are written to ``<source_dir>/latent``.
    """
    args = prep_arg_parser().parse_args()
    source_dir = resolve_source_dir(args.source_dir, project_root, current_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_path = source_dir / "latent"

    set_seeds(args.seed)

    if not args.use_existing:
        # Only the fresh-evaluation path needs the training codes.
        train_codes = np.load(source_dir / "train_codes.npy")
        model, layer_shapes, policy_args = _rebuild_autoencoder(source_dir, device)

        # Create a grid in code (latent) space
        grid_codes = _build_grid_codes(train_codes, args.num_points)
        grid_codes_tensor = torch.tensor(grid_codes, dtype=torch.float32).to(device)

        # Decode grid codes to weights
        with torch.no_grad():
            decoded_weights = model.decode(grid_codes_tensor).detach()

        # NOTE(review): the sample policy is built with the activation
        # functions from generation_args, while evaluation uses the ones from
        # stored_params["policy"]. Confirm these are meant to come from
        # different sources.
        stats = multi_reward_evaluate(
            env_id=stored_params["env"],
            wrapper_class=WRAPPER_CLASSES[stored_params["env"]],
            seed=args.seed,
            layer_shapes=layer_shapes,
            activation_fn=policy_args["activation_func"],
            last_activation_fn=policy_args["last_activation_func"],
            policies_weights=decoded_weights,
            stats_to_collect=args.stats,
            input_scaler=INPUT_SCALERS[stored_params["env"]],
            eval_batch_size=args.num_envs,  # num_envs is the batch size
            n_jobs=args.num_jobs,
        )
    else:
        grid_codes = np.load(save_path / "grid_codes.npy")
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # --use_existing on result directories produced by this script.
        with open(save_path / "viz_stats.pkl", "rb") as f:
            stats = pickle.load(f)

    # The output directory must exist before anything is written into it;
    # previously it was created only after np.save had already targeted it.
    save_path.mkdir(parents=True, exist_ok=True)

    multi_reward_visualize(
        grid_codes,
        args.stats,
        stats,
        save_path,
        args.num_points,
        not args.remove_color_bar,
    )

    # Persist the grid, the collected statistics and the arguments of this run.
    np.save(save_path / "grid_codes.npy", grid_codes)
    with open(save_path / "viz_stats.pkl", "wb") as f:
        pickle.dump(stats, f)

    testing_args = vars(args)
    with open(save_path / "testing_args.yml", "w") as f:
        yaml.dump(testing_args, f, sort_keys=False)


# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
