import os
import random
import numpy as np
import argparse
import torch
from torch import Tensor
from rescue.algorithms.initial_sampling import GenerateInitialSample
from rescue.utils import status
from experiments.config import PROBLEM_NAMES


class InitTrails:
    """Generate initial samples for one problem and persist them to disk.

    One dataset is produced per seed in ``seed_list``; the per-seed results
    are collected in a dict keyed by seed and written with ``torch.save``
    to ``data/init_trails/`` (relative to the current working directory).
    """

    # Directory that receives the saved ``.pt`` files.
    _OUTPUT_DIR = "data/init_trails"

    def __init__(
        self,
        problem_name: str,
        seed_list: list[int],
        device: str
    ) -> None:
        """
        Set up the samplers for the multi- and single-fidelity problems.

        Args:
            problem_name (str): The name of the problem to solve.
            seed_list (list[int]): A list of random seeds for reproducibility.
            device (str): Value exported as ``EXP_DEVICE`` (e.g., "cpu" or "cuda").

        Raises:
            ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
        """
        # The environment must be configured before the problem module is
        # imported, which is why the import below lives inside __init__
        # rather than at the top of the file.
        os.environ["EXP_DEVICE"] = device
        os.environ["EXP_DTYPE"] = "double"
        # future me, PROBLEM_CONFIGS must be imported
        # after EXP_DEVICE is set
        from experiments.problems.problems import PROBLEM_CONFIGS

        if problem_name not in PROBLEM_CONFIGS:
            raise ValueError(f"Unknown problem_name: {problem_name}")
        self.c = PROBLEM_CONFIGS[problem_name]
        self.seed_list = seed_list
        self.problem_name = problem_name
        self.init_sampling_MF = GenerateInitialSample(
            problem=self.c.PROBLEM_MF,
            bounds=self.c.BOUNDS_MF,
            has_constraints=self.c.HAS_CONSTRAINTS,
        )
        self.init_sampling_SF = GenerateInitialSample(
            problem=self.c.PROBLEM_SF,
            bounds=self.c.BOUNDS_SF,
            has_constraints=self.c.HAS_CONSTRAINTS,
        )
        # NOTE(review): presumably read by the ``status`` decorator to toggle
        # its spinner -- confirm against rescue.utils.status.
        self.status_spinner = True

    def _save_and_report(self, init_data: dict, fidelity: str) -> None:
        """Write ``init_data`` to disk and print a per-seed size summary.

        Args:
            init_data (dict): Mapping from seed to the generated data; element
                ``[0]`` of each value must expose ``.size()``.
            fidelity (str): ``"MF"`` or ``"SF"``; used verbatim in the printed
                messages and lower-cased in the file name.
        """
        # torch.save does not create missing parent directories, so make sure
        # the target directory exists instead of failing after all the
        # sampling work has already been done.
        os.makedirs(self._OUTPUT_DIR, exist_ok=True)
        path = (
            f"{self._OUTPUT_DIR}/"
            f"{self.problem_name}_init_trails_{fidelity.lower()}.pt"
        )
        torch.save(init_data, path)
        # print the init_data size for each seed and saved dir
        for seed in self.seed_list:
            print(
                f"Initial sample size {fidelity} for seed {seed}: "
                f"{init_data[seed][0].size()}"
            )
        print(f"Saved {fidelity} init data for {self.problem_name} to {path}")

    @status(show_func_name=True)
    def gen_sample_MF(
        self,
        n_full_fidelity_equiv: float | int
    ) -> dict[int, tuple[Tensor, Tensor, Tensor | None]]:
        """
        Generate initial data for the MF problem.
        Save the data as .pt file.

        Args:
            n_full_fidelity_equiv (float | int): cost(x) * n_full_fidelity_equiv

        Returns:
            Mapping from each seed in ``seed_list`` to the data generated
            after seeding the RNGs with that seed.
        """
        init_data = {}
        for seed in self.seed_list:
            # Re-seed all global RNGs so each entry is reproducible on its own.
            set_seed(seed)
            init_data[seed] = (
                self.init_sampling_MF.generate_initial_data_multifidelity(
                    n_full_fidelity_equiv=n_full_fidelity_equiv,
                    cost_fn=self.c.CostFn_MF,
                    is_discrete=self.c.IS_DISCRETE_FIDELITIES,
                    fidelity_levels=self.c.FIDELITY_LEVELS,
                )
            )
        self._save_and_report(init_data, "MF")
        return init_data

    @status(show_func_name=True)
    def gen_sample_SF(
        self,
        n: int | float
    ) -> dict[int, tuple[Tensor, Tensor, Tensor | None]]:
        """
        Generate initial data for the SF problem.
        Save the data as .pt file.

        Args:
            n (float | int): The number of samples to generate. A float is
                truncated to an integer with ``int()``.

        Returns:
            Mapping from each seed in ``seed_list`` to the data generated
            for that seed.
        """
        if isinstance(n, float):
            n = int(n)
        init_data = {}
        for seed in self.seed_list:
            set_seed(seed)
            init_data[seed] = (
                self.init_sampling_SF.generate_initial_data(n=n, seed=seed)
            )
        self._save_and_report(init_data, "SF")
        return init_data
    
def set_seed(seed: int) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic, non-benchmarking
    mode.

    Args:
        seed (int): The seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Exported for the environment; hash randomization of the running
    # interpreter is fixed at startup, so this affects child processes only.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
def main():
    """Parse the command line and generate the requested initial samples.

    Builds an ``InitTrails`` for the chosen problem, seeds and device, then
    runs the multi-fidelity generator for ``-f mf`` or the single-fidelity
    generator for ``-f sf``; ``--init_budget`` is forwarded as the budget
    or sample count respectively.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p", "--problem",
        required=True,
        choices=PROBLEM_NAMES,
    )
    cli.add_argument(
        "-f", "--fidelity",
        required=True,
        choices=["mf", "sf"],
        help="Specify the fidelity type: "
        "'mf' for multi-fidelity or 'sf' for single-fidelity.",
    )
    cli.add_argument(
        "-b", "--init_budget",
        type=float,
        required=True,
        help="Initial samples to generate based on the cost function. "
        "For multi-fidelity, this is the budget in terms cost(x) * n_full_fidelity_equiv.",
    )
    cli.add_argument("-d", "--device", default="cpu")
    cli.add_argument(
        "-s", "--seeds",
        type=int,
        nargs="+",
        required=True,
        help="List of random seeds for reproducibility.",
    )
    opts = cli.parse_args()

    generator = InitTrails(
        problem_name=opts.problem,
        seed_list=opts.seeds,
        device=opts.device,
    )

    # The two fidelity values are mutually exclusive, so at most one branch runs.
    if opts.fidelity == "mf":
        generator.gen_sample_MF(n_full_fidelity_equiv=opts.init_budget)
    elif opts.fidelity == "sf":
        generator.gen_sample_SF(n=opts.init_budget)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
