import numpy as np
import pgx

from src.env.grid_risk_v2 import GridRiskV2
from src.env.risk_space_invaders_v2 import RiskMinAtarSpaceInvadersV2
from src.env.stochastic_bipartite_matching import (
    StochasticBipartiteMatching,
    batch_sbm_instances,
    load_sbm_instance,
)
from src.env.stochastic_max_ind_set import (
    StochasticMaxIndependentSet,
    batch_stochastic_mis_instances,
    load_stochastic_mis_instance,
)


# Number of dataset instances loaded for every instance-backed environment.
_NUM_DATASET_INSTANCES = 1024

# Stochastic bipartite matching environment name -> folder holding its instances.
_SBM_INSTANCE_FOLDERS = {
    "stochastic-bipartite-matching-3020": "datasets/stochastic_bm/instances_30_20_180_final",
    "stochastic-bipartite-matching-301030": "datasets/stochastic_bm/instances_30_10_30",
    "stochastic-bipartite-matching-301060": "datasets/stochastic_bm/instances_30_10_60",
    "stochastic-bipartite-matching-6030180": "datasets/stochastic_bm/instances_60_30_180",
    "stochastic-bipartite-matching-6030270": "datasets/stochastic_bm/instances_60_30_270",
    "stochastic-bipartite-matching-6030360": "datasets/stochastic_bm/instances_60_30_360",
    "stochastic-bipartite-matching-6030180ba": "datasets/stochastic_bm/instances_60_30_180_ba",
}

# Stochastic max independent set environment name -> folder holding its instances.
_MIS_INSTANCE_FOLDERS = {
    "stochastic-max-ind-set-1": "./datasets/stochastic_mis/instances_30_65",
    "stochastic-max-ind-set-30130": "./datasets/stochastic_mis/instances_30_130",
    "stochastic-max-ind-set-30174": "./datasets/stochastic_mis/instances_30_174",
    "stochastic-max-ind-set-60354": "./datasets/stochastic_mis/instances_60_354",
    "stochastic-max-ind-set-60531": "./datasets/stochastic_mis/instances_60_531",
    "stochastic-max-ind-set-60708": "./datasets/stochastic_mis/instances_60_708",
    "stochastic-max-ind-set-60240ba": "./datasets/stochastic_mis/instances_60_240_ba",
    # NOTE(review): unlike the other entries, the suffix "50600" does not match
    # the folder name "instances_100_500" — confirm this mapping is intended.
    "stochastic-max-ind-set-50600": "./datasets/stochastic_mis/instances_100_500",
}

# Environment names for which make_env reports is_state_vector=True.
_STATE_VECTOR_ENVS = frozenset(
    {
        "grid-risk",
        "grid-risk-v2",
        "grid",
        "tree-risk",
        "mountain-car-risk",
        "mountain-car-c-risk-full",
        "mountain-car-cd-risk",
        "cartpole",
    }
)


def _even_then_odd_instance_ids(num_instances: int) -> np.ndarray:
    """Return ids 0..num_instances-1 reordered so all even ids precede all odd ids."""
    ids = np.arange(0, num_instances)
    return np.concatenate([ids[ids % 2 == 0], ids[ids % 2 == 1]])


def _load_batched_instances(folder_path: str, load_fn, batch_fn):
    """Load every dataset instance from ``folder_path`` and batch them.

    Args:
        folder_path: Directory passed to ``load_fn`` for each instance id.
        load_fn: Callable ``(instance_id, folder_path) -> instance``.
        batch_fn: Callable that combines the list of loaded instances.

    Returns:
        Whatever ``batch_fn`` returns for the loaded instances, which are
        ordered even ids first, then odd ids.
    """
    instance_ids = _even_then_odd_instance_ids(_NUM_DATASET_INSTANCES)
    print("Loading instances...")
    batched = batch_fn([load_fn(instance_id, folder_path) for instance_id in instance_ids])
    print("Finished loading instances.")
    return batched


def make_env(env_name: str, use_legal_actions: bool = False):
    """Build the environment registered under ``env_name``.

    Known project environments are constructed directly. Dataset-backed
    environments (stochastic bipartite matching / max independent set) load
    and batch their instances from disk first. Any other name is forwarded
    to ``pgx.make``.

    Args:
        env_name: Registered environment name.
        use_legal_actions: Forwarded to ``GridRiskV2`` only; ignored for all
            other environments.

    Returns:
        Tuple ``(env, is_state_vector, env.num_actions)``, where
        ``is_state_vector`` is True iff ``env_name`` is in
        ``_STATE_VECTOR_ENVS``.
    """
    if env_name == "grid-risk-v2":
        env = GridRiskV2(use_legal_actions=use_legal_actions)
    elif env_name == "space-invaders-risk-v2-2":
        env = RiskMinAtarSpaceInvadersV2()
    elif env_name in _SBM_INSTANCE_FOLDERS:
        instances = _load_batched_instances(
            _SBM_INSTANCE_FOLDERS[env_name],
            load_sbm_instance,
            batch_sbm_instances,
        )
        env = StochasticBipartiteMatching(instances=instances)
    elif env_name in _MIS_INSTANCE_FOLDERS:
        instances = _load_batched_instances(
            _MIS_INSTANCE_FOLDERS[env_name],
            load_stochastic_mis_instance,
            batch_stochastic_mis_instances,
        )
        env = StochasticMaxIndependentSet(instances=instances)  # type: ignore
    else:
        env = pgx.make(env_name)  # type: ignore

    is_state_vector = env_name in _STATE_VECTOR_ENVS
    return env, is_state_vector, env.num_actions
