import numpy as np
import qubovert as qv
import qubovert.utils
import generate_dWPE
import time
import sys
import os

import lib

# System sizes to tune. Alternatives previously commented out here:
# a single N = 60, or the sweep Nl = [60, 100, 140].
Nl = [140]

# Forwarded to generate_dWPE.gen_dWPE; alpha sets M = int(N * alpha) and
# D_WPE is also the number of reference ground states that are checked.
alpha = 0.8
D_WPE = 1
R_WPE = -1

# Annealing settings: T0 is forwarded to lib.save_to_file, T is the anneal
# duration handed to tune(), R is the number of anneals per Ps estimate.
T0 = 1_000
T = T0
R = 1_000

# Upper bound on the iterations of each tuning loop inside tune().
maxcount = 10_000

def tune(N, alpha, T, folder_name):
    """Tune simulated-annealing parameters for one dWPE instance and save them.

    Three greedy one-dimensional searches run in sequence. Each multiplies its
    parameter by ``f`` after every evaluation and stops once the score falls
    more than ``thresh / R`` (``thresh / (R * T)`` for the duration search)
    below the best score seen so far. The value that achieved the best score
    is kept:

    1. starting temperature ``T_start`` (``T_end`` fixed at its initial 4.0),
    2. final temperature ``T_end`` (``T_start`` fixed to the result of 1),
    3. anneal duration ``T``, starting from 100 and scored by ``Ps / T``.

    The final success probability, the time-to-solution and
    ``[T_start, T_end, T]`` are passed to ``lib.save_to_file``.

    Parameters
    ----------
    N : int
        Number of variables of the generated instance.
    alpha : float
        Passed to the generator as ``M = int(N * alpha)``; also formatted
        with two decimals for the output file.
    T : int
        ``anneal_duration`` used during the two temperature searches.
    folder_name : str
        Output folder forwarded to ``lib.save_to_file``.
    """
    T_start = 4.0
    T_end = 4.0
    # A drop in Ps only counts if it exceeds `thresh` anneals' worth of success.
    thresh = 5
    # Multiplicative step of every search.
    f = 1.5
    # Instance index forwarded to the generator.
    inst_i = 0

    def get_Ps(T_start, T_end):
        """Return the fraction of ``R`` anneals that end in a reference ground state.

        The instance is regenerated on every call from ``inst_i``. The anneal
        duration is read from the enclosing ``T`` at call time, so the
        duration search below changes it by rebinding ``T``.
        """
        M = int(N * alpha)
        J, E0, gs = generate_dWPE.gen_dWPE(inst_i, N, M, D_WPE, R_WPE)
        eps0 = 10**(-6)  # tolerance for the ground-state energy comparison

        # NOTE(review): the matrix is built with matrix_to_qubo but annealed as
        # a QUSO (spin) model, and energies are compared with 2*E0 below. This
        # looks like a deliberate factor-of-2 convention -- confirm.
        H = qubovert.utils.matrix_to_qubo(-J)

        # The temperature schedule comes from the search, so qubovert's
        # anneal_temperature_range heuristic is not consulted.
        anneal_res = qv.sim.anneal_quso(
            H,
            num_anneals=R,
            anneal_duration=T,
            temperature_range=(T_start, T_end),
        )

        # Reference ground states normalised so their first entry is 1. Found
        # states get the same normalisation; for +/-1 spins this identifies a
        # state with its global flip.
        refs = [gs[:, i] / gs[0, i] for i in range(D_WPE)]
        n_found = [0] * D_WPE

        for res in anneal_res:
            if res.value <= E0 * 2 + eps0:
                s = np.array([res.state[i] for i in range(N)], dtype=float)
                s = s / s[0]
                found = False
                for i, ref in enumerate(refs):
                    if (s == ref).all():
                        n_found[i] += 1
                        found = True
                if not found:
                    print("found spurious GS ", s)

        return np.sum(n_found) / R

    # NOTE(review): if Ps stays 0 a search runs for maxcount + 1 evaluations
    # with its parameter growing geometrically; consider an earlier stop.

    print("optimizing initial temp")

    # --- 1. optimize starting temperature ---
    prev_Ps = 0
    # Fall back to the initial value if Ps never improves on 0; a best of 0
    # would hand a zero temperature to the annealer.
    T_start_best = T_start
    for _ in range(maxcount + 1):
        Ps = get_Ps(T_start, T_end)
        if Ps > prev_Ps:
            T_start_best = T_start
        print(T_start, T_end, Ps)
        if Ps < prev_Ps - thresh / R:
            break
        T_start = T_start * f
        prev_Ps = max(prev_Ps, Ps)
    T_start = T_start_best

    print("optimizing final temp")

    # --- 2. optimize final temperature ---
    prev_Ps = 0
    T_end_best = T_end  # same fallback as for T_start
    for _ in range(maxcount + 1):
        Ps = get_Ps(T_start, T_end)
        if Ps > prev_Ps:
            T_end_best = T_end
        print(T_start, T_end, Ps)
        if Ps < prev_Ps - thresh / R:
            break
        T_end = T_end * f
        prev_Ps = max(prev_Ps, Ps)
    T_end = T_end_best

    print("optimizing N_step")

    # --- 3. optimize anneal duration by success probability per step ---
    T = 100
    prev_Ps_rate = 0
    # Track the best evaluated duration, like the temperature searches, rather
    # than stepping back with int(T // f), which may yield an unevaluated value.
    T_best = T
    for _ in range(maxcount + 1):
        Ps = get_Ps(T_start, T_end)
        Ps_rate = Ps / T
        print(T_start, T_end, T, Ps, Ps_rate)
        if Ps_rate > prev_Ps_rate:
            T_best = T
        if Ps_rate < prev_Ps_rate - thresh / (R * T):
            break
        T = int(T * f)
        prev_Ps_rate = max(prev_Ps_rate, Ps_rate)
    T = T_best

    Ps = get_Ps(T_start, T_end)
    # Time to reach a ground state with 99% confidence:
    #   TTS = T * ln(1 - 0.99) / ln(1 - Ps).
    # Ps == 0 would give -inf (ln 1 == +0.0) and Ps -> 1 gives values below
    # T, but at least one anneal of length T is always required.
    if Ps <= 0:
        TTS = np.inf
    elif Ps >= 0.99:
        TTS = float(T)
    else:
        TTS = T * np.log(0.01) / np.log(1 - Ps)
    print(Ps, TTS)

    alphatxt = '%0.2f' % alpha
    lib.save_to_file(folder_name, N, alphatxt, T0, Ps, [TTS], [T_start, T_end, T])

###########################################################

# All system sizes share one output folder, obtained from
# lib.create_timestamped_folder(outpath).
outpath = "SA"

# exist_ok=True replaces the race-prone "check, then create" sequence.
os.makedirs(outpath, exist_ok=True)

folder_name = lib.create_timestamped_folder(outpath)

for N in Nl:
    print(f'----------- Start N={N}')
    tune(N, alpha, T, folder_name)