import math
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions.normal as norm

from scipy.optimize import minimize
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import time
import os

# Working directory captured once, at import time; load_data() reads
# datasets from <path>/data/<file>.txt relative to it.
path = os.getcwd()

# Basic helpers: data loading and tensor/array conversion

def load_data(file, seed, va_ratio=0.1, te_ratio=0.1):
    """Load a regression dataset, split it, and standardise it.

    The file ``<path>/data/<file>.txt`` is read with ``np.loadtxt``. Every
    column except the last is a feature; the last column is the target.

    The test split is held out first. The validation split is then taken
    from the remaining rows. Features and targets are standardised with
    ``StandardScaler`` instances fitted on the training split only.

    Args:
        file: Dataset name (file stem, without the ``.txt`` extension).
        seed: ``random_state`` passed to both ``train_test_split`` calls.
        va_ratio: Fraction of the non-test rows used for validation.
        te_ratio: Fraction of all rows held out for testing.

    Returns:
        Tuple ``(X_tr, X_va, X_te, Y_tr, Y_va, Y_te, y_al, y_range)``:

        - ``X_*`` / ``Y_*`` are float32 tensors created by ``df_to_tensor``.
        - ``y_al`` is the standardised target column of the whole dataset,
          as a NumPy array of shape ``(n, 1)``.
        - ``y_range`` is a 1-element tensor holding
          ``max(y_al) - min(y_al)``.
    """
    # Load data: all columns but the last are features, the last is the target.
    data = np.loadtxt(path + '/data/' + file + '.txt')
    x_al = data[:, :-1]
    y_al = data[:, -1].reshape(-1, 1)

    # Hold out the test split first, then carve the validation split out of
    # the remaining rows. The same seed keeps both splits reproducible.
    x_tr, x_te, y_tr, y_te = train_test_split(
        x_al, y_al, test_size=te_ratio, random_state=seed)
    x_tr, x_va, y_tr, y_va = train_test_split(
        x_tr, y_tr, test_size=va_ratio, random_state=seed)

    # Fit BOTH scalers on the training split only. Fitting the target scaler
    # on the full dataset would leak validation/test statistics into the
    # normalisation.
    s_tr_x = StandardScaler().fit(x_tr)
    s_tr_y = StandardScaler().fit(y_tr)

    X_tr = df_to_tensor(s_tr_x.transform(x_tr))
    X_va = df_to_tensor(s_tr_x.transform(x_va))
    X_te = df_to_tensor(s_tr_x.transform(x_te))
    Y_tr = df_to_tensor(s_tr_y.transform(y_tr))
    Y_va = df_to_tensor(s_tr_y.transform(y_va))
    Y_te = df_to_tensor(s_tr_y.transform(y_te))
    y_al = s_tr_y.transform(y_al)
    Y_al = df_to_tensor(y_al)

    # Reduce on the tensor directly. The builtin max()/min() iterate the rows
    # one at a time in Python, forcing a device sync per comparison.
    # reshape(1) keeps the 1-element shape the builtin-based version produced.
    y_range = (Y_al.max() - Y_al.min()).reshape(1)

    print(file + ' - gmm regression')

    return X_tr, X_va, X_te, Y_tr, Y_va, Y_te, y_al, y_range


def t_to_n(x):
    """Detach tensor ``x`` from autograd, move it to host memory, and return it as a NumPy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()


def df_to_tensor(df, device='cuda'):
    """Convert array-like data (ndarray, DataFrame, nested lists) to a float32 tensor.

    Args:
        df: Anything ``np.array`` accepts. It is materialised as an ndarray
            before conversion.
        device: Device for the returned tensor. Defaults to ``'cuda'``, which
            matches the previous hard-coded ``.cuda()`` call. Pass ``'cpu'``
            to use the helper without a GPU.

    Returns:
        A ``torch.float32`` tensor on ``device``.
    """
    return torch.from_numpy(np.array(df)).float().to(device)


# GMM-related helper functions

def norm_pdf(x, m, s):
    """Elementwise Gaussian density of ``x`` under mean ``m`` and std ``s``.

    A small 1e-7 is added to ``s`` in the normaliser and to ``2 * s**2`` in
    the exponent to avoid division by zero.
    """
    inv_sqrt_2pi = 1 / math.sqrt(2 * math.pi)
    inv_scale = 1 / (s + 1e-7)
    exponent = -(x - m).pow(2) / (2 * s.pow(2) + 1e-7)
    return inv_sqrt_2pi * inv_scale * torch.exp(exponent)


def lognorm_pdf(x, m, s):
    """Elementwise Gaussian log-density of ``x`` under mean ``m`` and std ``s``.

    Uses the same 1e-7 stabilisers as ``norm_pdf``: one inside the log of
    ``s`` and one added to ``2 * s**2`` in the quadratic term.
    """
    quad_term = -(x - m).pow(2) / (2 * s.pow(2) + 1e-7)
    log_scale = torch.log(s + 1e-7)
    log_norm_const = math.log(math.sqrt(2 * math.pi))
    return quad_term - log_scale - log_norm_const


class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call ``validate`` once per evaluation with the current loss. A loss
    counts as an improvement only if it is strictly lower than the best loss
    seen so far. Otherwise a counter is incremented, and once that counter
    exceeds ``patience``, ``validate`` returns ``True``.

    Args:
        patience: Number of consecutive non-improving evaluations tolerated
            before stopping is signalled (0 stops on the first one).
        verbose: If truthy, print a message when stopping is signalled.
    """

    def __init__(self, patience=0, verbose=0):
        self._step = 0             # consecutive non-improving evaluations
        self._loss = float('inf')  # best (lowest) loss seen so far
        self.patience = patience
        self.verbose = verbose

    def validate(self, loss):
        """Record ``loss`` and return ``True`` if training should stop.

        ``loss`` may be a Python number or a single-element tensor. It is
        stored as a plain ``float``, so no tensor (or autograd graph attached
        to it) is kept alive between calls.

        A NaN loss is never treated as an improvement. A diverged run
        therefore counts against the patience budget instead of becoming the
        "best" loss and disabling early stopping for the rest of training.
        """
        loss = float(loss)
        if loss < self._loss:
            # Improvement: remember the new best and reset the counter.
            self._step = 0
            self._loss = loss
            return False

        self._step += 1
        if self._step > self.patience:
            if self.verbose:
                print('Training process is stopped early....')
            return True
        return False
