from pathlib import Path
import random
import os
from typing import Any, Union
import numpy as np
import pickle
import torch
from tabulate import tabulate
from omegaconf import DictConfig, OmegaConf
from scipy.signal import medfilt
from torch_geometric.data import InMemoryDataset


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    In addition to seeding, this puts cuDNN into deterministic mode, turns
    off cuDNN benchmarking, and stores ``str(seed)`` in the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Seed value handed to every generator.
    """
    # CPU-side generators: stdlib random, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only seeded when CUDA is actually usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels, no benchmarking.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Record the seed in the process environment.
    os.environ["PYTHONHASHSEED"] = str(seed)


def flatten_config(config, parent_key="", separator="_"):
    """
    Flatten an OmegaConf DictConfig into a single level dictionary.
    Keys are joined with parent keys using the specified separator.

    Nested ``DictConfig`` values are flattened recursively. OmegaConf list
    values (``ListConfig``) are converted to plain Python containers with
    ``OmegaConf.to_container`` (called with its default arguments). Every
    other value is stored unchanged.

    Args:
        config (DictConfig): The nested configuration
        parent_key (str): The parent key used in recursion
        separator (str): The separator to use between nested keys

    Returns:
        dict: Flattened configuration dictionary
    """
    flattened = {}

    for key, value in config.items():
        new_key = f"{parent_key}{separator}{key}" if parent_key else key

        if isinstance(value, DictConfig):
            # Recursively flatten child DictConfig and update the flattened dict
            flattened.update(flatten_config(value, new_key, separator))
            continue

        # A ListConfig is not a subclass of list/tuple, so an isinstance check
        # against the builtin types never matches it; ask OmegaConf instead.
        # Plain Python values (already native) are stored as they are.
        if OmegaConf.is_list(value):
            value = OmegaConf.to_container(value)
        flattened[new_key] = value

    return flattened


def print_flattened_config(cfg):
    """Print ``cfg`` as a two-column "Parameter" / "Value" table.

    The configuration is flattened with ``flatten_config`` using "." as the
    key separator, and the rows are sorted so the output order is stable.

    Args:
        cfg: Configuration object accepted by ``flatten_config``.
    """
    rows = sorted(flatten_config(cfg, separator=".").items())
    table = tabulate(rows, headers=["Parameter", "Value"], tablefmt="fancy_outline")
    print(table)


def best_epoch_select_strategy(perf, val_name, metric, strategy="ma", agg=3):
    """Select the index of the best epoch from per-epoch validation metrics.

    Args:
        perf: Sequence with one mapping per epoch. Each mapping is indexed
            with the key ``f"{val_name}/{metric}"``.
        val_name (str): Logger name used as the key prefix.
        metric (str): Metric name without prefix. ``"auto"`` selects
            ``"average_loss"``. Metrics missing from the internal preference
            table are minimized.
        strategy (str): Smoothing applied before picking the optimum:
            ``"ma"`` (moving average), ``"mf"`` (median filter),
            ``"ema"`` (exponential moving average) or ``"best"`` (no
            smoothing).
        agg (int): Strategy parameter. Window size for ``"ma"``; kernel size
            for ``"mf"`` (``scipy.signal.medfilt`` requires it to be odd);
            for ``"ema"`` the smoothing factor is ``agg * 0.1`` capped at 1.

    Returns:
        int: Index into ``perf`` of the selected epoch. When ``perf`` holds
        fewer than ``agg`` entries, the plain optimum of the raw values is
        returned regardless of ``strategy``.

    Raises:
        ValueError: If ``perf`` is empty or ``strategy`` is not recognised.
        KeyError: If an entry of ``perf`` lacks the requested metric key.
    """
    # Whether each known metric should be minimized or maximized.
    metric_preferences = {
        "average_loss": "min",
        "accuracy": "max",
        "f1_score": "max",
        "error_rate": "min",
        "auc_roc": "max",
        # Add other metrics as needed
    }

    # Resolve "auto" BEFORE the logger prefix is added; once prefixed, the
    # name can never compare equal to "auto".
    if metric == "auto":
        metric = "average_loss"

    # Default to minimizing if the metric isn't listed above.
    minimize_metric = metric_preferences.get(metric, "min") == "min"
    pick = np.argmin if minimize_metric else np.argmax

    metric_key = f"{val_name}/{metric}"
    metric_values = np.array([vp[metric_key] for vp in perf])
    if len(metric_values) == 0:
        raise ValueError("perf is empty: cannot select a best epoch")

    # Too few points to smooth: fall back to the raw optimum.
    if len(metric_values) < agg:
        return int(pick(metric_values))

    if strategy == "ma":  # Moving Average
        weights = np.ones(agg) / agg
        smoothed_metrics = np.convolve(metric_values, weights, mode="valid")
        # Entry i of a "valid" convolution covers epochs i .. i + agg - 1;
        # shift by half a window to report the window's centre.
        best_epoch = pick(smoothed_metrics) + agg // 2
    elif strategy == "mf":  # Median Filter
        # medfilt zero-pads its input, which drags the filtered values at
        # both ends towards zero and can make the first/last epoch win for
        # the wrong reason. Pad with the edge values ourselves and keep
        # only the part that corresponds to real epochs.
        half = agg // 2
        padded = np.pad(metric_values, half, mode="edge")
        filtered_metrics = medfilt(padded, kernel_size=agg)[half : half + len(metric_values)]
        best_epoch = pick(filtered_metrics)
    elif strategy == "ema":  # Exponential Moving Average
        # 'agg' scales the smoothing factor; cap it so it stays a valid
        # weight (alpha == 1 means no smoothing at all).
        alpha = min(agg * 0.1, 1.0)
        ema = [metric_values[0]]
        for v in metric_values[1:]:
            ema.append(alpha * v + (1 - alpha) * ema[-1])
        best_epoch = pick(np.array(ema))
    elif strategy == "best":  # Best overall
        best_epoch = pick(metric_values)
    else:
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of 'ma', 'mf', 'ema', 'best'"
        )

    # Ensure the best epoch is within the bounds of the actual data length
    return int(min(best_epoch, len(perf) - 1))


def metric_from_dataset_name(dataset_name):
    """Return the evaluation metric name registered for ``dataset_name``.

    Args:
        dataset_name (str): One of the dataset names listed below.

    Returns:
        str: Either ``"accuracy"`` or ``"auc_roc"``.

    Raises:
        KeyError: If ``dataset_name`` is not listed.
    """
    datasets_by_metric = {
        "accuracy": (
            "WikipediaNetwork-squirrel",
            "WebKB-texas",
            "WebKB-wisconsin",
            "Actor",
            "WikipediaNetwork-chameleon",
            "HeterophilousGraphDataset-Roman-Empire",
            "HeterophilousGraphDataset-Amazon-ratings",
            "Amazon-photo",
            "Amazon-computers",
            "Coauthor-CS",
            "Coauthor-Physics",
            "ogbn-arxiv",
            "WikiCS",
            "CityNetwork-shanghai",
            "CityNetwork-paris",
        ),
        "auc_roc": (
            "HeterophilousGraphDataset-Minesweeper",
            "HeterophilousGraphDataset-Tolokers",
            "HeterophilousGraphDataset-Questions",
        ),
    }
    # Invert the grouping into a dataset -> metric lookup.
    metric_of = {
        name: metric_name
        for metric_name, names in datasets_by_metric.items()
        for name in names
    }
    return metric_of[dataset_name]


def num_tuning_trials_from_dataset_name(dataset_name):
    """Return the number of tuning trials registered for ``dataset_name``.

    Args:
        dataset_name (str): One of the dataset names listed below.

    Returns:
        int: 300, 200 or 100, depending on the dataset.

    Raises:
        KeyError: If ``dataset_name`` is not listed.
    """
    datasets_by_trials = {
        300: (
            "WebKB-texas",  # N = 183
            "WebKB-wisconsin",  # N = 251
            "WikipediaNetwork-chameleon",  # N = 2,277
            "WikipediaNetwork-squirrel",  # N = 5,201
        ),
        200: (
            "Actor",  # N = 7,600
            "HeterophilousGraphDataset-Minesweeper",  # N = 10,000
            "HeterophilousGraphDataset-Tolokers",  # N = 11,758
            "Amazon-photo",  # N = 7,650
            "Amazon-computers",  # N = 13,752
            "WikiCS",  # N = 11,701
        ),
        100: (
            "HeterophilousGraphDataset-Roman-Empire",  # N = 22,662
            "HeterophilousGraphDataset-Amazon-ratings",  # N = 24,492
            "HeterophilousGraphDataset-Questions",  # N = 48,921
            "Coauthor-CS",  # N = 18,333
            "Coauthor-Physics",  # N = 34,493
            "ogbn-arxiv",  # N = 169,343, 1 split
        ),
    }
    # Invert the grouping into a dataset -> trial count lookup.
    trials_of = {
        name: n_trials
        for n_trials, names in datasets_by_trials.items()
        for name in names
    }
    return trials_of[dataset_name]


def split_indices_from_dataset_name(dataset_name):
    """Return the list of split indices available for ``dataset_name``.

    "ogbn-arxiv" has a single split, "WikiCS" has 20, and every other
    dataset name yields 10.

    Args:
        dataset_name (str): Name of the dataset.

    Returns:
        list[int]: Consecutive split indices starting at 0.
    """
    n_splits = 10
    if dataset_name == "WikiCS":
        n_splits = 20
    elif dataset_name == "ogbn-arxiv":
        n_splits = 1
    return list(range(n_splits))


def get_class_from_path(path: str):
    """Import and return the attribute named by a dotted ``path``.

    ``path`` is split at its last dot: everything before it is imported as a
    module, and the part after it is looked up on that module.

    Args:
        path: Dotted path such as ``"package.module.Name"``.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``path`` contains no dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    from importlib import import_module

    module_name, class_name = path.rsplit(".", 1)
    return getattr(import_module(module_name), class_name)


def move_to_device(data: Any, device: torch.device):
    """Recursively move every tensor inside ``data`` to ``device``.

    Tensors are moved with ``Tensor.to``. Dicts and lists are rebuilt as new
    plain dicts/lists with their values/elements processed recursively. Any
    other object (including tuples) is returned unchanged.

    Args:
        data: A tensor, a dict, a list, or any other object.
        device: Target device.

    Returns:
        The moved tensor, a new dict or list, or ``data`` itself.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)

    if isinstance(data, dict):
        moved = {}
        for key, item in data.items():
            moved[key] = move_to_device(item, device)
        return moved

    if isinstance(data, list):
        moved_items = []
        for item in data:
            moved_items.append(move_to_device(item, device))
        return moved_items

    # Not a supported container or tensor: hand it back untouched.
    return data


def get_dataset(
    data_root: Union[str, Path],
    dataset_cfg: DictConfig,
    eigen_subset_cfg: DictConfig,
    split_index: int,
) -> InMemoryDataset:
    """Load a preprocessed dataset, select one split and subset its eigenvectors.

    The dataset is unpickled from
    ``<data_root>/processed/<dataset_cfg.shortname>.pt``. The train/val/test
    masks are then reduced to column ``split_index``, and the callable named
    by ``eigen_subset_cfg.fn`` is invoked on the dataset with
    ``eigen_subset_cfg.args`` as keyword arguments.

    Args:
        data_root: Directory containing the ``processed`` folder.
        dataset_cfg: Config providing ``shortname``.
        eigen_subset_cfg: Config providing ``fn`` (dotted import path) and
            ``args`` (keyword arguments for that callable).
        split_index: Column of the masks to keep.

    Returns:
        The loaded dataset after split selection and eigenvector subsetting.

    Raises:
        AssertionError: If ``eigenvecs`` or ``eigenvals`` contain non-finite
            values afterwards.
    """
    ds_path = Path(data_root) / "processed" / f"{dataset_cfg.shortname}.pt"

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point this at dataset files from a trusted source.
    with open(ds_path, "rb") as handle:
        dataset: InMemoryDataset = pickle.load(handle)
        print(f"Preprocessed dataset loaded from {ds_path}")

    # Masks are stored as [n_nodes, splits]; keep only the requested split.
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        all_splits = getattr(dataset._data, mask_name)
        setattr(dataset._data, mask_name, all_splits[:, split_index])

    # Resolve the subsetting callable from its dotted path and apply it.
    subset_eigen = get_class_from_path(eigen_subset_cfg.fn)
    subset_eigen(dataset=dataset, **eigen_subset_cfg.args)

    # Sanity check on the spectral data, eigenvectors first.
    for eigen_attr in ("eigenvecs", "eigenvals"):
        assert torch.isfinite(getattr(dataset._data, eigen_attr)).all()
    return dataset


def generate_random_string(length=10):
    """Return a random string of ASCII letters and digits.

    Characters are drawn one at a time with ``random.choice`` from the
    module-level ``random`` generator, so the result is reproducible after
    ``random.seed``.

    Args:
        length (int): Number of characters to generate.

    Returns:
        str: String of ``length`` characters (empty when ``length`` is 0
        or negative).
    """
    from string import ascii_letters, digits

    alphabet = ascii_letters + digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
