import os
import pickle
import shutil
import subprocess
import warnings

import pandas as pd


class Logger:
    """Persist per-round experiment logs and agent parameters.

    Everything is written under ``<local_logging_path>/logs``:

    * ``logs_round_<r>.csv`` -- one row per agent with its instance
      attributes (those whose name does not contain ``"param"``) and the
      requested hyperparameters;
    * ``round_<r>/agent_<i>.pkl`` -- the pickled ``params`` of the agent at
      position ``i`` in the list passed to :meth:`write_logs`.

    Only the latest round directory is kept locally, which is what
    :meth:`read_logs` needs to resume an interrupted experiment. When a
    remote path is given, each round's CSV and the best agent's params are
    mirrored there (with ``gsutil`` for ``gs:/`` paths, plain file copies
    otherwise), and the full round directory is mirrored every 100 rounds.
    """

    def __init__(
        self,
        local_logging_path: str,
        hpo_keys: list,
        remote_logging_path: str = None,
    ) -> None:
        """Create the local log directory and record the remote target.

        Args:
            local_logging_path: Base directory; logs are written to its
                ``logs`` subdirectory, which is created if missing.
            hpo_keys: Names of the hyperparameters to log; each is looked up
                in ``agent.hyperparameters``. Any iterable of names works
                (for a dict, its keys are used).
            remote_logging_path: Optional mirror location. Paths starting
                with ``gs:/`` are handled with the ``gsutil`` CLI; anything
                else is treated as a local (or mounted) directory.
        """
        self.hpo_keys = hpo_keys

        self.local_logging_path = os.path.join(local_logging_path, "logs")
        os.makedirs(self.local_logging_path, exist_ok=True)

        self.remote_logging_path = remote_logging_path
        self.use_gcp = remote_logging_path is not None and remote_logging_path.startswith(
            "gs:/"
        )

    def read_logs(self):
        """Reload the last completed round so an experiment can be relaunched
        after a system failure.

        Returns:
            ``(False, None, None, None)`` when no round has been logged yet
            (a partial ``round_0`` directory is deleted so the run restarts
            cleanly). Otherwise ``(True, round_index, log_df, agent_params)``:
            the highest logged round, its CSV as a DataFrame, and a dict
            mapping each agent's position to its unpickled ``params``.
        """
        log_files = [
            file
            for file in os.listdir(self.local_logging_path)
            if file.startswith("logs_round_") and file.endswith(".csv")
        ]

        if not log_files:
            # The first round didn't complete, so we launch the xp from start.
            first_round_dir = os.path.join(self.local_logging_path, "round_0")
            if os.path.exists(first_round_dir):
                shutil.rmtree(first_round_dir)
            return False, None, None, None

        round_index = max(self._file_index(file) for file in log_files)
        log_df = pd.read_csv(
            os.path.join(self.local_logging_path, f"logs_round_{round_index}.csv")
        )

        round_dir = os.path.join(self.local_logging_path, f"round_{round_index}")
        agent_params = {}
        for file in os.listdir(round_dir):
            # Only consider the pickles written by _save_params.
            if not (file.startswith("agent_") and file.endswith(".pkl")):
                continue
            # NOTE: unpickling can execute arbitrary code; these files are
            # produced by _save_params and must not come from untrusted sources.
            with open(os.path.join(round_dir, file), "rb") as handle:
                agent_params[self._file_index(file)] = pickle.load(handle)

        return True, round_index, log_df, agent_params

    @staticmethod
    def _file_index(filename: str) -> int:
        """Extract ``N`` from names like ``logs_round_N.csv`` or ``agent_N.pkl``."""
        return int(filename.rsplit("_", 1)[-1].split(".")[0])

    def write_logs(self, agents, round_index: int):
        """Log one round and checkpoint every agent.

        Args:
            agents: Non-empty sequence of agents. Every instance attribute of
                ``agents[0]`` whose name does not contain ``"param"`` becomes
                a CSV column (a ``reward`` attribute is required to pick the
                best agent). Each agent also needs a ``hyperparameters``
                mapping containing every name in ``hpo_keys``, and a
                picklable ``params`` attribute.
            round_index: Index of the round being logged.
        """
        agent_keys = [key for key in vars(agents[0]) if "param" not in key]
        hpo_keys = list(self.hpo_keys)
        log_data = {key: [] for key in agent_keys + hpo_keys}

        for agent in agents:
            for key in agent_keys:
                log_data[key].append(getattr(agent, key))
            for hp_name in hpo_keys:
                log_data[hp_name].append(agent.hyperparameters[hp_name])

        log_df = pd.DataFrame(log_data)

        # Save every agent locally *before* writing the CSV. read_logs treats
        # logs_round_<r>.csv as "round r is complete", so a crash during the
        # pickling leaves the previous round as the resume point, instead of
        # a round with missing or partial agent files.
        self._save_params(agents, round_index)
        log_df.to_csv(
            os.path.join(self.local_logging_path, f"logs_round_{round_index}.csv"),
            index=False,
        )

        # To avoid overloading remote storage, only the best policy is sent.
        # argmax is positional, matching the agent_<position>.pkl file names.
        best_agent = int(log_df["reward"].argmax())

        if self.remote_logging_path is not None:
            self._write_to_remote(round_index, best_agent)

        if round_index >= 1:
            self._remove_previous_round(round_index)

    def _save_params(self, agents, round_index):
        """Pickle each agent's ``params`` to ``round_<round_index>/agent_<i>.pkl``,
        where ``i`` is the agent's position in ``agents``."""
        save_path = os.path.join(self.local_logging_path, f"round_{round_index}")
        os.makedirs(save_path, exist_ok=True)

        for global_index, agent in enumerate(agents):
            with open(os.path.join(save_path, f"agent_{global_index}.pkl"), "wb") as handle:
                pickle.dump(agent.params, handle)

    def _write_to_remote(self, round_index: int, best_agent: int):
        """Mirror round ``round_index`` to ``remote_logging_path``.

        Copies the round CSV (as ``log.csv``) and the params of the agent at
        position ``best_agent`` (as ``best_policy.pkl``). Every 100 rounds it
        also mirrors the whole local round directory.
        """
        local_round_dir = os.path.join(self.local_logging_path, f"round_{round_index}")
        remote_round_dir = os.path.join(self.remote_logging_path, f"round_{round_index}")

        # Copy the csv log file.
        self._copy_file_to_remote(
            os.path.join(self.local_logging_path, f"logs_round_{round_index}.csv"),
            os.path.join(remote_round_dir, "log.csv"),
        )
        # Also send the params in case we want to load the agent.
        self._copy_file_to_remote(
            os.path.join(local_round_dir, f"agent_{best_agent}.pkl"),
            os.path.join(remote_round_dir, "best_policy.pkl"),
        )

        # Save everything every 100 rounds just in case.
        if round_index % 100 == 0:
            if self.use_gcp:
                self._gsutil("-m", "rsync", "-r", local_round_dir, remote_round_dir)
            else:
                os.makedirs(os.path.dirname(remote_round_dir), exist_ok=True)
                shutil.copytree(local_round_dir, remote_round_dir, dirs_exist_ok=True)

    def _copy_file_to_remote(self, source: str, destination: str) -> None:
        """Copy a single local file to ``destination`` (gsutil or local filesystem)."""
        if self.use_gcp:
            self._gsutil("cp", source, destination)
        else:
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copy(source, destination)

    @staticmethod
    def _gsutil(*args: str) -> None:
        """Run ``gsutil`` with ``args`` and no shell.

        A failure is reported as a warning rather than raised, so remote
        mirroring stays best-effort and never aborts training.
        """
        try:
            result = subprocess.run(["gsutil", *args], check=False)
        except FileNotFoundError:
            warnings.warn("gsutil executable not found; remote logging skipped")
            return
        if result.returncode != 0:
            warnings.warn(
                f"gsutil {' '.join(args)} failed with exit code {result.returncode}"
            )

    def _remove_previous_round(self, round_index: int):
        """Delete the saved params of round ``round_index - 1`` to gain disk space.

        Only the latest round is needed to reload the xp; a missing directory
        is ignored.
        """
        try:
            shutil.rmtree(
                os.path.join(self.local_logging_path, f"round_{round_index - 1}")
            )
        except FileNotFoundError:
            return None
