import argparse
import os
import pickle
import sys
from pathlib import Path

import numpy as np
import torch
import yaml

# Directory that contains this script; "parameters.yml" is read from here.
current_path = Path(os.path.abspath(__file__)).parent
# Parent of the script directory; used as the project root for imports and
# as the base directory for generated output.
project_root = current_path.parent
# Appended before the project-local imports that follow, presumably so they
# resolve when this file is executed directly as a script.
sys.path.append(project_root.as_posix())

from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.evaluation import multi_reward_evaluate
from compression_autoencoder.utils.misc import get_seed_sequence, set_seeds
from scripts.constants import INPUT_SCALERS, WRAPPER_CLASSES

# Static configuration shared by the whole script, read once at import time
# from the "parameters.yml" file located next to this module. Only the safe
# YAML loader is used, so no arbitrary Python objects are constructed.
with open(current_path / "parameters.yml") as f:
    stored_params = yaml.safe_load(f)


def prep_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the policy generation script.

    Defaults for most options are taken from the ``"defaults"`` section of the
    module-level ``stored_params`` mapping.

    Returns:
        A configured ``argparse.ArgumentParser``. ``--policy_shape`` is
        mandatory; every other option is optional.

    Raises:
        KeyError: If ``stored_params["defaults"]`` lacks one of the expected
            entries.
    """
    defaults = stored_params["defaults"]
    parser = argparse.ArgumentParser(
        description="Generate and evaluate random policies for Mountain Car environment"
    )

    # Scalar options whose default is looked up in ``defaults`` under the
    # option's own name: (name, type, help text). Order matches --help output.
    config_backed_options = (
        ("seed", int, "Seed for RNG"),
        ("param_range", float, "RNG range"),
        (
            "num_envs",
            int,
            "Number of policies to evaluate in parallel on one device (batch size)",
        ),
        ("num_policies", int, "Number of policies to generate"),
        ("num_expected_eps", int, "Number of expected episodes per policy"),
        (
            "chunk_size",
            int,
            "Number of policies to generate/save in a single file chunk",
        ),
        ("num_jobs", int, "Number of CPU jobs to run in parallel"),
    )
    for name, value_type, help_text in config_backed_options:
        parser.add_argument(
            f"--{name}",
            type=value_type,
            default=defaults[name],
            help=help_text,
        )

    parser.add_argument(
        "--stats",
        nargs="+",
        default=defaults["stats"],
        help="Mountain car rewards to evaluate",
    )
    # This option is required, so no default is declared: argparse never
    # applies a default to a required option, and advertising one would
    # contradict the actual behaviour.
    parser.add_argument(
        "--policy_shape",
        type=str,
        choices=["small", "medium", "medium_large", "large"],
        help="Shape of the policy network",
        required=True,
    )
    parser.add_argument(
        "--skip_eval",
        action="store_true",
        help="Skip the evaluation phase (only generate and save policies)",
    )
    parser.add_argument(
        "--load_existing",
        action="store_true",
        help="Load existing policy weights instead of generating new ones",
    )
    return parser


def _dump_pickle(obj: object, path: Path) -> None:
    """Pickle ``obj`` into ``path``, replacing the file if it already exists."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def main() -> None:
    """Generate (or load) random policy weights in chunks and evaluate them.

    For every chunk the weights are either drawn uniformly from
    ``[-param_range, param_range]`` or loaded from ``weights_<idx>.npy``
    (``--load_existing``). Unless ``--skip_eval`` is given, each chunk is
    passed to ``multi_reward_evaluate`` and the returned stats are pickled to
    ``stats_<idx>.pkl``; the running per-stat minimum and maximum are written
    to ``min_stats.pkl`` / ``max_stats.pkl`` at the end. The run configuration
    is always written to ``args.yml``.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``num_policies`` is
            not divisible by ``chunk_size``.
    """
    args = prep_arg_parser().parse_args()

    # Checked first: a zero chunk size would otherwise surface as a bare
    # ZeroDivisionError in the modulo below, and a negative one could pass the
    # divisibility check and produce a negative chunk count.
    if args.chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if args.num_policies % args.chunk_size != 0:
        raise ValueError("num_policies must be divisible by chunk_size")

    # NOTE(review): --num_expected_eps is parsed but not used anywhere in this
    # function — confirm whether it should be forwarded to the evaluation.

    rng = set_seeds(args.seed)
    # Decouple policy generation from the number of batches. This ensures
    # that the same policies are generated regardless of the batch size.
    policy_generator = np.random.default_rng(rng.integers(0, 2**32 - 1))

    # Prep policy architecture
    policy_args = stored_params["policy"]
    layer_shapes = [
        tuple(item) for item in policy_args["layer_shapes"][args.policy_shape]
    ]
    # Identical for every chunk, so computed once rather than per iteration.
    num_params = Policy._count_params_from_shape(layer_shapes)

    # Create the output directory if needed. An existing directory is reused
    # and files written below replace any files of the same name in it.
    save_path = (
        project_root
        / "generated_policies"
        / f"{stored_params['env']}_{args.policy_shape}_{args.num_policies}_seed_{args.seed}/"
    )
    save_path.mkdir(parents=True, exist_ok=True)

    # Prep chunking and stats tracking
    num_chunks = args.num_policies // args.chunk_size
    min_stats = {stat: float("inf") for stat in args.stats}
    max_stats = {stat: float("-inf") for stat in args.stats}

    chunk_seeds = get_seed_sequence(rng, num_chunks)
    for chunk_idx, chunk_seed in enumerate(chunk_seeds):
        print(f"Processing file chunk {chunk_idx + 1} / {num_chunks}")
        weights_file = save_path / f"weights_{chunk_idx}.npy"

        if args.load_existing:
            policies_weights = torch.tensor(
                np.load(weights_file), dtype=torch.float32
            )
        else:
            policies_weights = torch.tensor(
                policy_generator.uniform(
                    -args.param_range,
                    args.param_range,
                    (args.chunk_size, num_params),
                ),
                dtype=torch.float32,
            )

        if not args.skip_eval:
            stats = multi_reward_evaluate(
                env_id=stored_params["env"],
                wrapper_class=WRAPPER_CLASSES[stored_params["env"]],
                seed=chunk_seed,
                layer_shapes=layer_shapes,
                activation_fn=policy_args["activation_func"],
                last_activation_fn=policy_args["last_activation_func"],
                policies_weights=policies_weights,
                stats_to_collect=args.stats,
                input_scaler=INPUT_SCALERS[stored_params["env"]],
                eval_batch_size=args.num_envs,  # num_envs is the batch size
                n_jobs=args.num_jobs,
            )

            # Fold this chunk's extremes into the running overall extremes.
            for stat in args.stats:
                min_stats[stat] = min(min_stats[stat], float(np.min(stats[stat])))
                max_stats[stat] = max(max_stats[stat], float(np.max(stats[stat])))

        # Save results for this chunk. Loaded weights are already on disk, so
        # they are not written back.
        if not args.load_existing:
            np.save(weights_file, policies_weights.numpy())
        if not args.skip_eval:
            _dump_pickle(stats, save_path / f"stats_{chunk_idx}.pkl")

    if not args.skip_eval:
        # Print and save overall stats
        print("Min stats:", min_stats)
        print("Max stats:", max_stats)

        _dump_pickle(min_stats, save_path / "min_stats.pkl")
        _dump_pickle(max_stats, save_path / "max_stats.pkl")

    # Dump parameters for reproducibility
    params = {
        "policy_shape": args.policy_shape,
        # Tuples are converted back to lists so YAML emits plain sequences.
        "layer_shapes": [list(shape) for shape in layer_shapes],
        "activation_func": policy_args["activation_func"],
        "last_activation_func": policy_args["last_activation_func"],
        "env": stored_params["env"],
        "eval_batch_size": args.num_envs,
        "num_jobs": args.num_jobs,
        "num_policies": args.num_policies,
        "chunk_size": args.chunk_size,
        "seed": args.seed,
        "param_range": args.param_range,
    }
    with open(save_path / "args.yml", "w") as f:
        yaml.dump(params, f, sort_keys=False)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
