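# 1. Multi-task running with GPU allocation (one worker process per GPU).
# 2. Task registration: a target function (serialized with dill) plus a dict of args; the function is
#    called with gpu, path, taskid and repeatid in addition to those args.
# 3. Task logging: the parameters of every call (task.json), its stdout/stderr and a run summary are saved.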

import os
import re
import time
import json
import dill
import random
import base64
import traceback
import tracemalloc
import collections
import subprocess as sp
import multiprocessing as mp


# Let dill recursively trace and pickle only the globals a function refers to, instead of
# trying to store its whole global dictionary (see dill's `recurse` option). Presumably this is
# what lets task targets be shipped to the spawned worker processes.
dill.settings.update(recurse=True)

class HookedOutput():
    """Tee for a text stream: every write goes to ``original`` and to the file ``filename``.

    Used to replace ``sys.stdout`` / ``sys.stderr`` in worker processes so their output is both
    shown and saved. Attributes not defined here (``isatty``, ``fileno``, ``encoding``, ...) are
    looked up on ``original``, so code that probes the stream keeps working.
    """

    def __init__(self, filename, original) -> None:
        # Explicit UTF-8 so logging arbitrary text does not depend on the platform locale.
        self.file = open(filename, 'w', encoding='utf-8')
        self.original = original

    def write(self, data):
        """Write ``data`` to both targets and return the number of characters written."""
        self.original.write(data)
        if not self.file.closed:
            self.file.write(data)
        return len(data)

    def flush(self):
        self.original.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Close the log file; the wrapped ``original`` stream is left open."""
        self.file.close()

    def __getattr__(self, name):
        # Only reached for attributes missing on self. Guard against infinite recursion when
        # ``original`` has not been set (e.g. an instance created without __init__).
        original = self.__dict__.get('original')
        if original is None:
            raise AttributeError(name)
        return getattr(original, name)

def dumps(obj):
    """Serialize ``obj`` with dill and return it as a base64-encoded text string."""
    pickled = dill.dumps(obj)
    encoded = base64.b64encode(pickled)
    return encoded.decode()
def loads(str):
    """Inverse of ``dumps``: base64-decode ``str`` and unpickle the result with dill.

    NOTE: unpickling can execute arbitrary code; only load data from trusted sources
    (e.g. task files this tool wrote itself).
    """
    pickled = base64.b64decode(str.encode())
    return dill.loads(pickled)
def json_serializable(obj):
    """Return True if ``json.dumps`` can encode ``obj`` with its default settings.

    ``json.dumps`` raises ``TypeError`` for unsupported values or dict keys and ``ValueError``
    for circular references; both mean "not serializable" here.
    """
    try:
        json.dumps(obj)
    except (TypeError, ValueError):
        return False
    return True

def subprocess(target, args, gpu, path, taskid, repeatid):
    """Run one task repeat; this is the target of the worker process started by ``Task.run``.

    Args:
        target: ``dumps``-encoded callable to run.
        args: ``dumps``-encoded dict of extra keyword arguments for ``target``.
        gpu: Index of the CUDA device this process is bound to.
        path: Run directory receiving ``log.txt``, ``err.txt`` and ``summary.json``.
        taskid: Index of the task, forwarded to ``target``.
        repeatid: Index of the repeat, forwarded to ``target``.

    An ``Exception`` raised by ``target`` is printed to the (logged) stderr and recorded as
    ``"Success": false``. Other ``BaseException``s (``SystemExit``, ``KeyboardInterrupt``, ...)
    are also recorded as a failure, then re-raised.
    """
    # 1: Set GPU run environment (torch is imported here, so only worker processes need it).
    import torch
    torch.cuda.set_device("cuda:" + str(gpu))
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # 2: Hook stdout & stderr so everything is also written to log.txt / err.txt.
    import sys
    sys.stdout = HookedOutput(os.path.join(path, "log.txt"), sys.stdout)
    sys.stderr = HookedOutput(os.path.join(path, "err.txt"), sys.stderr)

    # 3: Load the target and its kwargs, then start collecting statistics.
    func = loads(target)
    args = loads(args)
    exception = None
    start_time = {"Wall": time.time(), "User": time.process_time()}
    tracemalloc.start()

    # 4: Run.
    try:
        func(gpu=gpu, path=path, taskid=taskid, repeatid=repeatid, **args)
    except Exception as e:
        traceback.print_exc(file=sys.stderr)
        exception = e
    except BaseException as e:
        exception = e
        raise
    finally:
        # 5: Save summary. Server.get_training_state treats the presence of summary.json as
        #    "repeat finished", so it is written on every exit path of the target.
        _, mem_peak = tracemalloc.get_traced_memory()
        mem_snapshot = tracemalloc.take_snapshot()
        tracemalloc.stop()
        with open(os.path.join(path, "summary.json"), "w") as f:
            json.dump({
                "Success": exception is None,
                "Wall Time": time.time() - start_time["Wall"],
                "User Time": time.process_time() - start_time["User"],
                "Peak Memory Usage": mem_peak,
                "Top 10 Memory Used": [str(stat) for stat in mem_snapshot.statistics('lineno')[:10]]
            }, f, indent=4)


class Task:
    """
    Register a Task.

    A task is a ``target`` callable plus a dict of keyword ``args``, to be executed ``repeat``
    times; each execution runs in its own process (see ``run``).
    ``target`` is called with the keyword arguments ``gpu``, ``path``, ``taskid`` and
    ``repeatid`` in addition to ``args``.
    ``args`` is strongly recommended to be JSON serializable (strings, numbers, dicts, lists, ...);
    otherwise it is stored dill-encoded when the task is saved.
    """

    def __init__(self, target, args=None, repeat=1):
        # ``None`` instead of a ``{}`` default so separate tasks never share one args dict.
        if args is None:
            args = {}
        if not isinstance(args, dict):
            raise ValueError("Function Args is not a dict")

        self.target = target
        self.args = args
        self.repeat = repeat
        self.process = None    # the worker ``mp.Process`` once ``run`` has been called
        self.progress = None   # not used in this module; kept in case external code reads it

    def copy(self):
        """Return an independent Task: a fresh copy of the target (dill round trip) and of args."""
        return Task(loads(dumps(self.target)), self.args.copy(), self.repeat)

    def update(self, renew_args):
        """Return a copy of this task whose args are overridden/extended by ``renew_args``."""
        return Task(loads(dumps(self.target)), {**self.args, **renew_args}, self.repeat)

    def dict(self):
        """Return a JSON-ready description of the task.

        The target is always stored dill+base64 encoded under "Target". Args are stored as
        plain JSON when possible; otherwise they are encoded the same way, with their ``repr``
        under "Args Name" for readability.
        """
        # Callables such as functools.partial have no __name__; fall back to their repr.
        name = getattr(self.target, "__name__", repr(self.target))
        if json_serializable(self.args):
            return {"Target Name": name, "Args": self.args, "Repeat": self.repeat, "Target": dumps(self.target)}
        return {"Target Name": name, "Args Name": repr(self.args),
                "Repeat": self.repeat, "Target": dumps(self.target), "Args": dumps(self.args)}

    def save(self, filepath, additional=None):
        """Write ``dict()`` merged over ``additional`` (task keys win) to ``filepath`` as JSON."""
        with open(filepath, "w") as f:
            json.dump({**(additional or {}), **self.dict()}, f, indent=4)

    @staticmethod
    def load(task):
        """Rebuild a Task from a ``dict()`` result; a string "Args" value is dill-encoded."""
        args = loads(task['Args']) if isinstance(task['Args'], str) else task['Args']
        return Task(loads(task['Target']), args, task['Repeat'])

    def run(self, gpu, path, taskid, repeatid):
        """Start the task in a new process running ``subprocess``; does not wait for it."""
        self.process = mp.Process(target=subprocess, args=(dumps(self.target), dumps(self.args), gpu, path, taskid, repeatid))
        self.process.start()

    def alive(self):
        """Return True while the worker process runs; False if it was never started."""
        return self.process is not None and self.process.is_alive()

    def join(self):
        """Wait for the worker process to exit; a no-op if the task was never started."""
        if self.process is not None:
            self.process.join()

class Server:  # GPU program run server
    """Schedule registered tasks onto GPUs; every task repeat runs in its own worker process.

    Repeat ``j`` of task ``i`` is claimed by creating the directory ``<path>/<i>-<j>`` with
    ``os.mkdir`` and counts as finished once that directory contains ``summary.json``.
    Because claims go through the filesystem, several servers can presumably share one
    ``path`` (each repeat records the hostname that ran it).
    """

    def __init__(self, config_file=None, exp_name=None, path=None, device=None, repeat=None, hostname=None):
        """Create a server.

        Explicit arguments take precedence over the JSON ``config_file`` keys
        ("Experiment Name", "Path", "Device", "Repeat"), which take precedence over defaults.
        ``device`` may be "Auto" (pick idle GPUs via ``gpustat``), a comma separated string
        such as "0,1", a single int, or a list/tuple of GPU indices.
        """
        self.config = self.load_config(config_file) if config_file else {}
        self.exp_name = exp_name if exp_name is not None else self.config.get("Experiment Name", "Test")
        self.path = path if path is not None else self.config.get("Path", os.path.join('runs', time.strftime('%m.%d-%H.%M.%S') + self.exp_name))
        # makedirs: the default path lives under 'runs/', which may not exist yet, and exist_ok
        # avoids a crash when another server creates the same directory concurrently.
        os.makedirs(self.path, exist_ok=True)
        self.device = self._parse_device(device if device is not None else self.config.get("Device", "Auto"))
        self.repeat = repeat if repeat is not None else self.config.get("Repeat", 1)
        self.hostname = hostname if hostname is not None else sp.check_output(['hostname']).decode().strip()

        self.default_task = None
        self.dependency = False
        self.tasks = []            # List of tasks
        self.pool = {}             # GPU index -> Task most recently started on it

    @staticmethod
    def _parse_device(device):
        """Normalize a device spec to "Auto" or a list of int GPU indices."""
        if isinstance(device, str):
            return "Auto" if device == "Auto" else [int(d) for d in device.split(",")]
        if isinstance(device, int) and not isinstance(device, bool):
            return [device]
        if isinstance(device, (list, tuple)):
            return [int(d) for d in device]
        raise ValueError(f"Unknown device format: {device}")

    def load_config(self, config_file):
        """Load the JSON config file; print a message and return {} if it is missing or invalid."""
        try:
            with open(config_file, 'r') as f:
                config = json.load(f)
            return config
        except FileNotFoundError:
            print(f"Config file {config_file} not found.")
            return {}
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON config: {e}")
            return {}

    def allocate_gpu(self):
        """Return the index of a GPU a new task may use now, or None if none is free.

        A GPU is usable when no task this server started on it is still alive. With
        ``device == "Auto"`` it must additionally show no processes in ``gpustat``; with an
        explicit device list, GPUs busy with foreign processes are used anyway.
        """
        def available(gpu):
            task = self.pool.get(gpu)
            return task is None or not task.alive()

        if self.device == "Auto":  # Auto GPU allocation based on gpustat
            gpustat = json.loads(sp.check_output(['gpustat', '--json']).decode())["gpus"]
            for gpu in gpustat:
                if available(gpu['index']) and not gpu['processes']:
                    return gpu['index']
        else:                      # Traditional GPU allocation (run even if the GPU is busy)
            for gpu in self.device:
                if available(gpu):
                    return gpu

        return None

    def wait_and_get_gpu(self):
        """Poll ``allocate_gpu`` every 10 seconds until a GPU is available, then return it."""
        gpu = self.allocate_gpu()
        while gpu is None:
            time.sleep(10)
            gpu = self.allocate_gpu()
        return gpu

    def dict(self):
        """Return a JSON-ready description of the server and all registered tasks."""
        return {
            "Experiment Name": self.exp_name,
            "Devices": self.device,
            "Path": self.path,
            "NumTask": len(self.tasks),
            "Repeat": self.repeat,
            "Dependency": self.dependency,
            "Tasks": [task.dict() for task in self.tasks],
        }

    def save(self, filepath):
        """Write ``dict()`` to ``filepath`` as JSON."""
        with open(filepath, "w") as f:
            json.dump(self.dict(), f, indent=4)

    def set_default_task(self, task):
        """Set the template used by ``add_task`` for dict tasks (a Task or Task kwargs dict)."""
        if isinstance(task, Task):
            self.default_task = task
        elif isinstance(task, dict):
            self.default_task = Task(**{"repeat": self.repeat, **task})
        else:
            raise ValueError(f"Unknown Type {type(task)}")

    def add_task(self, task):
        """Register a task.

        A Task is added as is. A dict is treated as args overriding the default task when one
        is set, otherwise as Task keyword arguments (``repeat`` defaults to ``self.repeat``).
        """
        if isinstance(task, Task):
            self.tasks.append(task)
        elif isinstance(task, dict):
            if self.default_task:
                self.tasks.append(self.default_task.update(task))
            else:
                self.tasks.append(Task(**{"repeat": self.repeat, **task}))
        else:
            raise ValueError(f"Unknown Type {type(task)}")

    def load_task(self, taskfile):
        """Replace the task list with the tasks stored in a file written by ``save``."""
        with open(taskfile, "r") as f:
            config = json.load(f)
        self.tasks = [Task.load(task) for task in config["Tasks"]]

    def set_dependency(self, dependency):
        """Set the dependency of each task on the previous one (task 0 never depends).

        None / False: no dependency.
        True: repeat ``r`` of task ``i`` waits until repeat ``r % tasks[i-1].repeat`` of task
        ``i-1`` has finished.
        int ``k``: every repeat of task ``i`` waits for repeat ``k`` of task ``i-1``.
        A list/tuple gives one such setting per task; missing entries mean no dependency.
        """
        self.dependency = dependency

    def _dependency_of(self, i):
        """Return the raw dependency setting for task ``i`` (scalar or per-task list)."""
        if isinstance(self.dependency, (list, tuple)):
            return self.dependency[i] if i < len(self.dependency) else None
        return self.dependency

    def get_training_state(self):
        """Scan ``self.path`` for claimed (``<i>-<j>`` dir) and finished (has summary.json) repeats.

        Returns ``(next_repeat, claimed_tasks, finished_tasks)``, where ``next_repeat[i]`` is the
        number of consecutive repeats of task ``i`` claimed starting from repeat 0.
        """
        claimed_tasks = []
        finished_tasks = []
        with os.scandir(self.path) as entries:
            for entry in entries:
                mobj = re.match(r"^(\d+)-(\d+)$", entry.name)
                if mobj:
                    key = (int(mobj[1]), int(mobj[2]))
                    claimed_tasks.append(key)
                    if os.path.exists(os.path.join(entry.path, "summary.json")):
                        finished_tasks.append(key)

        next_repeat = collections.defaultdict(int)
        for i, j in sorted(claimed_tasks):
            if next_repeat[i] == j:
                next_repeat[i] += 1
        return next_repeat, claimed_tasks, finished_tasks

    def get_next_task(self):
        """
        Select the claimable task with the fewest claimed repeats (ties: lowest id).
        A repeat is claimable only if the dependency it has (see ``set_dependency``) is finished.
        Returns ``(taskid, repeatid)``, or ``(None, 1e5)`` when nothing is claimable right now.
        """
        taskid, minrepeat = None, 1e5
        next_repeat, _, finished_tasks = self.get_training_state()
        finished_tasks = set(finished_tasks)
        for i, task in enumerate(self.tasks):
            rep = next_repeat[i]
            if rep >= task.repeat or rep >= minrepeat:
                continue

            dep = self._dependency_of(i)
            if i == 0 or dep is None:
                depend = None
            elif isinstance(dep, int) and not isinstance(dep, bool):
                depend = dep                               # a fixed repeat of task i-1 (0 included)
            elif dep:
                depend = rep % self.tasks[i - 1].repeat    # True: the matching repeat of task i-1
            else:
                depend = None                              # False (or another falsy setting)

            if depend is None or (i - 1, depend) in finished_tasks:
                minrepeat = rep
                taskid = i

        return taskid, minrepeat

    def finished(self):
        """Return True once every repeat of every task has been claimed (not necessarily done)."""
        next_repeat, *_ = self.get_training_state()
        return all(next_repeat[i] >= task.repeat for i, task in enumerate(self.tasks))

    def run(self):
        """Launch task repeats on free GPUs until all are claimed, then wait for local workers."""
        # Workers initialise CUDA (see subprocess), so they are started with "spawn". Only switch
        # when needed: calling set_start_method again without force raises RuntimeError.
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        tasks_file = os.path.join(self.path, "tasks.json")
        if not os.path.exists(tasks_file):
            self.save(tasks_file)
        else:
            print(f"\033[33m############  Server Warning: Load Task List from existing file. Current Task List replaced  ############\033[0m")
            self.load_task(tasks_file)

        if not isinstance(self.dependency, (list, tuple)):       # Extend dependency info after loading tasks
            self.dependency = [self.dependency] * len(self.tasks)

        while not self.finished():                               # always try to find the next task
            gpu = self.wait_and_get_gpu()                        # first get a gpu
            i, j = self.get_next_task()
            if i is None:
                if self.finished():
                    break                                        # everything got claimed while waiting for a gpu
                # Remaining repeats wait on dependencies. Keep polling while some claimed repeat is
                # unfinished, since its summary.json may unblock them. NOTE: a worker that dies
                # without writing summary.json keeps this loop waiting.
                _, claimed_tasks, finished_tasks = self.get_training_state()
                if set(claimed_tasks) <= set(finished_tasks):
                    print("\033[33m############  Server Warning: Remaining tasks depend on repeats that will never run. Stop scheduling  ############\033[0m")
                    break
                time.sleep(10)
                continue

            path = os.path.join(self.path, f"{i}-{j}") + os.path.sep
            try:
                os.mkdir(path)                                   # try to occupy the path
            except FileExistsError:
                time.sleep(3 + random.randint(1, 5))             # failed to occupy, try again later
                continue

            task = self.tasks[i].copy()
            self.pool[gpu] = task
            task.save(os.path.join(path, "task.json"), additional={"Hostname": self.hostname, "GPU": gpu, "TaskID": i, "RepeatID": j})
            print(f"\033[32m############  Task {i}-{j} Start On GPU {gpu} ############\033[0m")
            task.run(gpu, path, taskid=i, repeatid=j)

        for task in self.pool.values():
            task.join()