import torch,os,time,random,hashlib,copy,json
from pathlib import Path
from joblib import Parallel, delayed
from dataclasses import dataclass,asdict,field,is_dataclass
from collections import deque


@dataclass
class Config():
    """Base experiment configuration.

    Subclasses implement ``run`` (invoked for every seeded config by
    ``conduct_rl_mnist_parallel``) and ``get_abs`` (the folder-name component
    that ``ConfigNode.get_folder`` joins onto its parent's folder).
    """

    save_folder: Path
    # NOTE(review): this has no type annotation, so @dataclass treats it as a
    # plain class attribute, not a field: it is not an __init__ argument and
    # asdict() omits it. RepeatedExperiments assigns it per instance.
    # Annotating it would make it a defaulted field, which breaks subclasses
    # that add fields without defaults - confirm intent before changing.
    random_seed = "test"

    def run(self):
        """Run the experiment described by this config. Must be overridden."""
        raise NotImplementedError(f"{type(self).__name__}.run is not implemented")

    def get_abs(self):
        """Return this config's folder-name component. Must be overridden."""
        raise NotImplementedError(f"{type(self).__name__}.get_abs is not implemented")

@dataclass
class Variables():
    """A configuration layer applied on top of a parent's full config.

    ``ConfigNode.get_full_config`` calls ``set`` with the parent's full config
    and uses the returned object as this node's full config.
    """

    def set(self, conf):
        """Return a new config derived from ``conf``.

        The base implementation returns an unchanged deep copy, so the caller
        can mutate the result without affecting ``conf``. Presumably subclasses
        override this to apply their own changes - confirm against callers.
        """
        return copy.deepcopy(conf)

    def get_abs(self):
        """Return this layer's folder-name component. Must be overridden."""
        raise NotImplementedError(f"{type(self).__name__}.get_abs is not implemented")
    
class ConfigNode:
    """A node in a tree of experiment configurations.

    Each node's folder is its parent's folder joined with ``config.get_abs()``,
    so a node's folder is the chain of ``get_abs()`` components from the root
    down to that node. A node's full config is built by applying each
    descendant's ``config.set`` to its parent's full config.
    """

    def __init__(self, config: "Config | Variables", parent=None):
        self.config = config
        self.parent = parent
        # The root has no parent; os.path.join("", x) == x, so its folder is
        # just its own get_abs() component.
        self.parent_folder = parent.get_folder() if parent is not None else ""
        self.children = []

    def add_child(self, child_node):
        """Append an already-built node to this node's children."""
        self.children.append(child_node)

    def add_child_config(self, child_config):
        """Wrap ``child_config`` in a new child node of this node and return it."""
        child = ConfigNode(child_config, self)
        self.add_child(child)
        return child

    def __repr__(self):
        return f"ConfigNode({self.config!r})"

    def get_folder(self):
        """Return the parent's folder joined with this config's ``get_abs()``."""
        return os.path.join(self.parent_folder, self.config.get_abs())

    def get_full_config(self):
        """Return a config with every layer from the root down to this node applied.

        The root contributes a deep copy of its own config, so callers may
        mutate the result without altering the tree. Each descendant passes
        its parent's full config through ``self.config.set``.
        """
        if self.parent is None:
            return copy.deepcopy(self.config)
        return self.config.set(self.parent.get_full_config())

    def bfs(self):
        """Return this node and all of its descendants in breadth-first order."""
        nodes = []
        queue = deque([self])
        while queue:
            node = queue.popleft()
            nodes.append(node)
            queue.extend(node.children)
        return nodes

    def is_leaf(self):
        """Return True when this node has no children."""
        return not self.children

            

def run_with_oom_wait(func, max_retry=12, base_sec=2, **kwargs):
    """Call ``func(**kwargs)``, retrying with exponential back-off on OOM errors.

    A ``RuntimeError`` whose message contains ``'out of memory'`` triggers
    ``torch.cuda.empty_cache()`` and a sleep of ``base_sec * 2**attempt``
    seconds plus up to 1 s of random jitter before the next attempt. Any other
    exception propagates immediately.

    Args:
        func: callable to run; receives ``kwargs`` as keyword arguments.
        max_retry: maximum number of calls to ``func``.
        base_sec: base delay in seconds for the exponential back-off.

    Returns:
        Whatever ``func`` returns on its first successful call.

    Raises:
        RuntimeError: ``"Still OOM after retries"`` when every attempt ran out
            of memory (chained to the last OOM error when there was one).
    """
    for attempt in range(max_retry):
        try:
            return func(**kwargs)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            if attempt == max_retry - 1:
                # Last attempt: give up now instead of sleeping for nothing.
                raise RuntimeError("Still OOM after retries") from e
        # Recovery runs outside the except clause. While the exception is
        # alive, its traceback keeps func's frames (and any tensors they hold)
        # referenced, so empty_cache() could not release that memory.
        torch.cuda.empty_cache()
        delay = base_sec * 2 ** attempt + random.uniform(0, 1)
        print(f"[{os.getpid()}] OOM, sleep {delay:.1f}s")
        time.sleep(delay)
    # Only reached when max_retry <= 0, i.e. func was never called.
    raise RuntimeError("Still OOM after retries")



def conduct_rl_mnist_parallel(configs: list[Config], n_jobs=20, **kwargs):
    """Run ``cfg.run`` for every config in parallel with joblib.

    Each call goes through ``run_with_oom_wait``, so out-of-memory errors are
    retried with back-off. ``kwargs`` are forwarded to ``run_with_oom_wait``:
    ``max_retry``/``base_sec`` are consumed there, and the rest reach
    ``cfg.run``. Prints the number of configs, then each job's index.
    """
    print("item nums", len(configs))

    def _launch(cfg, idx):
        print(idx)
        run_with_oom_wait(cfg.run, **kwargs)

    jobs = [delayed(_launch)(cfg, idx) for idx, cfg in enumerate(configs)]
    Parallel(n_jobs=n_jobs)(jobs)
    

  
def create_random_seed(n):
    """Return ``n`` seed labels: ``"test0"``, ``"test1"``, ... ``"test{n-1}"``."""
    return list(map("test{}".format, range(n)))

def class_name_only(obj):
    """Return the name of ``obj``'s class.

    Used as the ``default`` hook of ``json.dump`` in this file, so values JSON
    cannot serialize are written as their class name.
    """
    cls = obj.__class__
    return cls.__name__

def get_string_hash(s: str):
    """Return the first 6 hex digits of the MD5 digest of ``s`` (UTF-8 encoded).

    A short identifier, not a security measure.
    """
    digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    return digest[:6]


class RepeatedExperiments():
    """Save each node's config and run repeated, seeded experiments for leaf nodes."""

    def __init__(self):
        pass

    def get_otherhash(self, c):
        """Return a 6-hex-digit hash of the items of ``c`` joined with ``'-'``."""
        c_string = '-'.join(map(str, c))
        otherhash = get_string_hash(c_string)
        return otherhash

    @staticmethod
    def save_config(config_folder, c):
        """Write dataclass ``c`` as ``config.json`` in ``config_folder``, creating it if needed.

        Values JSON cannot serialize are written as their class name.
        """
        os.makedirs(config_folder, exist_ok=True)
        with open(os.path.join(config_folder, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(asdict(c), f, indent=2, default=class_name_only)

    def get_save_folder(self, p_folder, item, random_seed):
        """Return ``p_folder/<item>/<random_seed>``."""
        return os.path.join(p_folder, f"{item}", random_seed)

    def run(self, nodes: "list[ConfigNode]", repeat_num, n_jobs=20, **kwargs):
        """Save every node's config, then run ``repeat_num`` seeded copies of each leaf.

        Args:
            nodes: the tree nodes to process. Every node gets a ``config.json``
                in its folder; only leaf nodes are run.
            repeat_num: number of seeds to run per leaf node.
            n_jobs: number of parallel workers for ``conduct_rl_mnist_parallel``.
            **kwargs: forwarded to ``conduct_rl_mnist_parallel``.
        """
        configs = []
        seeds = create_random_seed(repeat_num)
        for node in nodes:
            node_folder = node.get_folder()
            self.save_config(node_folder, node.config)
            if not node.is_leaf():
                continue
            for seed in seeds:
                # A full config is rebuilt for every seed because it is mutated below.
                cur_conf = node.get_full_config()
                # One sub-folder per seed; otherwise the repeats overwrite each other.
                cur_conf.save_folder = os.path.join(node_folder, seed)
                cur_conf.random_seed = seed
                configs.append(cur_conf)

        conduct_rl_mnist_parallel(configs, n_jobs=n_jobs, **kwargs)

        

    # def run(self,common_conf_folder,common_conf,repeat_num,**kwargs):
    #     self.save_config(common_conf_folder,common_conf)
    #     seeds=create_random_seed(repeat_num)
    #     configs=[]
    #     for item in self.items:
    #         item_conf_folder=os.path.join(common_conf_folder,f"{item}")
    #         self.save_config(item_conf_folder,item)
    #         for seed in seeds:
    #             cur_folder=os.path.join(item_conf_folder,f"{seed}")
    #             cur_conf=item.set(common_conf)
    #             cur_conf:Config
    #             cur_conf.save_folder=cur_folder
    #             cur_conf.random_seed=seed
    #             configs.append(cur_conf)

    #     conduct_rl_mnist_parallel(configs,n_jobs=20,**kwargs)





# @util.repr_alias(attr_name=False)   
# @dataclass
# class OneExperiment():
#     conf:Config
#     save_folder:Path

# class Variable():
#     def __init__(self,name,set_func):
#         self.name=name
#         self.set_func=set_func
