import os
import h5py
import random
import numpy as np
import pandas as pd
from typing import Type
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split, Subset
import utils as utils

class ptbxldataset(Dataset):
    """Map-style dataset backed by one ``.npy`` data file and one ``.npy`` label file.

    Both arrays are loaded completely at construction time and cast to
    float32. Indexing returns a ``(data, label)`` pair of float32 tensors.
    """

    def __init__(self, data_path, label_path):
        """Load the two arrays and check that they hold the same number of rows.

        Args:
            data_path: Path of the ``.npy`` file with the samples.
            label_path: Path of the ``.npy`` file with the labels.
        """
        # NOTE: allow_pickle=True lets np.load unpickle object arrays; only
        # point this at files from a trusted source.
        raw_data = np.load(data_path, allow_pickle=True)
        raw_labels = np.load(label_path, allow_pickle=True)
        self.data = raw_data.astype(np.float32)
        self.labels = raw_labels.astype(np.float32)
        # The first axis of both arrays is the sample axis and must agree.
        assert self.data.shape[0] == self.labels.shape[0]

    def __len__(self):
        """Return the number of samples (length of the first data axis)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(data, label)`` at ``idx`` as freshly copied float32 tensors."""
        return (
            torch.tensor(self.data[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.float32),
        )

class g12ecdataset(Dataset):
    """Map-style dataset reading samples and labels from a pair of ``.npy`` files.

    The arrays are held in memory as float32; each item is a
    ``(data, label)`` tuple of float32 tensors.
    """

    def __init__(self, data_path, label_path):
        """Read both arrays, cast them to float32 and verify their row counts match.

        Args:
            data_path: Location of the ``.npy`` sample array.
            label_path: Location of the ``.npy`` label array.
        """
        # NOTE: allow_pickle=True permits unpickling; use trusted files only.
        loaded_data = np.load(data_path, allow_pickle=True)
        loaded_labels = np.load(label_path, allow_pickle=True)
        self.data = loaded_data.astype(np.float32)
        self.labels = loaded_labels.astype(np.float32)
        # One label row is required per sample row.
        assert self.data.shape[0] == self.labels.shape[0]

    def __len__(self):
        """Number of samples, i.e. the size of the first data axis."""
        return len(self.data)

    def __getitem__(self, idx):
        """Fetch the sample/label pair at ``idx`` as float32 tensors (copies)."""
        signal = torch.tensor(self.data[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return signal, target
    

class mimic_ecgdataset(Dataset):
    """Unlabelled dataset read from a raw float32 file of shape ``(data_size, 5000, 12)``."""

    def __init__(self, data_path, data_size):
        """Map the raw file and copy its contents into memory.

        Args:
            data_path: Path of the raw float32 binary file.
            data_size: Number of records stored in the file (first axis).
        """
        mapped = np.memmap(
            data_path,
            dtype='float32',
            mode='r',
            shape=(data_size, 5000, 12),
        )
        # np.array() copies the mapped file into a regular in-memory ndarray,
        # so the whole file is resident after construction.
        self.data = np.array(mapped)

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record at ``idx`` as a float32 tensor."""
        return torch.tensor(self.data[idx], dtype=torch.float32)

class ptbxldataset_MC(Dataset):
    """``.npy``-backed dataset whose items are passed through a dict-based transform.

    The data array must be 3-D; its last two axes are swapped once at load
    time. Every item is handed to ``transform`` as
    ``{"data": ..., "label": ...}`` and the transform's ``"data"`` and
    ``"label"`` entries are returned.
    """

    def __init__(self, data_path, label_path, transform):
        """Load, cast and re-orient the arrays.

        Args:
            data_path: Path of the ``.npy`` sample array.
            label_path: Path of the ``.npy`` label array.
            transform: Callable taking and returning a mapping with the keys
                ``"data"`` and ``"label"``.
        """
        # NOTE: allow_pickle=True permits unpickling; use trusted files only.
        raw_data = np.load(data_path, allow_pickle=True)
        raw_labels = np.load(label_path, allow_pickle=True)

        float_data = raw_data.astype(np.float32)
        self.labels = raw_labels.astype(np.float32)

        # Swap axes 1 and 2, keeping the sample axis first.
        self.data = float_data.transpose(0, 2, 1)
        self.transform = transform
        assert self.data.shape[0] == self.labels.shape[0]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the transformed ``(data, label)`` pair at ``index``.

        The outputs are whatever the transform produced; no tensor conversion
        happens here.
        """
        transformed = self.transform(
            {"data": self.data[index], "label": self.labels[index]}
        )
        return transformed["data"], transformed["label"]
    
class g12ecdataset_MC(Dataset):
    """Dataset over ``.npy`` samples/labels with a per-item dict transform.

    The 3-D data array has its last two axes exchanged during construction.
    Items are wrapped as ``{"data": ..., "label": ...}``, run through
    ``transform``, and unpacked again into a ``(data, label)`` tuple.
    """

    def __init__(self, data_path, label_path, transform):
        """Read the arrays, cast them to float32 and swap the last two data axes.

        Args:
            data_path: Location of the ``.npy`` sample array.
            label_path: Location of the ``.npy`` label array.
            transform: Callable mapping a ``{"data", "label"}`` dict to a
                mapping with the same two keys.
        """
        # NOTE: allow_pickle=True permits unpickling; use trusted files only.
        samples = np.load(data_path, allow_pickle=True)
        targets = np.load(label_path, allow_pickle=True)

        samples = samples.astype(np.float32)
        targets = targets.astype(np.float32)

        # Axis order (0, 2, 1): sample axis stays first, the other two swap.
        self.data = samples.transpose(0, 2, 1)
        self.labels = targets
        self.transform = transform
        assert self.data.shape[0] == self.labels.shape[0]

    def __len__(self):
        """Number of samples held by the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Apply the transform to item ``index`` and return its data and label."""
        item = {"data": self.data[index], "label": self.labels[index]}
        item = self.transform(item)
        return item["data"], item["label"]


class vitaldb_ECGDataset(Dataset):
    """
    Custom PyTorch Dataset for loading ECG signals from an HDF5 file, which is saved from the https://vitaldb.net/.
    Each __getitem__ returns all windows of one case.

    The dataset keeps the HDF5 file open for its whole lifetime. Call
    ``close()`` when done, or use the instance as a context manager.
    """
    def __init__(self, hdf5_path: str, case_ids_csv_path: str, window_length_samples: int):
        """
        Initialize the ECG dataset.

        Args:
            hdf5_path (str): Path to the HDF5 file with ECG signals.
            case_ids_csv_path (str): Path to the CSV file with processed case IDs.
                It must contain a ``caseid`` column.
            window_length_samples (int): Window length in samples.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        super().__init__()
        self.hdf5_path = hdf5_path
        self.window_length = window_length_samples
        # Defined up front so close() is safe even if construction fails below.
        self.hf = None

        if not os.path.exists(self.hdf5_path):
            raise FileNotFoundError(f"HDF5 file not found: {self.hdf5_path}")

        if not os.path.exists(case_ids_csv_path):
            raise FileNotFoundError(f"CSV file not found: {case_ids_csv_path}")

        # Load case IDs as string list. This happens BEFORE the HDF5 file is
        # opened so that a malformed CSV (e.g. missing 'caseid' column) cannot
        # leave an open file handle behind.
        self.case_ids = pd.read_csv(case_ids_csv_path, dtype={'caseid': str})['caseid'].tolist()

        self.hf = h5py.File(self.hdf5_path, 'r')

        print(f"Dataset initialized with {len(self.case_ids)} cases.")
        print(f"Each case will be split into {self.window_length}-sample windows.")

    def __enter__(self):
        """Return the dataset itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the HDF5 file on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

    def __len__(self) -> int:
        """Return number of cases in the dataset."""
        return len(self.case_ids)

    def __getitem__(self, idx: int):
        """
        Get one case split into consecutive, non-overlapping windows.

        Samples left over after the last full window are dropped. Loading is
        best-effort: if the case is shorter than one window, or reading or
        reshaping fails for any reason (missing key, closed file, a channel
        count other than 2), a message is printed and an empty tensor is
        returned instead of raising.

        Args:
            idx (int): Index of the case. Negative indices are not supported.

        Returns:
            tuple:
                - windows (torch.Tensor): Shape (num_windows, window_length, 2),
                  float32; ``num_windows`` is 0 in the fallback cases above.
                - case_id (str): Case ID.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if not (0 <= idx < len(self.case_ids)):
            raise IndexError(f"Index {idx} is out of bounds (dataset size: {len(self.case_ids)})")

        case_id = self.case_ids[idx]

        try:
            # Load signal from HDF5. The stored array is transposed so that
            # the first axis is the sample axis used for windowing below.
            signal_data_np = self.hf[case_id][()].T.astype(np.float32)

            num_total_samples = signal_data_np.shape[0]
            num_windows = num_total_samples // self.window_length

            if num_windows == 0:
                print(f"Warning: {case_id} has only {num_total_samples} samples (< {self.window_length}). Returning empty tensor.")
                return torch.empty(0, self.window_length, 2, dtype=torch.float32), case_id

            # Drop the tail that does not fill a complete window.
            trimmed_signal_np = signal_data_np[:num_windows * self.window_length, :]
            windows_np = trimmed_signal_np.reshape(num_windows, self.window_length, 2)

            windows_tensor = torch.from_numpy(windows_np)

            return windows_tensor, case_id

        except Exception as e:
            # Deliberately broad: one unreadable case must not abort iteration.
            print(f"Error loading {case_id}: {e}")
            return torch.empty(0, self.window_length, 2, dtype=torch.float32), case_id

    def close(self):
        """Close the HDF5 file. Does nothing if it is not open."""
        if self.hf:
            self.hf.close()
            print(f"HDF5 file {self.hdf5_path} closed.")





def prepare_dataloader_multiclass(config,
                                  task_name,
                                  dataset_name,
                                  batch_size,
                                  frequency,
                                  length) -> tuple:
    """
    Build the train/test dataloaders and class weights for the 3-class task.

    The training loader iterates over the concatenation of the train and
    validation splits; the test loader iterates over the test split only.

    Args:
        config: Configuration object providing ``MULTICLASS_LABELS_INDEX``,
            ``folder_group``, ``npy_files`` and the dataset root
            (``root_ptbxl500`` or ``root_g12ec``).
        task_name: Key into ``config.MULTICLASS_LABELS_INDEX`` selecting the target label.
        dataset_name: Either ``"ptbxl-500"`` or ``"g12ec"``.
        batch_size: Batch size for both dataloaders.
        frequency: Forwarded to ``prepare_preprocess_multiclass``.
        length: Forwarded to ``prepare_preprocess_multiclass``.

    Returns:
        tuple: ``(train_dataloader, test_dataloader, weight)`` where ``weight``
        is a NumPy array ``[1, num_target/num_normal, num_others/num_normal]``
        computed from the labels of the train split only.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # Fail early with a clear message; previously an unknown name fell through
    # both branches and crashed later with an UnboundLocalError.
    supported_datasets = ("ptbxl-500", "g12ec")
    if dataset_name not in supported_datasets:
        raise ValueError(
            f"Unsupported dataset name '{dataset_name}'. "
            f"Available options are: {list(supported_datasets)}"
        )

    print('Preparing dataloader for multi-classification...')
    normal_index = config.MULTICLASS_LABELS_INDEX["Normal"][dataset_name]
    target_index = config.MULTICLASS_LABELS_INDEX[task_name][dataset_name]

    transformations = prepare_preprocess_multiclass(
        frequency, length, normal_index, target_index, True)

    transformations_test = prepare_preprocess_multiclass(
        frequency, length, normal_index, target_index, False)

    if dataset_name == "ptbxl-500":
        dataset_cls = ptbxldataset_MC
        root = config.root_ptbxl500 + "all" + config.folder_group[0]
    else:
        dataset_cls = g12ecdataset_MC
        root = config.root_g12ec + config.folder_group[0]

    # config.npy_files is consumed in the order: train data, train labels,
    # test data, test labels, val data, val labels. At most 6 names are used.
    (train_data_path, train_labels_path,
     test_data_path, test_labels_path,
     val_data_path, val_labels_path) = [
        root + file_name for _, file_name in zip(range(6), config.npy_files)
    ]

    dataset_train = dataset_cls(train_data_path, train_labels_path, transformations)
    dataset_valid = dataset_cls(val_data_path, val_labels_path, transformations)
    dataset_test = dataset_cls(test_data_path, test_labels_path, transformations_test)

    # Train on train + validation together.
    dataset_train_concat = ConcatDataset([dataset_train, dataset_valid])
    train_dataloader = DataLoader(dataset_train_concat, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    labels = dataset_train.labels
    num_samples = labels.shape[0]
    # Extract normal and target dx labels
    normal_labels = labels[:, normal_index]
    target_labels = labels[:, target_index]

    num_normal = normal_labels.sum()
    num_target = target_labels.sum()
    num_others = num_samples - (num_normal + num_target)

    # NOTE(review): weights are the class counts relative to the normal count
    # (not their inverse), and num_normal == 0 yields inf/nan — confirm this
    # is what the loss expects.
    weight = np.array([1, num_target / num_normal, num_others / num_normal])
    return train_dataloader, test_dataloader, weight


def prepare_preprocess_multiclass(
    frequency: int,
    length: int,
    normal_index: int,
    target_index: int,
    is_train: bool
) -> transforms.Compose:
    """
    Prepare and compose transform functions.

    The pipeline is ``utils.ProcessLabel(normal_index, target_index)``
    followed by ``utils.ToTensor()``. Train and evaluation pipelines are
    currently identical, so ``frequency``, ``length`` and ``is_train`` have
    no effect; they are kept so existing callers keep working.

    Args:
        frequency (int): Currently unused (reserved for subsampling).
        length (int): Currently unused (reserved for subsampling).
        normal_index (int): Passed to ``utils.ProcessLabel``.
        target_index (int): Passed to ``utils.ProcessLabel``.
        is_train (bool): Currently unused (would select the subsampling variant).
    Returns:
        transforms.Compose: The composed transform instance.
    """
    # Subsampling is disabled. To re-enable it, prepend
    # utils.Subsample(int(frequency * length)) when is_train is true and
    # utils.SubsampleEval(int(frequency * length)) otherwise.
    steps = [
        utils.ProcessLabel(normal_index, target_index),
        utils.ToTensor(),
    ]
    return transforms.Compose(steps)



# Dataset names accepted by get_dataloader() for each mode.
_FINETUNE_DATASETS = ('ptbxl-500', 'g12ec')
_PRETRAIN_DATASETS = ('mimic',)


def _build_finetune_dataloaders(config, dataset_name, batch_size, shuffle):
    """Build (train, test, valid) dataloaders for a supported finetune dataset."""
    if dataset_name == 'ptbxl-500':
        print('Preparing dataloader for ptbxl-500...')
        dataset_cls = ptbxldataset
        root = config.root_ptbxl500 + config.TASKS_ptbxl[0] + config.folder_group[0]
        # Cleaning is not applied to ptbxl-500.
        apply_cleaning = False
    else:
        print('Preparing dataloader for g12ec...')
        dataset_cls = g12ecdataset
        root = config.root_g12ec + config.folder_group[0]
        apply_cleaning = True

    # config.npy_files is consumed in the order: train data, train labels,
    # test data, test labels, val data, val labels. At most 6 names are used.
    (train_data_path, train_labels_path,
     test_data_path, test_labels_path,
     val_data_path, val_labels_path) = [
        root + file_name for _, file_name in zip(range(6), config.npy_files)
    ]

    dataset_train = dataset_cls(train_data_path, train_labels_path)
    dataset_test = dataset_cls(test_data_path, test_labels_path)
    dataset_valid = dataset_cls(val_data_path, val_labels_path)
    if apply_cleaning:
        dataset_train = utils.clean_ecg_data(dataset_train)
        dataset_test = utils.clean_ecg_data(dataset_test)
        dataset_valid = utils.clean_ecg_data(dataset_valid)

    # Training uses train + validation together; the validation loader below
    # therefore iterates over data that is also part of the training set.
    dataset_train = ConcatDataset([dataset_train, dataset_valid])

    train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=8, pin_memory=True)
    test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    valid_dataloader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    return train_dataloader, test_dataloader, valid_dataloader


def _build_pretrain_dataloaders(config, batch_size, shuffle):
    """Build (train, valid) dataloaders over the six mimic pieces."""
    print('Preparing dataloader for mimic...Consdering the large size of the dataset, we divided the dataset into 6 pieces.')
    print('Preparing the 1-3 pieces.')
    pieces = [
        mimic_ecgdataset(config.mimic_path[i], config.mimic_size[i])
        for i in range(3)
    ]
    print('Preparing the 4-6 pieces.')
    pieces += [
        mimic_ecgdataset(config.mimic_path[i], config.mimic_size[i])
        for i in range(3, 6)
    ]

    dataset = ConcatDataset(pieces)

    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8, pin_memory=True)
    # The validation loader iterates the sixth piece, which is also part of
    # the training set above.
    valid_dataloader = DataLoader(pieces[-1], batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_dataloader, valid_dataloader


def get_dataloader(mode='pretrain', config=None, dataset_name=None, batch_size=None, shuffle=True):
    """
    Build the dataloaders for pretraining or finetuning.

    Args:
        mode: ``'pretrain'`` or ``'finetune'``.
        config: Configuration object holding dataset roots, file names and sizes.
        dataset_name: ``'mimic'`` for pretrain; ``'ptbxl-500'`` or ``'g12ec'`` for finetune.
        batch_size: Batch size passed to every DataLoader.
        shuffle: Whether the training loader shuffles; other loaders never shuffle.

    Returns:
        tuple: ``(train_dataloader, test_dataloader, valid_dataloader)`` in
        finetune mode, ``(train_dataloader, valid_dataloader)`` in pretrain mode.

    Raises:
        ValueError: If ``mode`` is invalid, ``dataset_name`` is missing, or
            ``dataset_name`` is not supported for the chosen mode.
    """
    if mode == 'finetune':
        if dataset_name is None:
            raise ValueError("Dataset name must be provided for finetune mode.")
        # Previously an unknown finetune dataset silently returned None.
        if dataset_name not in _FINETUNE_DATASETS:
            raise ValueError(f"Unsupported dataset name '{dataset_name}'. Available options are: {list(_FINETUNE_DATASETS)}")
        return _build_finetune_dataloaders(config, dataset_name, batch_size, shuffle)

    if mode == 'pretrain':
        if dataset_name is None:
            raise ValueError("Dataset name must be provided for pretrain mode.")
        if dataset_name not in _PRETRAIN_DATASETS:
            raise ValueError(f"Unsupported dataset name '{dataset_name}'. Available options are: {list(_PRETRAIN_DATASETS)}")
        return _build_pretrain_dataloaders(config, batch_size, shuffle)

    raise ValueError("Invalid mode. Please choose either 'pretrain' or 'finetune'.")