import os

import torch as t
import torch.distributions as td


class BaseSolver(t.nn.Module):
    """Base ``torch.nn.Module`` for solvers, with sampling and (de)serialization helpers.

    Subclasses implement :meth:`forward`, which must return unnormalized
    logits; :meth:`predict` samples from a categorical distribution over them.
    """

    def forward(
        self,
        states,
        encoded=False,
    ):
        """Return logits for ``states``. Must be implemented by subclasses.

        ``encoded`` is only part of the signature here; presumably it flags
        whether ``states`` are already encoded (TODO confirm with subclasses).
        Note that :meth:`predict` calls ``forward(states)`` without it.
        """
        raise NotImplementedError

    def predict(
        self,
        states,
    ):
        """Sample indices from ``Categorical(logits=self.forward(states))``.

        Returns:
            A NumPy array (moved to CPU) with one sampled index per
            distribution, i.e. the logits' shape without its last dimension.
        """
        return (
            td.Categorical(
                logits=self.forward(states),
            )
            .sample()
            .cpu()
            .numpy()
        )

    def save(
        self,
        file: str,
    ):
        """Write ``self.state_dict()`` to ``file`` via ``torch.save``.

        Missing parent directories are created. An empty or ``None`` file
        name prints a message and returns without writing anything.
        """
        if file is None or file == "":
            print("File name empty!!")
            return
        folder = os.path.dirname(file)
        # dirname() is "" for a bare file name; makedirs("") would raise.
        # exist_ok avoids the exists()/makedirs() race.
        if folder:
            os.makedirs(folder, exist_ok=True)

        t.save(self.state_dict(), file)

    @staticmethod
    def _strip_ckpt_prefix(key):
        """Remove one leading ``"solver."`` from ``key`` if present; otherwise return it unchanged."""
        prefix = "solver."
        # str.lstrip(prefix) would strip a *character set*, mangling keys
        # such as "solver.encoder.w" -> "ncoder.w"; strip the exact prefix only.
        return key[len(prefix):] if key.startswith(prefix) else key

    def load(
        self,
        file,
        strict=True,
        verbose=False,
    ):
        """Load parameters from ``file`` into this module.

        ``.ckpt`` files must contain a mapping with a ``"state_dict"`` entry
        (presumably a PyTorch Lightning checkpoint — confirm). From it,
        ``target_encoder.*`` entries are dropped and a leading ``solver.``
        prefix is removed from the remaining keys. Any other file is loaded
        as a plain state dict.

        Args:
            file: path to the saved state.
            strict: forwarded to ``load_state_dict``.
            verbose: print progress and the reason for early failures.

        Returns:
            True on success. False if ``file`` is empty/``None``, does not
            exist, or loading raises ``RuntimeError`` (e.g. key/shape mismatch)
            or ``KeyError`` (``.ckpt`` without ``"state_dict"``).
        """
        if verbose:
            print(f'loading from file: "{file}" with strict matching: {strict}')
        if file is None or file == "":
            if verbose:
                print("file name empty.")
            return False
        if not os.path.exists(file):
            if verbose:
                print("file does not exist.")
            return False

        try:
            # NOTE(review): torch.load unpickles the file; only load trusted files.
            file_data = t.load(file, map_location="cpu")
            if file.endswith(".ckpt"):
                file_data = {
                    self._strip_ckpt_prefix(k): v
                    for k, v in file_data["state_dict"].items()
                    if not k.startswith("target_encoder.")
                }

            self.load_state_dict(file_data, strict=strict)
            return True
        except (RuntimeError, KeyError) as e:
            print(f"failed to load model.\nException: {e}")
            return False
