from typing import overload
import torch
import os

# Module-level state shared with the GDS_facility backward hooks.

# Not read or written anywhere else in this file.
fallback_root_registered = False

# Optimizers installed through GDS_facility.set_backward_token / _norm /
# _output. None means "nothing registered"; the hooks skip a None optimizer.
token_opt, norm_opt, output_opt = None, None, None

# Per-transformer optimizers and counters, all indexed by the same id.
# These lists start empty and nothing in this file grows them, so whoever
# registers a transformer optimizer must size them first.
transformer_opts = list()
last_nums_transformer, cur_nums_transformer = list(), list()

# Hook-invocation counters. A hook steps its optimizer when cur == last and
# resets cur to 0; otherwise it increments cur.
last_num_norm, cur_num_norm = 0, 0
last_num_token, cur_num_token = 0, 0

class GDS_facility:
    begin_from_cpu = False
    enable_teraio = True
    enable_interleaved_optimizer = True
    backward_hook_function_transformers = []

    @classmethod
    def backward_hook_function_transformer(cls, id):
        global transformer_opts
        global last_nums_transformer
        global last_nums_transformer
        if transformer_opts[id] is not None:
            if cur_nums_transformer[id] == last_nums_transformer[id]:
                transformer_opts[id].step()
                transformer_opts[id].zero_grad()
                cur_nums_transformer[id] = 0
            else:
                cur_nums_transformer[id] = cur_nums_transformer[id] + 1

    @classmethod
    def backward_hook_function_tok(cls, module, input_grad, outputgrad):
        global token_opt
        global cur_num_token
        global last_num_token
        if token_opt is not None:
            if cur_num_token == last_num_token:
                token_opt.step()
                token_opt.zero_grad()
                cur_num_token = 0
            else:
                cur_num_token = cur_num_token + 1

    @classmethod
    def backward_hook_function_norm(cls, module, input_grad, outputgrad):
        global norm_opt
        global cur_num_norm
        global last_num_norm
        if norm_opt is not None:
            if cur_num_norm == last_num_norm:
                norm_opt.step()
                norm_opt.zero_grad()
                cur_num_norm = 0
            else:
                cur_num_norm = cur_num_norm + 1

    @classmethod
    def backward_hook_function_output(cls, module, input_grad, outputgrad):
        return

    @classmethod
    def set_backward_token(cls, opt):
        global token_opt
        token_opt = opt

    @classmethod
    def set_backward_norm(cls, opt):
        global norm_opt
        norm_opt = opt
    
    @classmethod
    def set_backward_output(cls, opt):
        global output_opt
        output_opt = opt

    @classmethod
    def set_backward_transformer(cls, id, opt):
        global transformer_opts
        transformer_opts[id] = opt

    @classmethod
    def init_rank(cls, rank: int, streams_per_ssd: int, io_threads_per_ssd:int) -> None:
        return torch._C._cuda_gds_liveness_ctl(f"setRank|{rank}|{cls.begin_from_cpu}|{streams_per_ssd}|{io_threads_per_ssd}")

    @classmethod
    def __register_fallback_root(cls, abs_path: str) -> None:
        deviceIdx = torch._C._cuda_gds_register_storage_dev(abs_path)
        assert deviceIdx >= 0, f"Invalid storage path {abs_path}"
        return deviceIdx

    @classmethod
    def init_GDS(cls, fallback_root: str, verbose=False) -> bool:
        cls.__gds_stats = torch._C._cuda_gds_setup(verbose)
        return cls.__register_fallback_root(fallback_root)

    @classmethod
    def register_files_root(cls, abs_path: str) -> int:
        return torch._C._cuda_gds_register_storage_dev(abs_path)

    """
        Register a tensor using device index, file index, offset, and size
    """
    @classmethod
    def register_tensor(cls, devIdx: int, file: int, offset: int, size: int) -> bool:
        pass

    """
        Register a tensor using device index, file relative path, offset, and size
    """
    @classmethod
    def register_tensor(cls, devIdx: int, rel_path: str, offset: int, size: int) -> bool:
        pass

    """
        Register a tensor using file absolute path, offset, and size
        @note: The method will fail if the path is not a file in any of the
               registered roots. Method not recommended to be used, will perform
               a linear search on all registered roots
    """
    @classmethod
    def register_tensor(cls, abs_path: str, offset: int, size: int) -> bool:
        cls.register_file(abs_path)
        pass

    @classmethod
    def get_filename(cls, dev: int, file: int) -> str:
        return torch._C._cuda_gds_find_register_file(dev, file)
    
    @classmethod
    def get_tensor_id(cls, tensor: torch.Tensor):
        pass

    @classmethod
    def associate_tensor_with_file(cls,
                                   tensor: torch.Tensor,
                                   devIdx: int, fileIdx: int, offset: int, size: int):
        return torch._C._cuda_gds_associate_tensor_with_file(tensor, devIdx, fileIdx, offset, size)

    @classmethod
    def register_tensor_with_file(cls,
                                  tensor: torch.Tensor,
                                  devIdx: int, path: str, offset: int, size: int):
        fileIdx: int = cls.register_file(devIdx, path)
        return cls.associate_tensor_with_file(tensor, devIdx, fileIdx, offset, size)
    
    @classmethod
    def register_tensor_with_file(cls,
                                  tensor: torch.Tensor,
                                  devIdx: int, path: str, size: int):
        return cls.register_tensor_with_file(tensor, devIdx, path, -1, size)

    @classmethod
    def register_tensor_with_file(cls,
                                  tensor: torch.Tensor,
                                  abs_path: str, offset: int, size: int):
        pass

    @classmethod
    def deregister_tensor_with_file(cls, tensor: torch.Tensor):
        pass

    @classmethod
    def start_profile_run(cls):
        return torch._C._cuda_gds_liveness_ctl("startProfile")

    @classmethod
    def end_profile_run(cls):
        return torch._C._cuda_gds_liveness_ctl("endProfile")

    @classmethod
    def start_liveness_record(cls):
        torch._C._cuda_gds_liveness_ctl("startRecord")

    @classmethod
    def end_liveness_record(cls):
        torch._C._cuda_gds_liveness_ctl("endRecord")

    @classmethod
    def start_replay(cls):
        torch._C._cuda_gds_liveness_ctl("startReplay")

    @classmethod
    def end_replay(cls):
        torch._C._cuda_gds_liveness_ctl("endReplay")

    @classmethod
    def start_load_mode(cls, sync_prior_migration=False):
        torch._C._cuda_gds_liveness_ctl(f"startLoad|{sync_prior_migration}")

    @classmethod
    def end_load_mode(cls, sync_prior_migration=False):
        torch._C._cuda_gds_liveness_ctl(f"endLoad|{sync_prior_migration}")

    @classmethod
    def mark_forward_end(cls):
        torch._C._cuda_gds_liveness_ctl("forwardEnd")

    @classmethod
    def mark_iter_end(cls):
        return torch._C._cuda_gds_liveness_ctl("iterEnd")

    @classmethod
    def generate_migration_plan(cls, output_file: str | None=None):
        if output_file is None:
            torch._C._cuda_gds_liveness_ctl("genMigration")
        else:
            torch._C._cuda_gds_liveness_ctl(f"genMigration|{output_file}")

    @classmethod
    def reset_liveness_record(cls):
        torch._C._cuda_gds_liveness_ctl("resetLiveness")

    @classmethod
    def reset_replay_record(cls):
        torch._C._cuda_gds_liveness_ctl("resetReplay")

    @classmethod
    def serialize_liveness(cls, iter, verbose=False, output_file: str | None=None):
        if output_file is None:
            torch._C._cuda_gds_liveness_ctl(f"serializeLiveness|{iter}|{int(verbose)}")
        else:
            torch._C._cuda_gds_liveness_ctl(f"serializeLiveness|{iter}|{int(verbose)}|{os.path.abspath(output_file)}")

    @classmethod
    def release_involved_tensors(cls, iter):
        torch._C._cuda_gds_liveness_ctl(f"releaseInv|{iter}")
    
    @classmethod
    def release_trl(cls):
        torch._C._cuda_gds_liveness_ctl(f"releaseTRL")
        return
    
    @classmethod
    def set_print_info(cls, enable):
        torch._C._cuda_gds_liveness_ctl(f"setPrintInfo|{enable}")
        return
    
    @classmethod
    def set_emu_bandwidth(cls, rw, bw):
        torch._C._cuda_gds_liveness_ctl(f"setEmuBandwidth|{rw}|{bw}")
        return
    
    @classmethod
    def reset_storage_alloc(cls):
        torch._C._cuda_gds_liveness_ctl(f"resetStorageAlloc")
        return
    
    @classmethod
    def set_skip_list(cls, ls):
        torch._C._cuda_gds_liveness_ctl(f"setSkipList|{ls}")
        return

    @classmethod
    def deserialize_migration_plan(cls, input_file):
        torch._C._cuda_gds_liveness_ctl(f"readMigration|{os.path.abspath(input_file)}")

    @classmethod
    def read_profile_reload_plan(cls, input_file):
        torch._C._cuda_gds_liveness_ctl(f"readProfileReload|{os.path.abspath(input_file)}")

    @classmethod
    def load_migration_plan_from_teraio_external_executable(cls, liveness_info_file, generate, rank, migration_plan_path):
        import shutil
        teraio_algo_src_dir = f"/home/{os.getlogin()}/teraio-algorithm/src"
        teraio_algo_exe = os.path.join(teraio_algo_src_dir, f"{rank}/teraio-algorithm")
        teraio_algo_args = os.path.join(teraio_algo_src_dir, "configs", f"demo_rank{rank}.config")

        if generate:
            target_input_file = os.path.join(teraio_algo_src_dir, f"{rank}/semantics.in")
            src_input_file = os.path.join(os.path.abspath(liveness_info_file))
            assert os.path.isfile(src_input_file)
            if os.path.isfile(target_input_file):
                os.remove(target_input_file)
            shutil.copy(src_input_file, target_input_file)

            import subprocess
            p = subprocess.Popen([teraio_algo_exe, teraio_algo_args],
                                stdout=subprocess.DEVNULL,
                                cwd=os.path.join(teraio_algo_src_dir, f"{rank}"))
            p.wait()

        migration_plan_path = migration_plan_path.replace('X', f"{rank}")
        target_output_file = os.path.join(migration_plan_path, "migration_plan.txt")
        assert os.path.isfile(target_output_file)

        cls.deserialize_migration_plan(target_output_file)
