import os
import math
import glob
import json

from scipy.io import loadmat
from scipy.signal import resample
from scipy.interpolate import interp1d

import numpy as np
import random
import pandas as pd

import torch.utils.data as data
from torch.utils.data import DataLoader


def normalize(data, minn, maxx):
    """Scale a quantity by its global maximum and cast it to float32.

    Only ``maxx`` takes part in the computation (``data / maxx``); ``minn``
    is accepted for interface compatibility but is not used.

    Args:
        data (np.array): Quantity to scale.
        minn: Global minimum of the quantity (currently unused).
        maxx: Global maximum of the quantity, used as the divisor.

    Returns:
        np.array: ``data / maxx`` cast to ``np.float32``.
    """
    return (data / maxx).astype(np.float32)


def denormalize(data, minn, maxx):
    """Undo max-scaling of a quantity and cast the result to float32.

    Only ``maxx`` takes part in the computation (``data * maxx``); ``minn``
    is accepted for interface compatibility but is not used.

    Args:
        data (np.array): Scaled quantity.
        minn: Global minimum of the quantity (currently unused).
        maxx: Global maximum of the quantity, used as the multiplier.

    Returns:
        np.array: ``data * maxx`` cast to ``np.float32``.
    """
    return (data * maxx).astype(np.float32)


def get_samples(pid, length, opt):
    """Enumerate sliding windows over a sequence of ``length`` steps.

    Args:
        pid: Identifier stored as the third element of every window.
        length (int): Number of time steps in the sequence.
        opt: Options object providing ``stride`` (step between window
            starts) and ``window`` (window size).

    Returns:
        list: ``[start, start + opt.window, pid]`` entries, one per start
        offset for which ``start + opt.window < length``.
    """
    # NOTE(review): the comparison is strict, so a window ending exactly at
    # `length` is not emitted -- confirm whether that is intended.
    return [
        [start, start + opt.window, pid]
        for start in range(0, length, opt.stride)
        if start + opt.window < length
    ]


def load_temperature_data(opt):
    """Read the Temperature CSV splits and build per-profile windows.

    Args:
        opt: Options object providing ``data_dir`` plus whatever
            ``get_samples`` reads (``stride`` and ``window``).

    Returns:
        tuple: ``(train_dfs, val_dfs, metadata_df, train_samples,
        val_samples)``. The ``*_dfs`` dicts map each ``profile_id`` to the
        rows of that profile; the ``*_samples`` lists hold the windows
        produced by ``get_samples`` for every profile.
    """
    base_dir = opt.data_dir
    train_df = pd.read_csv(os.path.join(base_dir, 'Temperature/train.csv'))
    val_df = pd.read_csv(os.path.join(base_dir, 'Temperature/val.csv'))
    metadata_df = pd.read_json(os.path.join(base_dir, 'Temperature/metadata.json'))

    def _split_by_profile(frame):
        # One sub-frame per profile id, plus the windows cut from each.
        per_profile = {}
        windows = []
        for profile in frame.profile_id.unique():
            rows = frame[frame['profile_id'] == profile]
            per_profile[profile] = rows
            windows.extend(get_samples(profile, rows.shape[0], opt))
        return per_profile, windows

    train_dfs, train_samples = _split_by_profile(train_df)
    val_dfs, val_samples = _split_by_profile(val_df)

    return train_dfs, val_dfs, metadata_df, train_samples, val_samples


def temperature_loader(full_load, sample, metadata, opt, type='flat'):
    """Cut one window from a profile and standardise its columns.

    Args:
        full_load (dict): Maps a profile id to its DataFrame.
        sample: ``[start, stop, profile_id]``; rows ``start:stop`` are used.
        metadata: Table whose ``'mean'`` and ``'std'`` columns are indexed
            by quantity name.
        opt: Options object providing comma-separated ``inp_quants`` and
            ``out_quants``.
        type (str): Accepted for interface compatibility; not used.

    Returns:
        tuple: ``(inputs, outputs)`` numpy arrays, each transposed so that
        the first axis runs over quantities and the second over rows.
    """
    inp_quants = opt.inp_quants.split(',')
    out_quants = opt.out_quants.split(',')

    start, stop = sample[0], sample[1]
    profile = full_load[sample[2]]

    def _standardized(quants):
        # (window - mean) / std, aligned on the quantity names.
        window = profile[quants][start:stop]
        return (window - metadata['mean'][quants]) / metadata['std'][quants]

    inp_data = _standardized(inp_quants)
    out_data = _standardized(out_quants)

    return np.asarray(inp_data).transpose(1, 0), np.asarray(out_data).transpose(1, 0)


def ResampleLinear1D(original, targetLen):
    """Resample a 1-D sequence to ``targetLen`` points by linear interpolation.

    ``targetLen`` evenly spaced fractional positions are placed between the
    first and the last index of ``original``; each output value is the
    linear blend of the two neighbouring input samples.

    Args:
        original: 1-D sequence of numbers (converted to a float array).
        targetLen (int): Number of output samples.

    Returns:
        np.ndarray: Float array of length ``targetLen``.
    """
    original = np.array(original, dtype='float')
    last = len(original) - 1

    index_arr = np.linspace(0, last, num=targetLen, dtype='float')
    index_floor = index_arr.astype(int)  # positions are >= 0, so this rounds down
    # Clamp the upper neighbour to the final sample instead of wrapping to
    # index 0: its weight is 0 at the last position, but a non-finite first
    # sample (inf/nan) times 0 would still turn the last output into NaN.
    index_ceil = np.minimum(index_floor + 1, last)
    index_rem = index_arr - index_floor  # fractional part = weight of upper neighbour

    interp = original[index_floor] * (1.0 - index_rem) + original[index_ceil] * index_rem
    return interp


def fault_subsample(experiment,label):
    """Downsample one fault experiment by a factor of 30 and attach its label.

    Args:
        experiment: Mapping of channel name to a 2-D array; column 0 of
            each listed channel is resampled.
        label: Class label repeated for every resampled time step.

    Returns:
        dict: Resampled channels under the keys ``i_d``, ``i_q``, ``u_d``,
        ``u_q``, ``time``, ``Trigger``, ``Vib_acpi``, ``Vib_axial``,
        ``Vib_base``, ``Vib_carc``, plus ``fault`` holding the repeated label.
    """
    new_time = experiment['t'].shape[0] // 30

    # (output key, source key) pairs, in the order the output dict is filled.
    channels = (
        ('i_d', 'Id'),
        ('i_q', 'Iq'),
        ('u_d', 'Vd'),
        ('u_q', 'Vq'),
        ('time', 't'),
        ('Trigger', 'Trigger'),
        ('Vib_acpi', 'Vib_acpi'),
        ('Vib_axial', 'Vib_axial'),
        ('Vib_base', 'Vib_base'),
        ('Vib_carc', 'Vib_carc'),
    )

    nexperiment = {
        target: ResampleLinear1D(experiment[source][:, 0], new_time)
        for target, source in channels
    }
    nexperiment['fault'] = np.asarray([label] * new_time)

    return nexperiment

def load_fault_data(opt):
    """Load the fault ``.mat`` experiments and split them by torque level.

    For every fault type the experiments with torques 05/15/25/35 feed the
    training windows and those with torques 10/20/30/40 feed the validation
    windows. All experiments are stored in one shared dict.

    Args:
        opt: Options object providing ``data_dir`` plus whatever
            ``get_samples`` reads (``stride`` and ``window``).

    Returns:
        tuple: ``(dataset, metadata_df, train_samples, val_samples)`` where
        ``dataset`` maps ``'<fault>_torque<NN>'`` to the subsampled experiment.
    """
    metadata_df = pd.read_json(os.path.join(opt.data_dir, 'fault_data/clean_data/metadata.json'))

    class_map = {'rs': 0, 'r1b': 1, 'r2b': 2, 'r3b': 3, 'r4b': 4}

    dataset = {}
    train_samples = []
    val_samples = []

    # Each torque group is paired with the sample list it contributes to.
    splits = (
        (['05', '15', '25', '35'], train_samples),
        (['10', '20', '30', '40'], val_samples),
    )

    for fault, label in class_map.items():
        for torques, bucket in splits:
            for torque in torques:
                experiment_name = f'{fault}_torque{torque}'
                mat_path = os.path.join(opt.data_dir,
                                        f'fault_data/clean_data/{experiment_name}.mat')
                experiment = fault_subsample(loadmat(mat_path), label)
                dataset[experiment_name] = experiment
                bucket.extend(get_samples(experiment_name, experiment['time'].shape[0], opt))

    return dataset, metadata_df, train_samples, val_samples
    

def fault_loader(full_load, sample, metadata, opt, type='flat'):
    """Cut one window from a fault experiment.

    Input quantities are standardised as ``(x - mean) / (std + 1e-6)``;
    output quantities are returned as stored.

    Args:
        full_load (dict): Maps an experiment name to its dict of arrays.
        sample: ``[start, stop, experiment_name]``.
        metadata: Lookup supporting ``metadata['mean'][q]`` and
            ``metadata['std'][q]`` for each input quantity ``q``.
        opt: Options object providing comma-separated ``inp_quants`` and
            ``out_quants``.
        type (str): Accepted for interface compatibility; not used.

    Returns:
        tuple: ``(inputs, outputs)`` numpy arrays with one row per quantity.
    """
    inp_quants = opt.inp_quants.split(',')
    out_quants = opt.out_quants.split(',')

    start, stop = sample[0], sample[1]
    experiment = full_load[sample[2]]

    inp_data = [
        (experiment[quant][start: stop] - metadata['mean'][quant])
        / (metadata['std'][quant] + 1e-6)
        for quant in inp_quants
    ]
    out_data = [experiment[quant][start: stop] for quant in out_quants]

    return np.asarray(inp_data), np.asarray(out_data)


def speed_torque_subsample(experiment):
    """Downsample one speed/torque experiment by a factor of 50.

    Args:
        experiment: Mapping of channel name to a 2-D array; row 0 of each
            listed channel is resampled.

    Returns:
        dict: Resampled channels under the keys ``i_d``, ``i_q``, ``u_d``,
        ``u_q``, ``speed``, ``torque`` and ``time``.
    """
    new_time = experiment['time'].shape[1] // 50

    # (output key, source key) pairs, in the order the output dict is filled.
    channels = (
        ('i_d', 'current_d'),
        ('i_q', 'current_q'),
        ('u_d', 'voltage_d'),
        ('u_q', 'voltage_q'),
        ('speed', 'speed'),
        ('torque', 'torque'),
        ('time', 'time'),
    )

    return {
        target: ResampleLinear1D(experiment[source][0, :], new_time)
        for target, source in channels
    }

def speed_torque_missing(experiment, opt):
    """Attach a random boolean mask and a delta counter to each input quantity.

    For a quantity listed in ``opt.fail_quants`` with probability ``p``,
    ``int(p * M)`` randomly placed mask entries are set to True (``M`` is the
    length of ``experiment['time']``). The delta sequence starts at 0; at each
    later step it grows by one while the mask is True and is 1 otherwise.
    Quantities without a probability get an all-False mask and the delta
    ``[0, 1, 1, ...]``. True presumably marks a missing sample -- confirm
    against the consumer of the masks.

    Args:
        experiment (dict): Experiment arrays; modified in place.
        opt: Options object providing comma-separated ``fail_quants``,
            ``inp_quants`` and ``fail_quants_prob``.

    Returns:
        dict: The same ``experiment`` with ``'<quant>_mask'`` and
        ``'<quant>_delta'`` entries added for every input quantity.
    """
    fail_quants = opt.fail_quants.split(",")
    inp_quants = opt.inp_quants.split(",")
    fail_quants_prob = [float(p) for p in opt.fail_quants_prob.split(",")]

    prob_map = {name: fail_quants_prob[pos] for pos, name in enumerate(fail_quants)}

    M = experiment['time'].shape[0]
    for quant in inp_quants:
        if quant not in prob_map:
            experiment[quant + '_mask'] = np.zeros((M), dtype=bool)
            experiment[quant + '_delta'] = np.asarray([0] + [1] * (M-1))
            continue

        n_flagged = int(prob_map[quant] * M)
        # The drawn values are not used, but the call advances the global
        # NumPy RNG; it is kept so seeded runs produce the same masks.
        np.random.randn(M)
        mask = np.zeros((M), dtype=bool)
        mask[:n_flagged] = True
        np.random.shuffle(mask)
        experiment[quant + '_mask'] = mask

        delta = [0]
        for flagged in mask[1:]:
            delta.append(delta[-1] + 1 if flagged else 1)
        experiment[quant + '_delta'] = np.asarray(delta)

    return experiment

def speed_torque_loader(full_load, sample, metadata, opt, type='flat'):
    """Cut one standardised window from a speed/torque experiment.

    Every input and output quantity is standardised with its own statistics
    as ``(x - mean) / (std + 1e-6)``. When ``opt.fail_quants`` is non-empty,
    the ``'<quant>_mask'`` and ``'<quant>_delta'`` windows of every input
    quantity are returned as well.

    Args:
        full_load (dict): Maps an experiment name to its dict of arrays.
        sample: ``[start, stop, experiment_name]``.
        metadata: Lookup supporting ``metadata['mean'][q]`` and
            ``metadata['std'][q]`` for each quantity ``q``.
        opt: Options object providing comma-separated ``inp_quants`` and
            ``out_quants``, and ``fail_quants``.
        type (str): Accepted for interface compatibility; not used.

    Returns:
        tuple: ``(inputs, outputs, masks, deltas)`` numpy arrays with one
        row per quantity; ``masks`` and ``deltas`` are ``None`` when
        ``opt.fail_quants`` is empty.
    """
    inp_quants = opt.inp_quants.split(',')
    out_quants = opt.out_quants.split(',')

    start, stop = sample[0], sample[1]
    experiment = full_load[sample[2]]
    with_masks = len(opt.fail_quants) > 0

    def _standardized(quant):
        # Always look the statistics up under the quantity being windowed;
        # the outputs previously reused the last *input* quantity's stats.
        window = experiment[quant][start: stop]
        return (window - metadata['mean'][quant]) / (metadata['std'][quant] + 1e-6)

    inp_data = [_standardized(quant) for quant in inp_quants]
    out_data = [_standardized(quant) for quant in out_quants]

    if not with_masks:
        return np.asarray(inp_data), np.asarray(out_data), None, None

    mask_data = [experiment[quant + '_mask'][start: stop] for quant in inp_quants]
    delta_data = [experiment[quant + '_delta'][start: stop] for quant in inp_quants]

    return np.asarray(inp_data), np.asarray(out_data), np.asarray(mask_data), np.asarray(delta_data)

def load_speed_torque_data(opt):
    """Load the speed/torque ``.mat`` experiments for both splits.

    Every directory entry whose name contains ``'.mat'`` is loaded,
    subsampled, optionally given mask/delta entries (when
    ``opt.fail_quants`` is non-empty) and cut into windows.

    Args:
        opt: Options object providing ``data_dir`` and ``fail_quants``, plus
            whatever ``get_samples`` and ``speed_torque_missing`` read.

    Returns:
        tuple: ``(train_dataset, val_dataset, metadata_df, train_samples,
        val_samples)``; the dataset dicts are keyed by file name.
    """
    root = opt.data_dir
    metadata_df = pd.read_json(os.path.join(root, 'Data_27012021_noisy/metadata.json'))
    train_mats = os.listdir(os.path.join(root, 'Data_27012021_noisy/train/'))
    val_mats = os.listdir(os.path.join(root, 'Data_27012021_noisy/val/'))

    def _load_split(subdir, file_names):
        # Load each .mat file of one split and collect its windows.
        split_dataset = {}
        split_samples = []
        for file_name in file_names:
            if '.mat' not in file_name:
                continue
            experiment = speed_torque_subsample(
                loadmat(os.path.join(root, subdir, file_name)))
            if len(opt.fail_quants):
                experiment = speed_torque_missing(experiment, opt)
            split_dataset[file_name] = experiment
            split_samples.extend(
                get_samples(file_name, experiment['time'].shape[0], opt))
        return split_dataset, split_samples

    train_dataset, train_samples = _load_split('Data_27012021_noisy/train/', train_mats)
    val_dataset, val_samples = _load_split('Data_27012021_noisy/val/', val_mats)

    return train_dataset, val_dataset, metadata_df, train_samples, val_samples
    

class FlatInFlatOut(data.Dataset):
    """Dataset that yields flattened input windows.

    Without ``opt.fail_quants`` each item is ``(inputs.flatten(), outputs)``;
    with it, each item is the flattened ``(inputs, outputs, masks, deltas)``.

    Args:
        full_load: Loaded data, passed through to ``loader``.
        samples (list): Window descriptors; shuffled in place on construction.
        metadata: Normalisation statistics, passed through to ``loader``.
        loader (callable): Called as ``loader(full_load, sample, metadata,
            opt, mode)``; returns at least ``(inputs, outputs)`` and, when
            masks are requested, ``(inputs, outputs, masks, deltas)``.
        opt: Options object providing ``fail_quants``.
    """

    def __init__(self, full_load, samples, metadata, loader, opt):
        random.shuffle(samples)  # note: reorders the caller's list as well
        self.samples = samples
        self.full_load = full_load
        self.metadata = metadata
        self.opt = opt
        self.loader = loader

    def __getitem__(self, index):
        sample = self.samples[index]

        if len(self.opt.fail_quants):
            inp_seq, out_seq, mask_seq, delta_seq = self.loader(self.full_load, sample, self.metadata, self.opt, 'seq')
            return inp_seq.flatten(), out_seq.flatten(), mask_seq.flatten(), delta_seq.flatten()

        # Loaders return either (inputs, outputs) or a 4-tuple padded with
        # None masks; index instead of unpacking so both shapes are accepted.
        loaded = self.loader(self.full_load, sample, self.metadata, self.opt, 'flat')
        inp_seq, out_seq = loaded[0], loaded[1]
        return inp_seq.flatten(), out_seq

    def __len__(self):
        return len(self.samples)


class SeqInFlatOut(data.Dataset):
    """Dataset that yields ``(inputs, outputs)`` exactly as the loader built them.

    Args:
        full_load: Loaded data, passed through to ``loader``.
        samples (list): Window descriptors; shuffled in place on construction.
        metadata: Normalisation statistics, passed through to ``loader``.
        loader (callable): Called as ``loader(full_load, sample, metadata,
            opt, 'flat')``; its first two return values are used.
        opt: Options object, passed through to ``loader``.
    """

    def __init__(self, full_load, samples, metadata, loader, opt):
        random.shuffle(samples)  # note: reorders the caller's list as well
        self.samples = samples
        self.full_load = full_load
        self.metadata = metadata
        self.opt = opt
        self.loader = loader

    def __getitem__(self, index):
        sample = self.samples[index]

        # Loaders return either (inputs, outputs) or a 4-tuple that also
        # carries masks/deltas; index instead of unpacking so both work.
        loaded = self.loader(self.full_load, sample, self.metadata, self.opt, 'flat')

        return loaded[0], loaded[1]

    def __len__(self):
        return len(self.samples)


class SeqInSeqOut(data.Dataset):
    """Dataset that yields sequence inputs and sequence outputs.

    Each item is ``(inputs, outputs, masks, deltas)`` when
    ``opt.fail_quants`` is non-empty and ``(inputs, outputs)`` otherwise.

    Args:
        full_load: Loaded data, passed through to ``loader``.
        samples (list): Window descriptors; shuffled in place on construction.
        metadata: Normalisation statistics, passed through to ``loader``.
        loader (callable): Called as ``loader(full_load, sample, metadata,
            opt, 'seq')``; returns at least ``(inputs, outputs)`` and, when
            masks are requested, ``(inputs, outputs, masks, deltas)``.
        opt: Options object providing ``fail_quants``.
    """

    def __init__(self, full_load, samples, metadata, loader, opt):
        random.shuffle(samples)  # note: reorders the caller's list as well
        self.samples = samples
        self.full_load = full_load
        self.metadata = metadata
        self.opt = opt
        self.loader = loader

    def __getitem__(self, index):
        sample = self.samples[index]

        loaded = self.loader(self.full_load, sample, self.metadata, self.opt, 'seq')

        if len(self.opt.fail_quants):
            inp_seq, out_seq, mask_seq, delta_seq = loaded
            return inp_seq, out_seq, mask_seq, delta_seq

        # Without masks only the first two values are needed, so loaders
        # that return a plain (inputs, outputs) pair are accepted too.
        return loaded[0], loaded[1]

    def __len__(self):
        return len(self.samples)


def _get_prelaoder_class(opt):
    if 'fnn' in opt.model:
        return FlatInFlatOut
    if 'gain' in opt.impute_model or 'gan2stage' in opt.impute_model or 'e2e' in opt.impute_model or 'sgan' in opt.impute_model:
        return SeqInSeqOut
    if 'cnn' in opt.model:
        return SeqInFlatOut
    if 'rnn' in opt.model or 'lstm' in opt.model or 'encdec' in opt.model  or 'unet' in opt.model:
        return SeqInSeqOut
    if 'grud' in opt.impute_model or 'mrnn' in opt.impute_model or 'brits' in opt.impute_model:
        return SeqInSeqOut


def get_dataloaders(opt):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        opt: Options object providing ``dataset_name`` (one of
            ``'Temperature'``, ``'FaultVibration'``, ``'SpeedTorque'``),
            ``batch_size``, ``num_workers``, and whatever the selected
            loading functions and dataset class read.

    Returns:
        tuple: ``(train_loader, val_loader, metadata)``.

    Raises:
        ValueError: If ``opt.dataset_name`` is not a supported dataset, or
            if no dataset class matches ``opt.model`` / ``opt.impute_model``.
    """
    known_datasets = ('Temperature', 'FaultVibration', 'SpeedTorque')
    # Fail early with a clear message; an unknown name used to fall through
    # every branch and surface as an UnboundLocalError further down.
    if opt.dataset_name not in known_datasets:
        raise ValueError(
            f'Unknown dataset_name {opt.dataset_name!r}; '
            f'expected one of {known_datasets}')

    preloader_class = _get_prelaoder_class(opt)
    if preloader_class is None:
        raise ValueError(
            'No dataset class matches the configured model / impute_model')

    if opt.dataset_name == 'Temperature':
        train_data, val_data, metadata, train_samples, val_samples = load_temperature_data(opt)
        loader = temperature_loader
    elif opt.dataset_name == 'FaultVibration':
        # Train and validation windows index into the same experiment dict.
        dataset, metadata, train_samples, val_samples = load_fault_data(opt)
        train_data = val_data = dataset
        loader = fault_loader
    else:
        train_data, val_data, metadata, train_samples, val_samples = load_speed_torque_data(opt)
        loader = speed_torque_loader

    print ('Train Samples ', len(train_samples))
    print ('Val Samples ', len(val_samples))

    train_preloader = preloader_class(train_data, train_samples, metadata, loader, opt)
    val_preloader = preloader_class(val_data, val_samples, metadata, loader, opt)

    train_loader = DataLoader(train_preloader, batch_size=opt.batch_size,
                            shuffle=True, num_workers=opt.num_workers)

    # NOTE(review): the validation loader is shuffled too -- confirm intended.
    val_loader = DataLoader(val_preloader, batch_size=opt.batch_size,
                            shuffle=True, num_workers=opt.num_workers)

    return train_loader, val_loader, metadata