import os
import random
import numpy as np
import pandas as pd
import os.path as osp
import torch
from torch.utils.data import Dataset
from scipy.signal import stft
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from . import register_dataset
from .base_dataset import BaseDataset

def load_bin_file(file_path, segment_length=1024):
    """Load an interleaved float32 recording and cut it into fixed-length clips.

    The file is read as a flat float32 stream and de-interleaved: values at
    even positions form row 0 and values at odd positions form row 1 (the
    column-major reshape mirrors a MATLAB ``reshape(data, 2, [])``). The
    rows are treated as the I and Q channels respectively.

    Args:
        file_path: Path of the binary file to read.
        segment_length: Number of samples per channel in each clip.

    Returns:
        A list of arrays of shape ``(2, segment_length)``; trailing samples
        that do not fill a complete clip are dropped.
    """
    with open(file_path, 'rb') as handle:
        flat = np.fromfile(handle, dtype=np.float32)

    # Column-major reshape: consecutive value pairs become one (I, Q) column.
    channels = flat.reshape(2, -1, order='F')
    in_phase, quadrature = channels[0], channels[1]

    clip_count = len(in_phase) // segment_length
    clip_starts = (index * segment_length for index in range(clip_count))

    return [
        np.stack([
            in_phase[begin:begin + segment_length],
            quadrature[begin:begin + segment_length],
        ])
        for begin in clip_starts
    ]

def add_noise_awgn(x, snr):
    """Add white Gaussian noise to a batch of clips at a given SNR in dB.

    Args:
        x: Array-like with three axes; the power of each entry along axis 0
            is measured as the mean square over axes 1 and 2.
        snr: Signal-to-noise ratio in dB, converted via ``10 ** (snr / 10)``.

    Returns:
        The noisy batch as a nested Python list with the same shape as ``x``.
    """
    batch = np.array(x)

    # One power value per batch entry, kept broadcastable against the batch.
    per_item_power = np.mean(batch ** 2, axis=(1, 2), keepdims=True)
    snr_linear = 10 ** (snr / 10.0)
    target_noise_power = per_item_power / snr_linear

    gaussian = np.random.normal(0, 1, batch.shape).astype(batch.dtype)
    # NOTE(review): per_item_power is already a per-element average over both
    # rows of axis 1, so halving it here makes the per-element noise variance
    # half of target_noise_power (realised SNR ~3 dB above nominal) - confirm
    # that this is the intended SNR definition.
    gaussian *= np.sqrt(target_noise_power / 2)

    noisy = batch + gaussian
    return noisy.tolist()

@register_dataset("Techrec")
class Techrec(BaseDataset):
    """IQ-clip dataset labelled by signal type ('lte', 'wf', 'dvbt').

    The data is prepared once per process by :meth:`create`, which stores the
    train/validation/test splits as class attributes. Instances are light
    views on one split and yield tensors in one of three modes:

    * ``"default"``: ``(sample, empty_tensor, label, snr)``
    * ``"copy"``:    ``(sample, label, snr)``
    * ``"stft"``:    ``(stft_of_row_0, label, snr)``
    """

    @classmethod
    def create(cls, test_size=0.2, dataset_path=None, *args, **kwargs):
        """Load (or build and cache) the dataset and split it on the class.

        If ``Techrec.npz`` exists in ``dataset_path`` it is loaded as-is;
        note that in this case ``minSNR``/``maxSNR`` only affect
        ``cls.SNR_list``, not the cached samples. Otherwise every ``.bin``
        file below ``dataset_path`` is segmented, shuffled, split into one
        group per SNR value, corrupted with noise and written to the cache.

        Args:
            test_size: Fraction of all samples reserved for the test split.
                A quarter of the remainder becomes the validation split.
            dataset_path: Directory holding the recordings and the cache.
                Defaults to a ``Techrec`` folder next to this module.
            **kwargs: Optional ``minSNR`` (default -15) and ``maxSNR``
                (default 20); SNR values run between them in steps of 5.

        Raises:
            ValueError: If no usable ``.bin`` recording is found when the
                cache has to be built.
        """
        cls.classes = ['lte', 'wf', 'dvbt']

        minSNR = kwargs.get("minSNR", -15)
        maxSNR = kwargs.get("maxSNR", 20)
        cls.minSNR = minSNR
        cls.maxSNR = maxSNR
        cls.SNR_list = range(minSNR, maxSNR + 1, 5)

        if dataset_path is None:
            dataset_path = osp.join(osp.dirname(osp.abspath(__file__)), "Techrec")
        cache_path = osp.join(dataset_path, "Techrec.npz")

        if osp.exists(cache_path):
            # Indexing an NpzFile reads the array eagerly, so the values stay
            # valid after the archive is closed.
            with np.load(cache_path) as data:
                samples = data["samples"]
                SNR = data["SNR"]
                labels = data["labels"]
        else:
            samples, labels, SNR = cls._generate_noisy_samples(dataset_path)
            np.savez_compressed(cache_path,
                                samples=samples,
                                SNR=SNR,
                                labels=labels)

        # First carve out the test split, then take 25% of the remainder for
        # validation; both splits are stratified by class label.
        rest_x, test_x, rest_y, test_y, rest_snr, test_snr = train_test_split(
            samples,
            labels,
            SNR,
            test_size=test_size,
            random_state=233,
            stratify=labels)
        train_x, val_x, train_y, val_y, train_snr, val_snr = train_test_split(
            rest_x,
            rest_y,
            rest_snr,
            test_size=0.25,
            random_state=233,
            stratify=rest_y)

        # Each split is stored as [samples, labels, SNRs].
        cls.train_dataset = [list(train_x), list(train_y), list(train_snr)]
        cls.val_dataset = [list(val_x), list(val_y), list(val_snr)]
        cls.test_dataset = [list(test_x), list(test_y), list(test_snr)]
        cls.dataset = [cls.train_dataset, cls.val_dataset, cls.test_dataset]

    @staticmethod
    def _scan_recordings(dataset_path):
        """Collect metadata parsed from the name of every ``.bin`` file found.

        File names are split on ``'_'``: field 0 carries the signal type
        (optionally fused with a ``<N>Msps`` sampling rate), field 1 a
        one-character prefix followed by an integer id, and field 3 a
        one-character prefix followed by the centre frequency and ``.bin``.
        """
        recordings = []
        for root, _dirs, files in os.walk(dataset_path):
            for file_name in files:
                if not file_name.endswith('.bin'):
                    continue
                parts = file_name.split('_')
                head = parts[0]
                if 'Msps' in head:
                    # Strip the digits and the letters of "Msps" to leave the
                    # signal type; the digits alone give the rate in Msps.
                    signal_type = ''.join(
                        ch for ch in head if not ch.isdigit() and ch not in 'Msp')
                    sampling_rate = int(''.join(ch for ch in head if ch.isdigit())) * 1e6
                else:
                    signal_type = head
                    sampling_rate = 1e6

                recordings.append({
                    'signal_type': signal_type,
                    'sampling_rate': sampling_rate,
                    'usrp': int(parts[1][1:]),
                    # Name of the directory that directly contains the file.
                    'location': osp.basename(root),
                    'center_frequency': parts[3][1:-4] + 'MHz',
                    'file_path': os.path.join(root, file_name),
                })
        return recordings

    @classmethod
    def _generate_noisy_samples(cls, dataset_path):
        """Segment all recordings and add noise, one SNR value per group.

        Returns:
            ``(samples, labels, SNR)`` as NumPy arrays, matching what is
            obtained when the same data is read back from the cache.
        """
        clips = []
        clip_labels = []
        for recording in cls._scan_recordings(dataset_path):
            file_clips = load_bin_file(recording['file_path'])
            if not file_clips:
                continue  # shorter than one segment: contributes nothing
            class_id = cls.classes.index(recording['signal_type'])
            clips.extend(file_clips)
            clip_labels.extend([class_id] * len(file_clips))

        num_snr = len(cls.SNR_list)
        total = len(clips)
        per_snr = total // num_snr
        if total == 0:
            raise ValueError(
                f"No usable .bin recordings were found under {dataset_path!r}.")

        # Shuffling an index list applies the same permutation to clips and
        # labels without materialising a list of (clip, label) pairs.
        order = list(range(total))
        random.shuffle(order)
        shuffled_IQ = np.array(clips)[order]
        shuffled_label = np.array(clip_labels)[order]

        samples = []
        labels = []
        SNR = []
        last_group = num_snr - 1
        for group, snr in enumerate(cls.SNR_list):
            start = group * per_snr
            # The last group also absorbs the remainder of the division.
            stop = None if group == last_group else start + per_snr
            group_labels = shuffled_label[start:stop]
            if len(group_labels) == 0:
                continue
            SNR.extend([snr] * len(group_labels))
            samples.extend(add_noise_awgn(shuffled_IQ[start:stop], snr))
            labels.extend(group_labels)

        # Convert to arrays so a freshly built dataset has the same types as
        # one loaded from the cache (nested lists cannot be indexed [0, :]).
        return np.asarray(samples), np.asarray(labels), np.asarray(SNR)

    def __init__(self, split="train", mode="default"):
        """Select one split and output mode of the already created dataset.

        Args:
            split: One of ``"train"``, ``"valid"`` or ``"test"``.
            mode: One of ``"default"``, ``"copy"`` or ``"stft"``. In
                ``"stft"`` mode the transforms of the whole split are
                computed eagerly here.

        Raises:
            ValueError: If :meth:`create` has not been called, or if
                ``split``/``mode`` is not supported.
        """
        split_list = ["train", "valid", "test"]
        if not hasattr(Techrec, "train_dataset"):
            raise ValueError("The Techrec dataset is not created, please use Techrec.create() to create instance.")
        if split not in split_list:
            raise ValueError(f"The split type {split} is not supported!")
        if mode not in ["default", "copy", "stft"]:
            raise ValueError(f"The mode type {mode} is not supported!")

        self.split_id = split_list.index(split)
        self.split = split
        self.mode = mode
        if mode == "stft":
            self.STFTs = []
            split_samples = self.dataset[self.split_id][0]
            with tqdm(total=len(split_samples)) as t:
                t.set_description('Generating STFT:')
                for IQ in split_samples:
                    # Only row 0 of the clip is transformed: Blackman window,
                    # nperseg=31, noverlap=30, nfft=128.
                    _, _, stp = stft(np.asarray(IQ)[0, :], 1.0, 'blackman', 31, 30, 128)
                    # Keep the first 32 frequency rows and add a leading axis.
                    self.STFTs.append(np.expand_dims(stp[:32, :], 0))
                    t.update(1)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset[self.split_id][0])

    def __getitem__(self, idx):
        """Return the item at ``idx`` in the layout of the selected mode."""
        samples, labels, snrs = self.dataset[self.split_id]
        label = torch.tensor(labels[idx], dtype=torch.long)
        snr = snrs[idx]

        if self.mode == "default":
            return torch.Tensor(samples[idx]), torch.Tensor([]), label, snr
        if self.mode == "copy":
            # torch.tensor (the factory) accepts dtype; the legacy
            # torch.Tensor constructor does not and raised TypeError here.
            return torch.tensor(samples[idx], dtype=torch.float), label, snr
        if self.mode == "stft":
            # NOTE(review): scipy's stft output is complex-valued; confirm
            # that torch.Tensor converts it the way downstream code expects.
            return torch.Tensor(self.STFTs[idx]), label, snr

    @property
    def get_pretrain_data(self):
        """Training samples, labels and SNRs as three NumPy arrays."""
        return np.array(self.train_dataset[0]), np.array(self.train_dataset[1]), np.array(self.train_dataset[2])
