import sys
sys.path.append('/opt/app/hard_label_manifolds')

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import utils
import glob
import json
import pandas as pd
import concurrent.futures

from fnmatch import filter
from collections import defaultdict
from tqdm import tqdm
from scipy.stats import norm as normal
from scipy.stats import rv_histogram
from datetime import datetime


def get_time_stamp():
    """Return the current time from ``datetime.now()`` as ``MMDDYY-HHMMSS``."""
    # Format spec on a datetime delegates to strftime with the same pattern.
    return f"{datetime.now():%m%d%y-%H%M%S}"


def mi(num_train, num_manifold, mu=1, c=0.5, c1=0.5, c2=0.5, d=1, eps=0., seed=9):
    """Numerically estimate a mutual-information-style sum, scaled by ``d``.

    For each sign ``g`` in {-1, +1} a scalar ``x`` is modelled as
    ``Normal(g * mu, sigma)``. Integrals are approximated with
    ``num_manifold // 2`` uniform random draws over the central
    [1e-5, 1 - 1e-5] quantile range of that normal; only draws with
    ``x > 0`` are kept and each draw is weighted by a fixed width ``dx``.

    Args:
        num_train: Not used by the computation; kept for interface
            compatibility.
        num_manifold: Controls the number of draws (``num_manifold // 2``).
        mu: Mean magnitude; must equal 1.
        c: Multiplier for ``sigma`` when ``eps == 0``.
        c1: Multiplier for ``sigma`` when ``eps != 0``.
        c2: Not used by the computation; kept for interface compatibility.
        d: Dimension; ``sigma`` scales with ``d ** (1/4)`` and the result is
            multiplied by ``d``.
        eps: Only selects which multiplier (``c`` or ``c1``) is used.
        seed: Seed for the random draws. Every grid is drawn from a fresh
            generator with this seed, so the result is deterministic and the
            process-global NumPy RNG is left untouched.

    Returns:
        The sum over both signs of the per-sign contribution, times ``d``.

    Raises:
        AssertionError: If ``mu != 1``, or if either probability-mass sanity
            check is not within a tolerance of 0.1 of 1.
    """
    # Both settings share the d**(1/4) scaling; only the constant differs.
    scale = d ** (1/4)
    sigma_ = c * scale if eps == 0. else c1 * scale
    err_tol = 0.1

    mu_k = mu
    assert mu_k == 1  # schmidt constraint

    n_draws = num_manifold // 2
    # One frozen distribution per sign, reused for every pdf/ppf evaluation.
    conditionals = {g_k: normal(g_k * mu_k, sigma_) for g_k in (1, -1)}

    def joint(gk, xk):
        # Density of Normal(gk * mu, sigma) evaluated at xk.
        return conditionals[gk].pdf(xk)

    def marginal_m(xk):
        # Sum of the two sign-conditional densities at xk.
        return joint(1, xk) + joint(-1, xk)

    def positive_draws(g_k):
        # Draw the integration grid for sign g_k; return (points > 0, dx).
        rv = conditionals[g_k]
        a, b = rv.ppf(0.00001), rv.ppf(0.99999)
        dx = (b - a) / n_draws
        # A local legacy generator yields the same stream as seeding the
        # global one, without clobbering state shared with other code.
        rng = np.random.RandomState(seed)
        xk = np.sort(rng.uniform(low=a, high=b - dx, size=n_draws))
        return xk[xk > 0], dx

    # Draw each sign's grid once and reuse it for both the mass of g and the
    # per-sign integrand below (the same seed gives the same grid anyway).
    grids = {}
    mass_g = {}
    for g_k in (1, -1):
        x_k, dx = positive_draws(g_k)
        grids[g_k] = (x_k, dx)
        mass_g[g_k] = np.sum(joint(g_k, x_k) * dx)

    # gradient prob mass
    all_g = mass_g[1] + mass_g[-1]
    assert np.allclose([all_g], [1.], rtol=err_tol), f"Missing probability mass, got {all_g}"

    total = 0
    for g_k in (-1, 1):
        x_k, dx = grids[g_k]
        density_m = marginal_m(x_k)

        # manifold prob mass
        all_m = np.sum(density_m * dx)
        assert np.allclose([all_m], [1.], rtol=err_tol), f"Missing probability mass, got {all_m}"

        density_joint = joint(g_k, x_k)
        marginal_prod = mass_g[g_k] * density_m
        mi_case = density_joint * np.log(density_joint / marginal_prod) * dx
        # NOTE(review): the factor 2 presumably compensates for keeping only
        # the x > 0 draws -- confirm with the author.
        total += np.sum(mi_case * 2)

    return total * d


def schmidt_(epsilon, d, c2):
    """Return a count derived from ``epsilon``, ``d`` and ``c2``.

    Returns 1 when ``epsilon`` is at most ``0.25 / d ** 0.25``; otherwise
    returns ``ceil(c2 * epsilon**2 * sqrt(d))`` as an ``int``.
    """
    small_eps_bound = 0.25 / (d ** 0.25)
    if epsilon <= small_eps_bound:
        return 1

    scaled = c2 * (epsilon ** 2) * np.sqrt(d)
    return int(np.ceil(scaled))
    

def do_run(pack):
    """Evaluate ``mi`` for one sweep configuration and package the result.

    Args:
        pack: Tuple ``(eps, d, c, c1, c2, mu, seed)``.

    Returns:
        Dict with keys ``epsilon``, ``d``, ``mi``, ``c2``, ``seed`` and
        ``name``. ``mi`` is the value returned by ``mi(...)`` divided by
        ``d``, or NaN when ``mi(...)`` raised (the exception is printed).
    """
    eps, d, c, c1, c2, mu, seed = pack
    n = schmidt_(eps, d, c2)
    try:
        mi_res = mi(num_train=n, num_manifold=n, d=d, eps=eps, c1=c1, c=c, c2=c2, mu=mu, seed=seed) / d
    except Exception as e:
        # Best-effort: report the failure and record a missing value so the
        # result dict can still be built instead of hitting an unbound name.
        print(e)
        mi_res = np.nan

    # NOTE(review): the name does not include c2, so configurations that
    # differ only in c2 share a name -- confirm this is intended.
    name = f"{eps}-{d}_seed-{seed}"
    return {'epsilon': eps,
            'd': d,
            'mi': mi_res,
            'c2': c2,
            'seed': seed,
            'name': name}



# Sweep configuration: every (c2, eps, seed, d) combination becomes one job.
futures = []
c = 1
c1 = 1
mu = 1
eps_ = [0.090, 0.180, 0.250]
c2_ = [10**5, 10**6]
seeds = np.arange(0, 10, 1)

df_path = os.path.join('figures', f'mi_{get_time_stamp()}.csv')
# to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(df_path), exist_ok=True)

# One single-row frame per finished job. A list (not a dict keyed by name)
# is used because result names are not unique across c2 values.
row_frames = []

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    print("Synthesizing jobs...")
    for c2 in c2_:
        for eps in eps_:
            for seed in tqdm(seeds):
                for d in np.arange(75, 1005, 5):
                    pack = (eps, d, c, c1, c2, mu, seed)
                    futures.append(executor.submit(do_run, pack))

    # Collect while the executor is still open: leaving the ``with`` block
    # waits for every job, so collecting afterwards would make the progress
    # bar and the incremental CSV meaningless during the computation.
    with tqdm(total=len(futures)) as pb:
        for future in concurrent.futures.as_completed(futures):
            result = future.result()

            row = pd.DataFrame.from_dict({result['name']: result}, orient="index")
            # Append just this row (header only on the first write) instead
            # of rewriting the whole file for every result.
            is_first = not row_frames
            row.to_csv(df_path, mode='w' if is_first else 'a', header=is_first)
            row_frames.append(row)
            pb.update(1)

# Full results table, built once at the end.
df = pd.concat(row_frames) if row_frames else pd.DataFrame()

