import os
import numpy as np
import pandas as pd
import logging
from sklearn import model_selection
from scipy import signal
from scipy.signal import resample
from Models.connectivity import get_pseudolabels
import torch
import pickle
logger = logging.getLogger(__name__)


def load(config):
    """Load the preprocessed ``<data_dir>/<problem>.npy`` bundle into a dict.

    The ``.npy`` file is read with ``allow_pickle=True`` and unwrapped with
    ``.item()``, i.e. it is expected to contain a single pickled dict. When
    the stored ``'val_data'`` entry is truthy under ``np.any`` the stored
    train/val/test split is used unchanged; otherwise 10% of the stored
    training set is held out as validation data via ``split_dataset``.

    When ``config['problem'] == 'STEW'`` every data array is additionally
    passed through ``apply_bipolar_montage_14ch`` and the result of
    ``load_coherence_labels`` is stored under ``'coherence_labels'``.

    Args:
        config: Mapping providing ``'data_dir'`` and ``'problem'``. On the
            STEW path it is also forwarded to ``load_coherence_labels``.

    Returns:
        dict with the keys ``train_data``, ``train_label``, ``val_data``,
        ``val_label``, ``test_data``, ``test_label``, ``All_train_data``,
        ``All_train_label`` and ``max_len`` (plus ``coherence_labels`` for
        STEW).

    Raises:
        FileNotFoundError: if the preprocessed ``.npy`` file is missing.
            (Previously this surfaced as an unrelated ``KeyError`` on the
            empty result dict.)
    """
    npy_path = config['data_dir'] + '/' + config['problem'] + '.npy'
    if not os.path.exists(npy_path):
        raise FileNotFoundError(
            "Preprocessed data file not found: {}".format(npy_path))

    logger.info("Loading preprocessed data ...")
    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    # The dict is stored inside a 0-d object array, so unwrap it once here.
    stored = np.load(npy_path, allow_pickle=True).item()

    Data = {}
    if np.any(stored.get('val_data')):
        # The file already ships its own validation split.
        for key in ('train_data', 'train_label', 'val_data', 'val_label',
                    'test_data', 'test_label'):
            Data[key] = stored.get(key)
        # NOTE(review): this branch reads shape[1] while the fallback branch
        # below reads shape[2]; the logging further down unpacks the shape as
        # (samples, channels, time_steps). Confirm which axis is intended.
        Data['max_len'] = Data['train_data'].shape[1]
        Data['All_train_data'] = stored.get('All_train_data')
        Data['All_train_label'] = stored.get('All_train_label')
    else:
        # No usable validation split stored: hold out 10% of the train set.
        Data['train_data'], Data['train_label'], Data['val_data'], Data['val_label'] = \
            split_dataset(stored.get('train_data'), stored.get('train_label'), 0.1)
        Data['All_train_data'] = stored.get('train_data')
        Data['All_train_label'] = stored.get('train_label')
        Data['test_data'] = stored.get('test_data')
        Data['test_label'] = stored.get('test_label')
        Data['max_len'] = Data['train_data'].shape[2]

    if config['problem'] == 'STEW':
        for key in ('train_data', 'val_data', 'test_data', 'All_train_data'):
            Data[key] = apply_bipolar_montage_14ch(Data[key])
        # 128 is passed as the dataset's native rate (``default_rate``).
        Data['coherence_labels'] = load_coherence_labels(
            config['data_dir'], Data['All_train_data'], config, 128, single_file=True)

    logger.info("{} samples will be used for self-supervised training".format(len(Data['All_train_label'])))
    logger.info("{} samples will be used for fine tuning ".format(len(Data['train_label'])))
    samples, channels, time_steps = Data['train_data'].shape
    logger.info(
        "Train Data Shape is #{} samples, {} channels, {} time steps ".format(samples, channels, time_steps))
    logger.info("{} samples will be used for validation".format(len(Data['val_label'])))
    logger.info("{} samples will be used for test".format(len(Data['test_label'])))

    return Data

def load_coherence_labels(root, files, config, default_rate, single_file=False):
    """Return coherence pseudo-labels, using ``<root>/coherence_labels.npy`` as a cache.

    If the cache file exists it is loaded and returned directly. Otherwise
    each item is (optionally) resampled, scaled by its per-row 95th
    percentile of absolute amplitude, passed to ``get_pseudolabels`` and the
    collected results are saved to the cache file.

    Args:
        root: Directory holding the cache file and, when ``single_file`` is
            false, the pickled sample files.
        files: Iterable of items. With ``single_file=True`` each item is used
            as the signal itself; otherwise each item is a file name under
            ``root`` whose pickle contains the signal under key ``"X"``.
        config: Mapping providing ``'sampling_rate'``, ``'patch_size'``,
            ``'mixed'``, ``'label_type'`` and ``'temperature'``.
        default_rate: Native sampling rate of the data; resampling happens
            only when it differs from ``config['sampling_rate']``.
        single_file: See ``files``.

    Returns:
        The cached array when the cache exists (``np.load`` result),
        otherwise the freshly built Python list of ``get_pseudolabels``
        outputs. NOTE(review): the two branches return different container
        types; callers appear to rely on indexing only -- confirm.
    """
    cache_path = root + '/coherence_labels.npy'
    if os.path.exists(cache_path):
        logger.info("Loading saved coherence labels ...")
        # allow_pickle=True unpickles arbitrary objects: trusted files only.
        return np.load(cache_path, allow_pickle=True)

    logger.info("Generating coherence labels ...")
    target_rate = config['sampling_rate']
    res = []
    for item in files:
        if single_file:
            X = item
        else:
            # Context manager so the handle is closed deterministically.
            with open(os.path.join(root, item), "rb") as handle:
                sample = pickle.load(handle)
            X = sample["X"]
        if target_rate != default_rate:
            # Output length is fixed at 10 * target rate samples
            # (presumably 10-second segments -- TODO confirm).
            X = resample(X, 10 * target_rate, axis=-1)
        # Robust amplitude scaling; the epsilon guards all-zero rows.
        X = X / (
            np.quantile(np.abs(X), q=0.95, method="linear", axis=-1, keepdims=True)
            + 1e-8
        )
        X = torch.FloatTensor(X)
        res.append(get_pseudolabels(
            X.unsqueeze(0),
            D=config['patch_size'],
            is_mixed=config['mixed'],
            label_type=config['label_type'],
            sampling_rate=target_rate,
            temperature=config['temperature'],
        ))
    np.save(cache_path, res, allow_pickle=True)
    return res

def tuev_loader(config):
    """Build the TUEV datasets with a subject-level 90/10 train/validation split.

    Files are listed from ``<data_dir>/edf/processed_train`` and
    ``<data_dir>/edf/processed_eval``. The subject ID is the part of each
    training file name before the first underscore; 10% of the subjects
    (rounded down) are drawn with ``np.random.choice`` for validation, so no
    subject contributes to both the train and the validation set.

    Args:
        config: Mapping providing ``'data_dir'`` and ``'sampling_rate'``.

    Returns:
        dict with ``'pretrain_loader'`` (all training files),
        ``'train_loader'``, ``'eval_loader'`` and ``'test_loader'``, each a
        ``TUEVLoader``.
    """
    root = config['data_dir'] + '/edf'
    train_dir = os.path.join(root, "processed_train")
    eval_dir = os.path.join(root, "processed_eval")

    pretrain_files = os.listdir(train_dir)
    # Sorted on purpose: iterating a set of str depends on hash
    # randomisation, so an unsorted list would make the split below differ
    # between interpreter runs even with NumPy's RNG seeded.
    train_sub = sorted({f.split("_")[0] for f in pretrain_files})
    print("train sub", len(train_sub))
    test_files = os.listdir(eval_dir)

    val_sub = set(np.random.choice(train_sub, size=int(
        len(train_sub) * 0.1), replace=False))
    train_sub = set(train_sub) - val_sub
    # Set membership keeps the per-file subject lookup O(1).
    val_files = [f for f in pretrain_files if f.split("_")[0] in val_sub]
    train_files = [f for f in pretrain_files if f.split("_")[0] in train_sub]

    rate = config['sampling_rate']
    TrainLoader = TUEVLoader(train_dir, train_files, rate)
    PretrainLoader = TUEVLoader(train_dir, pretrain_files, rate)
    EvalLoader = TUEVLoader(train_dir, val_files, rate)
    TestLoader = TUEVLoader(eval_dir, test_files, rate)

    return {'pretrain_loader': PretrainLoader, 'train_loader': TrainLoader, 'eval_loader': EvalLoader, 'test_loader': TestLoader}

def tuab_loader(config):
    """Build the TUAB datasets from ``<data_dir>/edf/processed/{train,val,test}``.

    Args:
        config: Mapping providing ``'data_dir'`` and ``'sampling_rate'``.

    Returns:
        dict with ``'pretrain_loader'`` (train files followed by val files),
        ``'train_loader'``, ``'eval_loader'`` and ``'test_loader'``, each a
        ``TUABLoader``.
    """
    base = config['data_dir'] + '/edf/processed'
    train_dir = os.path.join(base, "train")
    val_dir = os.path.join(base, "val")
    test_dir = os.path.join(base, "test")

    train_names = os.listdir(train_dir)
    val_names = os.listdir(val_dir)
    test_names = os.listdir(test_dir)

    rate = config['sampling_rate']
    return {
        # Self-supervised pre-training sees the train and val splits together.
        'pretrain_loader': TUABLoader(train_dir, train_names, rate, val_dir, val_names),
        'train_loader': TUABLoader(train_dir, train_names, rate),
        'eval_loader': TUABLoader(val_dir, val_names, rate),
        'test_loader': TUABLoader(test_dir, test_names, rate),
    }

def chbmit_loader(config):
    """Build the CHB-MIT datasets from ``<data_dir>/clean_segments/{train,val,test}``.

    Test files are kept only when their pickled ``'X'`` array has shape
    ``(16, 2560)``; every test file is unpickled once to check this.

    Args:
        config: Mapping providing ``'data_dir'`` and ``'sampling_rate'``.

    Returns:
        dict with ``'pretrain_loader'`` (train files followed by val files),
        ``'train_loader'``, ``'eval_loader'`` and ``'test_loader'``, each a
        ``CHBMITLoader``.
    """
    root = config['data_dir'] + '/clean_segments'
    train_dir = os.path.join(root, "train")
    val_dir = os.path.join(root, "val")
    test_dir = os.path.join(root, "test")

    train_files = os.listdir(train_dir)
    val_files = os.listdir(val_dir)

    test_files = []
    for name in os.listdir(test_dir):
        # Open inside a context manager so the handles do not pile up while
        # scanning the whole test directory.
        with open(os.path.join(test_dir, name), "rb") as handle:
            segment_shape = pickle.load(handle)['X'].shape
        if segment_shape == (16, 2560):
            test_files.append(name)

    rate = config['sampling_rate']
    TrainLoader = CHBMITLoader(train_dir, train_files, rate)
    PretrainLoader = CHBMITLoader(train_dir, train_files, rate, val_dir, val_files)
    EvalLoader = CHBMITLoader(val_dir, val_files, rate)
    TestLoader = CHBMITLoader(test_dir, test_files, rate)
    return {'pretrain_loader': PretrainLoader, 'train_loader': TrainLoader, 'eval_loader': EvalLoader, 'test_loader': TestLoader}

def cross_domain_loader(config):
    """Build a cross-dataset pre-training set plus per-problem fine-tuning sets.

    The pre-training dataset combines TUAB train, TUAB val and all TUEV
    training files. The train/eval/test datasets are chosen by
    ``config['problem']`` (``'TUAB'``, ``'CHB-MIT'`` or ``'TUEV'``).

    Dataset roots are hard-coded relative paths (``../Dataset/...``);
    ``config['data_dir']`` is not consulted here.

    Args:
        config: Mapping providing ``'problem'`` and ``'sampling_rate'``.

    Returns:
        dict with ``'pretrain_loader'``, ``'train_loader'``,
        ``'eval_loader'`` and ``'test_loader'``.

    Raises:
        NotImplementedError: for any other ``config['problem']``.
    """
    root_tuab = '../Dataset/TUAB/edf/processed'
    tuab_train_dir = os.path.join(root_tuab, "train")
    tuab_val_dir = os.path.join(root_tuab, "val")
    tuab_test_dir = os.path.join(root_tuab, "test")
    train_files_tuab = os.listdir(tuab_train_dir)
    val_files_tuab = os.listdir(tuab_val_dir)
    test_files_tuab = os.listdir(tuab_test_dir)

    root_chb = '../Dataset/CHB-MIT/clean_segments'
    chb_train_dir = os.path.join(root_chb, "train")
    chb_val_dir = os.path.join(root_chb, "val")
    chb_test_dir = os.path.join(root_chb, "test")
    train_files_chb = os.listdir(chb_train_dir)
    val_files_chb = os.listdir(chb_val_dir)
    test_files_chb = os.listdir(chb_test_dir)

    root_tuev = '../Dataset/TUEV/edf'
    tuev_train_dir = os.path.join(root_tuev, "processed_train")
    tuev_eval_dir = os.path.join(root_tuev, "processed_eval")
    pretrain_files_tuev = os.listdir(tuev_train_dir)
    # Sorted on purpose: set iteration order of str depends on hash
    # randomisation, which would make the subject split below irreproducible
    # across interpreter runs even with NumPy's RNG seeded.
    train_sub_tuev = sorted({f.split("_")[0] for f in pretrain_files_tuev})
    print("train sub", len(train_sub_tuev))
    test_files_tuev = os.listdir(tuev_eval_dir)
    val_sub_tuev = set(np.random.choice(train_sub_tuev, size=int(
        len(train_sub_tuev) * 0.1), replace=False))
    train_sub_tuev = set(train_sub_tuev) - val_sub_tuev
    val_files_tuev = [f for f in pretrain_files_tuev if f.split("_")[0] in val_sub_tuev]
    train_files_tuev = [f for f in pretrain_files_tuev if f.split("_")[0] in train_sub_tuev]

    # Per domain: native sampling rate, pickle key of the signal, and how many
    # consecutive files form one item (TUEV items concatenate two files).
    # NOTE(review): no sampling_rate is passed, so CrossDomainLoader falls
    # back to its default of 256 regardless of config['sampling_rate'], and
    # the TUEV pairing follows os.listdir order -- confirm both are intended.
    PretrainLoader = CrossDomainLoader(
        [tuab_train_dir, tuab_val_dir, tuev_train_dir],
        [train_files_tuab, val_files_tuab, pretrain_files_tuev],
        [200, 200, 250],
        ['X', 'X', 'signal'],
        [1, 1, 2],
    )

    if config['problem'] == 'TUAB':
        TrainLoader = TUABLoader(tuab_train_dir, train_files_tuab, config['sampling_rate'])
        EvalLoader = TUABLoader(tuab_val_dir, val_files_tuab, config['sampling_rate'])
        TestLoader = TUABLoader(tuab_test_dir, test_files_tuab, config['sampling_rate'])
    elif config['problem'] == 'CHB-MIT':
        # Shape-filter the test segments only when they are actually used;
        # this unpickles every test file, so it is skipped for other problems.
        kept_test_files = []
        for name in test_files_chb:
            with open(os.path.join(chb_test_dir, name), "rb") as handle:
                segment_shape = pickle.load(handle)['X'].shape
            if segment_shape == (16, 2560):
                kept_test_files.append(name)
        TrainLoader = CHBMITLoader(chb_train_dir, train_files_chb, config['sampling_rate'])
        EvalLoader = CHBMITLoader(chb_val_dir, val_files_chb, config['sampling_rate'])
        TestLoader = CHBMITLoader(chb_test_dir, kept_test_files, config['sampling_rate'])
    elif config['problem'] == 'TUEV':
        TrainLoader = TUEVLoader(tuev_train_dir, train_files_tuev, config['sampling_rate'])
        EvalLoader = TUEVLoader(tuev_train_dir, val_files_tuev, config['sampling_rate'])
        TestLoader = TUEVLoader(tuev_eval_dir, test_files_tuev, config['sampling_rate'])
    else:
        raise NotImplementedError

    return {'pretrain_loader': PretrainLoader, 'train_loader': TrainLoader, 'eval_loader': EvalLoader, 'test_loader': TestLoader}

class TUABLoader(torch.utils.data.Dataset):
    """Dataset over pickled TUAB samples, optionally spanning two directories.

    Each pickle is expected to hold a mapping with the signal under ``"X"``
    and the label under ``"y"``. Indices ``0 .. len(files) - 1`` address
    ``files`` under ``root``; the following indices address ``files2`` under
    ``root2``.
    """

    def __init__(self, root, files, sampling_rate=256, root2=None, files2=None):
        """
        Args:
            root: Directory containing ``files``.
            files: File names of the primary split.
            sampling_rate: Target sampling rate; signals are resampled only
                when it differs from the native 200 Hz.
            root2: Optional directory containing ``files2``.
            files2: Optional file names appended after ``files``. ``None``
                (the default) means no secondary files.
        """
        self.root = root
        self.files = files
        self.root2 = root2
        # None sentinel instead of a shared mutable default list.
        self.files2 = [] if files2 is None else files2
        self.default_rate = 200
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files) + len(self.files2)

    def __getitem__(self, index):
        """Return ``(X, Y, index)`` with ``X`` a scaled ``torch.FloatTensor``."""
        if index < len(self.files):
            path = os.path.join(self.root, self.files[index])
        else:
            path = os.path.join(self.root2, self.files2[index - len(self.files)])
        # Context manager: DataLoader workers call this for every item, so
        # handles must be closed deterministically.
        with open(path, "rb") as handle:
            sample = pickle.load(handle)

        X = sample["X"]
        # from default 200Hz to self.sampling_rate
        if self.sampling_rate != self.default_rate:
            # Output length is fixed at 10 * rate samples (presumably
            # 10-second segments -- TODO confirm).
            X = resample(X, 10 * self.sampling_rate, axis=-1)
        # Scale each row by its 95th-percentile absolute amplitude; the
        # epsilon guards all-zero rows.
        X = X / (
            np.quantile(np.abs(X), q=0.95, method="linear", axis=-1, keepdims=True)
            + 1e-8
        )
        Y = sample["y"]
        X = torch.FloatTensor(X)
        return X, Y, index

class CrossDomainLoader(torch.utils.data.Dataset):
    """Unlabelled dataset that concatenates several pickled-sample domains.

    Domain ``i`` contributes ``len(files[i]) // steps[i]`` items. Item ``k``
    of a domain starts at file ``steps[i] * k``; when ``steps[i] == 2`` the
    following file is loaded as well and the two signals are concatenated
    along axis 1.
    """

    def __init__(self, root, files, sampling_rates, sample_names, steps, sampling_rate=256):
        """
        Args:
            root: Per-domain directories.
            files: Per-domain lists of file names.
            sampling_rates: Per-domain native sampling rates.
            sample_names: Per-domain pickle key under which the signal lives.
            steps: Per-domain number of consecutive files consumed per item.
            sampling_rate: Common target sampling rate.
        """
        self.root = root
        self.files = files
        self.sampling_rates = sampling_rates
        self.sample_names = sample_names
        self.steps = steps
        self.sampling_rate = sampling_rate
        self.lens = [len(names) // steps[i] for i, names in enumerate(files)]
        # Prefix sums: domain i covers [cumulative_lens[i], cumulative_lens[i + 1]).
        self.cumulative_lens = [0] + list(np.cumsum(self.lens))

    def __len__(self):
        return self.cumulative_lens[-1]

    @staticmethod
    def _load_pickle(file_path):
        """Unpickle ``file_path``; print a diagnostic and re-raise on failure."""
        try:
            with open(file_path, "rb") as f:
                return pickle.load(f)
        # pickle.load signals corrupt data with UnpicklingError (PicklingError
        # is only raised while dumping), so that is the one to report here.
        except (FileNotFoundError, pickle.UnpicklingError) as e:
            print(f"Error loading file: {file_path}. Skipping. Error: {e}")
            raise

    def __getitem__(self, index):
        """Return ``(X, index, index)``; there are no labels in this dataset.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not (0 <= index < self.__len__()):
            raise IndexError("Index out of bounds")

        # Last prefix sum <= index identifies the owning domain.
        domain_idx = np.searchsorted(self.cumulative_lens, index, side='right') - 1
        file_idx = index - self.cumulative_lens[domain_idx]

        domain_root = self.root[domain_idx]
        domain_files = self.files[domain_idx]
        key = self.sample_names[domain_idx]
        step = self.steps[domain_idx]

        X = self._load_pickle(os.path.join(domain_root, domain_files[step * file_idx]))[key]
        if step == 2:
            # Two-file items: append the partner file along axis 1.
            X2 = self._load_pickle(os.path.join(domain_root, domain_files[2 * file_idx + 1]))[key]
            X = np.concatenate((X, X2), axis=1)

        original_rate = self.sampling_rates[domain_idx]
        if original_rate != self.sampling_rate:
            # Keep the duration, change the number of samples.
            time_points = int(X.shape[-1] / original_rate * self.sampling_rate)
            X = resample(X, time_points, axis=-1)

        # Scale each row by its 95th-percentile absolute amplitude; the
        # epsilon guards all-zero rows.
        X = X / (
            np.quantile(np.abs(X), q=0.95, method="linear", axis=-1, keepdims=True)
            + 1e-8
        )

        X = torch.FloatTensor(X)

        return X, index, index
    
class CHBMITLoader(torch.utils.data.Dataset):
    """Dataset over pickled CHB-MIT samples, optionally spanning two directories.

    Each pickle is expected to hold a mapping with the signal under ``"X"``
    and the label under ``"y"``. Indices ``0 .. len(files) - 1`` address
    ``files`` under ``root``; the following indices address ``files2`` under
    ``root2``.
    """

    def __init__(self, root, files, sampling_rate=256, root2=None, files2=None):
        """
        Args:
            root: Directory containing ``files``.
            files: File names of the primary split.
            sampling_rate: Target sampling rate; signals are resampled only
                when it differs from the native 256 Hz.
            root2: Optional directory containing ``files2``.
            files2: Optional file names appended after ``files``. ``None``
                (the default) means no secondary files.
        """
        self.root = root
        self.files = files
        self.root2 = root2
        # None sentinel instead of a shared mutable default list.
        self.files2 = [] if files2 is None else files2
        self.default_rate = 256
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files) + len(self.files2)

    def __getitem__(self, index):
        """Return ``(X, Y, index)`` with ``X`` a scaled ``torch.FloatTensor``."""
        if index < len(self.files):
            path = os.path.join(self.root, self.files[index])
        else:
            path = os.path.join(self.root2, self.files2[index - len(self.files)])
        # Context manager: DataLoader workers call this for every item, so
        # handles must be closed deterministically.
        with open(path, "rb") as handle:
            sample = pickle.load(handle)

        X = sample["X"]
        # from default 256Hz to self.sampling_rate
        if self.sampling_rate != self.default_rate:
            # Output length is fixed at 10 * rate samples (presumably
            # 10-second segments -- TODO confirm).
            X = resample(X, 10 * self.sampling_rate, axis=-1)
        # Scale each row by its 95th-percentile absolute amplitude; the
        # epsilon guards all-zero rows.
        X = X / (
            np.quantile(np.abs(X), q=0.95, method="linear", axis=-1, keepdims=True)
            + 1e-8
        )
        Y = sample["y"]
        X = torch.FloatTensor(X)
        return X, Y, index
    
class TUEVLoader(torch.utils.data.Dataset):
    """Dataset over pickled TUEV samples.

    Each pickle is expected to hold a mapping with the signal under
    ``"signal"`` and a label sequence under ``"label"``; the first label
    entry minus one becomes the integer class index.
    """

    def __init__(self, root, files, sampling_rate=256):
        """
        Args:
            root: Directory containing ``files``.
            files: File names of the samples.
            sampling_rate: Target sampling rate; signals are resampled only
                when it differs from the native 250 Hz.
        """
        self.root = root
        self.files = files
        self.default_rate = 250
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return ``(X, Y, index)`` with ``X`` a scaled ``torch.FloatTensor``."""
        # Context manager: DataLoader workers call this for every item, so
        # handles must be closed deterministically.
        with open(os.path.join(self.root, self.files[index]), "rb") as handle:
            sample = pickle.load(handle)
        X = sample["signal"]
        if self.sampling_rate != self.default_rate:
            # Output length is fixed at 5 * rate samples (presumably
            # 5-second segments -- TODO confirm).
            X = resample(X, 5 * self.sampling_rate, axis=-1)
        # Scale each row by its 95th-percentile absolute amplitude; the
        # epsilon guards all-zero rows.
        X = X / (
            np.quantile(np.abs(X), q=0.95, method="linear", axis=-1, keepdims=True)
            + 1e-8
        )
        # Stored labels are shifted down by one to give the class index.
        Y = int(sample["label"][0] - 1)
        X = torch.FloatTensor(X)
        return X, Y, index

def apply_bipolar_montage_14ch(data):
    """Re-reference 14-channel EEG to 14 bipolar derivations.

    ``data`` is indexed as ``data[:, channel, :]`` with the channel order
    AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4. Output
    channel ``i`` is the difference of the ``i``-th electrode pair listed
    below. The result has the same shape and dtype as the input; any output
    channel beyond the 14th stays zero.

    Accepts either a ``torch.Tensor`` or an array understood by
    ``np.zeros_like`` and returns the matching type.
    """
    electrode_order = ('AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
                       'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4')
    position = {name: idx for idx, name in enumerate(electrode_order)}

    derivations = (
        # Left temporal chain (output channels 0-3)
        ('AF3', 'F7'), ('F7', 'T7'), ('T7', 'P7'), ('P7', 'O1'),
        # Right temporal chain (4-7)
        ('AF4', 'F8'), ('F8', 'T8'), ('T8', 'P8'), ('P8', 'O2'),
        # Left parasagittal (8-9)
        ('AF3', 'F3'), ('F3', 'FC5'),
        # Right parasagittal (10-11)
        ('AF4', 'F4'), ('F4', 'FC6'),
        # Interhemispheric connections (12-13)
        ('F3', 'F4'), ('FC5', 'FC6'),
    )

    # Allocate the output with the same container type as the input.
    make_zeros = torch.zeros_like if isinstance(data, torch.Tensor) else np.zeros_like
    montage = make_zeros(data)

    for out_ch, (first, second) in enumerate(derivations):
        montage[:, out_ch, :] = data[:, position[first], :] - data[:, position[second], :]
    return montage


def split_dataset(data, label, validation_ratio):
    """Split ``data``/``label`` into stratified train and validation parts.

    Uses a single ``StratifiedShuffleSplit`` with a fixed ``random_state`` of
    1234, so the same inputs always produce the same split.

    Args:
        data: Array indexable by an integer index array along axis 0.
        label: Class labels, one per sample; used for stratification.
        validation_ratio: Passed as ``test_size`` to the splitter.

    Returns:
        Tuple ``(train_data, train_label, val_data, val_label)``.
    """
    splitter = model_selection.StratifiedShuffleSplit(
        n_splits=1, test_size=validation_ratio, random_state=1234)
    # n_splits=1, so the generator yields exactly one (train, val) pair.
    # Unpack it directly rather than via zip(*...), which wrapped each index
    # array in a 1-tuple and only worked through NumPy's tuple indexing.
    train_indices, val_indices = next(
        splitter.split(X=np.zeros(len(label)), y=label))
    return (
        data[train_indices],
        label[train_indices],
        data[val_indices],
        label[val_indices],
    )