import re
from torch.utils.data import Dataset
import pandas as pd


def flatten_dict(d, keys_to_flatten, parent_key='', sep='.'):
    """Merge selected nested dicts of ``d`` into one single-level dict.

    Non-dict values are copied over unchanged. A dict value is descended into
    (recursively) only when its key is in ``keys_to_flatten``; any other dict
    value is dropped. Flattened entries keep only their own (leaf) key: no
    parent prefix is added, so ``sep`` is currently unused and a later
    duplicate key overwrites the value of an earlier one.

    Args:
        d: Dictionary to flatten.
        keys_to_flatten: Container of keys whose dict values should be merged.
        parent_key: Key of the dict currently being processed (set by the
            recursion). When truthy, keys at this level are converted to str.
        sep: Accepted for API compatibility; not used.

    Returns:
        A new, flat dict.
    """
    flat = {}
    for key, value in d.items():
        # Inside a flattened sub-dict (truthy parent key) keys are stringified;
        # otherwise they keep their original type.
        name = key if not parent_key else f"{key}"
        if not isinstance(value, dict):
            flat[name] = value
        elif name in keys_to_flatten:
            flat.update(flatten_dict(value, keys_to_flatten, name, sep=sep))
    return flat


def extract_dataset_name_number(string):
    """Parse a dataset selector such as ``"classical[0, 3, 10:20]"``.

    The selector is a name made of word characters immediately followed by a
    bracketed list of single indices (``3``) and ``start:end`` ranges
    (``10:20``). ``name[all]`` selects everything.

    Args:
        string: The selector text.

    Returns:
        A tuple ``(name, numbers, ranges)``. ``numbers`` is a list of ints and
        ``ranges`` a list of ``(start, end)`` int tuples, each in order of
        appearance. For ``name[all]`` both lists are empty.

    Raises:
        ValueError: If ``string`` does not start with ``name[...]``.
    """
    match = re.match(r'(\w+)\[(.*?)\]', string)
    if match is None:
        raise ValueError(
            f"Invalid dataset selector {string!r}; expected 'name[...]'"
        )
    name, inside = match.group(1), match.group(2)
    if inside == "all":
        return name, [], []
    ranges_int = [
        (int(start), int(end))
        for start, end in re.findall(r'(\d+):(\d+)', inside)
    ]
    # A standalone index must not touch a digit or ':' on either side.
    # Excluding digits in the look-ahead stops the regex from backtracking into
    # the leading digits of a multi-digit range start (e.g. the "1" of "12:15").
    numbers_int = [
        int(num) for num in re.findall(r'(?<![\d:])(\d+)(?![\d:])', inside)
    ]
    return name, numbers_int, ranges_int

def create_dataset_list(datagenerator, dataset_list_AD, dataset_list_string):
    """Build the flat list of datasets from selector strings and literal entries.

    Args:
        datagenerator: Object exposing ``dataset_list_<name>`` attributes that
            support integer indexing and slicing (presumably lists of datasets;
            confirm against callers).
        dataset_list_AD: Optional iterable of selector strings such as
            ``"classical[0, 2:5]"`` or ``"classical[all]"``, parsed with
            ``extract_dataset_name_number``.
        dataset_list_string: Optional iterable of entries appended verbatim.

    Returns:
        A new list. For each selector, its single indices come first, then its
        slices, in the order given. The literal entries follow at the end.
    """
    dataset_list_complete = []
    for dataset_string in dataset_list_AD or []:
        name, numbers_int, ranges_int = extract_dataset_name_number(dataset_string)
        dataset_list = getattr(datagenerator, f"dataset_list_{name}")
        if not numbers_int and not ranges_int:
            # "name[all]" (or an empty selector) takes the whole list.
            dataset_list_complete.extend(dataset_list)
            continue
        for number in numbers_int:
            # append, not extend: one index selects ONE dataset; extend would
            # splice in that element's contents (e.g. a string's characters).
            dataset_list_complete.append(dataset_list[number])
        for start, end in ranges_int:
            dataset_list_complete.extend(dataset_list[start:end])
    if dataset_list_string:
        dataset_list_complete.extend(dataset_list_string)
    return dataset_list_complete

def get_model_backbone_from_csv(path, id):
    """Look up the training type and backbone model for a run ID.

    Args:
        path: Path (or file-like object) of a ``;``-separated CSV containing
            the columns ``ID``, ``Train_Type`` and ``Backbone_Model``.
        id: Value to match against the ``ID`` column. It must compare equal
            to the parsed column values (e.g. an int if IDs are numeric).

    Returns:
        ``(train_type, backbone)`` taken from the first matching row.

    Raises:
        IndexError: If no row has the given ID.
    """
    df = pd.read_csv(path, sep=";")
    rows = df[df["ID"] == id]
    if rows.empty:
        # Same exception type as before, but with an actionable message instead
        # of pandas' "single positional indexer is out-of-bounds".
        raise IndexError(f"No row with ID={id!r} found in {path!r}")
    return rows["Train_Type"].iloc[0], rows["Backbone_Model"].iloc[0]

class CustomDataset(Dataset):
    """Map-style dataset serving items straight from an in-memory container.

    ``tensor_data`` may be any object supporting ``len()`` and indexing
    (e.g. a tensor or a list); items are returned exactly as stored.
    """

    def __init__(self, tensor_data):
        # Kept by reference: no copy, conversion or validation happens here.
        self.tensor_data = tensor_data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        data = self.tensor_data
        return len(data)

    def __getitem__(self, idx):
        """Return the item stored at ``idx``, unchanged."""
        item = self.tensor_data[idx]
        return item