import json
import os
import posixpath
from typing import Dict, List, Optional

import numpy as np
from s3fs.core import S3FileSystem


class S3FileHandler:
    """Read and write files either on an S3 bucket or on the local disk.

    The handler is built with an S3 endpoint or ``None``. When an endpoint is
    given, every method reads from / writes to the selected S3 bucket, whose
    root path is taken from the ``S3_INPUT_PATH`` or ``S3_OUTPUT_PATH``
    environment variable. Otherwise paths are resolved against ``./`` on the
    local disk. The goal is to use the same functions when the script is run
    on the server or on a local machine.
    """

    def __init__(self, s3_endpoint: Optional[str], bucket: Optional[str]) -> None:
        """Configure the backend.

        Args:
            s3_endpoint (Optional[str]): URL of the S3 endpoint, or a falsy
                value to use the local disk.
            bucket (Optional[str]): "input" or "output" when ``s3_endpoint``
                is set; ignored otherwise.

        Raises:
            ValueError: if ``s3_endpoint`` is set and ``bucket`` is neither
                "input" nor "output".
        """
        self.s3_endpoint = s3_endpoint
        self.bucket = bucket

        # The bucket must be in {"input", "output"}. This could change along
        # with the server's infrastructure, which is for now structured as an
        # input and an output bucket.
        # WARNING: if bucket == "output", the files will be stored under a
        # folder named after the experiment, whereas saving in the "input"
        # bucket is more straightforward as the path is directly used.
        # NOTE(review): that folder layout is not built by this class —
        # presumably it is encoded in S3_OUTPUT_PATH; confirm.

        if s3_endpoint:
            self.s3 = S3FileSystem(client_kwargs={"endpoint_url": s3_endpoint})

            # os.environ.get may return None; this is reported lazily by
            # _full_path so that constructing the handler never fails on it.
            if bucket == "input":
                self.bucket_path = os.environ.get("S3_INPUT_PATH")
            elif bucket == "output":
                self.bucket_path = os.environ.get("S3_OUTPUT_PATH")
            else:
                raise ValueError("bucket should be in {input, output}")
        else:
            self.bucket_path = "./"

    def _full_path(self, path: str) -> str:
        """Resolve ``path`` against the bucket root (S3) or ``./`` (local).

        Raises:
            ValueError: if S3 is used but the bucket's environment variable
                was not set when the handler was created.
        """
        if self.bucket_path is None:
            raise ValueError(
                f"No root path configured for the {self.bucket!r} S3 bucket: "
                "set the S3_INPUT_PATH / S3_OUTPUT_PATH environment variable."
            )
        if self.s3_endpoint:
            # S3 keys always use "/" as separator, whatever the host OS is.
            return posixpath.join(self.bucket_path, path)
        return os.path.join(self.bucket_path, path)

    def _open(self, path: str, mode: str = "r"):
        """Open ``path`` with ``mode`` on S3 or on the local disk."""
        full_path = self._full_path(path)
        if self.s3_endpoint:
            return self.s3.open(full_path, mode)
        return open(full_path, mode)

    # Handling text files
    def read_text(self, path: str) -> List[str]:
        """Wrapper around the python text read method.

        Args:
            path (str): Path from which the text is to be read.

        Returns:
            List[str]: lines read from the text file (line endings kept).
        """
        # Text mode on both backends: s3fs opens in binary mode by default,
        # which would return ``bytes`` lines on S3 but ``str`` lines locally.
        with self._open(path, "r") as f:
            lines: List[str] = f.readlines()
        return lines

    def save_text(self, path: str, lines: List[str]) -> None:
        """Wrapper around the python text save method.

        Args:
            path (str): Path to which the text is to be saved.
            lines (List[str]): Lines to be written in the text file. They are
                written as-is, so they must carry their own line endings.
        """
        with self._open(path, "w") as f:
            f.writelines(lines)

    # Handling json files
    def read_json(self, path: str) -> Dict:
        """Wrapper around the python json load method.

        Args:
            path (str): Path from which the json file is to be read.

        Returns:
            Dict: Dictionary written in the json file.
        """
        # json.load accepts both text and binary files; S3 keeps the s3fs
        # default binary mode, the local disk uses text mode.
        mode = "rb" if self.s3_endpoint else "r"
        with self._open(path, mode) as f:
            data: Dict = json.load(f)
        return data

    def save_json(self, path: str, data: Dict) -> None:
        """Wrapper around the python json save method.

        Args:
            path (str): Path to which the json file is to be saved.
            data (Dict): Dictionary to be saved.
        """
        with self._open(path, "w") as f:
            json.dump(data, f)

    # Handling os operations
    def listdir(self, path: str) -> List[str]:
        """Wrapper around the listdir command.

        Args:
            path (str): Path to the folder that needs to be inspected, relative
                to the bucket root (S3) or to ``./`` (local disk).

        Returns:
            List[str]: List of entry names in the folder, as in os.listdir.
        """
        full_path = self._full_path(path)
        if self.s3_endpoint:
            # s3.ls returns full keys; keep only the last "/"-separated
            # component to mimic os.listdir.
            return [posixpath.basename(key) for key in self.s3.ls(full_path)]
        return os.listdir(full_path)


# Make sure all data in logger.sacred info is json serializable
def _json_key(key):
    """Turn a dict key into one json.dumps accepts (str/int/float/bool/None)."""
    if isinstance(key, np.generic):
        key = key.item()
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    return str(key)


def make_json_serializable(obj):
    """Recursively convert ``obj`` into something ``json.dumps`` accepts.

    - dicts, lists and tuples are rebuilt with converted contents (tuples stay
      tuples, which json encodes as arrays); dict keys that json cannot encode
      are converted with ``str``;
    - sets and frozensets become lists of converted elements;
    - NumPy arrays become nested lists and NumPy scalars become Python
      scalars;
    - objects with a ``__dict__`` are replaced by their converted ``__dict__``;
    - anything else is returned unchanged if json can encode it, otherwise as
      its ``str`` representation.

    Args:
        obj: Arbitrary value to convert.

    Returns:
        A json-serializable equivalent of ``obj``.
    """
    if isinstance(obj, dict):
        return {_json_key(k): make_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [make_json_serializable(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(make_json_serializable(v) for v in obj)
    elif isinstance(obj, (set, frozenset)):
        return [make_json_serializable(v) for v in obj]
    elif isinstance(obj, np.ndarray):  # NumPy arrays only (not JAX arrays)
        data = obj.tolist()
        # Only object arrays can hold non-serializable Python objects; numeric
        # arrays already yield plain Python scalars and skip the extra pass.
        return make_json_serializable(data) if obj.dtype.hasobject else data
    elif isinstance(obj, np.generic):  # NumPy scalars, e.g. np.int64
        return make_json_serializable(obj.item())
    elif hasattr(obj, "__dict__"):  # Check if the object is a custom class instance
        return make_json_serializable(obj.__dict__)
    else:
        try:
            json.dumps(obj)  # Test if the object is JSON serializable
        except TypeError:
            # Convert the non-serializable object to a string representation
            return str(obj)
        return obj
