import os
import multiprocessing
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import torch
import numpy as np
import src.assets.simulator.multi_layer_model as real_simulator

# Number of layers in every stack of the dataset.
NUM_LAY = 10
# Path to data folder (placeholder: replace with the absolute path of the
# dataset_10_layer/data directory before running).
PATH_TO_FOLDER = "/insert/here/absolute/path/to/dataset_10_layer/data"

# Material name -> class index used by the one-hot encoding. Indices must be
# contiguous from 0 because they address columns of a NUM_MAT-wide one-hot row.
materials_dict = {
    'SiO2': 0,
    'TiO2': 1,
    'SiC': 2,
    'MgF2': 3,
    'Al2O3': 4,
    'AlN': 5,
    'ZnO': 6,
#    'Ag': 7
}

# Derived from materials_dict so enabling/disabling a material (e.g. 'Ag')
# cannot leave the one-hot width out of sync with the lookup table.
NUM_MAT = len(materials_dict)

def idx_to_mat(idx):
    '''
    Reverse lookup in materials_dict.

    Returns the first material name whose class index compares equal to ``idx``,
    or None when no entry matches. Equality (not hashing) is used, so numpy
    scalars and 0-d tensors work as ``idx`` too.
    '''
    matches = (name for name, class_idx in materials_dict.items() if idx == class_idx)
    return next(matches, None)

def one_hot_encode(number, num_classes):
    '''
    Build a float32 vector of length ``num_classes`` with a single 1 at ``number``.

    ``number`` is used as a plain numpy index, so an out-of-range value raises
    IndexError and a negative value counts from the end.
    '''
    encoded = np.zeros(shape=(num_classes,), dtype=np.float32)
    encoded[number] = 1.0
    return encoded

def one_hot_decode(one_hot):
    '''
    Return the index of the largest entry of ``one_hot`` (the first one on ties),
    i.e. the class index of a one-hot or score vector, via ``np.argmax``.
    '''
    hot_index = np.argmax(one_hot)
    return hot_index


def material_df_to_tensor(material_df):
    '''
    Encode a dataframe of layer stacks as a float32 feature tensor.

    Input: a dataframe whose columns alternate material / thickness per layer,
        e.g. layer_material_0  layer_thickness_0 .... layer_material_9  layer_thickness_9.
        Material cells must be keys of ``materials_dict``.
    Output: a float32 tensor of shape (num_rows, num_layers * NUM_MAT + num_layers).
        In each row the first num_layers * NUM_MAT entries are the per-layer one-hot
        material encodings (layer 0 first), followed by the num_layers thicknesses.

    Raises KeyError if a material name is not in ``materials_dict``.
    '''
    # Even columns hold material names, odd columns hold thicknesses.
    material_names = material_df.iloc[:, 0::2].to_numpy()
    thicknesses = material_df.iloc[:, 1::2].to_numpy(dtype=np.float32)
    num_rows, num_layers = material_names.shape

    # Names -> class indices; reshape keeps the 2-D layout even for an empty frame.
    material_idx = np.array(
        [materials_dict[name] for name in material_names.ravel()],
        dtype=np.int64,
    ).reshape(num_rows, num_layers)

    # Row k of the identity matrix is the one-hot vector for class k, so fancy
    # indexing yields (num_rows, num_layers, NUM_MAT) in one vectorised step.
    one_hot = np.eye(NUM_MAT, dtype=np.float32)[material_idx]
    one_hot = one_hot.reshape(num_rows, num_layers * NUM_MAT)

    features = np.concatenate([one_hot, thicknesses], axis=1)
    return torch.tensor(features, dtype=torch.float32)


def spectra_df_to_tensor(values_df):
    '''
    Convert a dataframe of spectra into a float32 tensor.

    Input: a dataframe of numeric columns, one spectrum per row.
    Output: a float32 tensor of shape (num_rows, num_columns) with the same values.
    '''
    # One vectorised conversion instead of copying the frame row by row; this also
    # keeps the 2-D shape (0, num_columns) when the frame has no rows.
    spectra = values_df.to_numpy(dtype=np.float32)
    # torch.tensor copies, so the result never aliases the dataframe's memory.
    return torch.tensor(spectra, dtype=torch.float32)

def material_tensor_to_lists(material):
    '''
    Decode a SINGLE encoded layer stack back into material names and thicknesses.

    Input: a 1-D torch tensor of shape (NUM_LAY * NUM_MAT + NUM_LAY,), laid out as one
        row of material_df_to_tensor's output. It may live on any device and may
        require grad.
    Output: a tuple of material and thickness: (mat = ['SiO2', 'SiC', ...], thic = [0.3449, 0.1894, ...])
        where ``mat`` is a list of NUM_LAY names from idx_to_mat (arg-max of each
        layer's NUM_MAT scores) and ``thic`` is a numpy array of the trailing entries.
    '''
    # Detach and move to host first: Tensor.numpy() refuses tensors that require
    # grad or live on a GPU, and decoding plain numpy rows keeps np.argmax from
    # going through torch's array-wrapping fallback.
    values = material.detach().cpu().numpy()

    # NUM_LAY layers, NUM_MAT material scores per layer.
    one_hot_rows = values[:NUM_LAY * NUM_MAT].reshape((NUM_LAY, NUM_MAT))
    thicknesses = values[NUM_LAY * NUM_MAT:]

    mats = [idx_to_mat(one_hot_decode(row)) for row in one_hot_rows]
    return mats, thicknesses


def simulate_material(material, device='cuda'):
    '''
    Simulate the spectrum of a single encoded layer stack.

    The stack is copied to CPU, decoded with material_tensor_to_lists, and passed
    to ``real_simulator.simulate`` as (material names, thicknesses) lists. The
    simulator output is wrapped in a tensor (dtype inferred from that output) and
    moved to ``device``.
    '''
    host_material = material.cpu()
    names, thicknesses = material_tensor_to_lists(host_material)
    simulated = real_simulator.simulate(list(names), list(thicknesses))
    spectrum = torch.tensor(simulated)
    return spectrum.to(device)
    


def load_data(data_input):
    '''
    Read one params/values CSV pair; used as the worker function in load_dataset's pool.

    ``data_input`` is an ``(idx, params_file_name, dataset_path)`` tuple. The values
    file is named ``values_`` followed by whatever comes after the last ``_`` in the
    params file name (e.g. ``params_3.csv`` -> ``values_3.csv``).

    Returns ``(idx, params_df, values_df)`` so the caller can restore input order.
    '''
    position, params_name, folder = data_input
    folder = Path(folder)
    _, _, suffix = params_name.rpartition("_")
    params_frame = pd.read_csv(folder / params_name)
    values_frame = pd.read_csv(folder / f"values_{suffix}")
    return position, params_frame, values_frame

def load_dataset(dataset_path):
    '''
    Load every params/values CSV pair in ``dataset_path`` using a process pool.

    Files whose name contains "params" are sorted by name, and each is paired with
    its values file by load_data. The per-file frames are concatenated in that
    sorted order (params and values alike), and the first column of each
    concatenated frame is dropped (presumably the index column written by
    DataFrame.to_csv — confirm against the dataset writer).

    Returns:
        (params_df, values_df)

    Raises:
        ValueError: if ``dataset_path`` contains no file whose name contains "params".
    '''
    params_file_names = sorted(fn for fn in os.listdir(dataset_path) if "params" in fn)
    # Fail fast with a clear message instead of pandas' "No objects to concatenate"
    # after spinning up a whole process pool for nothing.
    if not params_file_names:
        raise ValueError(f"No 'params' files found in {dataset_path}")
    print("loading data")

    data_inputs = [(i, pfn, dataset_path) for i, pfn in enumerate(params_file_names)]
    # No point starting more workers than there are files.
    num_workers = min(multiprocessing.cpu_count(), len(data_inputs))
    with multiprocessing.Pool(processes=num_workers) as pool:
        # imap (unlike imap_unordered) yields results in input order: no re-sort needed.
        results = list(tqdm(pool.imap(load_data, data_inputs), total=len(data_inputs)))

    params_df = pd.concat([params for _, params, _ in results])
    values_df = pd.concat([values for _, _, values in results])
    params_df = params_df.drop(columns=params_df.columns[0])
    values_df = values_df.drop(columns=values_df.columns[0])
    return params_df, values_df


def get_x_y_data(invd_steps=1000, val_split=None, seed=0, device='cpu'):
    '''
    Load the dataset from PATH_TO_FOLDER and split it into train / (val) / test tensors.

    The last ``invd_steps`` samples (in load order) form the test set and the rest
    form the training pool. When ``val_split`` is given, a seeded random permutation
    of the training pool is divided into train and validation parts.

    Args:
        invd_steps: number of trailing samples held out as the test set (>= 0).
        val_split: fraction of the training pool used for validation, strictly
            between 0 and 1, or None for no validation set.
        seed: seed of the train/validation permutation.
        device: device the tensors are moved to.

    Returns:
        ((x_train, y_train), (x_val, y_val) or None, (x_test, y_test))

    Raises:
        ValueError: if invd_steps is negative or val_split is outside (0, 1).
    '''
    # Validate arguments before the (expensive) dataset load.
    if invd_steps < 0:
        raise ValueError("Invalid invd_steps. Expected invd_steps >= 0")
    if val_split is not None and (val_split <= 0.0 or val_split >= 1.0):
        raise ValueError("Invalid val_split. Expected [0 < val_split < 1]")

    params, values = load_dataset(PATH_TO_FOLDER)

    materials = material_df_to_tensor(params).to(device)
    spectra = spectra_df_to_tensor(values).to(device)

    # Split at an explicit index instead of slicing with ``-invd_steps``:
    # ``x[:-0]`` is empty, so invd_steps == 0 would send every sample to the test set.
    split = max(materials.shape[0] - invd_steps, 0)
    x_train, y_train = materials[:split], spectra[:split]
    x_test, y_test = materials[split:], spectra[split:]

    if val_split is None:
        return (x_train, y_train), None, (x_test, y_test)

    remaining_sample = x_train.shape[0]
    num_val = int(remaining_sample * val_split)
    num_train = remaining_sample - num_val

    # Private generator: reproducible split without reseeding (or, as a trailing
    # torch.seed() would, re-randomising) the caller's global RNG.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(remaining_sample, generator=generator)

    # Slice from num_train rather than with ``-num_val``: for num_val == 0,
    # ``indices[-0:]`` would select *every* sample as validation.
    train_idx, val_idx = indices[:num_train], indices[num_train:]

    return (
        (x_train[train_idx], y_train[train_idx]),
        (x_train[val_idx], y_train[val_idx]),
        (x_test, y_test),
    )



if __name__ == "__main__":
    # Manual smoke check: load everything without a validation split, then print the
    # training inputs, their shape, the (absent) validation data and the test shape.
    train_data, val_data, test_data = get_x_y_data(val_split=None, seed=0)
    x_train = train_data[0]

    print(x_train)

    print(x_train.shape)
    print(val_data)
    print(test_data[0].shape)