import os
import torch
from torch.utils.data import Dataset
import pandas as pd
import scipy.io as sio
import numpy as np

class Data:
    """Loader for the benchmark time series and their ground-truth graphs.

    Files are read from ``<root_dir>/<dataset>/data`` and
    ``<root_dir>/<dataset>/label``.  ``series_id`` is a 1-based index into the
    series list of the chosen dataset (see ``_SERIES``).

    Args:
        root_dir: Directory that contains one sub-directory per dataset.
        dataset: One of the keys of ``_SERIES``.
        series_id: 1-based index of the series inside the dataset.
        subject: For ``'fMRI'`` the index of the subject to return (``-1``
            returns every subject).  For the simulated datasets (``'VAR3'``,
            ``'GLV'``, ``'Lorenz96'``) it is the trial number that appears in
            the file name.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
        AssertionError: If ``series_id`` is outside the range of the dataset.
    """

    # dataset -> (series names in series_id order, wording of the range error).
    # Single source of truth for both validation and file-name lookup.
    _SERIES = {
        'DREAM3': (('Ecoli1', 'Ecoli2', 'Yeast1', 'Yeast2', 'Yeast3'), 'DREAM datasets'),
        'DREAM4': (('Gene_1', 'Gene_2', 'Gene_3', 'Gene_4', 'Gene_5'), 'DREAM datasets'),
        'CausalTime': (('medical', 'pm25', 'traffic'), 'CausalTime dataset'),
        'fMRI': (('fMRI_15', 'fMRI_50'), 'fMRI dataset'),
        'VAR3': (('D10_T500', 'D10_T1000', 'D50_T500', 'D50_T1000'), 'VAR dataset'),
        'GLV': (('D10_T500', 'D10_T1000', 'D50_T500', 'D50_T1000'), 'GLV dataset'),
        'Lorenz96': (
            (
                'D10_T500_F50', 'D10_T500_F10', 'D10_T1000_F50', 'D10_T1000_F10',
                'D50_T500_F50', 'D50_T500_F10', 'D50_T1000_F50', 'D50_T1000_F10',
            ),
            'Lorenz96 dataset',
        ),
    }

    _UNKNOWN_DATASET_MSG = (
        "dataset must be either 'DREAM3', 'DREAM4', 'CausalTime', 'fMRI', "
        "'VAR3', 'GLV' or 'Lorenz96'"
    )

    # Number of time points per trajectory in the DREAM3 / DREAM4 files.
    _DREAM_T = 21

    def __init__(self, root_dir: str, dataset: str, series_id: int = 1, subject=0):
        self.root_dir = root_dir
        self.dataset = dataset
        self.series_id = series_id
        # fMRI: subject index, -1 means "load all subjects".
        # Simulated data: trial number used in the file name.
        self.subject = subject
        self.data_path = os.path.join(self.root_dir, dataset, 'data')
        self.label_path = os.path.join(self.root_dir, dataset, 'label')
        self._validate()

    def _validate(self):
        """Check ``self.dataset`` and ``self.series_id`` against ``_SERIES``."""
        if self.dataset not in self._SERIES:
            raise ValueError(self._UNKNOWN_DATASET_MSG)
        names, wording = self._SERIES[self.dataset]
        if not 1 <= self.series_id <= len(names):
            valid_ids = ','.join(str(i) for i in range(1, len(names) + 1))
            # Raised explicitly (instead of an ``assert`` statement) so the
            # check survives ``python -O``; otherwise series_id=0 would
            # silently select the last series through negative indexing.
            # AssertionError is kept so existing callers see the same type.
            raise AssertionError(f"series_id must be in {{{valid_ids}}} for {wording}")

    def load_data(self):
        """Load the selected series together with its ground-truth graph.

        Returns:
            Tuple ``(X, graph, var_names)``: the time-series tensor, the
            adjacency tensor and the list of variable names.  For the graph,
            entry ``(i, j)`` indicates the edge ``j -> i``; the edge-list and
            CausalTime / fMRI label files store ``source -> target`` and are
            transposed accordingly.  Simulated labels are returned exactly as
            stored in the file.

        Raises:
            ValueError: If ``self.dataset`` is not supported.
            AssertionError: If ``self.series_id`` is out of range.
        """
        # Re-checked here because the attributes are public and may have been
        # changed after construction.
        self._validate()
        series_name = self._SERIES[self.dataset][0][self.series_id - 1]

        if self.dataset == 'DREAM3':
            X, var_names = self._load_dream3_data(series_name)
            graph = self._load_dream3_label(series_name, var_names)
        elif self.dataset == 'DREAM4':
            X, var_names = self._load_dream4_data(series_name)
            graph = self._load_dream4_label(series_name, var_names)
        elif self.dataset == 'CausalTime':
            X, var_names = self._load_causaltime_data(series_name)
            graph = self._load_causaltime_label(series_name)
        elif self.dataset == 'fMRI':
            X, var_names, graph = self._load_fmri(series_name, self.subject)
        else:
            # 'VAR3', 'GLV' and 'Lorenz96' share the simulated-data file layout.
            X, var_names = self._load_simulated_data(series_name)
            graph = self._load_simulated_label(series_name)

        print(f"Loading data for dataset: {self.dataset}, series: {series_name} (ID: {self.series_id})")

        return X, graph, var_names

    @staticmethod
    def _load_edge_list_label(file_path, var_names):
        """Build an adjacency matrix from a headerless ``source, target, label`` TSV.

        Rows whose source or target is not in ``var_names`` are ignored.

        Returns:
            float32 tensor of shape ``[len(var_names), len(var_names)]`` with
            ``graph[target, source] = label``.
        """
        df = pd.read_csv(file_path, sep='\t', header=None)
        df.columns = ['source', 'target', 'label']

        name_to_idx = {name: idx for idx, name in enumerate(var_names)}
        graph = torch.zeros((len(var_names), len(var_names)), dtype=torch.float32)
        for src, tgt, val in df.itertuples(index=False, name=None):
            if src in name_to_idx and tgt in name_to_idx:
                graph[name_to_idx[tgt], name_to_idx[src]] = float(val)
        return graph

    def _load_dream3_data(self, series_name):
        """Read ``<series_name>-trajectories.tsv`` into a ``[-1, 21, n_vars]`` tensor.

        Every column except ``'Time'`` is treated as a variable.
        """
        file_path = os.path.join(self.data_path, f"{series_name}-trajectories.tsv")
        df = pd.read_csv(file_path, sep='\t')
        var_names = [c for c in df.columns if c != 'Time']
        values = torch.tensor(df[var_names].values, dtype=torch.float32)
        X = values.reshape(-1, self._DREAM_T, len(var_names))

        return X, var_names

    def _load_dream3_label(self, series_name, var_names):
        """Load the DREAM3 edge list ``<series_name>.txt`` as an adjacency matrix."""
        file_path = os.path.join(self.label_path, f"{series_name}.txt")
        return self._load_edge_list_label(file_path, var_names)

    def _load_dream4_data(self, series_name):
        """Read ``<series_name>.mat`` into a ``[-1, 21, n_vars]`` tensor.

        The first non-dunder key of the ``.mat`` file is used as the data
        matrix; variables are named ``G1 .. Gn`` after its columns.
        """
        file_path = os.path.join(self.data_path, f"{series_name}.mat")
        mat = sio.loadmat(file_path)
        key = [k for k in mat.keys() if not k.startswith("__")][0]
        X = torch.tensor(mat[key], dtype=torch.float32)
        var_names = [f"G{i+1}" for i in range(X.shape[1])]
        X = X.reshape(-1, self._DREAM_T, len(var_names))

        return X, var_names

    def _load_dream4_label(self, series_name, var_names):
        """Load the DREAM4 edge list ``<series_name>.tsv`` as an adjacency matrix."""
        file_path = os.path.join(self.label_path, f"{series_name}.tsv")
        return self._load_edge_list_label(file_path, var_names)

    def _load_causaltime_data(self, series_name):
        """Read ``<series_name>.npy`` and keep the first half of the last axis.

        The first half of the variables are the causal variables, the second
        half are residual variables and are dropped.
        """
        file_path = os.path.join(self.data_path, f"{series_name}.npy")
        gen_data = np.load(file_path)
        all_dim = gen_data.shape[2]
        X = torch.tensor(gen_data[:, :, :all_dim//2], dtype=torch.float32)
        var_names = [f"V{i+1}" for i in range(X.shape[2])]

        return X, var_names

    def _load_causaltime_label(self, series_name):
        """Read the CausalTime graph and transpose it to the ``graph[i, j]: j -> i`` layout."""
        file_path = os.path.join(self.label_path, f"{series_name}.npy")
        gen_graph = np.load(file_path)
        # In the file graph[j, i] = 1 means j causes i; we want graph[i, j] = 1.
        return torch.tensor(gen_graph, dtype=torch.float32).T

    def _load_fmri(self, series_name, subject):
        """Read ``<series_name>.mat`` and return ``(X, var_names, graph)``.

        Args:
            series_name: Base name of the ``.mat`` file.
            subject: Index of the subject to return, or ``-1`` for all
                subjects (``X`` is then ``[n_subjects, T, nodes]``).

        Raises:
            AssertionError: If the non-zero patterns of the first two
                subjects' networks differ.
        """
        file_path = os.path.join(self.data_path, f"{series_name}.mat")
        mat = sio.loadmat(file_path)

        time_len = int(mat['Ntimepoints'][0][0])
        nodes = int(mat['Nnodes'][0][0])
        var_names = [f"ROI_{i+1}" for i in range(nodes)]

        # The subject count is inferred rather than hard-coded to 50.
        # NOTE(review): assumes `ts` stacks the subjects' series row-wise with
        # one column per node — confirm against the .mat files.
        all_subjects = torch.tensor(mat['ts'], dtype=torch.float32).reshape(-1, time_len, nodes)

        net = mat['net']
        first_mask = net[0] != 0
        # Only the edge pattern (not the weights) has to agree across subjects.
        if not np.array_equal(first_mask, net[1] != 0):
            raise AssertionError("Graphs for different subjects are not the same!")
        graph = torch.tensor(first_mask.astype(np.int32), dtype=torch.int32).T

        if subject == -1:
            return all_subjects, var_names, graph
        return all_subjects[subject, :, :], var_names, graph

    def _load_simulated_data(self, series_name):
        """Read ``<series_name>_trials_<subject>.npy`` as a float32 tensor."""
        file_path = os.path.join(self.data_path, f"{series_name}_trials_{self.subject}.npy")
        X = torch.tensor(np.load(file_path), dtype=torch.float32)
        # NOTE(review): names follow axis 1, i.e. this assumes a [T, d] array — confirm.
        var_names = [f"V{i+1}" for i in range(X.shape[1])]

        return X, var_names

    def _load_simulated_label(self, series_name):
        """Read ``<series_name>_trials_<subject>.npy`` from the label folder (no transpose)."""
        file_path = os.path.join(self.label_path, f"{series_name}_trials_{self.subject}.npy")
        return torch.tensor(np.load(file_path), dtype=torch.int32)

class TimeSeriesDataset(Dataset):
    """Lagged (window -> next step) samples built from multivariate time series.

    Each sample pairs the flattened window ``x[t], ..., x[t + lag - 1]``
    (length ``lag * d``, oldest step first) with the following step
    ``x[t + lag]``.  Samples from all series are concatenated series by series.

    Args:
        X: ``numpy.ndarray`` or ``torch.Tensor`` of shape ``[T, d]`` or
            ``[B, T, d]``.  Integer / bool tensors are cast to float32.
        lag: Window length, ``>= 1``.
        Norm: If true, standardise every variable of every series over time.
        device: Device on which the input / output sample tensors are placed.

    Raises:
        AssertionError: If ``lag < 1``.
        TypeError: If ``X`` is neither an ndarray nor a tensor.
        ValueError: If ``X`` is not 2-D or 3-D.
        RuntimeError: If a series has no more than ``lag`` time steps.
    """

    def __init__(self, X, lag: int, Norm=False, device='cpu'):
        if not lag >= 1:
            # Explicit raise so the check is not stripped by ``python -O``.
            raise AssertionError("lag must be >= 1")

        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).float()
        elif isinstance(X, torch.Tensor):
            # Integer / bool tensors can neither be averaged nor require
            # gradients, so promote them like the numpy path does.
            if not (X.is_floating_point() or X.is_complex()):
                X = X.float()
        else:
            raise TypeError("X must be a numpy.ndarray or torch.Tensor")

        if X.ndim == 2:
            X = X.unsqueeze(0)  # treat a single series as a batch of one
        elif X.ndim != 3:
            raise ValueError('X must have shape [T, d] or [B, T, d]')

        self.batch_size, self.T, self.dim = X.shape
        self.device = device
        self.lag = lag

        if Norm:
            # Standardise each series over its own time axis.
            mean = X.mean(dim=1, keepdim=True)
            std = X.std(dim=1, keepdim=True)
            # A constant variable has std == 0 (NaN for a single time step);
            # dividing by it would fill the series with NaN/inf.  Use 1 so the
            # centred, all-zero values are kept instead.
            std = torch.where(std > 0, std, torch.ones_like(std))
            self.norm_X = (X - mean) / std  # [B, T, d]
        else:
            self.norm_X = X

        self.inputs, self.outputs = self._create_lagged_sequences(self.norm_X)

        self.input_dim = self.inputs.shape[1]
        self.output_dim = self.outputs.shape[1]

    def split_lag(self, split_ratio=0.7):
        """Split along time: the first ``split_ratio`` of the steps train, the rest validate.

        The validation slice is extended backwards by ``lag`` steps so that its
        first target is the first step after the training range.

        Returns:
            ``(train_dataset, val_dataset)`` as ``TensorDataset`` objects.
        """
        train_size = int(split_ratio * self.T)
        val_size = self.T - train_size
        train_data = self.norm_X[:, :train_size, :]
        val_data = self.norm_X[:, -(val_size + self.lag):, :]
        return self._as_tensor_datasets(train_data, val_data)

    def split_series(self, split_ratio=0.7):
        """Split across series: the first ``split_ratio`` of the series train, the rest validate.

        Falls back to :meth:`split_lag` when there is only one series.

        Returns:
            ``(train_dataset, val_dataset)`` as ``TensorDataset`` objects.
        """
        if self.batch_size == 1:  # no replicates to split across
            return self.split_lag(split_ratio)
        train_size = int(split_ratio * self.batch_size)
        train_data = self.norm_X[:train_size, :, :]
        val_data = self.norm_X[train_size:, :, :]
        return self._as_tensor_datasets(train_data, val_data)

    def _as_tensor_datasets(self, train_data, val_data):
        """Turn two ``[B, T, d]`` slices into lagged ``TensorDataset`` objects."""
        train_X, train_Y = self._create_lagged_sequences(train_data)
        val_X, val_Y = self._create_lagged_sequences(val_data)
        return (
            torch.utils.data.TensorDataset(train_X, train_Y),
            torch.utils.data.TensorDataset(val_X, val_Y),
        )

    def _create_lagged_sequences(self, X):
        """Build the ``(inputs, outputs)`` sample tensors from ``X`` of shape ``[B, T, d]``.

        Returns:
            ``inputs`` of shape ``[B * (T - lag), lag * d]`` (with
            ``requires_grad`` enabled) and ``outputs`` of shape
            ``[B * (T - lag), d]``, both on ``self.device``.

        Raises:
            RuntimeError: If ``T <= lag`` so that no window fits.
        """
        _, T, dim = X.shape
        n_windows = T - self.lag
        if n_windows <= 0:
            # Same exception type torch raised here before (empty list passed
            # to torch.cat), now with an actionable message.
            raise RuntimeError(
                f"need more than lag={self.lag} time steps to build lagged samples, got T={T}"
            )

        # [B, n_windows, lag, d]: window t holds steps t .. t + lag - 1.
        windows = torch.stack([X[:, t:t + self.lag, :] for t in range(n_windows)], dim=1)
        inputs = windows.reshape(-1, self.lag * dim)
        outputs = X[:, self.lag:, :].reshape(-1, dim)

        inputs = inputs.to(self.device).requires_grad_(True)
        outputs = outputs.to(self.device)

        return inputs, outputs

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]

if __name__ == "__main__":
    # Manual smoke tests: load every dataset found under ./data and print the
    # tensor shapes, then exercise TimeSeriesDataset on synthetic data.
    root_dir = './data'

    ################### TEST FOR DREAM3 ###################
    data_extractor = Data(root_dir, 'DREAM3', 1)
    X, graph, var_names = data_extractor.load_data()
    print(X.shape)

    ################### TEST FOR DREAM4 ###################
    data_extractor = Data(root_dir, 'DREAM4', 1)
    X, graph, var_names = data_extractor.load_data()
    print(X.shape)

    ################# TEST FOR CausalTime #################
    # series_id 1: medical, 2: pm25, 3: traffic
    for series_id in (1, 2, 3):
        data_extractor = Data(root_dir, 'CausalTime', series_id)
        X, graph, var_names = data_extractor.load_data()
        print(X.shape)

    #################### TEST FOR fMRI ####################
    # All fMRI subjects share the same graph; subject=-1 loads every subject.
    data_extractor = Data(root_dir, 'fMRI', 1, -1)  # series_id 1: fMRI_15
    all_subject, graph_1, var_names = data_extractor.load_data()
    print(all_subject.shape)

    ################# TEST FOR SIMULATION #################
    for dataset in ['VAR3', 'Lorenz96']:
        n_series = 4 if dataset == 'VAR3' else 8
        for subject in range(1, 6):
            for series_id in range(1, n_series + 1):
                print(f"Testing dataset: {dataset}, series_id: {series_id}, subject: {subject}")
                data_extractor = Data(root_dir, dataset, series_id, subject)
                X, graph, var_names = data_extractor.load_data()
                print(X[:3, 0])

    ################### TEST FOR DATASET ##################
    # batch: 10, time length: 4, dim: 2
    batch_size = 10
    T = 4
    d = 2
    lag = 1

    data = torch.arange(batch_size * T * d).float().reshape(batch_size, T, d)
    ts_dataset = TimeSeriesDataset(X=data, lag=lag, Norm=False)
    # The class only offers train/validation splits, so no test split is
    # reported; the former code referenced undefined train_ds/val_ds/test_ds.
    train_ds, val_ds = ts_dataset.split_series(split_ratio=0.7)

    def print_dataset_shapes(ds, name):
        """Print the input/target shapes of a TensorDataset."""
        inputs, targets = ds.tensors
        print(f"{name} X: {inputs.shape}, Y: {targets.shape}")

    print_dataset_shapes(train_ds, "Train")
    print_dataset_shapes(val_ds, "Validation")
