import sys
import numpy as np
import os
import random
import torch
from scipy.io import loadmat

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tools.data_loader_mis import Standard_mat_DataLoader
from tools.data_loader_mis import SP_DataLoader


def _resample_training_set(xxtr, xytr, seed, train_samples_num):
    """Draw a seeded, with-replacement resample of the training data.

    Args:
        xxtr: Indexable whose first element holds the training inputs;
            only ``xxtr[0]`` is read.
        xytr: Iterable of per-entry training outputs; every entry is indexed
            with the same sample indices as ``xxtr[0]``.
        seed: Value passed to ``random.seed``. This reseeds Python's
            module-level RNG, which is a process-wide side effect.
        train_samples_num: Number of samples to draw. Indices are drawn from
            ``[0, train_samples_num - 1]``.

    Returns:
        ``(xtr, ytr)`` where ``xtr`` is a one-element list containing the
        stacked inputs and ``ytr`` is a list with one stacked tensor per
        entry of ``xytr``.
    """
    random.seed(seed)
    # NOTE: indices are sampled WITH replacement and only from the first
    # `train_samples_num` rows, so duplicates are expected. This is kept
    # exactly as-is so that existing seeds reproduce the same training sets.
    ind = [random.randint(0, train_samples_num - 1)
           for _ in range(train_samples_num)]

    # torch.as_tensor returns tensors unchanged and converts array-like
    # samples, so one code path serves every dataset.
    xtr = [torch.stack([torch.as_tensor(xxtr[0][j]) for j in ind])]
    ytr = [torch.stack([torch.as_tensor(outputs[j]) for j in ind])
           for outputs in xytr]
    return xtr, ytr


def data_preparation(data_name,
                     fidelity_num,
                     seed,
                     train_samples_num):
    """Load a dataset and build a seeded resample of its training split.

    Args:
        data_name: Name of the dataset. Must be one of the names listed in
            the two availability lists below.
        fidelity_num: Currently unused; kept so existing callers keep working.
        seed: Seed for Python's ``random`` module (reseeds the global RNG).
        train_samples_num: Number of training samples to draw; must be >= 1.
            Indices are drawn from ``[0, train_samples_num - 1]``, so the
            loaded training set is assumed to hold at least that many rows.

    Returns:
        ``(xtr, ytr, xte, yte)``: the resampled training inputs and outputs
        (lists of stacked tensors) followed by the test inputs and outputs
        exactly as returned by the loader's ``get_data()``.

    Raises:
        ValueError: If ``train_samples_num`` is smaller than 1 or
            ``data_name`` is not a supported dataset.
    """
    SP_DataLoader_available = ['plasmonic2_MF']

    Standard_mat_DataLoader_available = [
                            'Burget_mfGent_v5_15',
                            'Heat_mfGent_v5',
                            'Poisson_mfGent_v5',
                            'TopOP_mfGent_v6',
                            ]

    if train_samples_num < 1:
        raise ValueError(
            f"train_samples_num must be >= 1, got {train_samples_num!r}")

    if data_name in Standard_mat_DataLoader_available:
        mat_data = Standard_mat_DataLoader(data_name, True)
    elif data_name in SP_DataLoader_available:
        mat_data = SP_DataLoader(data_name, None)
    else:
        # Previously this fell through to the return statement and failed
        # with an UnboundLocalError that hid the real cause.
        supported = Standard_mat_DataLoader_available + SP_DataLoader_available
        raise ValueError(
            f"Unknown data_name {data_name!r}; expected one of {supported}")

    # Load first, then seed: keeps the original order of operations.
    xxtr, xytr, xte, yte = mat_data.get_data()
    xtr, ytr = _resample_training_set(xxtr, xytr, seed, train_samples_num)

    return xtr, ytr, xte, yte