import numpy as np
import time

import CAC_Ising

#import toolbox
import lib

import copy

import os

from scipy.ndimage import gaussian_filter1d
sigma = 10  # width (in samples) of the Gaussian kernel used to smooth panel a's curves


# ---------------------------------------------------------------------------
# Experiment parameters (same as previous)
# ---------------------------------------------------------------------------
do_RERUN = 1  # 1: run the simulations and save results; 0: reload saved results

T = 500 * 10          # forwarded to the solver as hyperparams['T']
fMH = 20 / (3000 * 10)  # forwarded to the solver as hyperparams['fMH']

K = 10  # number of parallel trajectories per run (columns of traj['E'])

nskip = int(T / 10 * 0.5)  # leading trajectory entries discarded before histogramming

betal = [0.03, 0.05, 0.07, 0.09, 0.11]  # inverse temperatures beta that are scanned
lamb, xi, a = 0.1, 0.1, 1

# kappa = beta_tilde / beta values scanned, expressed in units of lamb
kappal = np.linspace(1 / 4, 6, 10) * lamb

doal = [0, 1]  # 0: CACm (without MH), 1: MHCACm (with MH)

PARAM_NAMES = ["beta", "kappa", "lamb", "xi", "a"]

#pt_device = "cuda"
pt_device = "cpu"
solvertype = 'MHCACm'

# All per-run results and the final figure are written here.
new_directory = './fig_effectiveTemperature'
if not os.path.exists(new_directory):
    os.makedirs(new_directory)
            



# load instance

if do_RERUN:
    # For every (beta, kappa, doa) run the solver and compute, per trajectory,
    # the KL divergence between the sampled energy histogram and the exact
    # distribution returned by lib.partitionf. Each result list is also saved
    # to new_directory so a later run with do_RERUN = 0 can reload it.
    # Resulting nesting of `res`: [beta][kappa][doa][trajectory].

    N = 18

    name = 'WISHART_%d_0.78' % N
    #name = 'WISHART_%d_0.80' % N

    J = lib.LoadInstance(name, N, 0)
    # Optimal energy of the instance; passed to every solver below and must
    # not be overwritten inside the loops.
    H0 = lib.LoadOptimal(name, N)[0]
    # Mean absolute coupling: rescales beta and is assigned to solver.eps.
    eps0 = np.mean(np.abs(J))

    res = []

    for ibeta, beta in enumerate(betal):

        print('------------- %0.2f percent completed' % (100.0 * ibeta / len(betal)))

        betaMH = beta / eps0

        Pth, Ho = lib.partitionf(J, N, betaMH, 1)

        # Histogram edges centred on each level in Ho (+-0.5 around each level;
        # presumes unit-spaced, sorted levels -- TODO confirm against
        # lib.partitionf). They depend only on Ho, so build them once per beta.
        bins = Ho
        binedges = np.copy(bins).tolist()
        binedges.append(binedges[-1] + 1)
        binedges = np.array(binedges) - 0.5

        samples_KLmat = []

        for i, kappa in enumerate(kappal):

            print('---- %0.2f percent completed' % (100.0 * i / len(kappal)))

            # Log-parameters, replicated once per trajectory (one column each).
            x = np.log([beta, kappa, lamb, xi, a])
            x_init = np.tile(np.expand_dims(x, 1), [1, K])

            # do MHCAC
            samples_KL = []

            for doa in doal:

                if doa == 0:
                    print('Without MH')
                else:
                    print('with MH')

                hyperparams = {'T': T, 'doa': doa, 'dosampling': 1, 'fMH': fMH}

                # run (fresh solver for every configuration)
                solver = CAC_Ising.CAC(pt_device, N, J=J, H0=H0, solvertype=solvertype)
                solver.eps = eps0

                for idx, param_name in enumerate(PARAM_NAMES):
                    setattr(solver, param_name, np.exp(x_init[idx, :]))

                solver.init(K, PARAM_NAMES, hyperparams)

                Ps, E_opt, traj = solver.traj(H0, R_rec=K)

                # Energy trajectories; indexed below as [time, trajectory]
                # (assumes traj['E'] has shape (steps, K) -- TODO confirm).
                Hf = np.array(traj['E'])

                samples_KL_ = []
                for k in range(K):
                    # Discard the first nskip entries of trajectory k.
                    cHf = Hf[nskip:, k]

                    hist, _ = np.histogram(cHf, bins=binedges)
                    Phist = hist / np.sum(hist)

                    # Forward divergence KL(Phist || Pth). Unvisited levels give
                    # 0 * log(0) = nan, which nansum skips (convention 0*log0 = 0).
                    with np.errstate(divide='ignore', invalid='ignore'):
                        sKL = np.nansum(Phist * np.log(Phist / Pth))
                    samples_KL_.append(sKL)

                savefile = new_directory + f'/doa_{doa}_beta_{beta}_kappa_{kappa}_info.txt'
                np.savetxt(savefile, samples_KL_)

                del solver

                samples_KL.append(samples_KL_)

            samples_KLmat.append(samples_KL)

        res.append(samples_KLmat)

else:
    # Reload results saved by a previous run, using the same file-name pattern
    # and the same [beta][kappa][doa] nesting as above.
    res = [
        [
            [
                np.loadtxt(new_directory + f'/doa_{doa}_beta_{beta}_kappa_{kappa}_info.txt')
                for doa in doal
            ]
            for kappa in kappal
        ]
        for beta in betal
    ]
    
    
    
#plot

fs = 14  # font size for axis labels and legends

new_directory = './fig_effectiveTemperature'
os.makedirs(new_directory, exist_ok=True)

import matplotlib.pyplot as plt
plt.style.use('plot_style.txt')
import copy

plt.figure(figsize=(12, 4))

# Panel b: mean KL divergence (with 1.96 * std / sqrt(K) error bars over the
# K trajectories) as a function of kappa, for CACm and MHCACm.
ax = plt.subplot(1,3,2)

colors = ['r','b']
lines = ['-','-','-','-',':']

# One [kappa* for CACm, kappa* for MHCACm] pair per beta, where kappa* is the
# scanned kappa with the smallest mean KL divergence. Panel c multiplies this
# by beta to plot the effective inverse temperature beta_tilde* = kappa* beta.
kappa_opt = []

for count, (beta, samples_KLmat, line) in enumerate(zip(betal, res, lines)):

    # Axes: [kappa, doa, trajectory], following the nesting order of `res`.
    samples_KLmat = np.array(samples_KLmat)

    mV = np.mean(samples_KLmat, 2)
    sV = np.std(samples_KLmat, 2) * 1.96 / np.sqrt(K)

    # Only the 4th and 5th entries of betal are drawn in this panel.
    if count in (3, 4):
        plt.errorbar(kappal,mV[:,0],yerr=sV[:,0],label=r'CACm, $\beta=%0.2f$' % beta,color=colors[0],linestyle=line)
        plt.errorbar(kappal,mV[:,1],yerr=sV[:,1],label=r'MHCACm, $\beta=%0.2f$' % beta,color=colors[1],linestyle=line)

    # Store the minimising kappa (not the minimal KL value) for each doa.
    kappa_opt.append(np.asarray(kappal)[np.argmin(mV, 0)].tolist())

plt.xlabel(r'$\kappa = \tilde{\beta} / \beta $',fontsize=fs)
plt.ylabel('KL divergence',fontsize=fs)
plt.legend(ncols=2,fontsize=10,loc='lower left')

plt.xlim(np.min(kappal),np.max(kappal))
plt.ylim(.25,1.5)

ax.spines["top"].set_visible(True)
ax.spines["right"].set_visible(True)

ax.set_title('b')

# Panel a (data exported by fig_BoltzmannDistribution.py): energy histograms
# of CACm and MHCACm plus the exact Gibbs distribution, each smoothed with a
# Gaussian kernel of width `sigma` and drawn on a logarithmic y axis.
colors = ['r','g','b','m']

ax = plt.subplot(1,3,1)

for doa in doal:

    if doa != 1:
        # CACm: only the sampled histogram is drawn (Ho/Pth are loaded but unused).
        bins = np.load('./fig_BoltzmannDistribution/bins_noMH.npy')
        Phist = np.load('./fig_BoltzmannDistribution/Phist_noMH.npy')
        Ho = np.load('./fig_BoltzmannDistribution/Ho_noMH.npy')
        Pth = np.load('./fig_BoltzmannDistribution/Pth_noMH.npy')

        y_smooth = gaussian_filter1d(Phist, sigma)
        ax.plot(bins, y_smooth, '-', color=colors[0], label='CACm', lw=1.0)
        continue

    # MHCACm: sampled histogram and the exact reference distribution.
    bins = np.load('./fig_BoltzmannDistribution/bins_MH.npy')
    Phist = np.load('./fig_BoltzmannDistribution/Phist_MH.npy')
    Ho = np.load('./fig_BoltzmannDistribution/Ho_MH.npy')
    Pth = np.load('./fig_BoltzmannDistribution/Pth_MH.npy')

    y_smooth = gaussian_filter1d(Phist, sigma)
    ax.plot(bins, y_smooth, '-', color=colors[2], label='MHCACm (proposed)', lw=1.0)

    y_smooth = gaussian_filter1d(Pth, sigma)
    ax.plot(Ho, y_smooth, '--', label='Exact Gibbs distribution', color=colors[1], lw=1.0)

# Axis formatting does not depend on doa, so apply it once after plotting.
ax.set_ylim(3*10**(-4), 0.5*10**(-2))
ax.set_xlim(-0.4*10**(5), 0.3*10**(5))

ax.set_xlabel('$H$', fontsize=fs)
ax.set_ylabel('$P(H)$', fontsize=fs)

ax.set_yscale('log')

ax.set_title('a')

for spine_name in ("top", "right"):
    ax.spines[spine_name].set_visible(True)

ax.legend(fontsize=fs)



# Panel c: beta (y axis) against each kappa_opt entry times beta (x axis),
# for CACm and MHCACm, with the guide line beta_tilde = beta / 2.
# kappa_opt holds one [CACm, MHCACm] pair per beta (filled in panel b).
ax = plt.subplot(1,3,3)

beta_lo, beta_hi = np.min(betal), np.max(betal)
opt_arr = np.array(kappa_opt)
beta_arr = np.array(betal)

plt.plot(opt_arr[:, 0] * beta_arr, betal, 'd-r', label='CACm')
plt.plot(opt_arr[:, 1] * beta_arr, betal, 'o:b', label='MHCACm')

yv = np.array([beta_lo, beta_hi])
xv = yv / 2
plt.plot(xv, yv, '--k')

plt.ylim(beta_lo, beta_hi)
plt.xlim(np.min(xv), np.max(xv))

ax.set_xlabel(r'effective inv. temperature $\tilde{\beta}^*$', fontsize=fs)
ax.set_ylabel(r'inv. temperature $\beta$', fontsize=fs)

ax.set_title('c')

for spine_name in ("top", "right"):
    ax.spines[spine_name].set_visible(True)

plt.legend(fontsize=fs)

plt.tight_layout()

# Write the assembled three-panel figure in both vector and raster form.
plt.savefig('./fig_effectiveTemperature/fig_effectiveTemperature.pdf', format='pdf')
plt.savefig('./fig_effectiveTemperature/fig_effectiveTemperature.png')


