import neal
import neal.simulated_annealing as sa

from pysa.sa import Solver
from pysa.ising import get_energy

import dimod
import numpy as np

import generate_dWPE

import itertools

from multiprocessing import Pool

def par_run_pt(x):
    """Call ``x[0].run_pt`` on each of the first ``x[2]`` entries of ``x[1]``.

    The arguments arrive packed in a single sequence ``(runner, inputs, count)``,
    presumably so the function can be mapped by ``multiprocessing.Pool``
    (imported in this file) -- confirm against callers.

    NOTE(review): ``PT.run_pt`` in this file takes five positional arguments
    (min_temp, max_temp, splt, T, K); passing a single positional argument as
    done here would raise TypeError unless ``runner`` is some other object.

    Returns:
        A list with one ``run_pt`` result per processed entry, in order.
    """
    runner, inputs, count = x[0], x[1], x[2]
    return [runner.run_pt(inputs[k]) for k in range(count)]

def generate_sk(N, rng=None):
    """Generate a zero-field SK-style instance with random +/-1 couplings.

    Each off-diagonal coupling is the sign of a standard-normal draw, so
    ``J[i, j]`` is +1 or -1 with equal probability (a draw of exactly 0.0 would
    give 0, which has probability zero). The returned matrix is symmetric with
    a zero diagonal.

    Args:
        N: Number of spins.
        rng: Optional random source exposing ``normal(loc, scale, size)``,
            e.g. ``np.random.default_rng(seed)``. When ``None`` (default) the
            global ``np.random`` state is used, exactly as before.

    Returns:
        (h, J): ``h`` is a length-N vector of zeros (no external fields) and
        ``J`` is an (N, N) coupling matrix.
    """
    if rng is None:
        rng = np.random

    # Linear biases (fields) are all zero.
    h = np.zeros(N)

    # +/-1 couplings from the sign of N(0, 1) samples; keep only the strict
    # upper triangle (drops the diagonal), then mirror it to make J symmetric.
    signs = np.sign(rng.normal(0, 1, size=(N, N)))
    upper = np.triu(signs, k=1)
    J = upper + upper.T

    return h, J

def ising_hamiltonian(config, h, J):
    """Return the Ising energy E(s) = -h.s - 0.5 * s.J.s of spin vector ``config``.

    ``config`` may be any sequence (list, tuple, array); it is converted to a
    NumPy array first. The 0.5 factor compensates for a symmetric ``J``
    counting each pair twice.
    """
    spins = np.array(config)
    field_term = np.dot(h, spins)
    coupling_term = np.dot(spins, np.dot(J, spins))
    return -field_term - 0.5 * coupling_term

def brute_force(h, J):
    """Find the lowest-energy spin configuration by exhaustive enumeration.

    All 2**N assignments of -1/+1 spins (N = len(h)) are visited in
    ``itertools.product`` order and scored with ``ising_hamiltonian``. Only a
    strictly lower energy replaces the incumbent, so the first minimiser found
    wins ties.

    Returns:
        (ground_state, ground_energy): the best configuration as a tuple of
        ints and its energy. If no energy compares below ``inf``, the result
        is ``(None, inf)``.
    """
    best_energy = float('inf')
    best_config = None

    for candidate in itertools.product((-1, 1), repeat=len(h)):
        candidate_energy = ising_hamiltonian(candidate, h, J)
        if candidate_energy < best_energy:
            best_config, best_energy = candidate, candidate_energy

    return best_config, best_energy

class PT:
    """Parallel-tempering runner built on pysa's Metropolis solver.

    Holds a zero-field Ising problem given by the coupling matrix ``J`` together
    with a default temperature range derived from ``neal.default_beta_range``.
    """

    def __init__(self, N, J, H0, prec):
        """Store the problem and compute the default temperature range.

        Args:
            N: Unused; the spin count is read from ``J`` instead. Kept so
                existing call sites keep working.
            J: Square coupling matrix with ``np.shape(J)[0]`` spins.
            H0: Reference energy; stored as ``self.H0`` but not used here.
            prec: Energy precision; ``run_pt`` reports ``floor(E / prec)``.
        """
        self.J = J
        self.N = np.shape(J)[0]
        h = np.zeros(self.N)  # zero external fields
        self.H0 = H0
        self.prec = prec

        # Default parameters. The BQM is built only so neal can propose a
        # beta range for the (negated) couplings.
        self.bqm = dimod.BQM(h, -J, 'SPIN')
        beta_start, beta_end = neal.default_beta_range(self.bqm)

        # Temperature = 1/beta: the hot (small) starting beta gives the highest
        # temperature, the cold final beta the lowest.
        self.max_temp = 1/beta_start
        self.min_temp = 1/beta_end
        # Default split exponent (run_pt still takes splt explicitly).
        self.splt = 0.25

    # Annealing schedule hook (identity, i.e. linear, for now); not used by run_pt.
    def schedule(self, tau):
        return tau

    @staticmethod
    def _split_budget(T, splt):
        """Split a total budget ``T`` into ``(n_sweeps, n_replicas)``.

        n_sweeps ~ T**(splt / (1 + splt)) and n_replicas ~ T**(1 / (1 + splt)),
        so their product is ~T. Each count is floored, but with a tiny relative
        tolerance: exact integer powers such as 1000**(1/3) == 10 evaluate in
        floating point to 9.999..., and a plain int() would lose one.
        """
        tol = 1 + 1e-9
        n_sweeps = int(np.floor(T**(splt/(1 + splt)) * tol))
        n_replicas = int(np.floor(T**(1/(1 + splt)) * tol))
        return n_sweeps, n_replicas

    def run_pt(self, min_temp, max_temp, splt, T, K):
        """Run pysa parallel tempering on ``-self.J`` and return best results.

        Args:
            min_temp: Lowest temperature of the replica ladder.
            max_temp: Highest temperature of the replica ladder.
            splt: Exponent controlling how ``T`` is split between sweeps and
                replicas (see ``_split_budget``).
            T: Total budget, split into roughly n_sweeps * n_replicas.
            K: Number of reads passed to the solver.

        Returns:
            (energies, solutions): the solver result's ``'best_energy'`` values
            divided by ``self.prec`` and floored, and the matching
            ``'best_state'`` values.
        """
        n_sweeps, n_replicas = self._split_budget(T, splt)
        n_reads = K

        # Couplings are passed negated and without any normalisation.
        solver = Solver(problem=-self.J, problem_type='ising', float_type='float32')

        # Apply Metropolis
        res_1 = solver.metropolis_update(
            num_sweeps=n_sweeps,
            num_reads=n_reads,
            num_replicas=n_replicas,
            update_strategy='sequential',
            min_temp=min_temp,
            max_temp=max_temp,
            initialize_strategy='random',
            recompute_energy=False,
            sort_output_temps=True,
            parallel=True,  # True by default
            verbose=False)

        # Report energies in units of prec, floored, to match how callers
        # discretise the reference energy.
        energies = np.floor(res_1['best_energy'].values/self.prec)
        solutions = res_1['best_state'].values

        return energies, solutions

    def get_default_temp(self):
        """Return ``(max_temp, min_temp)`` computed in ``__init__``."""
        return self.max_temp, self.min_temp
    
if __name__ == "__main__":

    T = 1000     # total budget; run_pt splits it into n_sweeps x n_replicas
    K = 50       # number of reads
    N = 11       # number of spins

    # Random instance index passed to generate_dWPE (presumably its seed).
    i = int(100000 * np.random.rand())
    alpha = 0.8
    M = int(N * alpha)

    # Problem choice: True -> unbiased Wishart (dWPE) instance for tests,
    # False -> SK instance whose ground state is found by brute force.
    USE_DWPE = True
    if USE_DWPE:
        data = {}
        data['D_WPE'] = 1  # 3
        data['R_WPE'] = -1  # 6
        data['bias'] = 0.0
        J, H0, gs = generate_dWPE.gen_dWPE(i, N, M, data['D_WPE'], data['R_WPE'])
        prec = 1  # precision for GS energy (e.g. 10**(-6) for finer resolution)
    else:
        h, J = generate_sk(N)
        sig, H0 = brute_force(h, J)
        prec = 1

    # run_pt reports floor(E / prec); express the reference energy in the
    # same units in both branches.
    H0 = np.floor(H0 / prec)

    solver = PT(N, J, H0, prec)

    max_temp, min_temp = solver.get_default_temp()
    splt = 0.25

    energies, solutions = solver.run_pt(min_temp, max_temp, splt, T, K)

    # Gap between each reported energy and the reference energy H0.
    print(energies - H0)