import os

import random
import numpy as np
import json
import torch
from sklearn.metrics import roc_auc_score, average_precision_score

def set_seed(seed=42):
    """Seed every RNG used by the experiments and remember the seed.

    Seeds torch, Python's ``random`` module and NumPy's global RNG (plus the
    current and all CUDA devices when CUDA is available), then puts cuDNN into
    deterministic, non-benchmarking mode. The seed is stored in the
    module-level ``_GLOBAL_SEED`` so ``get_global_seed`` can return it later.

    :param seed: Integer seed applied to all RNGs (default 42).
    """
    global _GLOBAL_SEED
    _GLOBAL_SEED = seed

    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_global_seed():
    """Return the seed most recently passed to ``set_seed``.

    :raises RuntimeError: If ``set_seed`` has not been called yet. The
        module-level ``_GLOBAL_SEED`` only exists once ``set_seed`` assigns it.
    """
    try:
        return _GLOBAL_SEED
    except NameError:
        # Without this, callers would get an opaque NameError about a private name.
        raise RuntimeError("Global seed is not set; call set_seed() first.") from None

def build_save_path(model_name, dataset_name, seed, base_dir="./analysis/models/"):
    """Return the checkpoint path for a (model, dataset, seed) run.

    The path has the form
    ``<base_dir>/<model_name>/<model_name>_<dataset_name>_<seed>.pt``.

    Side effect: creates the ``<base_dir>/<model_name>`` directory, including
    any missing parents, if it does not exist yet.
    """
    folder_path = os.path.join(base_dir, model_name)
    os.makedirs(folder_path, exist_ok=True)
    # Join the file name with os.path.join too, so the separator is consistent
    # with the folder part on every platform.
    return os.path.join(folder_path, f"{model_name}_{dataset_name}_{seed}.pt")

def save_model(model, path):
    """Serialize ``model.state_dict()`` to ``path`` with ``torch.save``.

    Only the state dict is written, not the module object, so loading requires
    an instance of the same architecture.
    """
    weights = model.state_dict()
    torch.save(weights, path)

def load_model(model, path):
    """Load a state dict from ``path`` into ``model`` in place.

    Tensors in the checkpoint are mapped to CPU on load, so checkpoints saved
    on a GPU can be read on CPU-only machines. ``torch.load`` unpickles the
    file, so only load checkpoints from trusted sources.
    """
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)

def cal_metric(y_true, y_score, pos_label=1):
    """Compute ROC-AUC and average precision (PR-AUC) for binary scores.

    :param y_true: Ground-truth labels, one per sample.
    :param y_score: Scores where a higher value means "more likely the
        ``pos_label`` class".
    :param pos_label: Label in ``y_true`` treated as the positive class. It is
        applied to both metrics, so they always refer to the same class.
    :return: ``{'aucroc': ..., 'aucpr': ...}``.
    """
    # roc_auc_score takes no pos_label; for binary y_true it scores the class
    # with the greater label. Binarize explicitly so a non-default pos_label
    # gives an aucroc for the same class as aucpr. For the default
    # pos_label=1 with {0, 1} labels the result is unchanged.
    is_positive = np.asarray(y_true) == pos_label
    aucroc = roc_auc_score(y_true=is_positive, y_score=y_score)
    aucpr = average_precision_score(y_true=y_true, y_score=y_score, pos_label=pos_label)

    return {'aucroc': aucroc, 'aucpr': aucpr}

class SweepConfig:
    """
    A helper class to load and provide sweep configurations for different models
    from a JSON file (sweep_space.json).

    The JSON root is expected to be an object mapping model names to their
    sweep configuration dicts.
    """
    def __init__(self, path: str):
        """
        Load the sweep configuration from a JSON file.

        :param path: Path to the sweep-space JSON file.
        :raises FileNotFoundError: If ``path`` does not exist.
        :raises json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            self._cfg = json.load(f)

    def models(self):
        """
        :return: A list of all model names defined in the JSON, in file order.
        """
        return list(self._cfg.keys())

    def get_config(self, model_name: str):
        """
        Retrieve the sweep configuration for a specific model.

        :param model_name: The name of the model to fetch.
        :return: A dict containing the sweep settings for that model.
        :raises KeyError: If the specified model_name is not found in the JSON.
        """
        if model_name not in self._cfg:
            raise KeyError(f"'{model_name}' is not defined in the sweep configuration.")
        return self._cfg[model_name]

    def items(self):
        """
        :return: An iterator over (model_name, config_dict) pairs for all defined models.
        """
        return self._cfg.items()
    
class ModelConfig:
    """
    Per-model configuration loaded from ``<config_dir>/<model_name>.json``.

    The JSON is expected to map dataset names to sections that contain
    ``'dataloader'``, ``'model'`` and ``'train'`` keys (see ``resolve``).
    Arbitrary top-level values can be read with ``get_param``.
    """
    def __init__(self, model_name: str, config_dir: str = './methods/configs'):
        """
        :param model_name: Model name; selects the file ``<model_name>.json``.
        :param config_dir: Directory holding the per-model JSON files. The
            default ``./methods/configs`` is resolved relative to the current
            working directory.
        :raises FileNotFoundError: If the config file does not exist.
        """
        config_path = os.path.join(config_dir, f'{model_name}.json')

        if not os.path.isfile(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        self.model_name = model_name
        self.config_path = config_path
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

    def get_param(self, key: str):
        """
        :return: The top-level value stored under ``key``, or ``None`` if absent.
        """
        return self.config.get(key)

    def resolve(self, dataset_name):
        """
        Split the section for ``dataset_name`` into its three sub-configs.

        :param dataset_name: Top-level key of the dataset section.
        :return: ``(config_dataloader, config_model, config_train)``.
        :raises KeyError: If the dataset, or one of its three keys, is missing.
        """
        try:
            section = self.config[dataset_name]
        except KeyError:
            # Name the file, so the error says where the dataset was expected.
            raise KeyError(
                f"'{dataset_name}' is not defined in {self.config_path}"
            ) from None
        config_dataloader = section['dataloader']
        config_model = section['model']
        config_train = section['train']
        return config_dataloader, config_model, config_train

if __name__ == "__main__":
    # Manual smoke test: load the sweep space and one model config, and show
    # the values that were previously computed and silently discarded.
    sc = SweepConfig("./methods/configs/sweep_space_(usad,dagmm).json")
    print(f"usad sweep parameters: {sc.get_config('usad')['parameters']}")
    for model_name, sweep_cfg in sc.items():
        print(f"Model: {model_name}, Config: {sweep_cfg}")

    model_cfg = ModelConfig('OmniAnomaly')
    print(f"OmniAnomaly use_PNF: {model_cfg.get_param('model')['use_PNF']}")