from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.nn.utils import weight_norm
import scipy.io as sio
import numpy as np
from os import path
from .data_importers import get_data
from .filters import bandpass_torch
from sklearn.model_selection import train_test_split



class TransformSubset(torch.utils.data.Dataset):
    """
    Subset of a dataset at specified indices, carrying its own transform.

    Arguments:
        dataset (Dataset): The whole dataset. It must expose a writable
            ``transform`` attribute.
        indices (sequence): Indices in the whole set selected for subset.
        transform (callable, optional): Installed as ``dataset.transform``
            right before every item is fetched through this subset.
    """
    def __init__(self, dataset, indices, transform=None):
        self.transform = transform
        self.indices = indices
        self.dataset = dataset

    def __getitem__(self, idx):
        base = self.dataset
        wanted = self.transform
        # Several subsets may wrap the same underlying dataset object, so this
        # subset's transform is (re)installed on it before reading an item.
        if base.transform != wanted:
            base.transform = wanted
        source_index = self.indices[idx]
        return base[source_index]

    def __len__(self):
        return len(self.indices)


class PhysionetMMMI(torch.utils.data.Dataset):
    """
    Physionet MMMI trials held in memory as tensors.

    If ``<datapath>/<num_classes>class.npz`` does not exist, the data is built
    with ``get_data`` and cached in that file; every instantiation then loads
    the arrays ``X`` (samples) and ``y`` (labels) from it.

    Arguments:
        datapath (str): Directory containing the data and the ``.npz`` cache.
        num_classes (int): Forwarded to ``get_data`` and used to name the
            cache file.
        transform (callable, optional): Applied to each sample returned by
            ``__getitem__``.
    """

    # Sampling rate (Hz) handed to the band-pass filter; this is the value the
    # filtering code has always used for this dataset.
    FS_EEG = 160

    def __init__(self, datapath, num_classes=3, transform=None):
        self.datapath = datapath
        self.transform = transform
        self.num_classes = num_classes
        cache_file = path.join(datapath, f'{num_classes}class.npz')
        if not path.isfile(cache_file):
            print("npz file not existing. Load .edf and save data in npz files for faster loading of data next time.")
            X, y = get_data(datapath, n_classes=num_classes)
            np.savez(cache_file, X=X, y=y)
        # NpzFile keeps its archive open until closed; the context manager
        # releases the handle once both arrays have been read into memory.
        with np.load(cache_file) as npzfile:
            X, y = npzfile['X'], npzfile['y']
        self.samples = torch.as_tensor(X, dtype=torch.float)
        self.labels = torch.as_tensor(y, dtype=torch.long)
        # Placeholders until separate_test_labels() fills in a real test split.
        self.test_samples = torch.zeros(1, dtype=torch.float)
        self.test_labels = torch.zeros(1, dtype=torch.long)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(sample, label)``; ``transform`` (if set) is applied to the sample."""
        sample = self.samples[idx, :, :]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

    def ignore_trials(self, target=2):
        """Drop, in place, every trial whose label equals ``target``."""
        trial_list = torch.where(self.labels != target)[0]
        self.samples = self.samples[trial_list]
        self.labels = self.labels[trial_list]

    def bandpass_torch(self, low_f, high_f, device, fs_eeg=FS_EEG):
        """
        Band-pass filter all samples between ``low_f`` and ``high_f`` in place.

        ``fs_eeg`` is the sampling rate passed to the filter (defaults to
        ``FS_EEG``). The call below resolves to the module-level
        ``bandpass_torch`` helper, not to this method.
        """
        self.samples = bandpass_torch(self.samples, low_f, high_f, fs_eeg=fs_eeg, device=device)

    def remove_target(self, model, target, device, frequency_band=(0.1, 40)):
        """
        Drop, in place, every trial that ``model`` predicts as ``target``.

        Each trial is given two leading singleton dimensions, moved to
        ``device``, band-passed to ``frequency_band`` and classified by the
        ``argmax`` of the model output. Trials whose prediction differs from
        ``target`` are kept.

        Arguments:
            model (callable): Classifier applied to one filtered trial.
            target (int): Predicted class whose trials are removed.
            device: Device used for filtering and inference.
            frequency_band (sequence of 2 floats): ``(low, high)`` cut-offs.
        """
        low_f, high_f = frequency_band[0], frequency_band[1]
        keep = []
        # Pure inference: no autograd graph is needed for the predictions.
        with torch.no_grad():
            for sample in self.samples:
                batch = sample[(None,) * 2].float().to(device)
                # Keyword arguments, matching the call in bandpass_torch(), so
                # `device` can never be bound to the sampling-rate parameter.
                filtered = bandpass_torch(batch, low_f, high_f, fs_eeg=self.FS_EEG, device=device)
                keep.append(model(filtered).argmax().item() != target)

        # Build the mask on CPU and move it next to each tensor: indexing a
        # tensor with a boolean mask living on a different device raises.
        mask = torch.tensor(keep, dtype=torch.bool)
        self.samples = self.samples[mask.to(self.samples.device)]
        self.labels = self.labels[mask.to(self.labels.device)]

    def separate_test_labels(self, test_size=0.15, random_state=0):
        """
        Randomly split off a test set with sklearn's ``train_test_split``.

        The training part replaces ``samples``/``labels``; the held-out part
        is stored in ``test_samples``/``test_labels``.
        """
        self.samples, self.test_samples, self.labels, self.test_labels = train_test_split(
            self.samples, self.labels, test_size=test_size, random_state=random_state)


class ReshapeTensor(object):
    """Transform that views a 2-D sample ``(rows, cols)`` as ``(1, rows, cols)``."""

    def __call__(self, sample):
        rows, cols = sample.shape[0], sample.shape[1]
        # view() shares storage with the input; only a leading unit axis is added.
        return sample.view(1, rows, cols)
