import h5py
import numpy as np
import pandas as pd
import os.path as osp
import torch
from torch.utils.data import Dataset
from scipy.signal import stft
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from . import register_dataset
from .base_dataset import BaseDataset

@register_dataset("RML2018")
class RML2018(BaseDataset):
    """RML2018.01 dataset exposing train / valid / test splits.

    The HDF5 file is read once by :meth:`create`, which stores the splits as
    class attributes. Instances only select a split and an output ``mode``;
    they do not copy the underlying data.
    """

    @classmethod
    def create(cls, test_size=0.2, dataset_path=None, *args, **kwargs):
        """Load the HDF5 file and build the class-level train/val/test splits.

        Args:
            test_size: Fraction of all samples held out as the test split.
            dataset_path: Directory containing ``GOLD_XYZ_OSC.0001_1024.hdf5``.
                Defaults to an ``RML2018.01`` directory next to this module.
            **kwargs: Optional ``minSNR`` (default -20) and ``maxSNR``
                (default 30), stored on the class together with ``SNR_list``.

        Side effects:
            Sets ``classes``, ``train_dataset``, ``val_dataset``,
            ``test_dataset``, ``dataset``, ``minSNR``, ``maxSNR`` and
            ``SNR_list`` on ``cls``. Each ``*_dataset`` is a list of three
            lists: ``[samples, class indices, SNR values]``.
        """
        cls.classes = ['OOK','4ASK','8ASK',
               'BPSK','QPSK','8PSK','16PSK','32PSK',
               '16APSK','32APSK','64APSK','128APSK',
               '16QAM','32QAM','64QAM','128QAM','256QAM',
               'AM-SSB-WC','AM-SSB-SC','AM-DSB-WC','AM-DSB-SC',
               'FM','GMSK','OQPSK']
        if dataset_path is None:
            dataset_path = osp.join(osp.dirname(osp.abspath(__file__)), "RML2018.01")
        with h5py.File(osp.join(dataset_path, 'GOLD_XYZ_OSC.0001_1024.hdf5'), 'r') as f:
            # Slicing materialises the datasets as in-memory numpy arrays
            # before the file is closed.
            IQ_data = f['X'][:, :, :]
            class_label = f['Y'][:, :]
            SNR_label = f['Z'][:]
        # Swap the last two axes of every sample
        # (presumably (N, 1024, 2) -> (N, 2, 1024) -- TODO confirm).
        IQ_data = IQ_data.transpose(0, 2, 1)
        # Collapse the per-class score rows (presumably one-hot) to indices.
        class_labels = np.argmax(class_label, axis=1)
        SNR_labels = SNR_label.squeeze()

        # NOTE(review): the SNR bounds are only recorded here; no sample is
        # filtered by SNR anywhere in this class.
        cls.minSNR = kwargs.get("minSNR", -20)
        cls.maxSNR = kwargs.get("maxSNR", 30)
        cls.SNR_list = range(cls.minSNR, cls.maxSNR + 1, 2)

        # First carve off the test split, stratified by class.
        (rest_x, test_x,
         rest_y, test_y,
         rest_snr, test_snr) = train_test_split(IQ_data, class_labels, SNR_labels,
                                                test_size=test_size,
                                                random_state=233,
                                                stratify=class_labels)
        # Then take a quarter of the remainder as the validation split.
        (train_x, val_x,
         train_y, val_y,
         train_snr, val_snr) = train_test_split(rest_x, rest_y, rest_snr,
                                                test_size=0.25,
                                                random_state=233,
                                                stratify=rest_y)

        cls.train_dataset = [list(train_x), list(train_y), list(train_snr)]
        cls.val_dataset = [list(val_x), list(val_y), list(val_snr)]
        cls.test_dataset = [list(test_x), list(test_y), list(test_snr)]
        # Order must match ``split_list`` in ``__init__``.
        cls.dataset = [cls.train_dataset, cls.val_dataset, cls.test_dataset]

    def __init__(self, split="train", mode="default"):
        """Select a split and an output mode.

        Args:
            split: One of ``"train"``, ``"valid"`` or ``"test"``.
            mode: One of ``"default"``, ``"copy"`` or ``"stft"``. In
                ``"stft"`` mode the STFT of every sample in the split is
                computed eagerly here.

        Raises:
            ValueError: If :meth:`create` has not been called, or if
                ``split`` / ``mode`` is not a supported value.
        """
        split_list = ["train", "valid", "test"]
        # type(self) rather than the module-level name so that subclasses on
        # which create() was called are recognised as well.
        if not hasattr(type(self), "train_dataset"):
            raise ValueError("The RML2018 dataset is not created, please use RML2018.create() to create instance.")
        if split not in split_list:
            raise ValueError(f"The split type {split} is not supported!")
        if mode not in ["default", "copy", "stft"]:
            raise ValueError(f"The mode type {mode} is not supported!")

        self.split_id = split_list.index(split)
        self.split = split
        self.mode = mode
        if mode == "stft":
            self.STFTs = []
            samples = self.dataset[self.split_id][0]
            for IQ in tqdm(samples, desc='Generating STFT:'):
                # Only row 0 of each sample is transformed.
                _, _, stp = stft(IQ[0, :], fs=1.0, window='blackman',
                                 nperseg=31, noverlap=30, nfft=128)
                # Keep the first 32 frequency bins and add a leading axis.
                self.STFTs.append(np.expand_dims(stp[:32, :], 0))

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset[self.split_id][0])

    def __getitem__(self, idx):
        """Return the sample at ``idx`` in the layout chosen by ``mode``.

        Returns:
            * ``"default"``: ``(signal, empty tensor, label, snr)``
            * ``"copy"``: ``(signal as float tensor, label, snr)``
            * ``"stft"``: ``(stft tensor, label, snr)``

            ``label`` is a ``torch.long`` tensor; ``snr`` is returned as
            stored, without conversion.
        """
        signals, labels, snrs = self.dataset[self.split_id]
        label = torch.tensor(labels[idx], dtype=torch.long)
        snr = snrs[idx]

        if self.mode == "default":
            return torch.Tensor(signals[idx]), torch.Tensor([]), label, snr
        if self.mode == "copy":
            # torch.tensor, not torch.Tensor: the legacy constructor rejects
            # the ``dtype`` keyword and raised TypeError here.
            return torch.tensor(signals[idx], dtype=torch.float), label, snr
        # Remaining validated mode is "stft".
        # NOTE(review): scipy's stft output is complex; confirm how
        # torch.Tensor handles it for the intended model input.
        return torch.Tensor(self.STFTs[idx]), label, snr

    @property
    def get_pretrain_data(self):
        """Training samples, class indices and SNR values as numpy arrays."""
        samples, labels, snrs = self.train_dataset
        return np.array(samples), np.array(labels), np.array(snrs)
