import numpy as np
import pandas as pd
import os.path as osp
import torch
from torch.utils.data import Dataset
from scipy.signal import stft
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from . import register_dataset
from .base_dataset import BaseDataset

@register_dataset("RML2016")
class RML2016(BaseDataset):
    """RadioML 2016 modulation-classification dataset.

    The data is loaded once, at class level, by :meth:`create`.
    Instances constructed afterwards with ``RML2016(split=..., mode=...)`` are
    views over one of the three class-level splits (train / valid / test).
    Each split is stored as ``[samples, labels, snrs]``.
    """

    @classmethod
    def create(cls, test_size=0.2, dataset_path=None, dataset_name="RML2016.10a", *args, **kwargs):
        """Load a RadioML 2016 pickle and build the train/valid/test splits.

        Each ``(modulation, SNR)`` bucket of the pickle is split on its own.
        ``test_size`` of the bucket goes to the test split, and the remainder
        is split 75/25 into train and validation. Both splits use
        ``random_state=233``, so the result is reproducible.

        Args:
            test_size: Passed to ``train_test_split`` for the test portion of
                each bucket.
            dataset_path: Directory containing the pickle file. Defaults to a
                directory named ``dataset_name`` next to this module.
            dataset_name: ``"RML2016.10b"`` or ``"RML2016.04c"``. Any other
                value loads ``RML2016.10a_dict.pkl``.
            *args: Ignored.
            **kwargs: Optional ``minSNR`` (default -20) and ``maxSNR``
                (default 18). Buckets whose SNR lies outside
                ``[minSNR, maxSNR]`` are skipped.

        Side effects:
            Sets these class attributes: ``dataset_name``, ``classes``,
            ``signal_length``, ``minSNR``, ``maxSNR``, ``SNR_list``,
            ``train_dataset``, ``val_dataset``, ``test_dataset`` and
            ``dataset``.
        """
        cls.dataset_name = dataset_name
        if dataset_path is None:
            dataset_path = osp.join(osp.dirname(osp.abspath(__file__)), dataset_name)

        # NOTE(review): pd.read_pickle unpickles the file, which can execute
        # arbitrary code. Only point dataset_path at trusted files.
        if dataset_name == "RML2016.10b":
            data = pd.read_pickle(osp.join(dataset_path, 'RML2016.10b.dat'))
            cls.classes = ['8PSK', 'BPSK', 'CPFSK', 'GFSK', 'PAM4', 'QAM16', 'QAM64', 'QPSK', 'AM-DSB', 'WBFM']
        elif dataset_name == "RML2016.04c":
            data = pd.read_pickle(osp.join(dataset_path, '2016.04C.multisnr.pkl'))
            cls.classes = ['8PSK', 'BPSK', 'CPFSK', 'GFSK', 'PAM4', 'QAM16', 'QAM64', 'QPSK', 'AM-DSB', 'AM-SSB', 'WBFM']
        else:
            data = pd.read_pickle(osp.join(dataset_path, 'RML2016.10a_dict.pkl'))
            cls.classes = ['8PSK', 'BPSK', 'CPFSK', 'GFSK', 'PAM4', 'QAM16', 'QAM64', 'QPSK', 'AM-DSB', 'AM-SSB', 'WBFM']
        cls.signal_length = 128

        cls.train_dataset = [[], [], []]
        cls.val_dataset = [[], [], []]
        cls.test_dataset = [[], [], []]
        # Publish the split container before the loop. This way it exists even
        # when no bucket passes the SNR/class filter. The lists are filled in
        # place below, so cls.dataset sees every sample.
        cls.dataset = [cls.train_dataset, cls.val_dataset, cls.test_dataset]

        minSNR = kwargs.get("minSNR", -20)
        maxSNR = kwargs.get("maxSNR", 18)
        cls.minSNR = minSNR
        cls.maxSNR = maxSNR
        cls.SNR_list = range(minSNR, maxSNR + 1, 2)

        for (label, snr), samples in data.items():
            if snr < minSNR or snr > maxSNR or label not in cls.classes:
                continue
            labels = np.full(len(samples), cls.classes.index(label))
            snrs = np.full(len(samples), snr)
            x_trval, x_test, y_trval, y_test, snr_trval, snr_test = train_test_split(
                samples, labels, snrs,
                test_size=test_size, random_state=233, stratify=labels)
            x_train, x_val, y_train, y_val, snr_train, snr_val = train_test_split(
                x_trval, y_trval, snr_trval,
                test_size=0.25, random_state=233, stratify=y_trval)
            for split, xs, ys, ss in (
                (cls.train_dataset, x_train, y_train, snr_train),
                (cls.val_dataset, x_val, y_val, snr_val),
                (cls.test_dataset, x_test, y_test, snr_test),
            ):
                split[0].extend(xs)
                split[1].extend(ys)
                split[2].extend(ss)

    def __init__(self, split="train", mode="default"):
        """Create a view over one split of the loaded dataset.

        Args:
            split: One of ``"train"``, ``"valid"`` or ``"test"``.
            mode: ``"default"``, ``"copy"`` or ``"stft"``. ``"stft"``
                precomputes a short-time Fourier transform for every sample
                of the split.

        Raises:
            ValueError: If :meth:`create` has not been called, or if
                ``split`` or ``mode`` is not supported.
        """
        split_list = ["train", "valid", "test"]
        if not hasattr(RML2016, "train_dataset"):
            raise ValueError("The RML2016 dataset is not created, please use RML2016.create() to create instance.")
        if split not in split_list:
            raise ValueError(f"The split type {split} is not supported!")
        if mode not in ["default", "copy", "stft"]:
            raise ValueError(f"The mode type {mode} is not supported!")

        self.split_id = split_list.index(split)
        self.split = split
        self.mode = mode
        if mode == "stft":
            self.STFTs = []
            for IQ in tqdm(self.dataset[self.split_id][0], desc='Generating STFT:'):
                # Only row 0 of each sample is transformed (presumably the
                # I channel; confirm). Blackman window, 31-sample segments,
                # hop of 1 (noverlap=30), 128-point FFT.
                _, _, stp = stft(IQ[0, :], fs=1.0, window='blackman',
                                 nperseg=31, noverlap=30, nfft=128)
                # Keep the first 32 frequency bins and add a leading channel
                # axis.
                # NOTE(review): stp is complex-valued; confirm that the
                # torch.Tensor(...) call in __getitem__ treats it as intended.
                self.STFTs.append(np.expand_dims(stp[:32, :], 0))

    def __len__(self):
        """Return the number of samples in this instance's split."""
        return len(self.dataset[self.split_id][0])

    def __getitem__(self, idx):
        """Return sample ``idx`` of this split.

        Returns:
            In ``"default"`` mode, the tuple ``(signal, empty tensor, label,
            snr)``. In ``"copy"`` mode, ``(float32 signal copy, label, snr)``.
            In ``"stft"`` mode, ``(stft, label, snr)``. ``label`` is a
            ``torch.long`` scalar; ``snr`` is returned as stored.
        """
        samples, labels, snrs = self.dataset[self.split_id]
        if self.mode == "default":
            return torch.Tensor(samples[idx]),\
                torch.Tensor([]),\
                torch.tensor(labels[idx], dtype=torch.long),\
                snrs[idx]
        elif self.mode == "copy":
            # torch.Tensor(...) does not accept a dtype keyword (it raised a
            # TypeError here). torch.tensor copies the data with the
            # requested dtype.
            return torch.tensor(np.asarray(samples[idx]), dtype=torch.float),\
                torch.tensor(labels[idx], dtype=torch.long),\
                snrs[idx]
        elif self.mode == "stft":
            return torch.Tensor(self.STFTs[idx]),\
                torch.tensor(labels[idx], dtype=torch.long),\
                snrs[idx]

    @property
    def get_pretrain_data(self):
        """Return the train split as ``(samples, labels, snrs)`` numpy arrays."""
        return np.array(self.train_dataset[0]), np.array(self.train_dataset[1]), np.array(self.train_dataset[2])
