import numpy as np
import time

import CAC_Ising

#import toolbox
import lib

import copy

import os

from scipy.ndimage import gaussian_filter1d
sigma = 10  # std. dev. of the Gaussian kernel used to smooth the plotted curves


# parameters (same as previous)

T = 9000 * 10  # passed to the solver as hyperparams['T'] (presumably the trajectory length)
fMH = 10 / T   # passed to the solver as hyperparams['fMH']; numerically 10/(9000*10)
K = 10         # number of replicas: columns of the initial-parameter array / recorded trajectories

nskip = 100    # leading trajectory samples discarded before histogramming

beta, lamb = 0.07, 0.1
kappa = lamb   # kappa is tied to lamb

# doa flags to run: 0 -> "Without MH", 1 -> "with MH"
doal = [0, 1]

# Initial solver parameters in log space, ordered exactly like PARAM_NAMES
# (they are exponentiated before being set on the solver).
x = np.log([beta, kappa, lamb, 0.1, 1.0, 0.1])
PARAM_NAMES = ["beta", "kappa", "lamb", "xi", "a", "gamma"]
solvertype = 'MHCACm'

pt_device = "cuda"  # device string handed to CAC_Ising.CAC




# load instance

N = 18  # number of spins

# Alternative instance: f'WISHART_{N}_0.80'
name = f'WISHART_{N}_0.78'

J = lib.LoadInstance(name, N, 0)
H0 = lib.LoadOptimal(name, N)[0]  # first entry of LoadOptimal's result (presumably the optimal energy)
eps0 = np.mean(np.abs(J))         # mean absolute coupling; rescales beta below

# brute force: exact distribution at the rescaled inverse temperature.
# Pth / Ho are later used as the reference distribution and its energy grid.
betaMH = beta / eps0
betade = betaMH * kappa
Pth, Ho = lib.partitionf(J, N, betaMH, 1)
    
# do MHCAC: for each doa flag, run the solver, histogram the energy
# trajectory of every replica on the brute-force energy grid Ho, and
# measure its divergence from the exact distribution Pth.
samples_KL = []  # one list of K divergence values per doa flag

new_directory = './fig_BoltzmannDistribution'
os.makedirs(new_directory, exist_ok=True)

# Histogram bins are the energy levels Ho. Edges sit half a unit below each
# level, plus one extra edge after the last level. np.histogram needs
# monotonic edges, so this assumes Ho is sorted ascending (and presumably
# spaced ~1 apart) -- TODO confirm. None of this depends on doa, so it is
# computed once.
bins = Ho
binedges = np.append(bins, bins[-1] + 1) - 0.5

for doa in doal:

    if doa == 0:
        print('Without MH')
    else:
        print('with MH')

    hyperparams = {'T': T, 'doa': doa, 'dosampling': 1, 'fMH': fMH}

    # run
    solver = CAC_Ising.CAC(pt_device, N, J=J, H0=H0, solvertype=solvertype)
    solver.eps = eps0

    # Same initial (log-space) parameters for all K replicas: one column per replica.
    x_init = np.tile(np.expand_dims(x, 1), [1, K])
    for idx, param_name in enumerate(PARAM_NAMES):
        setattr(solver, param_name, np.exp(x_init[idx, :]))

    solver.init(K, PARAM_NAMES, hyperparams)

    # H0 (the loaded optimal energy) is reused by every run, so it must not
    # be reassigned anywhere in this loop.
    Ps, E_opt, traj = solver.traj(H0, R_rec=K)

    # KL div calculation
    Hf = traj['E']
    E_all = np.array(Hf)  # indexed [time, replica] below -- shape per solver.traj, TODO confirm

    samples_KL_ = []
    for k in range(K):
        cHf = E_all[nskip:, k]  # drop the first nskip samples of replica k

        hist, _ = np.histogram(cHf, bins=binedges)
        Phist = hist / np.sum(hist)

        # Forward KL(Phist || Pth). Bins with Phist == 0 give 0 * log(0) = nan,
        # which nansum drops (the 0 log 0 = 0 convention). A symmetrized KL
        # would also need KL(Pth || Phist), which is infinite whenever a bin
        # with Pth > 0 receives no samples, so only the forward term is used.
        with np.errstate(divide='ignore', invalid='ignore'):
            sKL = np.nansum(Phist * np.log(Phist / Pth))
        samples_KL_.append(sKL)

    # NOTE: only the histogram of the last replica (k = K - 1) is saved.
    tag = 'MH' if doa == 1 else 'noMH'
    np.save(os.path.join(new_directory, 'bins_' + tag), bins)
    np.save(os.path.join(new_directory, 'Phist_' + tag), Phist)
    np.save(os.path.join(new_directory, 'Ho_' + tag), Ho)
    np.save(os.path.join(new_directory, 'Pth_' + tag), Pth)

    del solver

    samples_KL.append(samples_KL_)
    
#plot

import matplotlib.pyplot as plt
plt.style.use('plot_style.txt')
import copy

colors = ['r', 'g', 'b', 'm']

plt.figure(figsize=(6, 4))
ax = plt.gca()

data_dir = './fig_BoltzmannDistribution'

# Draw one smoothed histogram per doa flag from the files saved by the run
# loop. The exact distribution is drawn once, alongside the MH run.
for doa in doal:
    tag = 'MH' if doa == 1 else 'noMH'
    bins = np.load(os.path.join(data_dir, 'bins_%s.npy' % tag))
    Phist = np.load(os.path.join(data_dir, 'Phist_%s.npy' % tag))
    Ho = np.load(os.path.join(data_dir, 'Ho_%s.npy' % tag))
    Pth = np.load(os.path.join(data_dir, 'Pth_%s.npy' % tag))

    if doa == 1:
        y_smooth = gaussian_filter1d(Phist, sigma)
        ax.plot(bins, y_smooth, '-', color=colors[2], label='MHCACm (proposed)', lw=1.0)

        y_smooth = gaussian_filter1d(Pth, sigma)
        ax.plot(Ho, y_smooth, '--', label='Exact Gibbs distribution', color=colors[1], lw=1.0)
    else:
        y_smooth = gaussian_filter1d(Phist, sigma)
        ax.plot(bins, y_smooth, '-', color=colors[0], label='CACm', lw=1.0)

# Axis configuration does not depend on doa; apply it once after all curves
# are drawn (the final figure is the same as configuring on every pass).
ax.set_ylim(3*10**(-4), 0.5*10**(-2))
ax.set_xlim(-0.4*10**(5), 0.3*10**(5))

ax.set_xlabel('$H$')
ax.set_ylabel('$P(H)$')

ax.set_yscale('log')

ax.set_title('a')

ax.spines["top"].set_visible(True)
ax.spines["right"].set_visible(True)

ax.legend()
