from sympy import *
from itertools import groupby
from operator import itemgetter
from scipy.integrate import dblquad
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
tqdm.pandas()

def hKSD(p_mean_val=0, p_var_val=1, q_mean_val=0, q_var_val=1, h1_val=1, h2_val=1):
    """Symbolically evaluate three Stein-kernel expectations under a Gaussian ``q``.

    The score ``-(t - p_mean_val) / p_var_val`` (the derivative of the log
    density of a Gaussian with mean ``p_mean_val`` and variance ``p_var_val``)
    is combined with two kernels ``k_i(x, y) = exp(-(x - y)**2 / (2 * h_i))``.
    Each resulting expression is integrated over the whole plane against the
    product ``q(x) * q(y)``, where ``q`` is the Gaussian density with mean
    ``q_mean_val`` and variance ``q_var_val``.

    Args:
        p_mean_val: Mean used in the score function.
        p_var_val: Variance used in the score function.
        q_mean_val: Mean of the density the expectation is taken under.
        q_var_val: Variance of the density the expectation is taken under.
        h1_val: Bandwidth parameter of the first kernel ``k1``.
        h2_val: Bandwidth parameter of the second kernel ``k2``.

    Returns:
        A 3-tuple ``(S_k1, S_k2, S_k1_k2)`` of ``evalf()``-ed SymPy results:
        the expectation built only from ``k1``, the one built only from
        ``k2``, and the mixed one built from both kernels.
    """
    # Create the integration symbols locally so the function also works when
    # the module is imported (no reliance on globals set up by a script
    # entry point). SymPy symbols with the same name and assumptions compare
    # equal, so these are interchangeable with identically declared globals.
    x = Symbol('x', real=True)
    y = Symbol('y', real=True)

    # Score function evaluated at each of the two arguments.
    sp_x = -(x - p_mean_val) / p_var_val
    sp_y = -(y - p_mean_val) / p_var_val

    # Gaussian density q evaluated at each of the two arguments.
    q_x = exp(-(x - q_mean_val)**2 / (2 * q_var_val)) / sqrt(2 * pi * q_var_val)
    q_y = exp(-(y - q_mean_val)**2 / (2 * q_var_val)) / sqrt(2 * pi * q_var_val)

    k1 = exp(-(x - y)**2 / (2 * h1_val))
    k2 = exp(-(x - y)**2 / (2 * h2_val))

    # First-order and mixed second-order partial derivatives of both kernels.
    d_k1_x = diff(k1, x)
    d_k1_y = diff(k1, y)
    d_k2_x = diff(k2, x)
    d_k2_y = diff(k2, y)
    d_k1_x_y = diff(k1, x, y)
    d_k2_x_y = diff(k2, x, y)

    def _expect(expr):
        # Double integral of expr against q(x) q(y) over the whole plane.
        return integrate(expr * q_x * q_y, (x, -oo, oo), (y, -oo, oo))

    # NOTE(review): the mixed expression takes its first two terms from k1 and
    # its last two from k2 -- confirm this matches the intended definition.
    S_k1_k2 = _expect(sp_x*k1*sp_y + d_k1_x*sp_y + sp_x*d_k2_y + d_k2_x_y)
    S_k1 = _expect(sp_x*k1*sp_y + d_k1_x*sp_y + sp_x*d_k1_y + d_k1_x_y)
    S_k2 = _expect(sp_x*k2*sp_y + d_k2_x*sp_y + sp_x*d_k2_y + d_k2_x_y)

    return S_k1.evalf(), S_k2.evalf(), S_k1_k2.evalf()


# Column name of the swept parameter (passed to hKSD as ``q_var_val``).
var_name = 'sigma_q'


if __name__ == '__main__':

    # Render figure text through LaTeX; amsfonts provides \mathbb.
    rc_fonts = {
        "text.usetex": True,
        'mathtext.default': 'regular',
        'text.latex.preamble': r'\usepackage{amsfonts}'
    }
    plt.rcParams.update(rc_fonts)

    os.makedirs('hKSD', exist_ok=True)

    # Symbolic integration variables, defined as module-level globals.
    x = Symbol('x', real=True)
    y = Symbol('y', real=True)

    p_var_val = 1
    h1_val = 1
    h2_val = 3

    x_axis_min = 0.7
    x_axis_max = 1.3
    x_axis_step_size = 0.02
    # round() before int() so floating-point error in the quotient cannot
    # silently drop the last grid point.
    x_axis_step_num = int(round((x_axis_max - x_axis_min) / x_axis_step_size)) + 1
    x_axis_vals = np.linspace(x_axis_min, x_axis_max, x_axis_step_num)

    df = pd.DataFrame(x_axis_vals, columns=[var_name])
    # One symbolic evaluation per grid value; each call returns the triple
    # (S_k1, S_k2, S_k1_k2). The lambda argument is deliberately not called
    # ``x`` so it does not shadow the SymPy symbol above.
    results = df[var_name].progress_apply(
        lambda q_var: hKSD(q_var_val=q_var, p_var_val=p_var_val, h1_val=h1_val, h2_val=h2_val)
    )
    # Column labels are hard-coded for h1_val = 1 and h2_val = 3.
    df['$h_1 = h_2 = 1$'], df['$h_1 = h_2 = 3$'], df['$h_1 = 1, h_2 = 3$'] = zip(*results)

    df_long = df.melt(id_vars=var_name)
    # hKSD returns SymPy numbers; convert to plain floats for plotting.
    df_long['value'] = df_long['value'].astype(float)

    fig, ax = plt.subplots()
    sns.lineplot(data=df_long, x=var_name, y='value', hue='variable', ax=ax).set(title='')

    ax.get_legend().set_title('')
    ax.set_yticks([0, 0.01, 0.02])
    # Raw strings: '\s' and '\m' are invalid escape sequences in normal literals.
    ax.set_xlabel(r'$\sigma$')
    ax.set_ylabel(r'$\mathbb S^*_{k_1,k_2}$', rotation=0, labelpad=20)

    plt.savefig(os.path.join('hKSD', 'hKSD.png'))
    plt.clf()