import os
import shutil
import sys
import traceback
from contextlib import redirect_stderr, redirect_stdout
from copy import deepcopy
from dataclasses import dataclass, field
from tempfile import NamedTemporaryFile
from typing import Callable, Dict, List, Optional, Tuple, TypeVar

import torch
from pathos.pools import ProcessPool, SerialPool
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from ..metadata import BinaryPathUpdater, Metadata
from ..utils import stdout_redirected

# Generic element type of the datasets handled by `save_dataset`.
T = TypeVar("T")

# Default `BinaryLoader.cache_size`; read once from the environment at import time.
loader_cache_size = int(os.environ.get("TORCHMALWARE_LOADER_CACHE_SIZE", "100"))

def _move_and_makedirs(src: str, dst: str) -> None:
    """Move ``src`` to ``dst``, creating the parent directory of ``dst`` if it is missing.

    Args:
        src: Path of the file to move.
        dst: Destination path. Missing parent directories are created recursively.
    """
    dst_dir = os.path.dirname(dst)
    # An empty dirname means ``dst`` lives in the current working directory, which
    # already exists (``os.makedirs('')`` would raise). ``exist_ok`` avoids a race
    # when several worker processes create the same directory concurrently.
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    shutil.move(src, dst)


def save_dataset(
    dataset: Dataset[T],
    path: Callable[[int], str],
    writer: Optional[Callable[[T, str], None]] = None,
    num_workers: int = 0,
    log: bool = False,
    redirect_fd: bool = True
) -> List[Optional[str]]:
    """Save a dataset to disk

    Note:
    Each element of the dataset is saved in a separate file.

    Args:
        dataset: Dataset to save to disk. Must use integer keys for indexing.
        path: A function that takes the index of an element in `dataset`, and returns the path to a file where it
            should be saved.

    Keyword args:
        writer: A function that writes an element of `dataset` to disk given its value and path. If not specified,
            defaults to `torch.save`.
        num_workers: Number of subprocesses to use. A value of `0` means the data will be read/cached in the main
            process.
        log: Whether to redirect stdout/stderr for each element to a file on disk (saved at path + '.stdout' and
            path + '.stderr'). Empty logs are deleted. Logs are kept even when the element fails to save, so
            the printed traceback ends up next to the intended output path.
        redirect_fd: Whether to redirect stdout/stderr at the file descriptor level.

    Returns:
        A list of paths to the saved files for each element in `dataset`. If a path is None, the corresponding element
            was not successfully saved to disk (loading or writing it raised an exception, whose traceback was printed
            to stderr).

    Raises:
        ValueError: If `num_workers` is negative.
    """
    if writer is None:
        writer = torch.save

    if redirect_fd:
        # File-descriptor-level redirection via the project helper `stdout_redirected`.
        stdout_redirect = lambda f: stdout_redirected(f, stdout = sys.__stdout__)
        stderr_redirect = lambda f: stdout_redirected(f, stdout = sys.__stderr__)
    else:
        # Python-level redirection of sys.stdout / sys.stderr only.
        stdout_redirect = redirect_stdout
        stderr_redirect = redirect_stderr

    def _save_element(idx: int, save_path: str) -> bool:
        """Load element `idx` and write it to `save_path`; return True on success."""
        try:
            writer(dataset[idx], save_path)
        except Exception:
            # Best-effort: report the failure and keep saving the remaining elements.
            traceback.print_exc()
            return False
        return True

    def _cache_instance(idx: int) -> Optional[str]:
        """Save element `idx`; return its path, or None if saving failed."""
        save_path = path(idx)
        return save_path if _save_element(idx, save_path) else None

    def _cache_instance_log(idx: int) -> Optional[str]:
        """Like `_cache_instance`, but capture stdout/stderr into files next to the output."""
        with NamedTemporaryFile(mode="w", delete=False) as f_out, stdout_redirect(f_out):
            with NamedTemporaryFile(mode="w", delete=False) as f_err, stderr_redirect(f_err):
                save_path = path(idx)
                saved = _save_element(idx, save_path)
        # Both temporary files are closed (and flushed) once the `with` blocks exit.
        for log_path, log_ext in ((f_err.name, '.stderr'), (f_out.name, '.stdout')):
            if os.path.getsize(log_path) == 0:
                # Remove log if empty
                os.remove(log_path)
            else:
                # Move alongside the (intended) instance path, even on failure.
                _move_and_makedirs(log_path, save_path + log_ext)
        return save_path if saved else None

    cache_instance = _cache_instance_log if log else _cache_instance

    if num_workers >= 1:
        pool = ProcessPool(num_workers)
    elif num_workers == 0:
        pool = SerialPool()
    else:
        raise ValueError("`num_workers` must be non-negative")

    num_elements = len(dataset)
    imap = pool.imap(cache_instance, range(num_elements))
    return list(tqdm(imap, total=num_elements))


@dataclass
class BinaryLoader:
    """Callable that loads a binary and its metadata from disk, with a bounded in-memory cache.

    Loaded entries are cached under the key ``(path, metadata_path)``. At most ``cache_size`` entries
    are kept (default: ``loader_cache_size``, i.e. the ``TORCHMALWARE_LOADER_CACHE_SIZE`` environment
    variable read at import time, else 100). When the cache is over budget, the oldest *inserted* entry is
    evicted first (FIFO; cache hits do not refresh an entry). A ``cache_size`` of 0 or less disables caching.
    """
    # (path, metadata_path) -> (binary bytes, Metadata) for every cached entry.
    data_dict: Dict[tuple, tuple] = field(default_factory = dict)
    # Cached keys in insertion order; the front of the list is evicted first.
    data_queue: List[tuple] = field(default_factory=list)
    # Maximum number of cached entries; <= 0 disables caching.
    cache_size: int = loader_cache_size

    def __call__(self, path: str, metadata_path: Optional[str] = None) -> Tuple[bytes, Metadata]:
        """Loads a binary from disk

        Args:
            path: Path to the binary
            metadata_path: Path to the metadata, loaded with `torch.load`. If falsy, an empty `Metadata()` is
                used instead.

        Returns:
            A binary (raw file bytes) and its associated metadata. The metadata has `binary_path` set to `path`
            with a `BinaryPathUpdater` attached. When the entry comes from the cache, the metadata is a deep copy,
            so callers may mutate it freely.
        """
        key = (path, metadata_path)
        # Cache the value if not cached before
        if key not in self.data_dict:
            with open(path, "rb") as f:
                binary = f.read()
            if metadata_path:
                # NOTE(review): torch.load unpickles the file, so only load trusted metadata files.
                # On recent torch versions `weights_only` defaults to True, which may reject a pickled
                # Metadata object -- confirm against the torch version in use.
                metadata = torch.load(metadata_path)
            else:
                metadata = Metadata()
            metadata['binary_path'] = path
            metadata.set_updater('binary_path', BinaryPathUpdater())
            value = (binary, metadata)
            # Caching disabled: the freshly built metadata is not shared, so no copy is needed.
            if self.cache_size <= 0:
                return value
            self.data_dict[key] = value
            self.data_queue.append(key)

            # Evict oldest entries until within budget. A loop (rather than a single pop) also honors a
            # `cache_size` lowered after construction. Because cache_size >= 1 here, `key` itself survives.
            while len(self.data_queue) > self.cache_size:
                old_key = self.data_queue.pop(0)
                del self.data_dict[old_key]
        binary, metadata = self.data_dict[key]
        # Metadata is mutable: return a deep copy so callers cannot corrupt the cached entry.
        return binary, deepcopy(metadata)

