import h5py
import numpy as np
from pathlib import Path
import scipy.io as sio
import os
import torch
from torch.utils.data import DataLoader

class h5Dataset:
    """Thin convenience wrapper around a single open ``h5py.File``.

    The file ``<path>/<name>.hdf5`` is opened by the constructor and stays
    open until :meth:`save` is called or the ``with`` block is left.
    """

    def __init__(self, path: Path, name: str, mode: str = 'a') -> None:
        """Open ``path / f'{name}.hdf5'``.

        Args:
            path: Directory that holds the file (``Path`` or ``str``).
            name: File name without the ``.hdf5`` suffix.
            mode: ``'a'`` or ``'r'``, passed straight to ``h5py.File``.

        Raises:
            ValueError: If ``mode`` is neither ``'a'`` nor ``'r'``.
        """
        # Validate before touching the filesystem so a bad mode never
        # creates or opens a file.
        if mode not in ('a', 'r'):
            raise ValueError(f'can not set mode to {mode}, only "a" or "r"')
        self.__name = name
        self.__f = h5py.File(Path(path) / f'{name}.hdf5', mode)

    def __enter__(self) -> 'h5Dataset':
        """Allow ``with h5Dataset(...) as ds:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the file on leaving the ``with`` block; never swallows errors."""
        self.save()

    def get_group_names(self) -> list:
        """Return the names of all top-level entries in the file."""
        return list(self.__f.keys())

    def get_dataset_names_from_group(self, grpName: str) -> list:
        """Return the names of all entries stored under group ``grpName``."""
        return list(self.__f[grpName].keys())

    def get_group(self, grpName: str) -> 'h5py.Group':
        """Return the group stored under the key ``grpName``."""
        return self.__f[grpName]

    def get_dataset_from_group(self, grpName: str, dsName: str) -> 'h5py.Dataset':
        """Return dataset ``dsName`` from group ``grpName``."""
        return self.__f[grpName][dsName]

    def addGroup(self, grpName: str) -> 'h5py.Group':
        """Print the current group names, then create and return ``grpName``."""
        print(self.get_group_names())
        return self.__f.create_group(grpName)

    def addDataset(self, grp: 'h5py.Group', dsName: str, arr: np.ndarray,
                   chunks: 'tuple|None' = None) -> 'h5py.Dataset':
        """Create dataset ``dsName`` inside ``grp`` from ``arr``.

        Args:
            grp: Group that receives the dataset.
            dsName: Name of the new dataset.
            arr: Data to store.
            chunks: Chunk shape forwarded to ``create_dataset``; when ``None``
                the argument is omitted and h5py's default applies.

        Returns:
            The newly created dataset.
        """
        if chunks is None:
            return grp.create_dataset(dsName, data=arr)
        return grp.create_dataset(dsName, data=arr, chunks=chunks)

    def addAttributes(self, src: 'h5py.Dataset|h5py.Group', attrName: str, attrValue) -> None:
        """Store ``attrValue`` under ``attrName`` in ``src.attrs``."""
        src.attrs[f'{attrName}'] = attrValue

    def save(self) -> None:
        """Close the underlying file handle."""
        self.__f.close()

    @property
    def name(self) -> str:
        """Name given at construction (file name without ``.hdf5``)."""
        return self.__name


# Reference ordering of channel labels (95 entries). The position of a label
# in this list is the index used for its one-hot row in the script below, so
# the order must not change. One label family per row for readability.
standard_channel_list = [
    'Fp1', 'Fpz', 'Fp2', 'Fp9', 'Fp10', 'Nz',
    'AF1', 'AF2', 'AFz', 'AF3', 'AF4', 'AF5', 'AF6', 'AF7', 'AF8', 'AF9', 'AF10',
    'F1', 'F2', 'Fz', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9', 'F10',
    'FC1', 'FC2', 'FCz', 'FC3', 'FC4', 'FC5', 'FC6',
    'FT7', 'FT8', 'FT9', 'FT10',
    'C1', 'C2', 'Cz', 'C3', 'C4', 'C5', 'C6',
    'T7', 'T8', 'T9', 'T10',
    'I1', 'I2',
    'CP1', 'CP2', 'CPz', 'CP3', 'CP4', 'CP5', 'CP6',
    'TP7', 'TP8', 'TP9', 'TP10',
    'P1', 'P2', 'Pz', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9', 'P10',
    'PO1', 'PO2', 'POz', 'PO3', 'PO4', 'PO5', 'PO6', 'PO7', 'PO8', 'PO9', 'PO10',
    'O1', 'O2', 'Oz', 'O9', 'O10',
    'Iz',
    'CB1', 'CB2',
    'A1', 'A2',
]

def build_channel_alignment(channels, standard_channels):
    """Match one dataset's channel labels against the standard ordering.

    Args:
        channels: Channel labels in the dataset's own order. ``bytes`` labels
            (as HDF5 attributes may deliver them) are decoded first; every
            label is upper-cased before comparison.
        standard_channels: Upper-cased standard labels in target order;
            assumed to contain no duplicates.

    Returns:
        ``(source_indices, standard_indices)``: two equally long lists. Entry
        ``i`` says that dataset channel ``source_indices[i]`` corresponds to
        standard channel ``standard_indices[i]``. Both follow the order of
        ``standard_channels``; labels missing from the dataset are skipped.
    """
    first_position = {}
    for idx, label in enumerate(channels):
        if isinstance(label, bytes):
            label = label.decode()
        # setdefault keeps the first occurrence, matching list.index().
        first_position.setdefault(label.upper(), idx)

    source_indices = []
    standard_indices = []
    for std_idx, label in enumerate(standard_channels):
        src_idx = first_position.get(label)
        if src_idx is not None:
            source_indices.append(src_idx)
            standard_indices.append(std_idx)
    return source_indices, standard_indices


def project_batch(eeg, source_indices, left_embedding, right_embedding):
    """Project a batch onto its leading right singular vectors and scatter
    the result into the standard channel layout.

    Args:
        eeg: Float tensor of shape ``(batch, channels, n)``; axis 1 is indexed
            by ``source_indices`` (last axis presumably time samples).
        source_indices: Dataset channel indices to keep, in standard order.
        left_embedding: One-hot selector broadcastable to
            ``(batch, n_standard, n_selected)``.
        right_embedding: Its transpose, broadcastable to
            ``(batch, n_selected, n_standard)``.

    Returns:
        Tensor of shape ``(batch, n_standard, n_standard)``; rows and columns
        of standard channels absent from the dataset are zero.
    """
    selected = eeg[:, source_indices, :]
    # NOTE(review): full_matrices=True builds every right singular vector and
    # then keeps only the first n_selected rows; full_matrices=False should
    # yield the same rows far more cheaply. Left as-is to keep the numerics
    # identical -- confirm before switching.
    _, _, vt = torch.linalg.svd(selected, full_matrices=True)
    vt = vt[:, :selected.shape[1], :]
    projected = torch.matmul(selected, vt.transpose(2, 1))
    projected = torch.matmul(left_embedding, projected)
    return torch.matmul(projected, right_embedding)


def main():
    """Convert every group of ``../data/all_data.hdf5`` (except ``'seed'``)
    and store the result in ``../data/new_U_data_order.hdf5``."""
    h5_path = "../data/all_data.hdf5"
    print(len(standard_channel_list))
    standard_channels = [s.upper() for s in standard_channel_list]
    n_standard = len(standard_channels)

    # Use the GPU when present, but keep the script runnable without one.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    all_channel_embedding = torch.eye(n_standard, requires_grad=False, device=device)

    save_path = Path('../data')
    h5_dataset = h5Dataset(save_path, 'new_U_data_order')
    try:
        with h5py.File(h5_path, 'r') as f:
            keys = list(f.keys())
            print(keys)  # show every top-level group of the source file

            for key in keys:
                if key == 'seed':
                    # 'seed' is excluded from the conversion.
                    continue
                dataset = f[key]['data']
                print(key)
                print("dataset:", dataset.shape)

                source_indices, standard_indices = build_channel_alignment(
                    dataset.attrs['channels'], standard_channels)
                channel_index = torch.tensor(
                    standard_indices, dtype=torch.long, device=device)

                # Built once per dataset; the leading axis of size 1
                # broadcasts over the batch inside matmul.
                right_embedding = all_channel_embedding[channel_index].unsqueeze(0)
                left_embedding = right_embedding.transpose(2, 1)

                group = h5_dataset.addGroup(key)

                # NOTE(review): drop_last=True discards the final partial
                # batch (up to 2047 samples), and a dataset shorter than one
                # batch makes torch.cat fail below -- confirm this is intended.
                data_loader = DataLoader(dataset, batch_size=2048, drop_last=True)
                batches = []
                for eeg in data_loader:
                    eeg = eeg.float().to(device)
                    projected = project_batch(
                        eeg, source_indices, left_embedding, right_embedding)
                    # Move each batch off the device right away so device
                    # memory does not grow with the dataset size.
                    batches.append(projected.cpu())

                X = torch.cat(batches, dim=0).numpy()
                h5_dataset.addDataset(
                    group, 'data', X, chunks=(1, n_standard, n_standard))
                print(X.shape)
    finally:
        # Close the output file even if the conversion fails midway.
        h5_dataset.save()


if __name__ == '__main__':
    main()