import os
import random
import numpy as np
import argparse
import torch
from torch import Tensor
from rescue.algorithms.initial_sampling import GenerateInitialSample
from rescue.utils import status
from experiments.config import PROBLEM_NAMES

class ObservationalDataGenerator:
    """Generate seeded initial (observational) samples for one problem and save them.

    Two samplers are built from the problem's entry in ``PROBLEM_CONFIGS``: one
    for the multi-fidelity (MF) variant and one for the single-fidelity (SF)
    variant. Each ``gen_observational_data_*`` method draws one sample set per
    seed in ``seed_list``. It stores all sets in a dict keyed by seed, writes that
    dict to a single ``.pt`` file under ``OUTPUT_DIR``, and returns it.
    """

    # Output directory, relative to the current working directory.
    OUTPUT_DIR = "data/observational_data"

    def __init__(
        self,
        problem_name: str,
        seed_list: list[int],
        device: str,
    ) -> None:
        """
        Build the MF and SF initial-sample generators for ``problem_name``.

        Args:
            problem_name (str): Key into ``PROBLEM_CONFIGS`` selecting the problem.
            seed_list (list[int]): Seeds; one sample set is generated per seed.
            device (str): Exported as the ``EXP_DEVICE`` environment variable
                before ``PROBLEM_CONFIGS`` is imported (e.g. "cpu" or "cuda").

        Raises:
            ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
        """
        # The problem definitions read EXP_DEVICE / EXP_DTYPE when their module
        # is imported, so both must be set *before* the import below.
        os.environ["EXP_DEVICE"] = device
        os.environ["EXP_DTYPE"] = "double"
        # NOTE(review): Python caches modules. If experiments.problems.problems
        # was already imported earlier in this process, the values above are
        # presumably ignored. Confirm that nothing imports it first.
        from experiments.problems.problems import PROBLEM_CONFIGS

        if problem_name not in PROBLEM_CONFIGS:
            raise ValueError(f"Unknown problem_name: {problem_name}")
        self.c = PROBLEM_CONFIGS[problem_name]
        self.seed_list = seed_list
        self.problem_name = problem_name
        self.init_sampling_MF = GenerateInitialSample(
            problem=self.c.PROBLEM_MF,
            bounds=self.c.BOUNDS_MF,
            has_constraints=self.c.HAS_CONSTRAINTS,
        )
        self.init_sampling_SF = GenerateInitialSample(
            problem=self.c.PROBLEM_SF,
            bounds=self.c.BOUNDS_SF,
            has_constraints=self.c.HAS_CONSTRAINTS,
        )
        # Presumably read by the `status` decorator; verify in rescue.utils.
        self.status_spinner = True

    def _output_path(self, fidelity_tag: str) -> str:
        """Return the ``.pt`` path for this problem and a fidelity tag ("mf"/"sf")."""
        return f"{self.OUTPUT_DIR}/{self.problem_name}_obs_data_{fidelity_tag}.pt"

    def _save(self, obs_data: dict, fidelity_tag: str) -> str:
        """Write ``obs_data`` with ``torch.save`` and return the path used.

        The output directory is created if missing, so a fresh checkout does
        not fail at save time after all the sampling work is done.
        """
        path = self._output_path(fidelity_tag)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(obs_data, path)
        return path

    @status(show_func_name=True)
    def gen_observational_data_MF(
        self,
        n_full_fidelity_equiv: float | int,
        verbose: bool = False,
    ) -> dict[int, tuple[Tensor, Tensor, Tensor | None]]:
        """
        Generate multi-fidelity initial data for every seed and save it as a .pt file.

        Args:
            n_full_fidelity_equiv (float | int): Sampling budget, forwarded to
                ``generate_initial_data_multifidelity`` together with the
                problem's cost function (budget in units of cost(x) *
                n_full_fidelity_equiv).
            verbose (bool): Forwarded to the sampler.

        Returns:
            dict: Maps each seed to the sampler's output for that seed. This is
            the same object that is written to disk.
        """
        obs_data = {}
        for seed in self.seed_list:
            # Reseed all RNGs so each seed's sample set is reproducible on its own.
            set_seed(seed)
            obs_data[seed] = self.init_sampling_MF.generate_initial_data_multifidelity(
                n_full_fidelity_equiv=n_full_fidelity_equiv,
                cost_fn=self.c.CostFn_MF,
                is_discrete=self.c.IS_DISCRETE_FIDELITIES,
                fidelity_levels=self.c.FIDELITY_LEVELS,
                verbose=verbose,
            )
        path = self._save(obs_data, "mf")
        for seed in self.seed_list:
            print(f"Observational sample size MF for seed {seed}: {obs_data[seed][0].size()}")
        print(f"Saved MF observational data for {self.problem_name} to {path}")
        return obs_data

    @status(show_func_name=True)
    def gen_observational_data_SF(
        self, n: int | float
    ) -> dict[int, tuple[Tensor, Tensor, Tensor | None]]:
        """
        Generate single-fidelity initial data for every seed and save it as a .pt file.

        Args:
            n (int | float): Number of samples per seed. A float is truncated
                toward zero with ``int()``.

        Returns:
            dict: Maps each seed to the sampler's output for that seed. This is
            the same object that is written to disk.
        """
        if isinstance(n, float):
            n = int(n)
        obs_data = {}
        for seed in self.seed_list:
            # Reseed all RNGs so each seed's sample set is reproducible on its own.
            set_seed(seed)
            obs_data[seed] = self.init_sampling_SF.generate_initial_data(n=n, seed=seed)
        path = self._save(obs_data, "sf")
        for seed in self.seed_list:
            print(f"Observational sample size SF for seed {seed}: {obs_data[seed][0].size()}")
        print(f"Saved SF observational data for {self.problem_name} to {path}")
        return obs_data

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed (int): Value used for every random number generator.
    """
    # Python-level generators.
    random.seed(seed)
    # Only affects child processes started afterwards: the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # PyTorch CPU generator, then the current and all CUDA generators.
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def main() -> None:
    """Command-line entry point: parse arguments, then generate and save observational data.

    Builds an ObservationalDataGenerator for the chosen problem, seeds and
    device. It then runs either the multi-fidelity ("mf") or the
    single-fidelity ("sf") generator with the given budget.
    """
    parser = argparse.ArgumentParser(
        description="Generate seeded observational (initial) data and save it as .pt files."
    )
    parser.add_argument("-p", "--problem", required=True, choices=PROBLEM_NAMES)
    parser.add_argument(
        "-f", "--fidelity", required=True, choices=["mf", "sf"],
        help="Specify the fidelity type: "
        "'mf' for multi-fidelity or 'sf' for single-fidelity.",
    )
    parser.add_argument(
        "-b", "--obs_budget", type=float, required=True,
        help="Observational samples to generate based on the cost function. "
        "For multi-fidelity, this is the budget in terms of cost(x) * n_full_fidelity_equiv. "
        "For single-fidelity, this is the number of samples (truncated to an integer).",
    )
    parser.add_argument(
        "-d", "--device", default="cpu",
        help="Device exported as EXP_DEVICE before loading the problems (e.g. 'cpu' or 'cuda').",
    )
    parser.add_argument("-s", "--seeds", type=int, nargs="+", required=True,
                        help="List of random seeds for reproducibility.")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output.")
    args = parser.parse_args()

    # Reject bad budgets up front, before any environment/import side effects.
    # The chained comparison is also False for NaN.
    if not 0 < args.obs_budget < float("inf"):
        parser.error("--obs_budget must be a positive, finite number.")
    if args.fidelity == "sf" and args.obs_budget < 1:
        # The SF path truncates the budget with int(), which would yield 0 samples.
        parser.error("--obs_budget must be at least 1 for single-fidelity sampling.")

    generator = ObservationalDataGenerator(
        problem_name=args.problem,
        seed_list=args.seeds,
        device=args.device,
    )
    if args.fidelity == "mf":
        generator.gen_observational_data_MF(
            n_full_fidelity_equiv=args.obs_budget,
            verbose=args.verbose,
        )
    else:  # argparse `choices` guarantees this is "sf"
        generator.gen_observational_data_SF(n=args.obs_budget)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
