import numpy as np
import pandas as pd
import os.path as osp
import torch
import h5py
from torch.utils.data import Dataset
from scipy.signal import stft
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from . import register_dataset
from .base_dataset import BaseDataset

@register_dataset("HisarMod2019")
class HisarMod2019(BaseDataset):
    """HisarMod2019 modulation-recognition dataset.

    Usage is two-step: call :meth:`create` once to load the HDF5 files and
    populate the class-level splits, then instantiate the class with the
    desired ``split`` and ``mode`` to obtain an indexable dataset.

    The splits are stored on the class (``train_dataset``, ``val_dataset``,
    ``test_dataset``), each as ``[samples, labels, snrs]``, so every instance
    shares the same underlying data.
    """

    # Modes whose __getitem__ output includes the STFT of the sample.
    _STFT_MODES = ("default", "stft")

    @staticmethod
    def _read_h5(path):
        """Read the 'samples', 'labels' and 'snr' datasets from one HDF5 file."""
        # Open read-only explicitly; the context manager closes the file.
        with h5py.File(path, "r") as h5file:
            samples = h5file['samples'][:]
            labels = h5file['labels'][:]
            snrs = h5file['snr'][:]
        return samples, labels, snrs

    @classmethod
    def create(cls, test_size = 0.2, dataset_path=None, *args, **kwargs):
        """Load the dataset files and populate the class-level splits.

        Args:
            test_size: Fraction of the training file held out as the
                validation split (forwarded to ``train_test_split``).
            dataset_path: Directory containing ``HisarMod2019train.h5`` and
                ``HisarMod2019test.h5``. Defaults to a ``HisarMod2019``
                directory next to this module.
            **kwargs: Optional ``minSNR`` (default -20) and ``maxSNR``
                (default 18). They are recorded on the class and used to
                build ``SNR_list``; the loaded samples are not filtered
                by them here.
        """
        cls.classes = ['BPSK', 'QPSK', '8PSK', '16PSK', '32PSK', '64PSK', '4QAM', '8QAM', '16QAM', '32QAM', 
                   '64QAM', '128QAM', '256QAM', '2FSK', '4FSK', '8FSK', '16FSK', '4PAM', '8PAM', '16PAM', 'AM-DSB', 
                   'AM-DSB-SC', 'AM-USB', 'AM-LSB', 'FM', 'PM']
        if dataset_path is None:
            dataset_path = osp.join(osp.dirname(osp.abspath(__file__)), "HisarMod2019")

        train, train_label, SNR_tr = cls._read_h5(osp.join(dataset_path, 'HisarMod2019train.h5'))

        # Hold out a validation split, stratified jointly on (label, SNR) so
        # both splits keep the same label/SNR mix. Fixed seed for repeatability.
        train, val, train_label, val_label, SNR_tr, SNR_va = train_test_split(
            train, train_label, SNR_tr,
            test_size=test_size,
            random_state=233,
            stratify=list(zip(train_label, SNR_tr)))

        test, test_label, SNR_te = cls._read_h5(osp.join(dataset_path, 'HisarMod2019test.h5'))

        minSNR = kwargs.get("minSNR", -20)
        maxSNR = kwargs.get("maxSNR", 18)

        cls.minSNR = minSNR
        cls.maxSNR = maxSNR
        cls.SNR_list = range(minSNR, maxSNR + 1, 2)

        # Each split is [samples, labels, snrs], stored as Python lists of
        # per-sample entries (iterating the arrays along their first axis).
        cls.train_dataset = [list(train), list(train_label), list(SNR_tr)]
        cls.val_dataset = [list(val), list(val_label), list(SNR_va)]
        cls.test_dataset = [list(test), list(test_label), list(SNR_te)]
        # Indexed by split id: 0 = train, 1 = valid, 2 = test.
        cls.dataset = [cls.train_dataset, cls.val_dataset, cls.test_dataset]

    def __init__(self, split="train", mode="default"):
        """Select a split and output mode.

        Args:
            split: One of ``"train"``, ``"valid"`` or ``"test"``.
            mode: ``"default"`` yields (sample, stft, label, snr);
                ``"copy"`` yields (sample, label, snr);
                ``"stft"`` yields (stft, label, snr).

        Raises:
            ValueError: If :meth:`create` has not been called, or if
                ``split`` / ``mode`` is not one of the supported values.
        """
        split_list = ["train", "valid", "test"]
        if not hasattr(HisarMod2019, "train_dataset"):
            raise ValueError("The HisarMod2019 dataset is not created, please use HisarMod2019.create() to create instance.")
        if split not in split_list:
            raise ValueError(f"The split type {split} is not supported!")
        if mode not in ["default", "copy", "stft"]:
            raise ValueError(f"The mode type {mode} is not supported!")
        
        self.split_id = split_list.index(split)
        self.split = split
        self.mode = mode

        # STFTs are only read by the "default" and "stft" modes, so skip the
        # (expensive) precomputation for "copy"; STFTs stays empty in that case.
        self.STFTs = []
        if mode in self._STFT_MODES:
            samples = self.dataset[self.split_id][0]
            with tqdm(total=len(samples)) as t:
                t.set_description('Generating STFT:')
                for IQ in samples:
                    # STFT of the first row of the sample only.
                    # NOTE(review): presumably row 0 is the I channel -- confirm.
                    _, _, stp = stft(IQ[0,:], fs=1.0, window='blackman',
                                     nperseg=31, noverlap=30, nfft=128)
                    # Keep the first 32 frequency rows and add a leading
                    # singleton axis.
                    self.STFTs.append(np.expand_dims(stp[:32,:], 0))
                    t.update(1)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset[self.split_id][0])

    def __getitem__(self, idx):
        """Return the item at ``idx`` in the layout selected by ``self.mode``."""
        samples, labels, snrs = self.dataset[self.split_id]
        if self.mode == "default":
            # NOTE(review): stft output is complex; confirm torch.Tensor's
            # conversion of it is what downstream code expects.
            return torch.Tensor(samples[idx]),\
                torch.Tensor(self.STFTs[idx]),\
                torch.tensor(labels[idx], dtype=torch.long),\
                snrs[idx]
        elif self.mode == "copy":
            # torch.tensor (factory), not torch.Tensor (constructor): only the
            # factory accepts a dtype keyword.
            return torch.tensor(samples[idx], dtype=torch.float),\
                torch.tensor(labels[idx], dtype=torch.long),\
                snrs[idx]
        elif self.mode == "stft":
            return torch.Tensor(self.STFTs[idx]),\
                torch.tensor(labels[idx], dtype=torch.long),\
                snrs[idx]

    @property
    def get_pretrain_data(self):
        """Training split as a tuple of arrays: (samples, labels, snrs)."""
        return np.array(self.train_dataset[0]), np.array(self.train_dataset[1]), np.array(self.train_dataset[2])