import argparse
import gc
import os
import pickle
import sys
from math import sqrt
from pathlib import Path

import numpy as np
import torch
import yaml
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

# Directory containing this script, and the directory one level above it.
current_path = Path(os.path.abspath(__file__)).parent
project_root = current_path.parent
# Put the parent directory on the import path; this must happen before the
# project imports that follow (presumably so the script can be run directly
# without installing the package — confirm).
sys.path.append(project_root.as_posix())

from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.misc import resolve_source_dir, set_seeds
from scripts.constants import INPUT_SCALERS

# Parse the parameter file that sits next to this script once, at import time.
# The functions below read CLI defaults, the environment name, the policy
# layer shapes and the keep percentages from the resulting mapping.
with (current_path / "parameters.yml").open() as params_file:
    stored_params = yaml.safe_load(params_file)


def prep_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the policy selection script.

    Defaults for the optional arguments are taken from the ``defaults``
    section of the module-level ``stored_params`` mapping.

    Returns:
        A configured parser; the caller is responsible for ``parse_args()``.
    """
    cfg_defaults = stored_params["defaults"]
    cli = argparse.ArgumentParser(
        description="Select a diverse subset of policies using KNN rejection sampling."
    )

    # The only mandatory option.
    cli.add_argument(
        "--source_dir",
        type=str,
        required=True,
        help="Location of directory that contains the generated policies' files",
    )

    # These integer options differ only in flag, defaults key and help text,
    # so they are declared as a table and registered in order.
    int_options = (
        (
            "--num_states",
            "num_states_rej",
            "Number of states to consider for behavioral comparison.",
        ),
        (
            "--k_neighbors",
            "k_neighbors_rej",
            "k neighbors for rejection sampling.",
        ),
        (
            "--num_jobs",
            "num_jobs",
            "Number of CPU jobs to run in parallel",
        ),
    )
    for flag, defaults_key, help_text in int_options:
        cli.add_argument(
            flag,
            type=int,
            default=cfg_defaults[defaults_key],
            help=help_text,
        )

    cli.add_argument(
        "--skip_stats",
        action="store_true",
        help="Skip loading and saving stats files (saves memory if not needed)",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=cfg_defaults["seed"],
        help="Random seed for reproducibility",
    )
    return cli


def calculate_avg_distances(
    actions: np.ndarray,
    k_for_density: int,
    verbose: bool = True,
    n_jobs: int = 1,
    chunk_size: int = 5000,
) -> np.ndarray:
    """Return each policy's mean distance to its k nearest neighbours.

    The value is used as an inverse density score: a larger mean distance
    means the policy lies in a sparser region of action space.

    Args:
        actions: 2-D array with one row per policy.
        k_for_density: Number of neighbours (excluding the policy itself) to
            average over. Must be at least 1 and smaller than the number of
            policies.
        verbose: When true, print status messages and show the progress bar.
        n_jobs: Passed through to ``NearestNeighbors`` as ``n_jobs``.
        chunk_size: Number of query rows sent to ``kneighbors`` per call;
            bounds the size of the intermediate distance matrix.

    Returns:
        1-D array with one averaged distance per row of ``actions``.

    Raises:
        ValueError: If ``k_for_density`` is below 1, if there are not more
            policies than ``k_for_density``, or if ``chunk_size`` is below 1.
    """
    n_policies = actions.shape[0]

    # With k < 1 the neighbour slice below would be empty and the mean would
    # silently become NaN for every policy, so reject it up front.
    if k_for_density < 1:
        raise ValueError(f"k_for_density must be >= 1, got {k_for_density}.")
    # k + 1 neighbours are requested (the query point itself plus k others),
    # which requires strictly more than k fitted samples.
    if n_policies <= k_for_density:
        raise ValueError(
            f"Need more than k_for_density={k_for_density} policies to estimate "
            f"density, got {n_policies}."
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}.")

    if verbose:
        print(
            f"Estimating density for {n_policies} policies using k={k_for_density}..."
        )

    nn_model = NearestNeighbors(n_neighbors=k_for_density + 1, n_jobs=n_jobs)
    nn_model.fit(actions)

    # Average each chunk immediately instead of keeping every (rows, k + 1)
    # distance matrix around until the end.
    chunk_means = []
    for start in tqdm(
        range(0, n_policies, chunk_size),
        desc="  K-NN Search",
        disable=not verbose,
    ):
        distances_chunk, _ = nn_model.kneighbors(actions[start : start + chunk_size])
        # Column 0 is dropped: the queries are the fitted points themselves,
        # so the closest match is expected to be the point at distance zero.
        chunk_means.append(np.mean(distances_chunk[:, 1:], axis=1))

    avg_distances = np.concatenate(chunk_means, axis=0)

    if verbose:
        print("Density estimation complete.")
    return avg_distances


def _build_comparison_states(env_name, num_states: int, rng) -> np.ndarray:
    """Return the 2-D array of states on which all policies are compared."""
    if env_name == "MountainCarContinuous-v0":
        # Regular n_points x n_points grid, so the requested count is rounded
        # down to the nearest perfect square.
        n_points = int(sqrt(num_states))
        s1 = np.linspace(-1.2, 0.6, n_points)
        s2 = np.linspace(-0.07, 0.07, n_points)
        ss1, ss2 = np.meshgrid(s1, s2)
        return np.stack([ss1.ravel(), ss2.ravel()], axis=1)

    # Every other environment: uniform samples from a fixed 6-dimensional box.
    lb = np.array([-1, -1, -1, -1, -5, -5])
    ub = np.array([1, 1, 1, 1, 5, 5])
    return rng.uniform(lb, ub, (num_states, 6))


def _load_stats(stat_files) -> dict:
    """Load every stats pickle and concatenate the per-key arrays in file order.

    ``stat_files`` must be non-empty; the keys of the first file define the
    keys of the result.
    """
    stats_chunks = []
    for stat_file in stat_files:
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this script at stats files produced by a trusted run.
        with open(stat_file, "rb") as file:
            stats_chunks.append(pickle.load(file))
    return {
        key: np.concatenate([chunk[key] for chunk in stats_chunks])
        for key in stats_chunks[0]
    }


def _percent_dir_name(keep_percentage: float) -> str:
    """Return the ``keep_<N>p`` directory name for a keep fraction.

    The product is rounded to six decimals before truncation so that binary
    floating-point noise (``0.29 * 100 == 28.999999999999996``) does not
    shift the label down by one.
    """
    return f"keep_{int(round(keep_percentage * 100, 6))}p"


def _gather_selected_weights(
    weights_files, selected_mask: np.ndarray, n_params: int
) -> np.ndarray:
    """Re-read the weight files and collect the rows flagged by ``selected_mask``.

    The mask is indexed by global policy position, i.e. by the running row
    count over ``weights_files`` in the given order.
    """
    num_selected = int(np.sum(selected_mask))
    # NOTE(review): np.empty defaults to float64, so the saved array is
    # float64 regardless of the dtype stored in the weight files — confirm
    # that this is intended.
    selected_weights = np.empty((num_selected, n_params))

    selected_idx = 0
    total_idx = 0
    for weight_file in tqdm(weights_files, desc="  Saving weights"):
        weights_chunk = np.load(weight_file)
        chunk_size = weights_chunk.shape[0]
        chunk_mask = selected_mask[total_idx : total_idx + chunk_size]
        num_selected_chunk = int(np.sum(chunk_mask))
        selected_weights[selected_idx : selected_idx + num_selected_chunk] = (
            weights_chunk[chunk_mask]
        )
        selected_idx += num_selected_chunk
        total_idx += chunk_size
        del weights_chunk

    return selected_weights


def main() -> None:
    """Score all generated policies by behavioural density and save the sparsest.

    Steps:
      1. Build a fixed set of comparison states for the configured environment.
      2. Evaluate every stored policy on those states.
      3. Score each policy by its mean k-NN distance in flattened action space.
      4. For each configured keep percentage, keep the policies whose score is
         at or above the matching percentile and save their weights (and
         stats, unless ``--skip_stats`` is given).

    Everything is written under
    ``<project_root>/selected_policies/<source_dir name>/``: ``avg_distances.npy``,
    one ``keep_<N>p/`` directory per percentage, ``states.npy`` and
    ``selection_args.yml``.

    Raises:
        FileNotFoundError: If no weight files are found, or if stats are
            requested but no stats files are found.
        ValueError: If the number of policies on disk does not match
            ``num_policies`` from the source directory's ``args.yml``.
    """
    args = prep_arg_parser().parse_args()
    source_dir = resolve_source_dir(args.source_dir, project_root, current_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    rng = set_seeds(args.seed)

    # Create state space grid
    states = _build_comparison_states(stored_params["env"], args.num_states, rng)
    n_states = states.shape[0]
    # Leading singleton axis added before the states are handed to the policy.
    states_tensor = torch.tensor(states, dtype=torch.float32).unsqueeze(0).to(device)

    # Load policy architecture
    # NOTE(review): yaml.FullLoader accepts more than plain data types. That
    # is fine for an args.yml written by this project, but do not point
    # --source_dir at an untrusted directory.
    with open(source_dir / "args.yml") as f:
        gen_policies_args = yaml.load(f, Loader=yaml.FullLoader)

    policy_args = stored_params["policy"]
    policy_size = gen_policies_args["policy_shape"]
    layer_shapes = [tuple(item) for item in policy_args["layer_shapes"][policy_size]]

    sample_policy = Policy(
        layer_shapes=layer_shapes,
        activation_func=gen_policies_args["activation_func"],
        last_activation_func=gen_policies_args["last_activation_func"],
        input_scaler=INPUT_SCALERS[stored_params["env"]],
        device=device,
    )

    print("Loading all policy chunks into memory and evaluating actions...")
    weights_files = sorted(source_dir.glob("weights_*.npy"))
    stat_files = sorted(source_dir.glob("stats_*.pkl"))

    if not weights_files:
        raise FileNotFoundError(
            f"Error: No 'weights_*.npy' files found in {source_dir}"
        )
    # Fail before the expensive evaluation rather than after it.
    if not args.skip_stats and not stat_files:
        raise FileNotFoundError(
            f"Error: No 'stats_*.pkl' files found in {source_dir} "
            "(pass --skip_stats to run without them)"
        )

    n_policies_total = gen_policies_args["num_policies"]
    all_actions = torch.empty(
        (n_policies_total, n_states, layer_shapes[-1][1]), device=device
    )

    processed_policies = 0
    processing_batch_size = 250
    # Loop through each file chunk of weights
    for weight_file in tqdm(weights_files):
        weights_chunk = np.load(weight_file)
        weights_tensor = torch.from_numpy(weights_chunk).float().to(device)
        n_policies_chunk = weights_tensor.shape[0]

        # Writing past the preallocated buffer would otherwise surface as an
        # opaque shape-mismatch error from torch.
        if processed_policies + n_policies_chunk > n_policies_total:
            raise ValueError(
                f"Error: Expected {n_policies_total} policies, but found at least "
                f"{processed_policies + n_policies_chunk}."
            )

        for i in range(0, n_policies_chunk, processing_batch_size):
            chunk_end = min(i + processing_batch_size, n_policies_chunk)
            num_in_chunk = chunk_end - i
            with torch.no_grad():
                actions_chunk = sample_policy.forward(
                    states_tensor, weights_tensor[i:chunk_end]
                ).squeeze(-1)

            all_actions[processed_policies : processed_policies + num_in_chunk] = (
                actions_chunk
            )
            processed_policies += num_in_chunk
            # Dropped here, not after the loop, so an empty weights file
            # (zero iterations) cannot leave the name unbound.
            del actions_chunk

        del weights_chunk, weights_tensor
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if processed_policies != n_policies_total:
        raise ValueError(
            f"Error: Expected {n_policies_total} policies, but found {processed_policies}."
        )

    all_stats = None if args.skip_stats else _load_stats(stat_files)

    # One flat feature vector per policy for the neighbour search.
    actions_np = all_actions.cpu().numpy()
    actions_np = actions_np.reshape(n_policies_total, -1)
    del all_actions
    if device.type == "cuda":
        torch.cuda.empty_cache()

    avg_distances = calculate_avg_distances(
        actions=actions_np,
        k_for_density=args.k_neighbors,
        verbose=True,
        n_jobs=args.num_jobs,
    )

    print("Saving selected policies...")
    dest_dir = project_root / "selected_policies" / source_dir.name
    dest_dir.mkdir(parents=True, exist_ok=True)

    # Save the calculated distances for potential future use
    np.save(dest_dir / "avg_distances.npy", avg_distances)

    # Loop-invariant: the flat parameter count depends only on the architecture.
    n_params = Policy._count_params_from_shape(layer_shapes)

    for keep_percentage in stored_params["keep_percentages"]:
        print(f"\nProcessing for {keep_percentage * 100:.1f}%...")

        # Determine the threshold from the pre-calculated distances
        percentile_to_cut = 100 * (1.0 - keep_percentage)
        distance_threshold = np.percentile(avg_distances, percentile_to_cut)
        selected_mask = avg_distances >= distance_threshold

        # Create a specific subdirectory for this percentage
        percent_dir = dest_dir / _percent_dir_name(keep_percentage)
        percent_dir.mkdir(exist_ok=True)

        if all_stats is not None:
            selected_stats = {k: v[selected_mask] for k, v in all_stats.items()}
            with open(percent_dir / "selected_stats.pkl", "wb") as f:
                pickle.dump(selected_stats, f)

        selected_weights = _gather_selected_weights(
            weights_files, selected_mask, n_params
        )
        np.save(percent_dir / "selected_weights.npy", selected_weights)

        print(f"Distance threshold: {distance_threshold:.6f}")
        # Release the large arrays before the next percentage allocates its own.
        del selected_weights
        del selected_mask
        gc.collect()

    np.save(dest_dir / "states.npy", states)
    # Save params for reproducibility
    # Copy so the parsed namespace itself is not mutated.
    selection_args = dict(vars(args))
    selection_args["keep_percentages"] = stored_params["keep_percentages"]
    with open(dest_dir / "selection_args.yml", "w") as f:
        yaml.dump(selection_args, f, sort_keys=False)


# Run the selection pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
