import argparse
import random
import torch
import yaml
from types import SimpleNamespace
import pynvml
import time
from typing import Union
class Config(SimpleNamespace):
    """Attribute-style namespace that also supports ``cfg["key"]`` item access.

    Values set through either syntax are stored as ordinary instance
    attributes, so ``cfg.key`` and ``cfg["key"]`` are interchangeable.
    """

    def __init__(self, **kwargs):
        # Each keyword argument becomes an instance attribute.
        super().__init__(**kwargs)

    def __getitem__(self, key):
        # Plain attribute lookup: a missing key raises AttributeError (not KeyError).
        value = getattr(self, key)
        return value

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __str__(self):
        as_dict = self.get_dict()
        return str(as_dict)

    def print(self):
        """Pretty-print this config via the module-level ``print_config``."""
        print_config(self)

    def get_dict(self):
        """Return this config recursively converted to plain dicts (via ``Config2Dict``)."""
        return Config2Dict(self)


class Options:
    """Command-line options layered on top of YAML configuration files.

    YAML files (``config/base.yaml`` via ``get_config``, any ``config_path``
    passed to ``parse``, then the files given with ``--config``) provide the
    base options; the fields set from the command line in ``parse`` override
    the values loaded from those files.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument("--config", type=str, default="", help="path to config file, separated by comma, e.g. 'config/base.yaml,config/adapt.yaml'")
        self.parser.add_argument("--seed", type=int, default=1, help='static seed, -1 for random seed')
        self.parser.add_argument("--gpu_ids", type=str, default="0,1,2,3,4,5,6,7", help='gpu ids: e.g. 0  0,1,2, 0,2.')
        self.parser.add_argument("--gpu", type=int, default=-2, help='gpu id: e.g. 0 1 2. use -1 for CPU, -2 for auto-select')
        
        self.parser.add_argument("--data_name", type=str, default="none", help='data name', choices=["rotate_mnist", "color_mnist", "portraits", "covertype", "cifar10", "cifar100", "imagenet"])
        self.parser.add_argument("--model_name", type=str, default="none", help='model name') # cnn, resnet, vgg
        self.parser.add_argument("--method_name", type=str, default="none", help='method name') # gst, goat, gdo, gas, gmma
        self.parser.add_argument("--domain_num", type=int, default=4, help='domain number', choices=[2, 3, 4, 5, 6])
        
        self.parser.add_argument("--corruption", type=str, default="none", help='corruption name')

    def parse(self, config_path: Union[list[str], str, None] = None) -> "Config":
        """Parse ``sys.argv`` and build the merged run configuration.

        Args:
            config_path: Extra YAML file(s) loaded after the base config and
                before the files given via ``--config``. The caller's list is
                never modified (the old ``[]`` default was mutated in place and
                accumulated paths across calls).

        Returns:
            The Config from ``get_config`` plus the CLI-derived fields
            ``time``, ``seed``, ``gpu_ids``, ``device``, ``data_name``,
            ``model_name``, ``method_name``, ``domain_num`` and ``corruption``.
        """
        opt = self.parser.parse_args()
        config = get_config(self._collect_config_paths(config_path, opt.config))

        # Local wall-clock timestamp of this run.
        config.time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

        # --seed -1 requests a fresh random seed.
        config.seed = opt.seed if opt.seed != -1 else random.randint(1, 10000)

        config.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)

        # --gpu -2 means "pick automatically"; any negative id ends up on CPU.
        gpu = self._auto_select_gpu() if opt.gpu == -2 else opt.gpu
        config.device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() and gpu >= 0 else "cpu")

        config.data_name = opt.data_name
        config.model_name = opt.model_name
        config.method_name = opt.method_name
        config.domain_num = opt.domain_num
        config.corruption = opt.corruption
        return config

    @staticmethod
    def _collect_config_paths(config_path, cli_config: str) -> list[str]:
        """Return a NEW list: ``config_path`` entries followed by the ``--config`` files."""
        if config_path is None:
            paths = []
        elif isinstance(config_path, str):
            paths = [config_path]
        else:
            paths = list(config_path)  # copy so the caller's list is never mutated
        # Spaces are stripped and empty entries (e.g. from a trailing comma) are dropped.
        paths += [p for p in cli_config.replace(' ', '').split(',') if p]
        return paths

    @staticmethod
    def _parse_gpu_ids(gpu_ids: str) -> list[int]:
        """Parse a comma-separated id string such as ``"0,2"`` into ``[0, 2]``."""
        return [int(tok) for tok in gpu_ids.split(',') if tok.strip()]

    @staticmethod
    def _auto_select_gpu() -> int:
        """Return the CUDA device index with the most free memory, or -1 for CPU.

        Falls back to device 0 when NVML cannot be initialised, and to -1 (CPU)
        when torch sees no CUDA device at all.
        """
        device_count = torch.cuda.device_count()
        if device_count == 0:
            print("there is no GPU available, using CPU")
            return -1
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            print("pynvml not available, using default GPU")
            return 0
        try:
            free_memory = []
            # NOTE(review): NVML enumerates physical GPUs and ignores
            # CUDA_VISIBLE_DEVICES, so NVML index i may not be torch device i
            # when devices are masked or reordered — confirm for your setup.
            for i in range(device_count):
                try:
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    free_memory.append(pynvml.nvmlDeviceGetMemoryInfo(handle).free)
                except pynvml.NVMLError:
                    free_memory.append(0)  # unqueryable device ranks last
        finally:
            pynvml.nvmlShutdown()  # always release NVML, even if a query raised
        # index() returns the first device among ties.
        return free_memory.index(max(free_memory))


def get_config(config_path: Union[str, list[str], None] = None, base_config: str = "config/base.yaml"):
    """Load ``base_config`` followed by ``config_path`` into a single Config.

    Files are merged in order by ``load_config``, so later files override
    earlier ones.

    Args:
        config_path: A single YAML path, a list of paths, or None for none.
        base_config: YAML file always loaded first (default ``config/base.yaml``).

    Returns:
        The merged Config, with ``config_path`` set to the (copied) list of
        extra paths that were requested.
    """
    if config_path is None:
        config_path = []
    elif isinstance(config_path, str):
        config_path = [config_path]
    else:
        # Copy: the list is stored on the config, and must not alias the
        # caller's list (or a shared default) that might be mutated later.
        config_path = list(config_path)
    cfg = load_config([base_config] + config_path)
    cfg.config_path = config_path
    return cfg
    
def load_config(config_path: Union[list[str], str]):
    """Read YAML files in order and merge them into a single Config.

    Each file is parsed with ``yaml.safe_load`` and merged with
    ``Dict2Config``, so keys from later files override earlier ones and
    nested mappings are merged key by key. An empty file (``safe_load``
    returns None) contributes nothing.

    Args:
        config_path: A list of YAML paths, or a single path string.
    """
    if isinstance(config_path, str):
        # A bare string would otherwise be iterated character by character.
        config_path = [config_path]
    config = Config()
    for path in config_path:
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            yaml_config = yaml.safe_load(f)
        config = Dict2Config(yaml_config, config)
    return config

def Dict2Config(d: Union[dict, None], cf: Union[Config, None] = None):
    """Recursively merge mapping ``d`` into ``cf`` (a new Config if None) and return it.

    Nested dicts become nested Configs. When ``cf`` already holds a Config
    under the same key, the two are merged recursively. Any other existing
    value is replaced, with a warning printed. ``d=None`` (e.g. from an empty
    YAML file) leaves ``cf`` unchanged.
    """
    cf = Config() if cf is None else cf
    if d is None:
        return cf
    for k, v in d.items():
        # Check instance attributes only: hasattr() would also report class
        # attributes such as the Config.print / Config.get_dict methods.
        present = k in vars(cf)
        current = getattr(cf, k) if present else None
        if isinstance(v, dict) and isinstance(current, Config):
            # Merge the nested section into the existing one (no warning).
            setattr(cf, k, Dict2Config(v, current))
            continue
        if present:
            print(f"[Warn] '{k}' already exists in config, will be overwritten by '{v}'")
        # A dict replacing a non-Config value becomes a fresh Config instead
        # of recursing into (and crashing on) e.g. an int or a list.
        setattr(cf, k, Dict2Config(v) if isinstance(v, dict) else v)
    return cf

def print_config(config: Config):
    """Pretty-print a (possibly nested) Config between header/footer banners.

    Entries are printed one per line, sorted by key. A nested Config is shown
    as a ``key:`` line followed by its own entries, indented two more spaces.
    """
    def _lines(node, level=0):
        pad = '  ' * level
        for key, value in sorted(vars(node).items()):
            if not isinstance(value, Config):
                yield f"{pad}{key}: {value}"
                continue
            yield f"{pad}{key}:"
            yield from _lines(value, level + 1)

    print("------------ Options ------------")
    for line in _lines(config):
        print(line)
    print("-------------- End --------------")
    
def Config2Dict(config: Config):
    """Convert a Config (recursively, including nested Configs) into plain dicts.

    Keys are inserted in sorted order; non-Config values are kept as-is.
    """
    result = {}
    for key, value in sorted(vars(config).items()):
        result[key] = Config2Dict(value) if isinstance(value, Config) else value
    return result
    
if __name__ == "__main__":
    # Smoke run: parse CLI args and show the resulting config in each format.
    config = Options().parse()
    for show in (print, print_config):
        show(config)
    print(Config2Dict(config))
