import os
import pickle
from os import path as pt
from os.path import join
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
def kurtosis_torch(x, dim=(0, 1), excess=True, dropdims=True):
    """Return the kurtosis of ``x`` taken along dimension 1.

    The fourth central moment (a plain mean over dimension 1) is divided by
    the square of ``torch.var(x, dim=1)``.

    Args:
        x: Tensor expected to be 3-D; the statistic is reduced over its
            second dimension.
        dim: Accepted for interface compatibility; not used.
        excess: When true, 3 is subtracted from the ratio.
        dropdims: Accepted for interface compatibility; not used.

    Returns:
        Tensor shaped like ``x`` with dimension 1 removed.
    """
    # Broadcasting against a kept dimension centres every series on its own mean.
    centered = x - x.mean(dim=1, keepdim=True)
    fourth_moment = centered.pow(4).mean(dim=1)
    squared_variance = x.var(dim=1).pow(2)
    ratio = fourth_moment / squared_variance
    return ratio - 3 if excess else ratio
def skew_torch(x, dim=(0, 1), dropdims=True):
    """Return the skewness of ``x`` taken along dimension 1.

    The third central moment (a plain mean over dimension 1) is divided by
    the cube of ``torch.std(x, dim=1)``.

    Args:
        x: Tensor expected to be 3-D; the statistic is reduced over its
            second dimension.
        dim: Accepted for interface compatibility; not used.
        dropdims: Accepted for interface compatibility; not used.

    Returns:
        Tensor shaped like ``x`` with dimension 1 removed.
    """
    # Broadcasting against a kept dimension centres every series on its own mean.
    centered = x - x.mean(dim=1, keepdim=True)
    third_moment = centered.pow(3).mean(dim=1)
    cubed_std = x.std(dim=1).pow(3)
    return third_moment / cubed_std

def savefig(filename, directory):
    """Save the current matplotlib figure as ``directory/filename``, then close it.

    Args:
        filename: Name of the output file, including its extension.
        directory: Folder the file is written into.
    """
    target = join(directory, filename)
    plt.savefig(target)
    # Closing releases the figure so later plots start from a clean canvas.
    plt.close()


from os import path as pt


def batch_corr(x):
    """Return the pairwise feature correlations of every batch entry.

    For each ``x[b]`` (rows are observations, columns are features) the sample
    correlation matrix is computed and its strict upper triangle is returned
    in row-major order.

    Per-batch means and standard deviations are applied through broadcasting.
    Expanding them with ``repeat(...).reshape(...)`` lays the statistics out
    in the wrong order and mixes values from different batch entries as soon
    as the batch size exceeds one.

    Args:
        x: 3-D tensor indexed as ``(batch, observation, feature)``.

    Returns:
        Tensor of shape ``(batch, F * (F - 1) / 2)`` with ``F = x.size(2)``,
        clamped to ``[-1, 1]``. A feature with zero variance yields
        non-finite entries because of the division by its standard deviation.
    """
    centered = x - x.mean(dim=1, keepdim=True)

    # Unbiased covariance of the features, one matrix per batch entry.
    cov = torch.bmm(centered.transpose(1, 2), centered) / (x.size(1) - 1)

    # Normalise entry (i, j) by std_i * std_j of the same batch entry.
    std = torch.diagonal(cov, offset=0, dim1=1, dim2=2).sqrt()
    corr = cov / (std.unsqueeze(2) * std.unsqueeze(1))

    # Rounding can push values marginally outside the valid range.
    corr = torch.clamp(corr, -1.0, 1.0)

    rows, cols = torch.triu_indices(x.size(2), x.size(2), 1, device=x.device)
    return corr[:, rows, cols]


def getavg_corr(x, rdf):
    """Return the mean entry of the correlation matrix of a slice of ``rdf``.

    The rows of ``rdf`` whose labels appear in ``x.index`` are selected, a
    small uniform random jitter (scaled by 1e-3) is added to every value, and
    the column-by-column correlation matrix of the result is averaged.

    Args:
        x: Object exposing an ``index`` of row labels present in ``rdf``.
        rdf: DataFrame the rows are taken from.

    Returns:
        Scalar average over all entries of the correlation matrix
        (the diagonal is included).
    """
    window = rdf.loc[x.index]
    jitter = np.random.uniform(0, 1, window.shape) * 1e-3
    correlation = (window + jitter).corr()
    column_means = correlation.mean()
    return column_means.mean()


def _finish_figure(filename, directory, save_figs):
    """Save the current figure through ``savefig`` or, when disabled, show it."""
    if save_figs:
        savefig(filename, directory)
    else:
        plt.show()


def _overlay_histograms(labelled_tensors, n_paths, per_path, tabular=False):
    """Draw one semi-transparent histogram per ``(label, tensor)`` pair on the current axes.

    ``per_path`` maps the DataFrame built from ``tensor[j]`` to the values that
    path contributes. With ``tabular`` the per-path results are assembled
    through ``pd.DataFrame`` (accepts Series or arrays), otherwise through
    ``np.stack``.
    """
    for label, tensor in labelled_tensors:
        rows = [per_path(pd.DataFrame(tensor[j].detach().cpu().numpy())) for j in range(n_paths)]
        values = pd.DataFrame(rows).values if tabular else np.stack(rows)
        pd.Series(values.flatten()).hist(bins=100, density=True, alpha=0.5, label=label)
    plt.legend()


def _weighted_path_norm(tensor, weights, n_paths, path_len):
    """Return the running square root of the cumulative weighted squared values per path."""
    stacked = np.array([(tensor[j].detach().cpu().numpy() ** 2) @ weights for j in range(n_paths)])
    return pd.DataFrame(stacked.reshape(-1, path_len)).T.cumsum().apply(np.sqrt)


def plt_figures(x_fake_gs, x_future_g, x_past_g, step_number, plt_directory, sample_kinds):
    """Save diagnostic plots comparing generated samples with real future and past data.

    For every generated tensor a series of figures is written to
    ``plt_directory``: a weighted path-sum difference histogram, overlaid
    histograms (future / past / generated) of correlations, returns, rolling
    and cross-sectional moments and lag-1 autocorrelations, and cumulative-sum
    line plots of the first path of each tensor. At most the first 32 paths of
    each tensor are used.

    Args:
        x_fake_gs: Iterable of generated 3-D tensors, one per sample kind.
        x_future_g: 3-D tensor of real future data.
        x_past_g: 3-D tensor of real past data.
        step_number: Value appended to titles and file names.
        plt_directory: Folder the figures are saved into.
        sample_kinds: Labels indexed in step with ``x_fake_gs``.

    Notes:
        Tensors are presumably ``(batch, time, features)`` - TODO confirm
        against the caller. The rolling statistics use column 0 only.
    """
    save_figs = True
    step = str(step_number)

    for co, x_fake_g in enumerate(x_fake_gs):
        addition = sample_kinds[co]
        # Built afresh for every kind; appending to a shared prefix made later
        # titles accumulate the names of all earlier kinds.
        prefix = 'Diffusion_' + '_' + addition

        n_paths = min(x_fake_g.size(0), 32)
        path_len = x_fake_g.size(1)
        third = int(path_len / 3)
        fifth = int(path_len / 5)

        weights = np.random.lognormal(1.05, 1.01, (x_fake_g.size(2), 1))
        weights = np.maximum(weights, 1)
        weights = weights / weights.sum()

        # NOTE(review): plt_data_directory, model_directory and self are not
        # defined anywhere in the visible part of this file. These statements
        # are kept verbatim; confirm the names exist at module level or move
        # the persistence into the object that owns the models.
        torch.save(x_fake_g, plt_data_directory + '\\x_fake_' + str(step_number) + '.pt')
        torch.save(x_future_g, plt_data_directory + '\\x_real_' + str(step_number) + '.pt')
        generator_path = model_directory + '\\generator_weights_' + str(step_number) + '.pth'
        discriminator_path = model_directory + '\\discriminator_weights_' + str(step_number) + '.pth'

        # Save the model weights
        torch.save(self.G.state_dict(), generator_path)
        if self.gan_algo not in ['CTVAE', 'CTNF']:
            torch.save(self.D.state_dict(), discriminator_path)

        # Weighted path-sum difference, normalised by the 1-based time index.
        fakepathweight = _weighted_path_norm(x_fake_g, weights, n_paths, path_len)
        realpathweight = _weighted_path_norm(x_future_g, weights, n_paths, path_len)
        difference = (realpathweight - fakepathweight).T / (realpathweight.index + 1)
        pd.Series(difference.T.tail(10).values.flatten()).hist(bins=100, density=True)
        plt.title(prefix + 'WeightedPathSumDiff' + step)
        _finish_figure(addition + 'WeightedPathSumDiff_' + step + '.png', plt_directory, save_figs)

        labelled = (('Future_real', x_future_g), ('Past_real', x_past_g), ('Generated', x_fake_g))

        # (title or None, file name, per-path statistic, assemble via DataFrame?)
        # NOTE(review): the two ACF file names carry no sample kind, so runs
        # with several kinds overwrite each other; names kept as they were.
        histogram_specs = [
            (prefix + '_Correlation_' + step, addition + 'Correlation_' + step + '.png',
             lambda z: z.corr().values.flatten(), False),
            (prefix + '_RollingCorrelation' + step, addition + '_RollingCorrelation' + step + '.png',
             lambda z: z[0].rolling(fifth).apply(lambda w: getavg_corr(w, z)).dropna().values, True),
            (prefix + '_Returns' + step, addition + '_Returns' + step + '.png',
             lambda z: z.values.flatten(), False),
            (prefix + '_RollingMean' + step, addition + '_RollingMean' + step + '.png',
             lambda z: z[0].rolling(third).mean().dropna().values, True),
            (prefix + '_std_' + step, addition + '_std_' + step + '.png',
             lambda z: z.std(axis=1).values.flatten(), False),
            (prefix + '_RollingStandardDeviation' + step,
             addition + '_RollingStandardDeviation' + step + '.png',
             lambda z: z[0].rolling(third).std().dropna().values, True),
            # The past histogram is now computed from x_past_g like every
            # other figure (it previously re-used x_future_g).
            (prefix + '_Kurtosis_' + step, addition + '_Kurtosis_' + step + '.png',
             lambda z: z.kurt(axis=1).values.flatten(), False),
            (prefix + '_RollingKurtosis' + step, addition + '_RollingKurtosis' + step + '.png',
             lambda z: z[0].rolling(third).kurt().dropna().values, True),
            (prefix + '_skew_' + step, addition + '_skew_' + step + '.png',
             lambda z: z.skew(axis=1).values.flatten(), False),
            (None, '_ACF_' + step + '.png',
             lambda z: z.apply(lambda col: col.autocorr()), True),
            (None, '_ACF_abs_' + step + '.png',
             lambda z: z.apply(lambda col: col.abs().autocorr()), True),
            (prefix + '_RollingSkew' + step, addition + '_RollingSkew' + step + '.png',
             lambda z: z[0].rolling(third).skew().dropna().values, True),
        ]
        for title, filename, per_path, tabular in histogram_specs:
            _overlay_histograms(labelled, n_paths, per_path, tabular)
            if title is not None:
                plt.title(title)
            _finish_figure(filename, plt_directory, save_figs)

        # Cumulative-sum line plots of the first path of each tensor.
        cumsum_specs = (
            ('_RealCumsum_', x_future_g),
            ('_Real_pastCumsum_', x_past_g),
            ('_FakeCumsum_', x_fake_g),
        )
        for stem, tensor in cumsum_specs:
            pd.DataFrame(tensor[0].detach().cpu().numpy()).cumsum().plot()
            plt.title(prefix + stem + step)
            _finish_figure(addition + stem + step + '.png', plt_directory, save_figs)


def return_emd(dff, dfr, opt, nbins=30):
    """Compare the value distributions of two tensors through shared-bin histograms.

    Both tensors are flattened and binned on ``nbins`` common intervals
    derived from their pooled values; each histogram is normalised to sum
    to one before the distances are computed.

    Args:
        dff: First tensor; the tensor results are moved to its device.
        dfr: Second tensor.
        opt: Mode label. ``'rolling'`` and every other value run the same
            computation.
        nbins: Number of histogram bins.

    Returns:
        Tuple ``(dist, CRPS, otdist, JS)``:
        ``dist`` - tensor, sum of squared differences of the bin frequencies;
        ``CRPS`` - sum of squared differences of the cumulative frequencies;
        ``otdist`` - tensor, cost of the ``ot.emd`` plan under the
        max-normalised ``ot.dist`` cost between bin mid-points;
        ``JS`` - tensor, half the sum of two ``F.kl_div`` terms between each
        smoothed histogram and their mixture.
    """
    fake_flat = dff.detach().cpu().numpy().flatten()
    real_flat = dfr.detach().cpu().numpy().flatten()

    # Common bin edges so both histograms live on the same support.
    _, edges = pd.cut(np.hstack((fake_flat, real_flat)), nbins, retbins=True)

    fake_hist = pd.cut(fake_flat, bins=edges).value_counts().sort_index()
    fake_hist = fake_hist / fake_hist.sum()
    real_hist = pd.cut(real_flat, bins=edges).value_counts().sort_index()
    real_hist = real_hist / real_hist.sum()

    # Both supports are the mid-points of the (shared) bins; two separate
    # arrays are built so ot.dist sees distinct inputs.
    fake_mids = np.array([interval.mid for interval in fake_hist.index]).reshape((nbins, 1))
    real_mids = np.array([interval.mid for interval in fake_hist.index]).reshape((nbins, 1))
    cost = ot.dist(fake_mids, real_mids)
    cost /= cost.max()

    CRPS = np.sum((fake_hist.cumsum() - real_hist.cumsum()) ** 2)

    plan = ot.emd(fake_hist.values, real_hist.values, cost)
    otdist = torch.tensor(np.trace(np.dot(np.transpose(plan), cost))).to(dff.device)

    dist = torch.tensor(sum((fake_hist - real_hist) ** 2)).to(dff.device)

    # Additive smoothing keeps the logarithms finite for empty bins.
    fake_p = fake_hist.values + 1e-6
    fake_p = fake_p / fake_p.sum()
    real_p = real_hist.values + 1e-6
    real_p = real_p / real_p.sum()
    mixture = 0.5 * (fake_p + real_p)

    # NOTE(review): F.kl_div takes log-probabilities first and the target
    # second, and 'batchmean' divides by the first dimension (nbins here) -
    # confirm this is the intended Jensen-Shannon scaling and direction.
    JS = 0.5 * F.kl_div(torch.tensor(fake_p).log(), torch.tensor(mixture), reduction='batchmean')
    JS += 0.5 * F.kl_div(torch.tensor(real_p).log(), torch.tensor(mixture), reduction='batchmean')

    return dist, CRPS, otdist, JS


def get_etf_data(**kwargs):
    """Load hourly ETF closes from PostgreSQL and return them as a tensor of returns.

    The ``N`` symbols with the highest average volume since 2020-03-20 are
    selected and their rows from 2013-03-20 onwards are pivoted into a
    datetime x symbol table of closing prices. Timestamps with 40 or more
    missing symbols and symbols with more than 30% missing values are
    dropped; prices are forward-filled, remaining gaps set to 0, and turned
    into percentage changes with infinities replaced by 0.

    Args:
        **kwargs: Must contain ``N``, the volume-rank cut-off (converted
            with ``int``).

    Returns:
        ``torch.Tensor`` of percentage changes, one row per retained
        timestamp and one column per retained symbol; rows still containing
        NaN are dropped.

    Raises:
        KeyError: If ``N`` is missing.
        ValueError: If ``N`` cannot be converted to an integer.
    """
    # Converting up front validates the value and hands the driver a plain int.
    top_n = int(kwargs['N'])
    conn_params = {
        'dbname': '',
        'user': '',
        'password': '',
        'host': '',
        'port': ''
    }

    # The rank cut-off is a bound parameter (%s) rather than text spliced
    # into the statement.
    query = """


    WITH ranked_symbols AS (
        SELECT symbol,
               AVG(volume) AS avg_volume,
               RANK() OVER (ORDER BY AVG(volume) DESC) AS volume_rank
        FROM public.etf_1hour_adjsplitdiv
        WHERE datetime >= '2020-03-20'
        GROUP BY symbol
    )
    SELECT
           public.etf_1hour_adjsplitdiv.*
    FROM (
        SELECT symbol
        FROM ranked_symbols
        WHERE volume_rank <= %s
    ) AS top_20
    JOIN public.etf_1hour_adjsplitdiv ON top_20.symbol = public.etf_1hour_adjsplitdiv.symbol
    where public.etf_1hour_adjsplitdiv.datetime >= '2013-03-20'



    """

    # Connect to PostgreSQL; the finally blocks release cursor and connection
    # even when the query fails.
    conn = psycopg2.connect(**conn_params)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query, (top_n,))
            conn.commit()
            rows = cursor.fetchall()
            # Read the column names while the cursor is still open.
            column_names = [desc[0] for desc in cursor.description]
        finally:
            cursor.close()
    finally:
        conn.close()

    df = pd.DataFrame(rows, columns=column_names)
    Prices = pd.pivot_table(df, index='datetime', columns='symbol', values='close')

    # Keep timestamps where fewer than 40 symbols are missing.
    filt = Prices.isna().sum(axis=1) < 40
    Prices = Prices[filt]

    # Drop symbols that are missing in more than 30% of the remaining rows.
    missing_percent = Prices.isnull().mean() * 100
    columns_to_drop = missing_percent[missing_percent > 30].index
    Prices = Prices.drop(columns=columns_to_drop)

    Prices_ = Prices.ffill().fillna(0).pct_change().replace(np.inf, 0)
    return torch.tensor(Prices_.dropna().values)


def get_stock_data(**kwargs):
    """Read stock data from a local HDF5 file and return it as a tensor.

    Args:
        **kwargs: Must contain ``T``. ``'1D'`` selects the daily file, any
            other value the 5-minute file.

    Returns:
        ``torch.Tensor`` holding the file's values with missing entries
        replaced by 0.
    """
    use_daily = kwargs['T'] == '1D'
    source = (
        'C:\\Users\\username\\PycharmProjects\\GPS\\Stockreturnsdaily.h5'
        if use_daily
        else 'C:\\Users\\username\\PycharmProjects\\GPS\\stock_5min.h5'
    )
    table = pd.read_hdf(source)
    filled = table.fillna(0)
    return torch.tensor(filled.values)



def get_crypto_data(**kwargs):
    """Load hourly crypto closes from PostgreSQL and return returns for the top symbols.

    Up to the 100000 most recent rows per symbol (excluding UST, DOGE, DAI
    and USDT) are pivoted into a datetime x symbol table of closing prices.
    Timestamps with 40 or more missing symbols and symbols with more than
    30% missing values are dropped; prices are forward-filled, remaining
    gaps set to 0, and turned into percentage changes with infinities
    replaced by 0. The ``N`` symbols with the largest summed notional
    (close * volume) over the last 3000 rows of the notional table are kept.

    Args:
        **kwargs: Must contain ``N``, the number of symbols to keep.

    Returns:
        ``torch.Tensor`` of percentage changes for the selected symbols;
        rows still containing NaN are dropped.

    Raises:
        KeyError: If ``N`` is missing.
    """
    N = kwargs['N']

    conn_params = {
        'dbname': '',
        'user': '',
        'password': '',
        'host': '',
        'port': ''
    }

    query = """
       WITH ranked_data AS (
        SELECT *,
               ROW_NUMBER() OVER (PARTITION BY symbol ORDER BY datetime DESC) AS rn
        FROM public.crypto_1hour where symbol not in ('UST','DOGE','DAI','USDT')
    )
    SELECT *
    FROM ranked_data
    WHERE rn <= 100000
    """

    # Connect to PostgreSQL; the finally blocks release cursor and connection
    # even when the query fails.
    conn = psycopg2.connect(**conn_params)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query)
            rows = cursor.fetchall()
            # Read the column names while the cursor is still open.
            column_names = [desc[0] for desc in cursor.description]
        finally:
            cursor.close()
    finally:
        conn.close()

    df = pd.DataFrame(rows, columns=column_names)
    Prices = pd.pivot_table(df, index='datetime', columns='symbol', values='close')

    # Keep timestamps where fewer than 40 symbols are missing.
    filt = Prices.isna().sum(axis=1) < 40
    Prices = Prices[filt]

    # Drop symbols that are missing in more than 30% of the remaining rows.
    missing_percent = Prices.isnull().mean() * 100
    columns_to_drop = missing_percent[missing_percent > 30].index
    Prices = Prices.drop(columns=columns_to_drop)

    Prices_ = Prices.ffill().fillna(0).pct_change().replace(np.inf, 0)

    # Rank the surviving symbols by traded notional over the latest 3000 rows.
    df['Notional'] = df['close'] * df['volume']
    Notional = pd.pivot_table(df.loc[df['symbol'].isin(Prices.columns)], index='datetime', columns='symbol',
                              values='Notional').tail(3000)
    top_symbols = list(Notional.sum().sort_values(ascending=False).head(N).index)
    return torch.tensor(Prices_[top_symbols].dropna().values)


def _rolling_feature(batch, window, feature, n_samples, device):
    """Rolling statistic of column 0 for the first ``n_samples`` entries of ``batch``."""
    rows = []
    for j in range(n_samples):
        frame = pd.DataFrame(batch[j].detach().cpu().numpy())
        roller = frame[0].rolling(window)
        if feature == 'corr':
            # Average pairwise correlation of all columns over each window.
            stat = roller.apply(lambda win: getavg_corr(win, frame))
        else:
            # 'mean', 'std', 'skew' and 'kurt' are pandas Rolling method names.
            stat = getattr(roller, feature)()
        rows.append(stat.dropna().values)
    return torch.tensor(np.array(rows)).to(device)


def _column_autocorr(batch, n_samples, device, absolute=False):
    """Per-column lag-1 autocorrelation for the first ``n_samples`` entries of ``batch``."""
    rows = []
    for j in range(n_samples):
        frame = pd.DataFrame(batch[j].add(1e-9).detach().cpu().numpy())
        if absolute:
            rows.append(frame.apply(lambda col: col.abs().autocorr()))
        else:
            rows.append(frame.apply(lambda col: col.autocorr()))
    return torch.tensor(np.array(rows)).to(device)


def histo_loss(real, fake, step):
    """Compare summary statistics of real and generated samples.

    Two families of statistics are scored with ``return_emd`` (defined
    elsewhere in this file and unpacked here as ``(loss, CRPS, otdist, JS)``):

    * ``'rolling'``: rolling mean / corr / std / skew / kurt of column 0,
      computed through pandas for at most the first 32 samples, with a window
      of ``size(1) / div`` where ``div`` is drawn from {2, 3, 4, 5}.
    * ``'full'``: std, pairwise correlation, raw values, mean, skew, kurtosis
      and lag-1 autocorrelation (plain and absolute) over the whole sample.

    Args:
        real: tensor of real samples; assumes (batch, time, features) --
            TODO confirm against the caller.
        fake: tensor of generated samples, same layout as ``real``.
        step: training step index, used only to print progress every 199
            steps.

    Returns:
        ``(output, loss_dictionary)`` where ``loss_dictionary[opt][feature]``
        holds ``'SqPdfLoss'``, ``'CRPS'``, ``'otdist'`` and ``'JS'``.
    """
    count = 0
    output = 0
    loss_dictionary = {}
    # Fixed precedence: the original ``step + 1 % 199`` evaluated as
    # ``step + (1 % 199)`` and therefore never triggered the logging.
    log_now = (step + 1) % 199 == 0

    for opt in ['rolling', 'full']:
        loss_dictionary[opt] = {}
        if opt == 'rolling':
            div1 = np.random.choice([2, 3, 4, 5], 1)[0]
            n_real = min(real.size(0), 32)
            real_window = int(real.size(1) / div1)
            rolling_losses = []
            for features in ['mean', 'corr', 'std', 'skew', 'kurt']:
                real12 = _rolling_feature(real, real_window, features, n_real, fake.device)
                if features == 'corr':
                    # The correlation feature sizes the window and sample
                    # count from ``fake`` itself; the others reuse ``real``'s.
                    fake11 = _rolling_feature(fake, int(fake.size(1) / div1), features,
                                              min(fake.size(0), 32), fake.device)
                else:
                    fake11 = _rolling_feature(fake, real_window, features, n_real, fake.device)

                l, CRPS, otdist, JS = return_emd(fake11, real12, opt, nbins=40)
                loss_dictionary[opt][features] = {'SqPdfLoss': l.item(), 'CRPS': CRPS, 'otdist': otdist.item(),
                                                  'JS': JS.item()}
                rolling_losses.append(l)
            output = output + torch.sum(torch.stack(rolling_losses))
        else:
            n_real = min(real.size(0), 32)
            for features in ['std', 'corr', 'normal', 'mean', 'skew', 'kurt', 'acf', 'acf_abs']:
                if features == 'normal':
                    real2 = real
                    fake1 = fake
                elif features == 'mean':
                    real2 = real.mean(dim=1)
                    fake1 = fake.mean(dim=1)
                elif features == 'corr':
                    real2 = batch_corr(real)
                    fake1 = batch_corr(fake)
                elif features == 'std':
                    real2 = real.std(dim=1)
                    fake1 = fake.std(dim=1)
                elif features == 'skew':
                    real2 = skew_torch(real)
                    fake1 = skew_torch(fake)
                elif features == 'kurt':
                    real2 = kurtosis_torch(real)
                    fake1 = kurtosis_torch(fake)
                elif features == 'acf':
                    real2 = _column_autocorr(real, n_real, fake.device)
                    fake1 = _column_autocorr(fake, n_real, fake.device)
                elif features == 'acf_abs':
                    real2 = _column_autocorr(real, n_real, fake.device, absolute=True)
                    fake1 = _column_autocorr(fake, n_real, fake.device, absolute=True)

                loss, CRPS, otdist, JS = return_emd(fake1, real2, opt, nbins=40)
                loss_dictionary[opt][features] = {'SqPdfLoss': loss.item(), 'CRPS': CRPS, 'otdist': otdist.item(),
                                                  'JS': JS.item()}

                if log_now:
                    print(loss, features)
                if count == 0:
                    # NOTE(review): this assignment replaces the rolling total
                    # accumulated above, so only the 'full' losses reach the
                    # returned value -- confirm that this is intended.
                    output = loss
                    count += 1
                else:
                    output += loss

    if log_now:
        print(output, 'total_loss')
    return output, loss_dictionary


def corrcoef(x):
    """Pearson correlation matrix of the rows of a 2-D tensor.

    Args:
        x: tensor of shape (variables, observations).

    Returns:
        Square tensor of shape (variables, variables) with entries clamped
        to [-1, 1]. A row with zero variance produces NaN entries.
    """
    # keepdim=True broadcasts each row's own mean across that row. The
    # previous repeat/reshape interleaved the means of different rows, which
    # centred the data incorrectly whenever x had more than one row.
    xm = x - torch.mean(x, 1, keepdim=True)

    # Unbiased covariance between rows.
    c = xm.mm(xm.t()) / (x.size(1) - 1)

    # Normalise by the standard deviations along columns, then along rows.
    stddev = torch.pow(torch.diag(c), 0.5)
    c = c.div(stddev.unsqueeze(0))
    c = c.div(stddev.unsqueeze(1))

    return torch.clamp(c, -1.0, 1.0)


def corr(x_batch: torch.Tensor) -> torch.Tensor:
    """Upper-triangular pairwise feature correlations for each sample.

    A 2-D input is treated as a single sample. For every sample the
    correlation matrix between its last-axis features is computed with
    ``corrcoef`` and the strictly-upper-triangular entries are returned,
    with NaNs replaced through ``torch.nan_to_num``.

    Returns:
        Tensor of shape (batch, n_features * (n_features - 1) / 2).
    """
    batch = x_batch.unsqueeze(0) if x_batch.dim() == 2 else x_batch

    width = batch.shape[2]
    upper_rows, upper_cols = torch.triu_indices(width, width, 1)

    per_sample = [
        torch.nan_to_num(corrcoef(sample.T.clone())[upper_rows, upper_cols])
        for sample in batch
    ]
    return torch.stack(per_sample)



def generate_correlation_matrix(n, high_corr_value=0.15, low_corr_value=0.5, max_attempts=1000):
    """Build a symmetric n x n matrix with two off-diagonal levels.

    Every pair starts at ``high_corr_value``; a random half (rounded down)
    of the pairs, chosen without replacement via ``np.random.choice``, is
    then set to ``low_corr_value``. The diagonal is 1.

    ``max_attempts`` is accepted but not used by this implementation.
    """
    pair_rows, pair_cols = np.triu_indices(n, k=1)
    pair_count = int(n * (n - 1) / 2)

    matrix = np.zeros((n, n))
    matrix[pair_rows, pair_cols] = high_corr_value

    chosen = np.random.choice(pair_count, pair_count // 2,
                              replace=False)
    matrix[pair_rows[chosen], pair_cols[chosen]] = low_corr_value

    # Mirror the upper triangle into the lower one, then set the diagonal.
    matrix = matrix + matrix.T
    np.fill_diagonal(matrix, 1)
    return matrix


def generate_cbm(N,props=[0.5,.3,0.2],corrvalues=[0.9,0.5,0.2]):
    """Return three constant-correlation matrices for a block model.

    Block ``b`` has ``int(N * props[b])`` rows and every off-diagonal entry
    equal to ``corrvalues[b]``. Only the first three entries of ``props``
    and ``corrvalues`` are read.

    Returns:
        Tuple of three square ``numpy`` arrays, one per block.
    """
    shares = (props[0], props[1], props[2])
    return tuple(
        generate_correlation_matrix(int(N * share),
                                    high_corr_value=corrvalues[idx],
                                    low_corr_value=corrvalues[idx])
        for idx, share in enumerate(shares)
    )
def lurie_goldberg_transform(A, n, epsilon=1e-9):
    """Iteratively adjust ``A`` until it admits a Cholesky factorisation.

    Each round takes the element-wise square root of ``A``, zeroes the
    eigenvalues of that root that are below ``epsilon``, rebuilds the matrix,
    squares it element-wise and resets the diagonal to 1. The number of
    failed rounds is printed before returning (or before raising).

    Args:
        A: square matrix to adjust.
        n: size of ``A``.
        epsilon: eigenvalues below this threshold are set to 0.

    Raises:
        ValueError: if no round out of 100 yields a matrix that
            ``np.linalg.cholesky`` accepts.
    """
    max_attempts = 100
    for failures in range(max_attempts):
        root = np.sqrt(A)
        values, vectors = np.linalg.eig(root)

        # Clip the spectrum and rebuild the element-wise root.
        values[values < epsilon] = 0
        rebuilt = vectors.dot(np.diag(values)).dot(vectors.T)

        # Back to the original scale; the diagonal is removed and reset to 1.
        squared = rebuilt ** 2
        A = squared - np.diag(squared) * np.identity(n)
        np.fill_diagonal(A, 1)

        try:
            np.linalg.cholesky(A)
        except np.linalg.LinAlgError:
            continue
        print(failures)
        return A

    print(max_attempts)
    raise ValueError(
        "Failed to generate a positive semidefinite correlation matrix within the maximum number of attempts.")


class Pipeline:
    """Ordered chain of named pre-/post-processing steps.

    ``steps`` is a sequence of ``(name, step)`` pairs; each step provides
    ``transform`` and ``inverse_transform``.
    """

    def __init__(self, steps):
        self.steps = steps

    @staticmethod
    def _run(x, stages, method, until):
        # Apply `method` of each stage in turn, stopping before `until`.
        for label, stage in stages:
            if label == until:
                return x
            x = getattr(stage, method)(x)
        return x

    def transform(self, x, until=None):
        """Apply the steps in order to a clone of ``x``.

        Processing stops before the step named ``until`` (that step is not
        applied); with ``until=None`` every step runs.
        """
        return self._run(x.clone(), self.steps, 'transform', until)

    def inverse_transform(self, x, until=None):
        """Undo the steps in reverse order, stopping before ``until``."""
        return self._run(x, self.steps[::-1], 'inverse_transform', until)


class StandardScalerTS():
    """Standardise a tensor with the mean and std taken along ``axis``.

    The statistics are computed on the first call to ``transform`` and
    reused afterwards.
    """

    def __init__(self, axis=(1)):
        self.axis = axis
        self.mean = None
        self.std = None

    def transform(self, x):
        """Return ``(x - mean) / std``, fitting the statistics if needed."""
        not_fitted = self.mean is None
        if not_fitted:
            # The reduced axis is dropped, so the stored statistics rely on
            # broadcasting against later inputs.
            self.mean = torch.mean(x, dim=self.axis)
            self.std = torch.std(x, dim=self.axis)
        shift = self.mean.to(x.device)
        scale = self.std.to(x.device)
        return (x - shift) / scale

    def inverse_transform(self, x):
        """Map standardised values back: ``x * std + mean``."""
        scale = self.std.to(x.device)
        shift = self.mean.to(x.device)
        return x * scale + shift


def get_MGBM(**kwargs):
    """Simulate log-return increments of a correlated multivariate GBM.

    Keyword Args:
        N: number of assets.
        sigma: base volatility; per-asset volatilities are drawn from
            ``Normal(sigma, sigma / 3)``.
        T: number of time steps.
        dt: step size.
        corr_matrix: correlation matrix to use, or ``None`` for a constant
            0.5 off-diagonal correlation.

    Other keys (``mu``, ``S_0``, ``seqlength``, ``bs``) may be passed but
    are not used: the per-asset drift is always drawn from
    ``Normal(0, 0.01)``.

    Returns:
        ``torch.Tensor`` of shape (1, T, N).
    """
    N = kwargs['N']
    sigma = kwargs['sigma']
    T = kwargs['T']
    dt = kwargs['dt']
    corr_matrix = kwargs['corr_matrix']

    mu = np.random.normal(0, 0.01, N)
    sigma = np.random.normal(sigma, sigma / 3, N)

    if corr_matrix is None:
        correlation = np.zeros((N, N))
        correlation[np.tril_indices_from(correlation)] = np.random.normal(0.5, .0, int(N * (N + 1) / 2))
        correlation = correlation + correlation.T
        np.fill_diagonal(correlation, 1)
    else:
        # A caller-supplied matrix used to hit ``== None`` on an array and
        # then an undefined ``correlation``; it is now used directly.
        correlation = np.asarray(corr_matrix, dtype=float)
    choleskyMatrix = np.linalg.cholesky(correlation)

    m = (mu - 0.5 * sigma ** 2) * dt
    # ``*`` and ``@`` share precedence and associate left to right, so the
    # Cholesky rows are scaled by sigma * sqrt(dt) before the matrix product.
    BM = sigma[:, np.newaxis] * np.sqrt(dt) * choleskyMatrix @ np.random.normal(0, 1, (N, T))
    paths = m[:, np.newaxis] + BM
    return torch.tensor(paths).unsqueeze(0).permute(0, 2, 1)


def simulate_v(**kwargs):
    """Simulate correlated returns with a GARCH-type variance recursion.

    Keyword Args:
        N: number of assets.
        T: number of time steps (the first column is discarded).
        mu: drift (scalar).
        omega, alpha, beta, gamma: scalars of the variance recursion
            ``v <- omega + alpha * v + beta * (shock - gamma * sqrt(v)) ** 2``;
            the initial variance is ``omega / (1 - alpha - beta)``.
        corr_coef: constant pairwise correlation, ``None`` meaning 0.3.

    ``S_0`` may be passed but is not needed: the price path was never
    returned and is no longer computed.

    Returns:
        Float ``torch.Tensor`` of shape (1, N, T - 1) holding the returns.
    """
    N = kwargs['N']
    T = kwargs['T']
    mu = kwargs['mu']
    omega = kwargs['omega']
    corr_coef = kwargs['corr_coef']
    alpha = kwargs['alpha']  # weight on the previous variance
    beta = kwargs['beta']  # weight on the squared shock term
    gamma = kwargs['gamma']  # leverage effect

    R = np.full((N, T), np.nan)
    v = np.full(N, omega / (1 - alpha - beta))
    BM = np.random.normal(0, 1, (N, T))

    if corr_coef is None:
        corr_coef = .3
    correlation = np.zeros((N, N))
    correlation[np.tril_indices_from(correlation)] = np.random.normal(corr_coef, 0, int(N * (N + 1) / 2))
    correlation = correlation + correlation.T
    np.fill_diagonal(correlation, 1)
    choleskyMatrix = np.linalg.cholesky(correlation)

    for i in range(1, T):
        sigma = np.sqrt(v)
        # NOTE(review): broadcasting a length-N vector against the Cholesky
        # factor scales its columns (noise j by sigma[j]) before mixing;
        # confirm this is preferred over scaling each asset's mixed shock.
        shock = sigma * choleskyMatrix @ BM[:, i]
        R[:, i] = (mu - 0.5 * sigma ** 2) + shock
        v = omega + alpha * v + beta * (shock - gamma * sigma) ** 2

    return torch.tensor(R[:, 1:]).unsqueeze(0).float()


def simulate_Heston_corr(**kwargs):
    """Simulate log returns of N correlated Heston-type assets.

    Based on https://d-nb.info/1027389406/34.

    Keyword Args:
        N: number of assets (the noise layout requires N >= 2).
        mu: length-N drift vector; it is also used as the mean of the
            sampled Gaussian noise.
        sigmav: volatility of variance.
        T: number of steps; the step size is ``1 / T``.
        corr_matrix: correlation matrix of the return shocks, or ``None``
            to use a constant correlation of ``corr_coef``.
        corr_coef: constant correlation used when ``corr_matrix`` is
            ``None``; ``None`` means 0.3.
        theta: long-run variance level.
        alpha: mean-reversion speed of the variance.
        rho: correlation between return and variance shocks.
        V_0: initial variance.

    ``S_0`` may be passed but is not needed: the price path was never
    returned and is no longer computed.

    Returns:
        Float ``torch.Tensor`` of shape (1, T + 1, N). Row 0 is NaN because
        no return is defined for the initial time point.
    """
    N = kwargs['N']
    mu = kwargs['mu']
    sigmav = kwargs['sigmav']
    T = kwargs['T']
    corr_matrix = kwargs['corr_matrix']
    theta = kwargs['theta']
    alpha = kwargs['alpha']
    rho = kwargs['rho']
    V_0 = kwargs['V_0']
    corr_coef = kwargs['corr_coef']
    dt = 1 / T
    cov = np.identity(N)

    R = np.full(shape=(T + 1, N), fill_value=np.nan)
    # dtype=float keeps an integer V_0 from producing an integer array that
    # would truncate every variance update.
    v = np.full(shape=(T + 1, N), fill_value=V_0, dtype=float)

    # Independent Gaussian draws; slices 0 and 1 of the last axis drive the
    # return and variance shocks respectively.
    Z = np.random.multivariate_normal(mu, cov, (T, N))

    if corr_matrix is None:
        if corr_coef is None:
            corr_coef = .3
        correlation = np.zeros((N, N))
        correlation[np.tril_indices_from(correlation)] = np.random.normal(corr_coef, 0, int(N * (N + 1) / 2))
        correlation = correlation + correlation.T
        np.fill_diagonal(correlation, 1)
    else:
        # A caller-supplied matrix used to hit ``== None`` on an array and
        # then an undefined ``correlation``; it is now used directly.
        correlation = np.asarray(corr_matrix, dtype=float)

    choleskyMatrix = np.linalg.cholesky(correlation)

    move_noise = Z[:, :, 0]
    variance_noise = Z[:, :, 1]
    for i in range(1, T + 1):
        new_move_noise = choleskyMatrix @ move_noise[i - 1, :]
        vnoise = new_move_noise * rho + np.sqrt(1 - rho ** 2) * variance_noise[i - 1, :]
        R[i] = (mu - 0.5 * v[i - 1]) * dt + np.sqrt(v[i - 1] * dt) * new_move_noise
        # The variance is floored at 10% of theta so it stays positive.
        v[i] = np.maximum(v[i - 1] + alpha * (theta - v[i - 1]) * dt + sigmav * np.sqrt(v[i - 1] * dt) * vnoise,
                          theta * 1e-1)
    return torch.tensor(R).unsqueeze(0).float()


def simulate_Heston_corr1(**kwargs):
    """Simulate Heston-type returns with three asset groups and regimes.

    The assets are split into three contiguous groups at ``int(0.45 * N)``
    and ``int(0.5 * N)``. Each group has its own constant-correlation
    structure, and all three switch together between regimes after every
    step:

    * ``np.mean(v[i]) > 0.03``: correlations (0.99, 0.5, 0.7);
    * otherwise, median of the previous five return rows above 0.00015:
      correlations (0.4, 0.2, 0.3);
    * otherwise: correlations (0.9, 0.2, 0.5), also the starting regime.

    Keyword Args:
        N: number of assets.
        mu, sigmav, theta, alpha: length-N arrays (they are sliced per
            group); ``mu`` is also the mean of the sampled noise.
        T: number of steps; the step size is ``1 / T``.
        rho: correlation between return and variance shocks.
        V_0: initial variance.

    ``S_0``, ``corr_matrix`` and ``corr_coef`` may be passed but are not
    used.

    Returns:
        Float ``torch.Tensor`` of shape (1, T + 1, N); row 0 is NaN.
    """
    N = kwargs['N']
    mu = kwargs['mu']
    sigmav = kwargs['sigmav']
    T = kwargs['T']
    theta = kwargs['theta']
    alpha = kwargs['alpha']
    rho = kwargs['rho']
    V_0 = kwargs['V_0']
    dt = 1 / T
    cov = np.identity(N)
    j1 = .45
    j3 = 0.5

    first_end = int(N * j1)
    second_end = int(N * j3)
    groups = (slice(0, first_end), slice(first_end, second_end), slice(second_end, N))
    # The last block covers the remaining N - int(j3 * N) assets; sizing it
    # as int(j3 * N) mismatched the slice whenever N was odd.
    sizes = (first_end, second_end - first_end, N - second_end)

    R = np.full(shape=(T + 1, N), fill_value=np.nan)
    v = np.full(shape=(T + 1, N), fill_value=V_0)
    Z = np.random.multivariate_normal(mu, cov, (T, N))

    def _regime(levels):
        # One Cholesky factor per group, factorised once instead of at
        # every time step.
        return [
            np.linalg.cholesky(
                generate_correlation_matrix(size, high_corr_value=level, low_corr_value=level))
            for size, level in zip(sizes, levels)
        ]

    base = _regime((.9, 0.2, 0.5))
    stressed = _regime((.99, .5, .7))
    calm = _regime((.4, .2, .3))

    move_noise = Z[:, :, 0]
    variance_noise = Z[:, :, 1]
    active = base
    for i in range(1, T + 1):
        for group, chol in zip(groups, active):
            shock = chol @ move_noise[i - 1, group]
            vshock = shock * rho + np.sqrt(1 - rho ** 2) * variance_noise[i - 1, group]
            prev_v = v[i - 1, group]
            R[i, group] = (mu[group] - 0.5 * prev_v) * dt + np.sqrt(prev_v * dt) * shock
            # The variance is floored at 10% of theta so it stays positive.
            v[i, group] = np.maximum(
                prev_v + alpha[group] * (theta[group] - prev_v) * dt
                + sigmav[group] * np.sqrt(prev_v * dt) * vshock,
                theta[group] * 1e-1)

        # Regime used for the next step.
        if np.mean(v[i]) > 0.03:
            active = stressed
        elif np.median(R[i - 5:i]) > 0.00015:
            active = calm
        else:
            active = base

    return torch.tensor(R).unsqueeze(0).float()


def simulate_Heston_corr_many(**kwargs):
    """Simulate several Heston-type blocks with a two-state correlation.

    Based on https://d-nb.info/1027389406/34. Each block draws its own
    parameters, simulates ``T + 100`` steps with step size ``1 / (T + 100)``
    and, once past ``max(t + 1, 100)`` steps, uses the high-correlation
    factor when the mean variance over the last ``t`` rows exceeds the
    ``perc``-th percentile of the variance history, and the low-correlation
    factor otherwise.

    Keyword Args:
        N: number of assets.
        T: requested number of steps (100 extra steps are simulated).
        Blocks: number of independent blocks to concatenate.

    Returns:
        Float ``torch.Tensor`` of shape (1, Blocks * (T + 100), N).
    """
    N = kwargs['N']
    Blocks = kwargs['Blocks']
    n_steps = kwargs['T'] + 100
    dt = 1 / n_steps
    identity_cov = np.identity(N)
    block_returns = []

    for _ in range(Blocks):
        # Several draws below are not used afterwards; they are kept so the
        # random stream is consumed in the same order as before.
        drift = np.random.normal(0.0, 0.15 / 252, N)
        vol_of_var = np.random.normal(0.5, 0.15, N)
        long_run_var = np.random.normal((0.1), (.05), N) ** 2
        reversion = np.random.choice(range(1, 10, 1), N)
        start_var = np.random.normal((0.1), (0), N) ** 2
        np.random.uniform(0.1, .5)
        np.random.uniform(0.1, .5)

        returns = np.full(shape=(n_steps + 1, N), fill_value=np.nan)
        variance = np.full(shape=(n_steps + 1, N), fill_value=start_var)
        noise = np.random.multivariate_normal(drift, identity_cov, (n_steps, N))
        leverage = np.random.uniform(-1, 0)
        np.random.uniform(-1, 0)
        high_level = np.random.uniform(0.3, .99)
        low_level = np.random.uniform(0.01, .35)

        high_factor = np.linalg.cholesky(
            generate_correlation_matrix(N, high_corr_value=high_level, low_corr_value=high_level))
        low_factor = np.linalg.cholesky(
            generate_correlation_matrix(N, high_corr_value=low_level, low_corr_value=low_level))
        active_factor = high_factor

        return_noise = noise[:, :, 0]
        variance_noise = noise[:, :, 1]
        lookback = np.random.randint(20, 200)
        threshold_pct = np.random.randint(25, 75)

        for step in range(1, n_steps + 1):
            prev_var = variance[step - 1]
            shock = active_factor @ return_noise[step - 1]
            var_shock = shock * leverage + np.sqrt(1 - leverage ** 2) * variance_noise[step - 1]
            returns[step] = (drift - 0.5 * prev_var) * dt + np.sqrt(prev_var * dt) * shock
            # The variance is floored at 10% of its long-run level.
            variance[step] = np.maximum(
                prev_var + reversion * (long_run_var - prev_var) * dt
                + vol_of_var * np.sqrt(prev_var * dt) * var_shock,
                long_run_var * 1e-1)

            if step > np.maximum(lookback + 1, 100):
                recent = np.mean(variance[step - lookback:step])
                cutoff = np.nanpercentile(variance[:step].flatten(), threshold_pct)
                active_factor = high_factor if recent > cutoff else low_factor

        block_returns.append(torch.tensor(returns[1:, :]).float())

    return torch.stack(block_returns).reshape(-1, N).unsqueeze(0)


def vineBeta(d, betaparam):
    """Sample a random d x d correlation matrix with the vine method.

    Partial correlations are drawn from ``Beta(betaparam, 0.8 * betaparam)``
    rescaled to [-1, 1] and converted to raw correlations by the recursive
    partial-correlation formula. Rows and columns are finally shuffled with
    a random permutation.

    Args:
        d: matrix dimension.
        betaparam: first shape parameter of the Beta distribution.

    Returns:
        Symmetric ``numpy`` array of shape (d, d) with unit diagonal.
    """
    P = np.zeros((d, d))  # storing partial correlations
    S = np.eye(d)

    # The outer index must start at 0: starting at 1 never assigned partial
    # correlations to the first variable, leaving it uncorrelated with all
    # the others.
    for k in range(d - 1):
        for i in range(k + 1, d):
            # Sample from the Beta distribution and shift linearly to [-1, 1].
            P[k, i] = (np.random.beta(betaparam, betaparam * .8) - 0.5) * 2
            p = P[k, i]
            for l in range(k - 1, -1, -1):  # converting partial correlation to raw correlation
                p = p * np.sqrt((1 - P[l, i] ** 2) * (1 - P[l, k] ** 2)) + P[l, i] * P[l, k]
            S[k, i] = p
            S[i, k] = p

    permutation = np.random.permutation(d)
    return S[permutation][:, permutation]


def jumps_only(**kwargs):
    """Simulate pure-jump returns with a time-varying jump regime.

    Every step draws a Poisson jump count per asset and a Gaussian jump
    size whose standard deviation is ``sqrt(count) * sigma``. After each
    step a uniform draw is compared with a sinusoidal threshold to choose
    the next step's regime: intensity 3 / sigma 0.1, or intensity 10 /
    sigma 0.3. Each block starts at intensity 5 / sigma 0.1.

    Keyword Args:
        N: number of assets.
        T: requested number of steps (100 extra steps are simulated).
        Blocks: number of independent blocks to concatenate.

    Returns:
        Float ``torch.Tensor`` of shape (1, Blocks * (T + 100), N).
    """
    N = kwargs['N']
    Blocks = kwargs['Blocks']
    n_steps = kwargs['T'] + 100
    dt = 1 / 252
    thresholds = [(1 + np.sin(angle)) / 2 for angle in np.linspace(-2 * np.pi, np.pi, n_steps + 1)]
    quiet_regime = (3, 0., .1)
    busy_regime = (10, 0., .3)
    block_returns = []

    for _ in range(Blocks):
        returns = np.full(shape=(n_steps + 1, N), fill_value=np.nan)
        intensity, jump_mean, jump_scale = 5, 0, .1

        for step in range(1, n_steps + 1):
            counts = np.random.poisson(intensity * dt, (N, 1))
            sizes = np.random.normal(jump_mean * (counts - intensity * dt), np.sqrt(counts) * jump_scale,
                                     (N, 1))
            returns[step] = sizes.flatten()

            draw = np.random.uniform(0, 1)
            # This draw's value is unused; it is kept so the random stream
            # is consumed in the same order as before.
            np.random.lognormal(0.1, .35, N)
            if draw > thresholds[step]:
                intensity, jump_mean, jump_scale = quiet_regime
            else:
                intensity, jump_mean, jump_scale = busy_regime

        block_returns.append(torch.tensor(returns[1:, :]).float())

    return torch.stack(block_returns).reshape(-1, N).unsqueeze(0)


def simulate_Heston_corr_many2(**kwargs):
    """Simulate Heston-type blocks with regime-switching correlation and jumps.

    Each block simulates ``T + 100`` steps with ``dt = 1 / 252``. The return
    shocks are mixed with one of two Cholesky factors: a constant 0.85
    correlation (the starting state) or a random vine correlation from
    ``vineBeta``. Once past ``max(memory + 1, 100)`` steps, the constant
    factor is used when the mean variance over the last ``memory`` rows
    exceeds the ``perc``-th percentile of the variance history; otherwise
    the vine factor is used. After every step a uniform draw against a
    sinusoidal threshold picks the long-run variance and jump parameters for
    the next step. Returns are clipped to [-0.64, 0.64].

    For ``N > 100`` the vine matrix of the first block is written to
    ``'<N>_<T + 100>_.h5'`` in the working directory and read back for the
    later blocks instead of being re-sampled.

    Keyword Args:
        N: number of assets.
        T: requested number of steps (100 extra steps are simulated).
        Blocks: number of independent blocks to concatenate.
        Jumps: truthy to add compound-Poisson jumps to the returns.
        perc: percentile for the regime switch; ``None`` means 70.
        memory: look-back length for the regime switch; ``None`` means 10.

    Returns:
        Float ``torch.Tensor`` of shape (1, Blocks * (T + 100), N).
    """
    N = kwargs['N']
    T = kwargs['T']
    Blocks = kwargs['Blocks']
    jumps = kwargs['Jumps']
    perc = kwargs['perc'] if kwargs['perc'] is not None else 70
    memory = kwargs['memory']
    t = memory if memory is not None else 10

    T = T + 100
    dt = 1 / 252
    cov = np.identity(N)
    Rs = []
    ps = [(1 + np.sin(x)) / 2 for x in np.linspace(-2 * np.pi, 2 * np.pi, T + 1)]
    cache_path = str(N) + '_' + str(T) + '_' + '.h5'

    for k in range(Blocks):
        # Draws whose value is not used are kept so that the random stream
        # is consumed in the same order as before.
        mu = np.random.normal(0.0, 0.15 / 252, N)
        sigmav = np.random.normal(0.5, 0.15, N)
        np.random.uniform(0, 1)
        # Both branches of the former if/else produced this same draw.
        theta = (np.random.lognormal(0.1, .35, N) * .15) ** 2
        alpha = np.random.choice(range(1, 10, 1), N)
        V_0 = np.random.normal((0.1), (0), N) ** 2
        np.random.uniform(0.1, .5)
        np.random.uniform(0.1, .5)

        R = np.full(shape=(T + 1, N), fill_value=np.nan)
        v = np.full(shape=(T + 1, N), fill_value=V_0)
        Z = np.random.multivariate_normal(mu, cov, (T, N))
        rho1 = np.random.uniform(-1, 0)
        np.random.uniform(-1, 0)

        hcv = .85
        correlation_matrix = generate_correlation_matrix(N, high_corr_value=hcv, low_corr_value=hcv)
        chm2 = np.linalg.cholesky(correlation_matrix)
        choleskyMatrix = chm2

        if k != 0 and N > 100:
            correlation_matrix2 = pd.read_hdf(cache_path).values
        else:
            correlation_matrix2 = vineBeta(N, N * .06)
            if k == 0 and N > 100:
                # ``key`` is passed by keyword: newer pandas versions
                # deprecate passing it positionally.
                pd.DataFrame(correlation_matrix2).to_hdf(cache_path, key='df')
        chm1 = np.linalg.cholesky(correlation_matrix2)

        move_noise = Z[:, :, 0]
        variance_noise = Z[:, :, 1]
        lambda_j = 5
        mu_jumps = 0
        sigma_jumps = .1

        for i in range(1, T + 1):
            if jumps:
                jump_count = np.random.poisson(lambda_j * dt, (N, 1))
                jump_size = np.random.normal(mu_jumps * (jump_count - lambda_j * dt),
                                             np.sqrt(jump_count) * sigma_jumps,
                                             (N, 1))
            new_move_noise1 = choleskyMatrix @ move_noise[i - 1]
            vnoise1 = new_move_noise1 * rho1 + np.sqrt(1 - rho1 ** 2) * variance_noise[i - 1]
            R[i] = (mu - 0.5 * v[i - 1]) * dt + np.sqrt(v[i - 1] * dt) * new_move_noise1
            if jumps:
                R[i] += jump_size.flatten()
            # Cap the magnitude of each return at 0.64, keeping its sign.
            R[i] = np.minimum(0.64, np.abs(R[i])) * np.sign(R[i])
            # The variance is floored at 10% of theta so it stays positive.
            v[i] = np.maximum(
                v[i - 1] + alpha * (theta - v[i - 1]) * dt + sigmav * np.sqrt(v[i - 1] * dt) * vnoise1,
                theta * 1e-1)

            if i > np.maximum(t + 1, 100):
                if np.mean(v[i - t:i]) > np.nanpercentile(v[:i].flatten(), perc):
                    choleskyMatrix = chm2
                else:
                    choleskyMatrix = chm1

            # Parameters for the next step.
            pp = np.random.uniform(0, 1)
            if pp > ps[i]:
                theta = (np.random.lognormal(0.1, .35, N) * .15) ** 2
                lambda_j = 3
                mu_jumps = 0.
                sigma_jumps = .2
            else:
                theta = (np.random.lognormal(0.1, .35, N) * .35) ** 2
                lambda_j = 20
                mu_jumps = 0.
                sigma_jumps = .5
        Rs.append(torch.tensor(R[1:, :]).float())

    return torch.stack(Rs).reshape(-1, N).unsqueeze(0)


def simulate_Heston_corr_many2old(**kwargs):
    """
    Simulate ``Blocks`` consecutive blocks of correlated Heston-style log-returns.

    Every block draws fresh model parameters, then runs an Euler scheme in which
    the Cholesky factor used to correlate the return shocks is switched between
    two correlation structures depending on the recent variance level.

    Expected keyword arguments:
        N: number of assets (the code reads ``Z[:, :, 1]``, so N must be >= 2).
        T: number of time steps per block; 100 is added to it internally and
           the extra steps are NOT removed from the output.
        Blocks: number of independently parameterised blocks.

    Returns:
        Float tensor of shape ``(1, Blocks * (T + 100), N)`` holding the
        log-returns of all blocks concatenated along the time axis.

    NOTE(review): results depend on the global NumPy RNG state, including the
    draws for variables that are never used afterwards (``j1``, ``j3``,
    ``rho2``); removing those draws would change the generated paths under a
    fixed seed.
    """
    N = kwargs['N']
    T = kwargs['T']
    Blocks = kwargs['Blocks']
    T = T + 100  # simulate 100 extra steps per block (kept in the returned tensor)
    dt = 1 / 252  # presumably one trading day expressed in years -- TODO confirm
    cov = np.identity(N)
    Rs = []
    # Regime-switch settings, drawn once and shared by all blocks:
    # t = length of the trailing variance window, perc = percentile threshold.
    t = np.random.randint(20, 200)
    perc = np.random.randint(60, 95)
    # Counters of how often each regime was selected; never read in this function.
    chmc1 = 0
    chmc2 = 0
    for k in range(Blocks):
        # Per-asset drift and vol-of-vol for this block.
        mu = np.random.normal(0.0, 0.15 / 252, N)
        sigmav = np.random.normal(0.5, 0.15, N)
        S_0 = 1.0
        # Long-run variance level: centred on 0.1**2 when pp > 0.25, else on 0.7**2.
        pp = np.random.uniform(0, 1)
        if pp > 0.25:
            theta = np.random.normal((0.1), (.05), N) ** 2
        else:
            theta = np.random.normal((0.7), (.05), N) ** 2
        # Mean-reversion speed per asset, an integer in 1..9.
        alpha = np.random.choice(range(1, 10, 1), N)
        # Standard deviation 0 makes this a constant: every asset starts at 0.1 ** 2.
        V_0 = np.random.normal((0.1), (0), N) ** 2
        # j1, j2, j3 are not used below, but the two draws advance the RNG.
        j3 = np.random.uniform(0.1, .5)
        j1 = np.random.uniform(0.1, .5)
        j2 = j3 - j1
        S = np.full(shape=(T + 1, N), fill_value=S_0)
        R = np.full(shape=(T + 1, N), fill_value=np.nan)
        v = np.full(shape=(T + 1, N), fill_value=V_0)
        v1 = np.full(shape=(T + 1, N), fill_value=V_0)
        # size=(T, N) with an N-dimensional mean gives Z of shape (T, N, N);
        # only the slices [:, :, 0] and [:, :, 1] are used.
        # NOTE(review): the noise mean is `mu`, not zero, so those two slices are
        # centred on mu[0] and mu[1] respectively -- confirm this is intended.
        Z = np.random.multivariate_normal(mu, cov, (T, N))
        rho1 = np.random.uniform(-1, 0)  # return/variance shock correlation
        rho2 = np.random.uniform(-1, 0)  # unused below; the draw still advances the RNG
        hcv = .85
        # chm2: Cholesky factor of the matrix built with high = low = hcv.
        correlation_matrix = generate_correlation_matrix(N, high_corr_value=hcv, low_corr_value=hcv)
        chm2 = np.linalg.cholesky(correlation_matrix)
        choleskyMatrix = chm2
        # chm1: Cholesky factor of the matrix produced by vineBeta.
        correlation_matrix2 = vineBeta(N, N * .06)
        chm1 = np.linalg.cholesky(correlation_matrix2)
        state = 0  # unused
        move_noise = Z[:, :, 0]
        varaince_noise = Z[:, :, 1]

        for i in range(1, T + 1):
            # Correlate the return shocks across assets with the active factor.
            new_move_noise1 = choleskyMatrix @ move_noise[i - 1]
            # Variance shock = rho1 * return shock + sqrt(1 - rho1^2) * independent noise.
            vnoise1 = new_move_noise1 * rho1 + np.sqrt(1 - rho1 ** 2) * varaince_noise[i - 1]
            # Euler step for the log-return; the price path S is updated but not returned.
            R[i] = (mu - 0.5 * v[i - 1]) * dt + np.sqrt(v[i - 1] * dt) * new_move_noise1
            S[i] = S[i - 1] * np.exp(R[i])
            # Euler step for the variance, floored at 10% of theta so it stays positive.
            v1[i] = np.maximum(
                v1[i - 1] + alpha * (theta - v1[i - 1]) * dt + sigmav * np.sqrt(v1[i - 1] * dt) * vnoise1,
                theta * 1e-1)
            v[i] = v1[i]
            # Regime switch, active once more than max(t + 1, 100) steps exist:
            # compare the mean variance of the last t steps (over all assets)
            # with the `perc`-th percentile of all variances simulated so far.
            if i > np.maximum(t + 1, 100):
                if np.mean(v[i - t:i]) > np.nanpercentile(v[:i].flatten(), perc):
                    choleskyMatrix = chm2
                    chmc1 += 1
                else:
                    choleskyMatrix = chm1
                    chmc2 += 1

        # Row 0 of R is the NaN placeholder, so it is dropped here.
        Rs.append(torch.tensor(R[1:, :]).float())

    # (Blocks, T, N) -> (Blocks * T, N) -> (1, Blocks * T, N)
    return torch.stack(Rs).reshape(-1, N).unsqueeze(0)


def simulate_Heston_corr2(**kwargs):
    """
    Simulate one path of N correlated Heston-style log-returns with a
    variance-driven switch between two correlation structures.

    Model reference: https://d-nb.info/1027389406/34

    Expected keyword arguments (all are read, so all must be present):
        N: number of assets (the code reads ``Z[:, :, 1]``, so N must be >= 2).
        mu: drift; also passed as the mean of the noise draw, so presumably a
            length-N vector -- TODO confirm against callers.
        sigmav: vol-of-vol.
        T: number of steps; 300 is added to it internally.
        S_0: initial price.
        theta: long-run variance level.
        alpha: mean-reversion speed.
        rho: correlation between return and variance shocks.
        V_0: initial variance.
        corr_matrix, corr_coef: read but not used in this function.

    Returns:
        Float tensor of shape ``(1, T + 301, N)`` with the log-returns.

    NOTE(review): row 0 of ``R`` is never written and therefore stays NaN in
    the returned tensor, and the 300 extra steps are not removed.
    """
    # from https://d-nb.info/1027389406/34
    N = kwargs['N']
    mu = kwargs['mu']
    sigmav = kwargs['sigmav']
    T = kwargs['T']
    S_0 = kwargs['S_0']

    corr_matrix = kwargs['corr_matrix']  # unused
    theta = kwargs['theta']
    alpha = kwargs['alpha']
    rho = kwargs['rho']
    V_0 = kwargs['V_0']
    corr_coef = kwargs['corr_coef']  # unused
    T = T + 300  # extra steps simulated before the regime switch becomes active
    dt = 1 / T  # step size: the (extended) horizon is treated as one unit of time
    cov = np.identity(N)
    # j1, j2, j3 are not used below.
    j3 = 0.5
    j1 = .45
    j2 = j3 - j1
    S = np.full(shape=(T + 1, N), fill_value=S_0)
    R = np.full(shape=(T + 1, N), fill_value=np.nan)
    v = np.full(shape=(T + 1, N), fill_value=V_0)
    # size=(T, N) with an N-dimensional mean gives Z of shape (T, N, N);
    # only the slices [:, :, 0] and [:, :, 1] are used.
    Z = np.random.multivariate_normal(mu, cov, (T, N))

    # chm1: factor of the matrix built with high = low = 0.9 (initial regime).
    correlation_matrix = generate_correlation_matrix(N, high_corr_value=.9, low_corr_value=0.9)
    chm1 = np.linalg.cholesky(correlation_matrix)
    choleskyMatrix = chm1

    # chm2: factor of the matrix built with high = low = 0.25.
    correlation_matrix2 = generate_correlation_matrix(N, high_corr_value=.25, low_corr_value=0.25)
    chm2 = np.linalg.cholesky(correlation_matrix2)

    move_noise = Z[:, :, 0]
    varaince_noise = Z[:, :, 1]
    for i in range(1, T + 1):
        # Correlate the return shocks across assets with the active factor.
        new_move_noise1 = choleskyMatrix @ move_noise[i - 1]

        # Variance shock = rho * return shock + sqrt(1 - rho^2) * independent noise.
        vnoise1 = new_move_noise1 * rho + np.sqrt(1 - rho ** 2) * varaince_noise[i - 1]

        # Euler step for the log-return; the price path S is updated but not returned.
        R[i] = (mu - 0.5 * v[i - 1]) * dt + np.sqrt(v[i - 1] * dt) * new_move_noise1
        S[i] = S[i - 1] * np.exp(R[i])
        # = np.log(S[i]/S[i-1])
        # Euler step for the variance, floored at 10% of theta so it stays positive.
        v[i] = np.maximum(v[i - 1] + alpha * (theta - v[i - 1]) * dt + sigmav * np.sqrt(v[i - 1] * dt) * vnoise1,
                          theta * 1e-1)

        if i > 300:
            # Use chm1 while the mean variance of the last 10 steps exceeds the
            # 75th percentile of `v`, otherwise chm2.
            # NOTE(review): the percentile is taken over the whole array, so rows
            # after i (still holding the V_0 fill value) are included -- confirm
            # whether v[:i] was intended.
            if np.mean(v[i - 10:i]) > np.nanpercentile(v.flatten(), 75):  # print('Trend up',np.median(R[i-5:i]))
                choleskyMatrix = chm1

            else:
                choleskyMatrix = chm2

    return torch.tensor(R).unsqueeze(0).float()




def simulate_ngarchm(**kwargs):
    """
    Simulate N return paths with an NGARCH-style variance recursion, where the
    assets are split into three groups that each get their own correlation
    matrix.

    Expected keyword arguments (all are read, so all must be present):
        N: number of assets.
        T: number of time steps (column 0 is only the starting point).
        mu: centre of the per-asset drift draw.
        omega, alpha, beta, gamma: coefficients of the variance recursion
            ``v <- omega + alpha * v + beta * (shock - gamma * sqrt(v)) ** 2``.
        S_0: initial price.
        corr_coef: read but not used in this function.

    Returns:
        A tuple ``(returns, corr)`` where ``returns`` is a float tensor of shape
        ``(1, N, T - 1)`` and ``corr`` is ``np.corrcoef`` of the same returns
        (an ``N x N`` array).

    Group layout along the asset axis (n = N, j1 = 0.35, j3 = 0.65):
        group 1: the first ``int(j1 * n)`` assets,
        group 2: the next ``N - int(N * j1) - int(N * (1 - j3))`` assets,
        group 3: the last ``int((1 - j3) * n)`` assets.

    NOTE(review): if ``int((1 - j3) * n)`` is 0 the slice ``[-0:]`` selects ALL
    rows rather than none, so very small N is not handled.
    """
    N = kwargs['N']
    T = kwargs['T']
    mu = kwargs['mu']
    omega = kwargs['omega']
    corr_coef = kwargs['corr_coef']  # unused
    alpha = kwargs['alpha']  # AR coefficient
    beta = kwargs['beta']  # rate of reversion
    gamma = kwargs['gamma']  # leverage effect
    S = kwargs['S_0'] * np.ones((N, T))
    R = np.nan * np.ones((N, T))
    # Every asset starts at the same variance level v0.
    v0 = omega / (1 - alpha - beta)
    v = v0 * np.ones((N, 1))
    # All shocks are drawn up front, so the draws inside the loop below do not
    # alter the simulated paths (they only advance the global RNG).
    BM = np.random.normal(0, 1, (N, T))
    j3 = 0.65
    j1 = .35
    j2 = j3-j1  # unused

    # One correlation matrix per asset group (see the layout in the docstring).
    correlation_matrix = generate_correlation_matrix(int(N * j1), high_corr_value=.9, low_corr_value=0.9)
    correlation_matrix2 = generate_correlation_matrix(N-int(N * j1)-int(N * (1-j3)), high_corr_value=0.2,
                                                      low_corr_value=0.2)
    correlation_matrix3 = generate_correlation_matrix(int((1-j3) * N), high_corr_value=0.5, low_corr_value=0.5)
    choleskyMatrix = np.linalg.cholesky(correlation_matrix)
    choleskyMatrix2 = np.linalg.cholesky(correlation_matrix2)
    choleskyMatrix3 = np.linalg.cholesky(correlation_matrix3)
    # Per-asset drift, shape (N, 1), drawn around the supplied mu with std sqrt(v0) / 10.
    mu = np.random.normal(mu, np.sqrt(v) / 10, np.shape(v))
    state = 0
    n = N
    state=1
    states=[]
    states.append(state)
    p_old =.05
    for i in range(1, T, 1):
        # NOTE(review): `state` is forced back to 1 further down, so only the
        # `state==1` branch can ever run, and p / p_old / pshock / states never
        # influence R, S or v. This looks like a disabled regime mechanism.
        pshock=np.random.uniform(-1,1)*1e-2
        if state==1:
            thresh = 0.5 - p_old - .01
            p=np.minimum(np.maximum(np.random.uniform(-p_old/2.,thresh/2.)+p_old,0),1)+pshock
            # p=p*.5
        elif state==3:
            thresh=1-p_old-.01
            thresh2 = .65 - p_old
            p=np.minimum(np.maximum(np.random.uniform(thresh2,thresh/2.)+p_old,0),1)+pshock
        else:
            thresh2 = .85 - p_old
            thresh1=.2-p_old
            p = np.minimum(np.maximum(np.random.uniform(thresh1, thresh2/2.) + p_old, 0), 1)+pshock
        p_old=p
        # if p<0.3:
        state=1
        chm = choleskyMatrix
        chm2 = choleskyMatrix2
        chm3 = choleskyMatrix3
        states.append(state)
        # Per-group volatility and drift term (mu - 0.5 * sigma^2) at the current variance.
        sigma1 = np.sqrt(v[:int(j1 * n)])
        m1 = (mu[:int(j1 * n)].reshape(sigma1.shape) - 0.5 * sigma1 ** 2)

        sigma2 = np.sqrt(v[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0]])
        m2 = (mu[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0]].reshape(sigma2.shape) - 0.5 * sigma2 ** 2)

        sigma3 = np.sqrt(v[-int((1-j3) * n):])
        m3 = (mu[-int((1-j3) * n):].reshape(sigma3.shape) - 0.5 * sigma3 ** 2)
        # NOTE(review): `*` and `@` share the same precedence and associate left
        # to right, so `sigma * chm @ BM` evaluates as `(sigma * chm) @ BM`
        # (sigma scales the columns of the Cholesky factor), not as
        # `sigma * (chm @ BM)` -- confirm which one is intended.
        returnt1 = m1.flatten() + sigma1.flatten() * chm @ BM[:int(j1 * n), i]
        R[:int(j1 * n), i] = returnt1
        S[:int(j1 * n), i] = S[:int(j1 * n), i - 1] * np.exp(returnt1)
        returnt2 = m2.flatten() + sigma2.flatten() * chm2 @ BM[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0], i]
        R[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0], i] = returnt2
        S[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0], i] = S[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0], i - 1] * np.exp(returnt2)
        returnt3 = m3.flatten() + sigma3.flatten() * chm3 @ BM[-int((1-j3) * n):, i]
        R[-int((1-j3) * n):, i] = returnt3
        S[-int((1-j3) * n):, i] = S[-int((1-j3) * n):, i - 1] * np.exp(returnt3)

        # Variance update per group:
        # v <- omega + alpha * v + beta * (correlated shock - gamma * sqrt(v)) ** 2
        v[:int(j1 * n)] = np.array(omega + alpha * v[:int(j1 * n)].flatten() + beta * (
                sigma1.flatten() * chm @ BM[:int(j1 * n), i] - gamma * np.sqrt(
            v[:int(j1 * n)]).flatten()) ** 2).reshape(v[:int(j1 * n)].shape)

        v[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0]] = np.array(omega + alpha * v[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0]].flatten() + beta * (
                sigma2.flatten() * chm2 @ BM[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0], i] - gamma * np.sqrt(
            v[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0]]).flatten()) ** 2).reshape(v[int(j1 * n):int(j1 * n)+correlation_matrix2.shape[0]].shape)

        v[-int((1-j3) * n):] = np.array(omega + alpha * v[-int((1-j3) * n):].flatten() + beta * (
                sigma3.flatten() * chm3 @ BM[-int((1-j3) * n):, i] - gamma * np.sqrt(
            v[-int((1-j3) * n):]).flatten()) ** 2).reshape(v[-int((1-j3) * n):].shape)

    # Column 0 of R was never written (NaN), so it is dropped from both outputs.
    return torch.tensor(R[:, 1:]).unsqueeze(0).float(), np.corrcoef(R[:, 1:])


def get_var_dataset(window_size, batch_size=5000, dim=3, phi=0.3, sigma=0.2):
    """
    Build a batch of simulated multivariate AR(1) paths and standard-scale them.

    :param window_size: number of time steps kept per path (after burn-in).
    :param batch_size: number of independent paths to simulate.
    :param dim: number of series per path.
    :param phi: centre of the per-series AR coefficient draw.
    :param sigma: std of the AR coefficient draw; also the off-diagonal entry
        of the noise covariance matrix.
    :return: ``(pipeline, data_raw, data_preprocessed)`` where ``data_raw`` is a
        float tensor of shape ``(batch_size, window_size, dim)`` and
        ``data_preprocessed`` is ``pipeline.transform(data_raw)``.
    """

    def multi_AR(window_size, dim=3, phi=0.3, sigma=0.2, burn_in=200):
        # Simulate burn_in extra steps and drop them at the end.
        total_steps = window_size + burn_in
        path = np.zeros((total_steps, dim))  # row 0 is the all-zero start value
        ar_coefs = np.random.normal(phi, sigma, dim)
        # Ones on the diagonal, `sigma` everywhere else.
        noise_cov = sigma * np.ones(dim) + (1 - sigma) * np.identity(dim)
        shocks = np.random.multivariate_normal(np.zeros(dim), noise_cov, total_steps)
        for step in range(1, total_steps):
            path[step] = ar_coefs * path[step - 1] + shocks[step - 1] / 252
        return path[burn_in:]

    simulated = [multi_AR(window_size, dim, phi=phi, sigma=sigma) for _ in range(batch_size)]
    data_raw = torch.from_numpy(np.array(simulated)).float()

    def get_pipeline():
        # Single-step pipeline: standard scaling over the batch and time axes.
        steps = [('standard_scale', StandardScalerTS(axis=(0, 1)))]
        return Pipeline(steps=steps)

    pipeline = get_pipeline()
    return pipeline, data_raw, pipeline.transform(data_raw)


def get_arch_dataset(window_size, lag=4, bt=0.055, N=5000, dim=1):
    """
    Simulate a batch of ARCH(lag) log-return paths and standard-scale them.

    The conditional variance follows
    ``arch[t + 1] = omega + bt * sum(logrtn[t - lag + 1 : t + 1] ** 2)`` and the
    return is ``sqrt(arch) * eps`` with standard normal ``eps``. The first
    2000 steps of every path are simulated as burn-in and discarded.

    :param window_size: number of time steps kept per path.
    :param lag: ARCH order, i.e. how many past squared returns enter the
        variance. It is now forwarded to the simulator; previously it was
        accepted but ignored, so the order was always 4.
    :param bt: coefficient applied to each lagged squared return.
    :param N: number of independent paths.
    :param dim: unused, kept for backward compatibility; the output always has
        a single feature.
    :return: ``(pipeline, data_raw, data_pre)`` where ``data_raw`` is a float
        tensor of shape ``(N, window_size, 1)`` and ``data_pre`` is
        ``pipeline.transform(data_raw)``.
    :raises ValueError: if ``lag`` is smaller than 1.
    """
    if lag < 1:
        raise ValueError('lag must be >= 1, got %r' % (lag,))

    def get_raw_data(N=5000, lag=4, T=2000, omega=0.00001, bt=0.055, burn_in=2000):
        beta = bt * np.ones(lag)
        total_steps = T + burn_in
        eps = np.random.randn(N, total_steps)
        logrtn = np.zeros((N, total_steps))

        # Starting variance level used to seed the first `lag` returns.
        initial_arch = omega / (1 - beta[0])
        arch = initial_arch + np.zeros((N, total_steps))
        logrtn[:, :lag] = np.sqrt(arch[:, :lag]) * eps[:, :lag]

        for t in range(lag - 1, total_steps - 1):
            # Weighted sum of the last `lag` squared returns, one value per path.
            recent_sq = np.square(logrtn[:, t - lag + 1:t + 1])
            arch[:, t + 1] = omega + recent_sq @ beta
            logrtn[:, t + 1] = np.sqrt(arch[:, t + 1]) * eps[:, t + 1]
        return arch[:, burn_in:], logrtn[:, burn_in:]

    pipeline = Pipeline(steps=[('standard_scale', StandardScalerTS(axis=(0, 1)))])
    # Forward `lag` so the caller's ARCH order is actually used.
    _, logrtn = get_raw_data(T=window_size, N=N, bt=bt, lag=lag)
    data_raw = torch.from_numpy(logrtn[..., None]).float()
    data_pre = pipeline.transform(data_raw)
    return pipeline, data_raw, data_pre


def load_pickle(path):
    """
    Read the file at ``path`` in binary mode and return the unpickled object.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


def rolling_window(x, x_lag, cum=False, norm=False, corr1=False, add_batch_dim=True, **kwargs):
    """
    Cut ``x`` into overlapping windows of length ``x_lag`` along dimension 1.

    Window ``t`` covers ``x[:, t:t + x_lag]`` for ``t`` in
    ``range(x.shape[1] - x_lag)``; the windows are concatenated along
    dimension 0.

    :param x: tensor to window; a leading batch dimension is added first when
        ``add_batch_dim`` is true.
    :param x_lag: window length.
    :param cum: when true, return the cumulative sum of each window, shifted
        and scaled with ``kwargs['sigma']`` and ``kwargs['dt']``.
    :param norm: when true (and ``cum`` is false) nothing is computed and
        ``None`` is returned.
    :param corr1: when true (and ``cum``/``norm`` are false) also return
        ``corr(windows)``.
    :return: the windows, ``(windows, corr(windows))``, the normalised
        cumulative windows, or ``None`` (see above).
    """
    if add_batch_dim:
        x = x[None, ...]
    n_windows = x.shape[1] - x_lag

    if cum == True:
        # Shift by dmin and divide by the band width (dmax - dmin).
        mult = np.maximum(int(x_lag / 5), 5)
        dmin = -mult * np.max(kwargs['sigma']) * np.sqrt(kwargs['dt'])
        dmax = -dmin
        ddiv = 1 / (2 * dmax)
        scaled = []
        for start in range(n_windows):
            running = torch.cumsum(x[:, start:start + x_lag], dim=1)
            scaled.append((running - dmin) * ddiv)
        return torch.cat(scaled, dim=0)

    if not (norm == False):
        # No branch of the original handles cum=False with norm=True.
        return None

    windows = torch.cat([x[:, start:start + x_lag] for start in range(n_windows)], dim=0)
    if corr1 == False:
        return windows
    return windows, corr(windows)


def transform(x_real, d):
    """
    Rank-transform ``x_real`` along dimension ``d``.

    :return: a 4-tuple ``(ranks, n, scaled_ranks, sorted_values)`` where
        ``ranks`` holds the 0-based rank of every element along ``d``,
        ``n`` is ``x_real.size(d)``, ``scaled_ranks`` is ``ranks / n - 0.5``
        and ``sorted_values`` is ``x_real`` sorted along ``d``.
    """
    n_along_d = x_real.size(d)
    sorted_values, sort_order = torch.sort(x_real, dim=d)
    # argsort of the sort permutation gives each element's rank.
    ranks = sort_order.argsort(dim=d)
    scaled_ranks = ranks / n_along_d - .5
    return ranks, n_along_d, scaled_ranks, sorted_values


def transform_back(example1, ranked_values, x_real, d):
    """
    Map scaled ranks back to data values by looking them up in ``ranked_values``.

    Each entry of ``example1`` is converted to an index
    ``round(clip((value + 0.5) * n, 0, n - 1))`` with ``n = x_real.size(1)``
    (after the optional permute) and replaced by the sorted value stored at
    that index for the same feature.

    :param example1: scaled ranks to invert.
    :param ranked_values: sorted reference values with a leading dimension of
        size 1 (it is squeezed away).
    :param x_real: reference tensor; only its size along the rank dimension is used.
    :param d: rank dimension; when it is 2, all three tensors are permuted with
        ``(0, 2, 1)`` on the way in and the result is permuted back.
    :return: tensor with the same shape as ``example1``.
    """
    ranks_on_last_axis = d == 2
    if ranks_on_last_axis:
        example1, ranked_values, x_real = (
            tensor.permute(0, 2, 1) for tensor in (example1, ranked_values, x_real)
        )

    n_obs = x_real.size(1)
    positions = torch.round(torch.clip((example1 + .5) * n_obs, 0, n_obs - 1)).long()

    # One row of sorted values per feature, gathered with one row of indices per feature.
    lookup = ranked_values.squeeze(0).permute(1, 0)
    flat_positions = positions.permute(2, 0, 1).reshape(positions.size(2), -1)
    restored = torch.gather(lookup, 1, flat_positions).permute(1, 0).reshape(positions.size())

    return restored.permute(0, 2, 1) if ranks_on_last_axis else restored


def get_new_rank_values(y, old_x_real):
    """
    Return the scaled ranks of ``y`` relative to ``y`` and ``old_x_real`` together.

    ``y`` is flattened to rows of ``old_x_real.size(2)`` features, placed in
    front of ``old_x_real`` along dimension 1, and the combined tensor is
    rank-transformed with ``transform(..., 1)``. The first
    ``y.size(0) * y.size(1)`` scaled ranks are returned in the shape of ``y``.
    """
    n_new_rows = y.size(0) * y.size(1)
    new_rows = y.reshape(-1, old_x_real.size(2)).unsqueeze(0)
    combined = torch.cat((new_rows, old_x_real), dim=1)
    _, _, scaled_ranks, _ = transform(combined, 1)
    return scaled_ranks[:, :n_new_rows].reshape(y.size())


def get_mit_arrythmia_dataset(filenames):
    """
    Load the first 3000 samples of each named record and standard-scale them.

    Records are read with ``wfdb.rdsamp`` from
    ``./data/mit-bih-arrhythmia-database-1.0.0/``, stacked along a new leading
    axis, shifted so the global minimum maps to 1, multiplied by 5 and
    log-transformed.

    :param filenames: iterable of record names inside the data directory.
    :return: ``(pipeline, data_raw, data_pre)`` where ``data_raw`` is the float
        tensor of transformed records and ``data_pre`` is
        ``pipeline.transform(data_raw)``.
    """
    import wfdb
    DATA_DIR = './data/mit-bih-arrhythmia-database-1.0.0/'
    # Element 0 of rdsamp's return value is used as the signal array.
    signals = [
        wfdb.rdsamp(os.path.join(DATA_DIR, fn), sampto=3000)[0][None, ...]
        for fn in filenames
    ]
    stacked = np.concatenate(signals, axis=0)
    log_scaled = np.log(5 * (stacked - stacked.min() + 1))
    data_raw = torch.from_numpy(log_scaled).float()
    pipeline = Pipeline(steps=[('standard_scale', StandardScalerTS(axis=(0, 1)))])
    return pipeline, data_raw, pipeline.transform(data_raw)


def get_data(data_type, p, q, actual=0, **data_params):
    """
    Load or simulate the series selected by ``data_type`` and cut it into
    rolling windows of length ``p + q``.

    :param data_type: dataset selector. Simulated sets are matched by equality
        ('HestonMany', 'HestonMany2Old', 'HestonMany2', 'HestonMany3', 'jumps',
        'NGARCH', 'NGARCH1', 'NGARCH2', 'ECG'); HDF-backed sets are matched by
        substring, in the order listed in ``hdf_sources``.
    :param p: first part of the window length.
    :param q: second part of the window length.
    :param actual: only used for 'sandi' data; a non-zero value loads the
        pre-election file instead (after the regular file has been read).
    :param data_params: forwarded to the selected simulator / loader.
    :return: ``(windows, x_real, windows_norm)``; ``windows_norm`` is ``None``
        for every dataset this function can currently load.
    :raises NotImplementedError: if ``data_type`` matches no known dataset.
    """
    # Order matters: 'experimentspetf' must be tested before 'experimentsp'.
    hdf_sources = (
        ('exp_intraday', r'C:\Users\username\PycharmProjects\GPS\intraday_10min.h5'),
        ('Dispersion', r'C:\Users\username\PycharmProjects\GPS\Dispersion.h5'),
        ('sandi', r'C:\Users\username\PycharmProjects\GPS\sandi.h5'),
        ('experimentspetf', r'C:\Users\username\PycharmProjects\GPS\stock475.h5'),
        ('experimentsp', r'C:\Users\username\PycharmProjects\GPS\stock463.h5'),
        ('experiment400', r'C:\Users\username\PycharmProjects\GPS\stock401.h5'),
        ('experiment_', r'C:\Users\username\PycharmProjects\GPS\stock50.h5'),
    )

    def _hdf_as_tensor(h5_path):
        # Read the 'df' table and add a leading batch dimension.
        frame = pd.read_hdf(h5_path, 'df')
        return torch.tensor(frame.values).unsqueeze(0).float()

    def _load_series():
        if data_type == 'HestonMany':
            return simulate_Heston_corr_many(**data_params)

        for tag, h5_path in hdf_sources:
            if tag in data_type:
                series = _hdf_as_tensor(h5_path)
                if tag == 'sandi' and actual != 0:
                    series = _hdf_as_tensor(r'C:\Users\username\PycharmProjects\GPS\sandipreelection.h5')
                return series

        if data_type == 'HestonMany2Old':
            return simulate_Heston_corr_many2old(**data_params)

        if data_type in ['HestonMany2', 'HestonMany3']:
            # Simulated once, then cached on disk under a name built from the parameters.
            ds_title = '_'.join([data_type, str(data_params['N']), str(data_params['T']),
                                 str(data_params['Blocks'])])
            cache_file = ds_title + '.h5'
            if pt.exists(cache_file):
                series = torch.tensor(pd.read_hdf(cache_file).values).unsqueeze(0)
            else:
                series = simulate_Heston_corr_many2(**data_params)
            if not pt.exists(cache_file):
                tosave = series.squeeze(0).detach().cpu().numpy()
                pd.DataFrame(tosave).to_hdf(cache_file, 'h5')
            return series

        if data_type in ['jumps']:
            return jumps_only(**data_params)

        if data_type in ['NGARCH', 'NGARCH1']:
            return simulate_ngarch(**data_params).permute(0, 2, 1)

        if data_type in ['NGARCH2']:
            series, _ = simulate_ngarchm(**data_params)
            return series.permute(0, 2, 1)

        if data_type == 'ECG':
            _, _, series = get_mit_arrythmia_dataset(**data_params)
            return series

        raise NotImplementedError('Dataset %s not valid' % data_type)

    x_real = _load_series()
    assert x_real.shape[0] == 1
    print(x_real.size())

    # NOTE(review): no loader above accepts 'MGBM' / 'MGBM1' (they raise
    # NotImplementedError first), so this branch is currently unreachable.
    if data_type in ['MGBM', 'MGBM1']:
        windows = rolling_window(x_real[0], p + q, cum=False, norm=False, **data_params)
        windows_norm = rolling_window(x_real[0], p + q, cum=True, norm=True, **data_params)
        return windows, x_real, windows_norm

    # NOTE(review): the guard compares the series length with q, while the
    # windows need p + q steps -- confirm which threshold is intended.
    if x_real[0].size(0) < q:
        return x_real, x_real, None
    return rolling_window(x_real[0], p + q, corr1=False), x_real, None






def skew_torch(x, dim=(0, 1), dropdims=True):
    """
    Return the skewness of ``x`` along dimension 1.

    Computed as ``mean((x - mean) ** 3) / (std ** 3 + 1e-6)``, where the
    third moment is a plain mean and ``std`` is ``torch.std`` with its default
    settings. The small constant keeps the result finite for constant series.

    The data is centred by broadcasting the ``keepdim`` mean instead of
    materialising a repeated copy of it. For 3-D input the values are the
    same as before, and inputs of any rank >= 2 are now supported.

    :param x: tensor with the time axis at dimension 1, e.g. (batch, time, features).
    :param dim: unused, kept for backward compatibility; the reduction is
        always over dimension 1.
    :param dropdims: unused, kept for backward compatibility.
    :return: tensor shaped like ``x`` with dimension 1 removed.
    """
    centered = x - x.mean(dim=1, keepdim=True)
    third_moment = centered.pow(3).mean(dim=1)
    std_cubed = x.std(dim=1).pow(3)
    return third_moment / (std_cubed + 1e-6)


def kurtosis_torch(x, dim=(0, 1), excess=True, dropdims=True):
    """
    Return the kurtosis of ``x`` along dimension 1.

    Computed as ``mean((x - mean) ** 4) / (var ** 2 + 1e-6)``, where the
    fourth moment is a plain mean and ``var`` is ``torch.var`` with its default
    settings. The small constant keeps the result finite for constant series.

    The data is centred by broadcasting the ``keepdim`` mean instead of
    materialising a repeated copy of it. For 3-D input the values are the
    same as before, and inputs of any rank >= 2 are now supported.

    :param x: tensor with the time axis at dimension 1, e.g. (batch, time, features).
    :param dim: unused, kept for backward compatibility; the reduction is
        always over dimension 1.
    :param excess: when true, subtract 3 from the result.
    :param dropdims: unused, kept for backward compatibility.
    :return: tensor shaped like ``x`` with dimension 1 removed.
    """
    centered = x - x.mean(dim=1, keepdim=True)
    fourth_moment = centered.pow(4).mean(dim=1)
    var_squared = x.var(dim=1).pow(2)
    kurtosis = fourth_moment / (var_squared + 1e-6)
    if excess:
        kurtosis = kurtosis - 3
    return kurtosis
