from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import os
from Asymmetric_Noise import *
import mne

import sys
sys.path.append('../../')

from pipeline.ca_database_api import DataHandler

## To log experiments with Weights & Biases (wandb), uncomment the lines below:
# import wandb
# wandb.init(project="noisy-label-project", entity="....")


def unpickle(file):
    """Load and return the object stored in the pickle file at *file*.

    ``encoding='latin1'`` lets pickles written by Python 2 be read.
    Unpickling can execute arbitrary code, so only call this on trusted files.
    """
    import _pickle as cPickle
    with open(file, 'rb') as handle:
        return cPickle.load(handle, encoding='latin1')

def get_input_size(args):
    """Load the training split and report its dimensions.

    Args:
        args: experiment namespace; reads ``database_save_dir``, ``data_name``,
            ``exp_id``, ``train_patient_list``, ``noise_ratio``,
            ``window_time``, ``slide_time`` and ``num_level``.

    Returns:
        ``(axis_size, n_class, num_samples)`` where ``axis_size`` is the size
        of the data's second-to-last axis (presumably the channel axis —
        confirm against DataHandler), ``n_class`` is the number of distinct
        labels, and ``num_samples`` is the number of windows after merging
        all leading axes.
    """
    handler = DataHandler(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        patient_list=args.train_patient_list,
        noise_ratio=args.noise_ratio,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    pack = handler.get_data()
    windows = pack.data.reshape(-1, *pack.data.shape[-2:])
    labels = pack.label.reshape(-1)
    return windows.shape[-2], np.unique(labels).size, windows.shape[0]


def train_dataset(args):
    """Load the training split and flatten it into a stack of windows.

    Args:
        args: experiment namespace; reads ``database_save_dir``, ``data_name``,
            ``exp_id``, ``train_patient_list``, ``noise_ratio``,
            ``window_time``, ``slide_time`` and ``num_level``.

    Returns:
        ``(x_train, y_train)``: the data reshaped to
        ``(-1, *data.shape[-2:])`` (all leading axes merged) and the labels
        flattened to 1-D. Labels come from ``DataHandler`` built with
        ``noise_ratio=args.noise_ratio`` (presumably noisy labels — confirm).
    """
    print("Loading the training dataset...")
    train_data_handler = DataHandler(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        patient_list=args.train_patient_list,
        noise_ratio=args.noise_ratio,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    train_data_pack = train_data_handler.get_data()
    x_train = train_data_pack.data.reshape(-1, *train_data_pack.data.shape[-2:])
    y_train = train_data_pack.label.reshape(-1)
    # Drop our references to the handler and pack before returning.
    del train_data_pack, train_data_handler

    return x_train, y_train

def valid_dataset(args):
    """Load the validation split for the patients in ``args.valid_patient_list``.

    The ``DataHandler`` settings mirror the training split, including
    ``noise_ratio=args.noise_ratio``.

    Returns:
        ``(x_valid, y_valid)``: data reshaped to ``(-1, *data.shape[-2:])``
        and labels flattened to 1-D.
    """
    print("Loading the validing dataset...")
    handler_kwargs = dict(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        patient_list=args.valid_patient_list,
        noise_ratio=args.noise_ratio,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    pack = DataHandler(**handler_kwargs).get_data()
    trailing_shape = pack.data.shape[-2:]
    x_valid = pack.data.reshape(-1, *trailing_shape)
    y_valid = pack.label.reshape(-1)
    return x_valid, y_valid

def test_dataset(args):
    """Load the test split for the patients in ``args.test_patient_list``.

    Unlike the train/validation loaders, ``noise_ratio`` is fixed to 0 here
    (presumably so the test labels are left clean — confirm against
    DataHandler).

    Returns:
        ``(x_test, y_test)``: data reshaped to ``(-1, *data.shape[-2:])``
        and labels flattened to 1-D.
    """
    print("Loading the testing dataset...")
    handler = DataHandler(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        patient_list=args.test_patient_list,
        noise_ratio=0,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    pack = handler.get_data()
    window_shape = tuple(pack.data.shape[-2:])
    x_test = pack.data.reshape((-1,) + window_shape)
    y_test = pack.label.reshape(-1)
    return x_test, y_test

# data augmentation
def time_domain_augmentation(data, sigma=1e-6):
    """Return *data* plus zero-mean Gaussian noise with standard deviation *sigma*.

    Args:
        data: NumPy array to perturb; it is not modified.
        sigma: noise standard deviation (default ``1e-6``, the value
            previously hard-coded).

    Returns:
        A new array with the same shape as *data*. Floating-point inputs keep
        their dtype, so augmented views match the clean view's precision.
        Non-float inputs are promoted to float64 as before.
    """
    noise = np.random.normal(0, sigma, size=data.shape)
    if np.issubdtype(data.dtype, np.floating):
        # Without this cast, float32 data would silently become float64.
        noise = noise.astype(data.dtype, copy=False)
    return data + noise

def freq_domain_augmentation(data):
    """Intended to perturb *data* in the time-frequency domain (2-40 Hz).

    NOTE(review): this function appears broken as written, and nothing in
    this file calls it (``blurred_dataset`` only uses
    ``time_domain_augmentation``). Confirm each point against the installed
    MNE version before using it:

    * ``mne.time_frequency.tfr_array_multitaper`` does not document an
      ``n_fft`` keyword, so the first call likely raises ``TypeError``.
    * MNE's ``tfr_array_*`` functions expect 3-D input
      ``(n_epochs, n_channels, n_times)``. A single 2-D sample would need a
      leading axis added.
    * ``tfr_array_morlet`` is a forward transform, not an inverse. Feeding it
      the multitaper output (which has an extra frequency axis) does not
      reconstruct a time-domain signal. With ``output='power'`` the result
      would be spectral power, not augmented EEG.
    * The sampling rate is hard-coded to 250 Hz; presumably it matches the
      dataset — confirm.
    """
    fmin, fmax = 2, 40
    n_fft = 256
    # Frequencies 2, 4, ..., 40 Hz.
    tfr = mne.time_frequency.tfr_array_multitaper(data, sfreq=250, freqs=np.arange(fmin, fmax + 1, 2), n_jobs=1, n_fft=n_fft)
    # Small additive Gaussian perturbation of the time-frequency coefficients.
    tfr += np.random.normal(0, 1e-9, size=tfr.shape)
    augmented_data = mne.time_frequency.tfr_array_morlet(tfr, sfreq=250, freqs=np.arange(fmin, fmax + 1, 2), output='power')

    return augmented_data

def inject_eye_blink_artifacts(data, blink_channels=(0, 1), blink_frequency=10, blink_amplitude=100):
    """Return a copy of *data* with a sinusoidal "blink" added to some channels.

    Args:
        data: array indexed as ``data[channel]``; it is not modified.
        blink_channels: channel indices that receive the artifact. The default
            is an immutable tuple, so it cannot leak state between calls; any
            iterable of ints, including a list, still works. A channel listed
            twice receives the artifact twice.
        blink_frequency: number of full sine cycles across the channel's
            length. The period is measured in samples, not in Hz.
        blink_amplitude: peak amplitude of the added sine wave.

    Returns:
        The augmented copy of *data*.
    """
    augmented_eeg = data.copy()

    for channel in blink_channels:
        # The sine spans exactly `blink_frequency` cycles over the channel's samples.
        time_points = np.arange(len(data[channel]))
        blink_signal = blink_amplitude * np.sin(2 * np.pi * blink_frequency * time_points / len(time_points))

        augmented_eeg[channel] += blink_signal

    return augmented_eeg

def apply_channel_dropout(data, dropout_fraction=0.1):
    """Return a copy of *data* with a random subset of rows (channels) zeroed.

    ``int(n_rows * dropout_fraction)`` distinct rows are chosen uniformly
    without replacement. The count is truncated, so small fractions of few
    channels drop nothing. *data* itself is left untouched.
    """
    total_rows = data.shape[0]
    rows_to_zero = int(total_rows * dropout_fraction)
    victims = np.random.choice(np.arange(total_rows), size=rows_to_zero, replace=False)

    result = data.copy()
    result[victims, :] = 0
    return result


class blurred_dataset(Dataset):
    """EEG-window dataset serving the views needed for noisy-label training.

    The ``mode`` argument decides what is loaded and what items look like:

    * ``'all'``: every training window; items are ``(seq, target, index)``.
    * ``'labeled'``: a class-balanced subset of training windows ranked as
      cleanest by ``probability``. The chosen indices are saved to
      ``<root_dir>/Clean_index.npz``. Items are
      ``(seq, seq2, seq3, seq4, target, weight)``, where ``seq2..seq4`` are
      independently noised copies of ``seq``.
    * ``'unlabeled'``: the training windows *not* listed in the saved
      ``Clean_index.npz``; items are ``(seq, seq2, seq3, seq4)``.
    * ``'test'``: windows and labels from ``valid_dataset(args)``; items are
      ``(seq, target)``.

    Args:
        args: experiment namespace. This class reads ``noise_ratio`` and
            ``num_samples`` and forwards ``args`` to the split loaders.
        sample_ratio: fraction of ``args.num_samples`` to select in
            ``'labeled'`` mode.
        root_dir: directory holding ``Clean_index.npz``.
        mode: ``'all'``, ``'labeled'``, ``'unlabeled'`` or ``'test'``.
        probability: required in ``'labeled'`` mode. Per-training-sample
            scores indexed by sample position; lower scores rank as cleaner.
            Must provide ``.cpu().numpy()`` and ``.clone()`` (e.g. a torch
            tensor). It is not modified.
        log: unused; accepted for interface compatibility.

    Raises:
        ValueError: if ``mode`` is unknown, or if ``probability`` is missing
            in ``'labeled'`` mode.
    """

    def __init__(self, args, sample_ratio, root_dir, mode, probability=None, log=''):
        self.r = args.noise_ratio  # noise ratio
        self.sample_ratio = sample_ratio
        self.mode = mode
        self.class_ind = {}

        num_sample = args.num_samples

        if self.mode == 'test':
            # NOTE(review): 'test' mode loads the validation split via
            # valid_dataset(); test_dataset() is not used here — confirm
            # this is intended.
            x_test, y_test = valid_dataset(args)
            self.test_data = x_test
            self.test_label = y_test
            return

        train_data, noise_label = train_dataset(args)

        # Group sample positions by (noisy) label. NOTE(review): this assumes
        # labels are exactly 0..num_class-1; any other label value is left
        # out of class_ind.
        num_class = len(np.unique(noise_label))
        label_array = np.asarray(noise_label)
        for kk in range(num_class):
            self.class_ind[kk] = np.flatnonzero(label_array == kk).tolist()

        if self.mode == 'all':
            self.train_data = train_data
            self.noise_label = noise_label
            return

        save_file = os.path.join(root_dir, 'Clean_index.npz')

        if self.mode == 'labeled':
            if probability is None:
                raise ValueError("probability is required in 'labeled' mode")

            # Per-class quota for the class-balanced selection.
            class_len = int(self.sample_ratio * num_sample / num_class)
            prob_np = probability.cpu().numpy()  # convert once, not per class

            ## Ranking-based Selection and Introducing Class Balance
            pred_idx = []
            for i in range(num_class):
                class_indices = np.asarray(self.class_ind[i], dtype=int)
                if len(class_indices) >= class_len:
                    # Keep the class_len lowest-scored samples of this class.
                    ranked = np.argsort(prob_np[class_indices])
                    pred_idx.extend(class_indices[ranked[:class_len]].tolist())
                else:
                    # Class is smaller than its quota: keep all of it.
                    pred_idx.extend(class_indices.tolist())

            # Save exactly the selected indices. Pre-allocating
            # int(sample_ratio * num_sample) slots would leave zero padding
            # that duplicates sample 0 in the labeled set.
            np.savez(save_file, index=pred_idx)

            ## Weights for label refinement
            # Work on a copy so the caller's tensor is not modified.
            refined = probability.clone()
            refined[refined < 0.5] = 0
            self.probability = [1 - refined[i] for i in pred_idx]

        elif self.mode == 'unlabeled':
            with np.load(save_file) as saved:
                clean = set(saved['index'].tolist())
            # Set membership keeps this O(num_sample) instead of O(num_sample * |clean|).
            pred_idx = [i for i in range(num_sample) if i not in clean]

        else:
            raise ValueError(f"Unknown mode: {self.mode!r}")

        self.train_data = train_data[pred_idx]
        self.noise_label = [noise_label[i] for i in pred_idx]

    def __getitem__(self, index):
        if self.mode == 'labeled':
            seq, target, prob = self.train_data[index], self.noise_label[index], self.probability[index]
            # Three independently noised views of the same window.
            seq2 = time_domain_augmentation(seq)
            seq3 = time_domain_augmentation(seq)
            seq4 = time_domain_augmentation(seq)
            return seq, seq2, seq3, seq4, target, prob

        if self.mode == 'unlabeled':
            seq = self.train_data[index]
            seq2 = time_domain_augmentation(seq)
            seq3 = time_domain_augmentation(seq)
            seq4 = time_domain_augmentation(seq)
            return seq, seq2, seq3, seq4

        if self.mode == 'all':
            return self.train_data[index], self.noise_label[index], index

        if self.mode == 'test':
            return self.test_data[index], self.test_label[index]

        return None

    def __len__(self):
        if self.mode != 'test':
            return len(self.train_data)
        return len(self.test_data)
        
class blurred_dataloader():
    """Build the DataLoaders used by each phase of noisy-label training.

    Args:
        root_dir: directory forwarded to every ``blurred_dataset`` (where
            ``Clean_index.npz`` is written and read).
        log: stored on the instance and forwarded to the labeled dataset.
        num_workers: worker processes per DataLoader. Defaults to 4, the
            value previously hard-coded.
    """

    def __init__(self, root_dir, log, num_workers=4):
        self.root_dir = root_dir
        self.log = log
        self.num_workers = num_workers

    def run(self, args, sample_ratio, mode, prob=None):
        """Build the loader(s) for one training phase.

        Args:
            args: experiment namespace; reads ``batch_size`` and forwards
                ``args`` to ``blurred_dataset``.
            sample_ratio: fraction of training samples treated as clean.
            mode: ``'warmup'``, ``'train'``, ``'test'`` or ``'eval_train'``.
            prob: per-sample scores passed to the labeled dataset in
                ``'train'`` mode.

        Returns:
            ``(labeled_trainloader, unlabeled_trainloader)`` for ``'train'``;
            a single DataLoader for the other modes.

        Raises:
            ValueError: if *mode* is not one of the above.
        """
        if mode == 'warmup':
            all_dataset = blurred_dataset(args, sample_ratio=sample_ratio, root_dir=self.root_dir, mode="all")
            return DataLoader(
                dataset=all_dataset,
                batch_size=args.batch_size * 2,
                shuffle=True,
                num_workers=self.num_workers)

        if mode == 'train':
            labeled_dataset = blurred_dataset(args, sample_ratio=sample_ratio, root_dir=self.root_dir, mode="labeled", probability=prob, log=self.log)
            labeled_trainloader = DataLoader(
                dataset=labeled_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=self.num_workers, drop_last=True)

            # Must be built after the labeled dataset, which writes the
            # Clean_index.npz file this one reads.
            unlabeled_dataset = blurred_dataset(args, sample_ratio=sample_ratio, root_dir=self.root_dir, mode="unlabeled")
            # Batch size is batch_size / (2 * sample_ratio), clamped to >= 1
            # because DataLoader rejects a batch size of 0.
            unlabeled_batch_size = max(1, int(args.batch_size / (2 * sample_ratio)))
            unlabeled_trainloader = DataLoader(
                dataset=unlabeled_dataset,
                batch_size=unlabeled_batch_size,
                shuffle=True,
                num_workers=self.num_workers, drop_last=True)

            return labeled_trainloader, unlabeled_trainloader

        if mode == 'test':
            test_dataset = blurred_dataset(args, sample_ratio=sample_ratio, root_dir=self.root_dir, mode='test')
            return DataLoader(
                dataset=test_dataset,
                batch_size=100,
                shuffle=False,
                num_workers=self.num_workers)

        if mode == 'eval_train':
            eval_dataset = blurred_dataset(args, sample_ratio=sample_ratio, root_dir=self.root_dir, mode='all')
            # Keep the final partial batch so every training sample is
            # visited. drop_last=True would skip the last len % 100 samples.
            return DataLoader(
                dataset=eval_dataset,
                batch_size=100,
                shuffle=False,
                num_workers=self.num_workers, drop_last=False)

        raise ValueError(f"Unknown mode: {mode!r}")
