import CAC_Ising
import lib

import numpy as np
import time
    
import os

import matplotlib.pyplot as plt

datapath = './../Data'  # root folder holding the GSET_{N} instance directories

###########################################
# Problem / run configuration

N = 800          # number of nodes; selects the GSET_{N} data folder
T = 10           # time budget per run (hyperparams['T'], also scales the TTS);
                 # the tuned MHCACm parameter sets below override it
i0 = 0           # instance id (0-based; LoadInstance adds 1 for the file name)

R = 2000         # batch size passed to solver.init (presumably independent runs)

pt_device = 'cuda'   # device string passed to CAC_Ising.CAC; use 'cpu' without a GPU
#pt_device = 'cpu'

fMH = 0.1        # forwarded to the MHCACm solver as hyperparams['fMH']
#fMH = 1.0

# Graph size N -> the `n` tag that appears in the instance / solution file names.
ni = {800: 21, 1000: 4, 2000: 21}
n = ni[N]


####################################################
# SOLVER
# Flip the condition below to False to run plain CACm instead of MHCACm.

if True:
    solvertype = 'MHCACm'

    PARAM_NAMES = ["beta", "kappa", "lamb", "xi", "gamma", "a"]
    # Parameters are stored in log-space (the driver applies np.exp).
    # Untuned default, used when i0 has no tuned set below:
    x = np.log([0.1, 0.1, 1.0, 0.1, 1.0, 1.0])

    if i0 == 0:
        # Earlier parameter sets for this instance, all run with T=3000:
        #  - already log-transformed, so np.log of them is invalid; their
        #    exponentials equal the linear-scale set that follows:
        #    [-2.623056121636534, -2.492811128986292, -0.8223195053974548,
        #     -4.234032365404217, -0.24728422167909597, 0.3250854187202053]
        #  - the same set in linear scale:
        #    [0.07258071, 0.08267722, 0.43941125, 0.01449383, 0.78091871, 1.38414887]
        #  - set provided by Sam:
        #    [0.0610020241933274, 0.06808693124191877, 0.26294841017516646,
        #     0.09664222668785089, 0.5158864574053653, 1.1905591350133589]
        # Current test: Sam's set with xi taken from the linear-scale set above.
        x = np.log([0.0610020241933274, 0.06808693124191877, 0.26294841017516646,
                    0.01449383, 0.5158864574053653, 1.1905591350133589])
        T = 3000
    elif i0 == 1:
        x = np.log([0.04845284629425895, 0.0651546464407786, 0.2544987881015054,
                    0.010536516774508809, 0.5224580962091915, 1.3608951397100029])
        T = 6000

    hyperparams = dict(T=T, doa=1, dosampling=0, fMH=fMH)

else:
    solvertype = 'CACm'
    PARAM_NAMES = ["beta", "lamb", "xi", "gamma", "a"]
    x = np.log([0.1, 0.2, 0.1, 1.0, 0.8])
    hyperparams = dict(T=T)

###########################################






def LoadInstance(i):
    """Load the coupling matrix of GSET instance ``i`` as a dense symmetric array.

    Reads ``{datapath}/GSET_{N}/GSET_{N}_{n}_{i+1}`` using the module-level
    ``datapath``, ``N`` and ``n``. Instance ids are 0-based here and 1-based
    in the file names. Every non-blank line must start with three numbers
    ``row col weight`` (1-based node indices); extra columns are ignored.
    If an entry is listed more than once, the last one wins.

    Args:
        i: 0-based instance id.

    Returns:
        (N, N) float ndarray ``w + w.T``, where ``w`` holds the listed entries.
    """
    path = datapath + f'/GSET_{N}/GSET_{N}_{n}_{i + 1}'

    w = np.zeros((N, N))
    # ``with`` guarantees the handle is closed, even if parsing fails.
    with open(path, 'r') as file:
        for line in file:
            # split() (not split(' ')) tolerates tabs, repeated and trailing spaces.
            fields = line.split()
            if not fields:
                continue  # skip blank lines instead of crashing on float('')
            row, col, weight = map(float, fields[:3])
            w[int(row - 1), int(col - 1)] = weight

    # Symmetrize. An entry listed in both directions in the file would be doubled.
    return w + w.T
    
def LoadOptimalC(N, i, w):
    """Return the reference energy ``H0`` for instance ``i``.

    Reads ``{datapath}/GSET_{N}/GSET_{N}_{ni[N]}_SOL`` (module-level
    ``datapath`` and ``ni``). Line ``i`` (0-based) must hold a single number
    ``C0`` (presumably the best known cut value of the instance). It is
    converted to the solver's energy scale as ``H0 = (-4 * C0 - sum(w)) / 2``.

    Args:
        N: graph size; selects the data folder and the ``ni[N]`` file tag.
        i: 0-based instance id, i.e. the line index in the solution file.
        w: coupling matrix of the instance, as returned by ``LoadInstance``.

    Returns:
        The reference energy ``H0`` (a NumPy float, because of ``np.sum``).

    Raises:
        IndexError: if the file has no line ``i``.
        ValueError: if line ``i`` does not contain exactly one number.
    """
    n = ni[N]
    path = datapath + f'/GSET_{N}/GSET_{N}_{n}_SOL'
    # ``with`` closes the file, which was previously left open.
    with open(path, 'r') as file:
        lines = file.readlines()
    # Parse only the needed line, directly as a Python float. This tolerates
    # surrounding whitespace and avoids float() on a 1-element NumPy array,
    # which is deprecated since NumPy 1.25.
    C0 = float(lines[i])
    H0 = (-4 * C0 - np.sum(w)) / 2

    return H0


def gen_problem():
    """Build the solver for instance ``i0`` and return it with its target energy.

    Loads the coupling matrix (``LoadInstance``) and the reference energy
    (``LoadOptimalC``), then replaces the energy by ``floor(H0 / prec)`` with
    ``prec = 1``. Constructs a ``CAC_Ising.CAC`` on ``pt_device`` using the
    module-level ``N`` and ``solvertype``, and sets its ``eps`` to the mean
    absolute coupling.

    Returns:
        tuple: ``(solver, H0)``, where ``H0`` is the floored target energy.
    """
    prec = 1  # energy precision; also handed to the solver as ``precGS``

    couplings = LoadInstance(i0)
    target = LoadOptimalC(N, i0, couplings)
    # Measured before the solver is built, in case the constructor alters J.
    mean_abs_coupling = np.mean(np.abs(couplings))
    target = np.floor(target / prec)

    solver = CAC_Ising.CAC(
        pt_device, N,
        J=couplings, H0=target, solvertype=solvertype, precGS=prec,
    )
    solver.eps = mean_abs_coupling

    return solver, target


def _time_to_solution(T, Ps, target=0.99):
    """Return the time-to-solution ``T * ln(1 - target) / ln(1 - Ps)``.

    ``T`` is the time budget of one run and ``Ps`` its success probability
    (scalar or array). Edge cases where the raw formula is wrong:

    * ``Ps >= target``: one run already reaches the target confidence, so the
      TTS is ``T``. The raw formula drops below ``T``, and at ``Ps == 1`` it
      evaluates ``log(0)`` and returns 0.
    * ``Ps <= 0``: the target is never reached, so the TTS is ``+inf``. The raw
      formula divides a negative number by 0.0 and returns ``-inf``.

    Returns a NumPy scalar for scalar ``Ps`` and an ndarray otherwise.
    """
    p = np.asarray(Ps, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        raw = T * np.log(1 - target) / np.log(1 - p)
    tts = np.where(p >= target, float(T), np.where(p <= 0, np.inf, raw))
    return tts[()]  # unwrap 0-d arrays to a scalar; n-d arrays are returned as-is


solver, E0 = gen_problem()

# Set up the solver. Every one of the R columns starts from the same
# (log-space) parameter vector x.
x_init = np.tile(np.expand_dims(x, 1), [1, R])

for idx, param_name in enumerate(PARAM_NAMES):
    setattr(solver, param_name, np.exp(x_init[idx, :]))

solver.init(R, PARAM_NAMES, hyperparams)

Ps, E_opt = solver.traj(E0)

# Time to reach the target energy E0 with 99% confidence.
TTS = _time_to_solution(T, Ps)

#print(np.min(E_opt))
#print(E0)
print(Ps, TTS)