import argparse
import os
import sys
from math import sqrt
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

# Directory that contains this script, and its parent directory, which the
# rest of the file treats as the project root.
current_path = Path(os.path.abspath(__file__)).parent
project_root = current_path.parent
# Put the project root on the import path. This has to run before the
# project-local imports below (presumably so that `compression_autoencoder`
# and `scripts` resolve when the file is executed directly — verify against
# how the script is launched).
sys.path.append(project_root.as_posix())

from compression_autoencoder.autoencoders.rl_neural_continuous_autoencoder import (
    RLNeuralContinuousAutoencoder,
)
from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.misc import resolve_source_dir, set_seeds
from compression_autoencoder.utils.policy_dataset import (
    PolicyDataset,
    custom_collate_fn,
)
from scripts.constants import INPUT_SCALERS

# Static experiment configuration stored next to this script. It is read once
# at import time with the safe loader, so only plain YAML types are built.
with (current_path / "parameters.yml").open() as f:
    stored_params = yaml.safe_load(f)


def prep_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the autoencoder training script.

    Most defaults are read from the module-level ``stored_params`` mapping
    (``stored_params["defaults"]``, ``stored_params["keep_percentages"]`` and
    ``stored_params["autoencoder"]["learning_rate"]``), so that mapping must
    be loaded before this function is called.

    Returns:
        A parser exposing ``--source_dir`` (required), ``--seed``,
        ``--latent_dim``, ``--epochs``, ``--batch_size``,
        ``--validation_split``, ``--n_states_per_net``, ``--use_saved_states``,
        ``--percentile`` and ``--learning_rate``.
    """
    parser = argparse.ArgumentParser(
        description="Train and evaluate the autoencoder using the selected policies"
    )
    defaults = stored_params["defaults"]
    parser.add_argument(
        "--source_dir",
        type=str,
        help="Location of directory that contains the selected policies' files",
        required=True,
    )
    parser.add_argument(
        "--seed", type=int, default=defaults["seed"], help="Seed for RNG"
    )
    parser.add_argument(
        "--latent_dim",
        type=int,
        default=defaults["latent_dim"],
        help="Latent space dimension",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=defaults["epochs_aut"],
        help="Max number of epochs to train",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=defaults["batch_size"],
        help="Size of training batches",
    )
    parser.add_argument(
        "--validation_split",
        type=float,
        default=defaults["validation_split"],
        help="Proportion of policies to be used as validation",
    )
    parser.add_argument(
        "--n_states_per_net",
        # This is a count, like --batch_size and --epochs: parse it as an
        # integer so a value given on the command line is not turned into a
        # float (e.g. 32.0) before it is used as a number of states.
        type=int,
        default=defaults["n_states_per_net"],
        help="Number of states on which each policy is evaluated at each gradient step.",
    )
    parser.add_argument(
        "--use_saved_states",
        action="store_true",
        help="If true, use the saved states to train the autoencoder.",
    )
    parser.add_argument(
        "--percentile",
        type=float,
        # argparse does not check a default against `choices`; only values
        # passed on the command line are validated.
        default=0.2,
        choices=stored_params["keep_percentages"],
        help="Which percentile dataset to use for training the autoencoder.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=stored_params["autoencoder"]["learning_rate"],
        help="Learning rate for training the autoencoder.",
    )

    return parser


def generate_dataset(
    weights: np.ndarray,
    states: np.ndarray,
    sample_policy: Policy,
    rng: np.random.Generator,
    batch_size: int,
    states_per_net: int,
    device: str,
    shuffle: bool = True,
) -> DataLoader:
    """Evaluate every policy on its states and wrap the result in a DataLoader.

    The policies are run through ``sample_policy.forward`` in fixed-size
    chunks (without gradient tracking) and the resulting actions are gathered
    on the CPU. Weights, states and actions are then handed to
    ``PolicyDataset``.

    Args:
        weights: Policy parameters, indexed by policy along the first axis.
        states: States for each policy; its first axis must have the same
            length as that of ``weights``.
        sample_policy: Policy object whose ``forward(states, weights)`` yields
            the actions.
        rng: Random generator forwarded to ``PolicyDataset``.
        batch_size: Batch size of the returned loader.
        states_per_net: Forwarded to ``PolicyDataset``.
        device: Device on which the policies are evaluated; also forwarded to
            ``PolicyDataset``.
        shuffle: Whether the returned loader shuffles its samples.

    Returns:
        A ``DataLoader`` over the ``PolicyDataset`` using ``custom_collate_fn``.
    """
    assert states.shape[0] == weights.shape[0], (
        "Number of states and weights must match"
    )
    chunk_size = 250
    total = weights.shape[0]

    action_chunks = []
    chunk_starts = range(0, total, chunk_size)
    for lo in tqdm(chunk_starts, desc="Generating dataset..."):
        # Slicing clamps a stop index past the end, so the last (possibly
        # shorter) chunk needs no special handling.
        hi = lo + chunk_size
        w_batch = torch.from_numpy(weights[lo:hi]).float().to(device)
        s_batch = torch.from_numpy(states[lo:hi]).float().to(device)

        with torch.no_grad():
            a_batch = sample_policy.forward(s_batch, w_batch)

        # Move results off the device right away to keep device memory bounded.
        action_chunks.append(a_batch.cpu())

    actions = torch.cat(action_chunks, dim=0).numpy()
    policy_dataset = PolicyDataset(
        weights, states, actions, states_per_net, rng, device
    )
    loader = DataLoader(
        policy_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=custom_collate_fn,
    )
    return loader


def plot_latent_space(
    ax: plt.Axes,
    train_codes: np.ndarray,
    val_codes: np.ndarray,
) -> None:
    """
    Plots the distribution of training and validation data in the latent space.

    The first two latent dimensions (three on a 3D axes) are used as
    coordinates. When the latent space has fewer dimensions than the axes,
    the missing coordinates are filled with zeros and the corresponding axis
    is left unlabelled, since it does not represent a latent code.

    Args:
        ax: The matplotlib Axes object to plot on. It is treated as 3D when
            ``ax.name == "3d"``.
        train_codes: Latent codes for the training dataset, one row per
            sample; its number of columns defines the latent dimension.
        val_codes: Latent codes for the validation dataset, with the same
            column layout as ``train_codes``.
    """
    latent_dim = train_codes.shape[1]
    is_3d = ax.name == "3d"
    n_axes = 3 if is_3d else 2

    def _coords(codes: np.ndarray) -> list:
        # Real latent columns first, then zero padding up to the number of
        # plot axes.
        columns = [codes[:, k] for k in range(min(latent_dim, n_axes))]
        padding = np.zeros_like(codes[:, 0])
        columns.extend([padding] * (n_axes - len(columns)))
        return columns

    # Plotting
    ax.scatter(*_coords(train_codes), s=1, alpha=0.7, label="Train")
    ax.scatter(
        *_coords(val_codes), s=5, alpha=0.9, label="Validation", marker="s"
    )

    # Aesthetics and labels: only label axes that carry an actual latent code.
    ax.set_title("Dataset Distribution in Latent Space")
    ax.set_xlabel("Latent Code 1")
    if latent_dim > 1:
        ax.set_ylabel("Latent Code 2")
    if is_3d and latent_dim > 2:
        ax.set_zlabel("Latent Code 3")
    ax.legend(markerscale=4)
    ax.grid(True, linestyle="--", alpha=0.6)


def main() -> None:
    """Train the autoencoder on a selected policy set and store the results.

    Reads the CLI arguments, rebuilds the policy architecture from the
    selection/generation metadata reachable from ``--source_dir``, trains an
    ``RLNeuralContinuousAutoencoder`` on the selected policy weights, and
    writes the following into a destination directory under the project root:
    the model, a summary figure, the training history, the train/validation
    indices, the latent codes and the effective arguments.
    """
    args = prep_arg_parser().parse_args()
    source_dir = resolve_source_dir(args.source_dir, project_root, current_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # The same generator is used for the split seed and for both datasets,
    # so the order of the statements below matters for reproducibility.
    rng = set_seeds(args.seed)

    # NOTE(review): FullLoader can build more than plain YAML types. These
    # files are presumably produced by the project's own scripts; do not
    # point this at untrusted YAML.
    with open(source_dir / "selection_args.yml") as f:
        selection_args = yaml.load(f, Loader=yaml.FullLoader)

    # The stored path may use Windows separators; normalise before resolving.
    generation_dir = resolve_source_dir(
        selection_args["source_dir"].replace("\\", "/"), project_root, current_path
    )
    with open(generation_dir / "args.yml") as f:
        generation_args = yaml.load(f, Loader=yaml.FullLoader)

    ## POLICY PARAMETERS
    policy_args = stored_params["policy"]
    policy_size = generation_args["policy_shape"]
    layer_shapes = [tuple(item) for item in policy_args["layer_shapes"][policy_size]]

    sample_policy = Policy(
        layer_shapes=layer_shapes,
        activation_func=generation_args["activation_func"],
        last_activation_func=generation_args["last_activation_func"],
        input_scaler=INPUT_SCALERS[stored_params["env"]],
        device=device,
    )

    ## AUTOENCODER PARAMETERS
    autoencoder_args = stored_params["autoencoder"]

    loss_f = autoencoder_args["loss_func"]
    activation = autoencoder_args["activation_func"]
    latent_dim = args.latent_dim
    learning_rate = args.learning_rate

    # Copy each layer spec so that patching the dimensions below does not
    # mutate the module-level `stored_params` configuration.
    encoder_layers = [list(item) for item in autoencoder_args["encoder_layers_shapes"]]
    decoder_layers = [list(item) for item in autoencoder_args["decoder_layers_shapes"]]

    input_dim, _ = sample_policy.count_params()

    # modify input and output dimensions of encoder and decoder with the correct values
    encoder_layers[0][0] = decoder_layers[-1][1] = input_dim
    encoder_layers[-1][1] = decoder_layers[0][0] = latent_dim

    encoder_layers_shapes = [tuple(item) for item in encoder_layers]
    decoder_layers_shapes = [tuple(item) for item in decoder_layers]

    model = RLNeuralContinuousAutoencoder(
        sample_policy=sample_policy,
        latent_dim=latent_dim,
        encoder_layers_shapes=encoder_layers_shapes,
        decoder_layers_shapes=decoder_layers_shapes,
        activation_func=activation,
        loss_func=loss_f,
        learning_rate=learning_rate,
        scheduler_kwargs=autoencoder_args["training_args"]["lr_scheduler"],
        early_stopping_kwargs=autoencoder_args["training_args"]["early_stopping"],
        input_scaler=INPUT_SCALERS[f"Autoencoder{stored_params['env']}"],
        device=device,
    )

    ## DATASET AND TRAINING PARAMETERS
    # NOTE(review): int() truncates rather than rounds. This presumably
    # mirrors the naming used when the keep_*p directories were created —
    # confirm against the selection script before changing it.
    keep_pct = int(args.percentile * 100)
    policies_weights = np.load(source_dir / f"keep_{keep_pct}p" / "selected_weights.npy")
    num_policies = policies_weights.shape[0]

    train_indices, val_indices = train_test_split(
        list(range(num_policies)),
        test_size=args.validation_split,
        random_state=rng.integers(2**32 - 1),
    )
    train_weights = policies_weights[train_indices]
    val_weights = policies_weights[val_indices]

    if args.use_saved_states:
        states = np.load(source_dir / "states.npy")
    else:
        # TODO: generalize to other environments
        # Regular n_points x n_points grid over [-1.2, 0.6] x [-0.07, 0.07].
        # If num_states is not a perfect square the grid holds fewer points.
        n_states = selection_args["num_states"]
        n_points = int(sqrt(n_states))
        s1 = np.linspace(-1.2, 0.6, n_points)
        s2 = np.linspace(-0.07, 0.07, n_points)
        ss1, ss2 = np.meshgrid(s1, s2)
        states = np.stack([ss1.ravel(), ss2.ravel()], axis=1)
    # Every policy is evaluated on the same states: repeat them once per policy.
    train_states = np.tile(states, (train_weights.shape[0], 1, 1))
    val_states = np.tile(states, (val_weights.shape[0], 1, 1))

    common_args = {
        "sample_policy": sample_policy,
        "rng": rng,
        "batch_size": args.batch_size,
        "states_per_net": args.n_states_per_net,
        "device": device,
    }
    train_loader = generate_dataset(
        weights=train_weights, states=train_states, **common_args
    )
    val_loader = generate_dataset(
        weights=val_weights, states=val_states, shuffle=False, **common_args
    )

    history = model.fit(train_loader, val_loader, args.epochs)

    dest_dir = (
        project_root
        / f"trained_autoencoders_keep_{keep_pct}p_lr_{args.learning_rate}_dim_{args.latent_dim}"
        / f"{source_dir.name}"
    )
    print(f"Saving model to {dest_dir} ...")
    dest_dir.mkdir(parents=True, exist_ok=True)
    model.save(dest_dir / "autoencoder.pth")

    # visualize training history and dataset distribution in latent space
    plt.style.use("tableau-colorblind10")

    # Code distribution in the latent space
    with torch.no_grad():
        train_codes = (
            model.encode(torch.from_numpy(train_weights).float().to(device))
            .cpu()
            .numpy()
        )
        val_codes = (
            model.encode(torch.from_numpy(val_weights).float().to(device)).cpu().numpy()
        )

    fig = plt.figure(figsize=(16, 7))
    ax1 = fig.add_subplot(1, 2, 1)
    # Only a 3-dimensional latent space gets a 3D scatter plot.
    if args.latent_dim == 3:
        ax2 = fig.add_subplot(1, 2, 2, projection="3d")
    else:
        ax2 = fig.add_subplot(1, 2, 2)

    history.visualize(ax=ax1)
    ax1.set_title("Training History")

    if args.latent_dim <= 3:
        plot_latent_space(ax2, train_codes, val_codes)
    else:
        # Too many dimensions to scatter: report the per-dimension range of
        # the training codes as text instead.
        ranges = []
        for i in range(args.latent_dim):
            dim_min = train_codes[:, i].min()
            dim_max = train_codes[:, i].max()
            ranges.append(f"Dim {i+1}: [{dim_min:.3f}, {dim_max:.3f}]")
        range_text = "Latent space ranges:\n" + "\n".join(ranges)
        ax2.text(
            0.5,
            0.5,
            range_text,
            horizontalalignment="center",
            verticalalignment="center",
            transform=ax2.transAxes,
            fontsize=14,
        )
        ax2.set_axis_off()
    output_path = dest_dir / "training_summary.pdf"
    fig.tight_layout()
    # Save this specific figure (not whatever pyplot considers current) and
    # release it afterwards.
    fig.savefig(output_path)
    plt.close(fig)

    # Save relevant info
    history.save(dest_dir / "training_history.pkl")
    np.save(dest_dir / "train_indices.npy", train_indices)
    np.save(dest_dir / "val_indices.npy", val_indices)
    np.save(dest_dir / "train_codes.npy", train_codes)
    np.save(dest_dir / "val_codes.npy", val_codes)
    # Save params for reproducibility (on a copy, so `args` itself is untouched)
    training_args = dict(vars(args))
    training_args["loss_func"] = loss_f
    training_args["activation_func"] = activation
    training_args["encoder_layers_shapes"] = encoder_layers
    training_args["decoder_layers_shapes"] = decoder_layers
    training_args["training_args"] = autoencoder_args["training_args"]
    with open(dest_dir / "training_args.yml", "w") as f:
        yaml.dump(training_args, f, sort_keys=False)


if __name__ == "__main__":
    main()
