from torch.utils.data import Dataset
import torch
import numpy as np 
import os 
import pandas as pd 
import json 


class MyDataset(Dataset):
    """Unlabelled dataset wrapping an indexable array of sequences.

    ``__getitem__`` indexes with three subscripts, so ``TimeSeries`` is
    expected to be (at least) 3-D, with samples along the first axis
    (presumably (n_samples, n_channels, seq_len) — TODO confirm with callers).
    """

    def __init__(self, TimeSeries):
        self.sequences = TimeSeries

    def __len__(self):
        # Size of the first axis. Unlike len(self.sequences[:, 0, 0]) this
        # does not build a temporary slice and does not raise IndexError when
        # the second or third axis has length zero.
        return len(self.sequences)

    def __getitem__(self, index):
        return self.sequences[index, :, :]


class SupervisedDatasetHierarchical(Dataset):
    """Labelled dataset pairing two input arrays with a target array.

    Item ``i`` is the triple ``(X_bot[i], X_top[i], y[i])``; the dataset
    length is taken from ``y``. ``X_bot`` and ``X_top`` are indexed with
    three subscripts, so they are assumed to be 3-D with samples on the
    first axis (TODO confirm exact layout with callers).
    """

    def __init__(self, X_bot, X_top, y):
        self.X_bot, self.X_top, self.y = X_bot, X_top, y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        bottom = self.X_bot[index, :, :]
        top = self.X_top[index, :, :]
        label = self.y[index]
        return bottom, top, label



def pad_single_centroide(list_text_digit):
    """Left-pad every one-character string in the sequence with ``'0'``.

    The sequence is modified in place (only entries of length 1 are
    reassigned) and the same object is returned. Entries of any other
    length, including the empty string, are left unchanged.
    """
    for position, token in enumerate(list_text_digit):
        if len(token) == 1:
            list_text_digit[position] = '0' + token
    return list_text_digit



def turn_array_into_text(array_centroids):
    """Turn each sample's channel-0 values into one space-separated string.

    For every index along the first axis, the values in
    ``array_centroids[i, 0, :]`` are converted with ``str``, single-character
    results are left-padded with ``'0'`` (via ``pad_single_centroide``), and
    the pieces are joined with spaces. Returns a list with one string per
    sample.
    """
    return [
        ' '.join(pad_single_centroide([str(value) for value in array_centroids[sample, 0, :]]))
        for sample in range(array_centroids.shape[0])
    ]


def find_best_accu_eval(dataset_name, result_path):
    """Return the test accuracy of the run with the best evaluation accuracy.

    Every regular file in ``result_path`` whose name contains
    ``dataset_name`` is read as a CSV with ``Accuracy_Eval`` and
    ``Accuracy_Test`` columns. Within each file the row with the highest
    ``Accuracy_Eval`` is selected; the file whose best eval accuracy is the
    highest overall wins.

    Parameters
    ----------
    dataset_name : str
        Substring used to select result files.
    result_path : str or path-like
        Directory containing the result CSV files (a trailing separator is
        not required).

    Returns
    -------
    tuple
        ``(test_for_best_eval, version)``: the ``Accuracy_Test`` value at the
        best-eval row of the winning file, and the 1-based position of that
        file in the *sorted* list of matching filenames.

    Raises
    ------
    ValueError
        If no file in ``result_path`` matches ``dataset_name``.
    """
    # os.listdir returns entries in arbitrary order; sort so the reported
    # version number is deterministic across machines/filesystems.
    matching_files = sorted(
        name for name in os.listdir(result_path)
        if dataset_name in name and os.path.isfile(os.path.join(result_path, name))
    )
    if not matching_files:
        raise ValueError(
            "No result file containing {!r} found in {!r}".format(dataset_name, result_path)
        )

    list_test_for_max_eval = []
    list_eval = []
    for name in matching_files:
        df = pd.read_csv(os.path.join(result_path, name))
        accu_eval = df['Accuracy_Eval'].values
        best_row = int(np.argmax(accu_eval))
        list_eval.append(accu_eval[best_row])
        # argmax gives a position, so look the test value up positionally
        # rather than by index label.
        list_test_for_max_eval.append(df['Accuracy_Test'].iloc[best_row])

    version_max = int(np.argmax(np.array(list_eval)))
    return list_test_for_max_eval[version_max], version_max + 1


def custom_analyzer_bis(text, lvl=2):
    """Yield whitespace tokens of ``text`` followed by concatenated n-grams.

    All unigrams are yielded first. Bigrams (two adjacent tokens concatenated
    with no separator) follow when ``lvl > 1``, and trigrams follow when
    ``lvl > 2``. No n-grams longer than 3 are produced, whatever ``lvl`` is.
    """
    tokens = text.split()
    yield from tokens

    # n-gram order 2 needs lvl > 1, order 3 needs lvl > 2.
    for order in (2, 3):
        if lvl > order - 1:
            for start in range(len(tokens) - order + 1):
                yield ''.join(tokens[start:start + order])






def compute_conv1D_sequence_out_length(L_in, padding, dilat, k_s, stride):
    """
    Return the output length L_out of a 1-D convolution.

    Uses the ``torch.nn.Conv1d`` output-length formula, which floors the
    division by the stride:

        L_out = floor((L_in + 2*padding - dilat*(k_s - 1) - 1) / stride) + 1

    With integer arguments the result is an ``int``. (True division would
    report fractional lengths such as 5.5 that no convolution produces.)
    """
    return (L_in + 2 * padding - dilat * (k_s - 1) - 1) // stride + 1




def compute_short_conv1D_sequence_out_length(L_in, k_s):
    """
    Return L_out after a 1-D convolution with padding=1, stride=2, dilation=1.

    Specialisation of the ``torch.nn.Conv1d`` length formula:

        L_out = floor((L_in + 2 - k_s) / 2) + 1

    Floor division is used so the result is a whole length (an ``int`` for
    integer inputs) even when ``L_in + 2 - k_s`` is odd.
    """
    return (L_in + 2 - k_s) // 2 + 1




def compute_deconv1D_sequence_out_length(L_in, padding, dilat, k_s, stride, output_padding):
    """
    Return the output length L_out of a 1-D transposed convolution.

    Implements the ``torch.nn.ConvTranspose1d`` length formula for arbitrary
    padding, dilation, kernel size, stride and output padding:

        L_out = (L_in - 1)*stride - 2*padding + dilat*(k_s - 1) + output_padding + 1
    """
    upsampled = (L_in - 1) * stride
    kernel_extent = dilat * (k_s - 1)
    return upsampled - 2 * padding + kernel_extent + output_padding + 1




def compute_k_s_list_encoder(L_in, scale_reduc):
    """Pick kernel sizes for ``int(scale_reduc)`` stride-2 convolution steps.

    At each step the kernel size is 4 when the current length is even and 3
    when it is odd. The next length is then computed with
    ``compute_short_conv1D_sequence_out_length`` (padding=1, stride=2).

    Returns
    -------
    (lengths, kernel_sizes)
        ``lengths`` starts with ``L_in`` and holds the length after each step
        (``int(scale_reduc) + 1`` entries). ``kernel_sizes`` holds the kernel
        chosen at each step.
    """
    lengths = [L_in]
    kernel_sizes = []
    current = L_in

    for _ in range(int(scale_reduc)):
        kernel = 4 if current % 2 == 0 else 3
        current = compute_short_conv1D_sequence_out_length(current, kernel)
        lengths.append(current)
        kernel_sizes.append(kernel)

    return lengths, kernel_sizes



def compute_decod_parameters(L_in, L_target, scale_reduc):
    """Solve ConvTranspose1d parameters that map length ``L_in`` to ``L_target``.

    The stride is ``scale_reduc``, with dilation 1 and output_padding 0. The
    kernel size follows from the transposed-convolution length formula:

        L_target = (L_in - 1)*scale_reduc - 2*padding + k_s

    Padding is 1 when ``L_target - (L_in - 1)*scale_reduc < 3``, which adds 2
    to the solved kernel size. Otherwise padding is 0.

    Returns
    -------
    tuple of three single-element lists
        ``([k_s], [stride], [padding])``, all converted to ``int``.

    Raises
    ------
    ValueError
        If the solved kernel size is smaller than 1, or if the integer
        parameters actually returned do not reproduce ``L_target``. This
        happens, for example, when non-integer lengths or stride are passed
        and ``int()`` truncates them.
    """
    L_tempo = (L_in - 1) * scale_reduc

    padding = 1 if L_target - L_tempo < 3 else 0
    k_s = L_target - (L_tempo - 2 * padding)

    k_s_int, stride_int, padding_int = int(k_s), int(scale_reduc), int(padding)

    if k_s_int < 1:
        raise ValueError(
            "Cannot map length {} to {} with stride {}: solved kernel size {} < 1".format(
                L_in, L_target, scale_reduc, k_s_int))

    # Verify with the integer values actually returned. Checking the
    # un-truncated k_s would pass even when int() changed it.
    supposed_length = compute_deconv1D_sequence_out_length(L_in, padding_int,
                                                           dilat=1, k_s=k_s_int,
                                                           stride=stride_int,
                                                           output_padding=0)
    if supposed_length != L_target:
        raise ValueError(
            "Decoder parameters k_s={}, stride={}, padding={} give length {} instead of {}".format(
                k_s_int, stride_int, padding_int, supposed_length, L_target))

    return [k_s_int], [stride_int], [padding_int]



def find_params_D_VAE(sequence, scale_reduc_bot):
    """Compute bottom-encoder kernel sizes and lengths for ``sequence``.

    The sequence length is read from the last axis of ``sequence``. Returns a
    dict with:
      - ``'k_s_list_encod_b'``: kernel sizes from ``compute_k_s_list_encoder``
      - ``'seq_len'``: the input sequence length
      - ``'seq_bot_len'``: the length after the last encoder step, as ``int``
    """
    seq_length = sequence.shape[-1]
    bot_lengths, bot_kernel_sizes = compute_k_s_list_encoder(seq_length, scale_reduc_bot)

    return {
        'k_s_list_encod_b': bot_kernel_sizes,
        'seq_len': seq_length,
        'seq_bot_len': int(bot_lengths[-1]),
    }


def find_params_VQ_AE(sequence, scale_reduc_bot, scale_reduc_mid):
    """Compute kernel sizes and lengths for a two-stage (bottom, middle) encoder.

    @sequence : array whose last axis is the sequence length
                (documented as (batch_size, nb_channels, seq_length))
    @scale_reduc_bot : passed to compute_k_s_list_encoder for the bottom
                       stage (documented values: 2 or 4)
    @scale_reduc_mid : passed to compute_k_s_list_encoder for the middle
                       stage, starting from the bottom stage's output length
                       (documented values: 2 or 4)

    --------------------------------------------------------
    return a dict with keys 'k_s_list_encod_b', 'k_s_list_encod_m',
    'seq_len', 'seq_bot_len' and 'seq_mid_len'
    """
    seq_length = sequence.shape[-1]

    bot_lengths, bot_kernel_sizes = compute_k_s_list_encoder(seq_length, scale_reduc_bot)
    bot_out_length = bot_lengths[-1]

    # The middle stage consumes the bottom stage's (un-cast) output length.
    mid_lengths, mid_kernel_sizes = compute_k_s_list_encoder(bot_out_length, scale_reduc_mid)

    return {
        'k_s_list_encod_b': bot_kernel_sizes,
        'k_s_list_encod_m': mid_kernel_sizes,
        'seq_len': seq_length,
        'seq_bot_len': int(bot_out_length),
        'seq_mid_len': int(mid_lengths[-1]),
    }


def compute_receptive_field(list_ks, list_strides):
    """Return the receptive-field size of a stack of 1-D convolution layers.

    Computes ``1 + sum_i (k_i - 1) * prod_{j < i} s_j``, where ``k_i`` is
    ``list_ks[i]`` and ``s_j`` is ``list_strides[j]``. The product is 1 for
    the first layer.
    """
    def _cumulative_stride(layer):
        # Product of the strides of every layer before `layer`.
        # Layer 0 uses the literal 1 (np.prod of an empty array would be 1.0).
        return 1 if layer == 0 else np.prod(np.array(list_strides[:layer]))

    total = sum(
        (kernel - 1) * _cumulative_stride(layer)
        for layer, kernel in enumerate(list_ks)
    )
    return total + 1


def compute_receptive_field_region(list_ks, list_strides, list_padding, uL, vL):
    """Map an output index range back to the input range that influences it.

    Given start ``uL`` and end ``vL`` indices at the output of the stack,
    returns ``(u0, v0)`` with

        u0 = uL * prod(s) - sum_i p_i * prod_{j<i} s_j
        v0 = vL * prod(s) - sum_i (1 + p_i - k_i) * prod_{j<i} s_j

    where ``k_i``, ``s_i`` and ``p_i`` come from ``list_ks``,
    ``list_strides`` and ``list_padding``.
    """
    total_stride = np.prod(np.array(list_strides))

    offset_u = 0
    offset_v = 0
    for layer, kernel in enumerate(list_ks):
        # Cumulative stride of all earlier layers (1 for the first layer).
        jump = 1 if layer == 0 else np.prod(np.array(list_strides[:layer]))
        pad = list_padding[layer]
        offset_u += pad * jump
        offset_v += (1 + pad - kernel) * jump

    return uL * total_stride - offset_u, vL * total_stride - offset_v





 
