#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
#from convert2WFC import transformation
import itertools
import math
import time
import os
#from find_max import find_max

# Seed NumPy's global RNG so that every run of this script is reproducible.
np.random.seed(42)


# In[2]:


# Starting spin configuration for all experiments below.
LATTICE_FILE = '../spin_configs/lattice_32x32.npy'
lattice = np.load(LATTICE_FILE)


# In[3]:


# Lattice dimensions (unpacking also enforces that the lattice is 2D).
Nx, Ny = lattice.shape


# In[4]:


#Nx,Ny = lattice_n.shape

# plt.figure()
# plt.colorbar(plt.imshow(lattice,cmap = 'gray'))

# plt.grid(False)
# plt.show()


# In[5]:


def logarithmic_cooling_fixed_range(t, total_steps, T_start, T_end, scale=1.0):
    """
    Logarithmic cooling schedule scaled to a fixed temperature range.

    The temperature moves from ``T_start`` (at ``t == 0``) to ``T_end`` (at
    ``t >= total_steps``) along the normalized curve
    ``log(1 + scale*t) / log(1 + scale*total_steps)``. Larger ``scale`` makes
    the temperature drop faster early on. ``scale == 0`` is the limit of that
    curve, i.e. plain linear interpolation.

    Args:
        t (int): The current time step or iteration number.
        total_steps (int): The total number of steps in the annealing process.
        T_start (float): The initial temperature.
        T_end (float): The final temperature.
        scale (float): Curvature of the logarithmic progress (default 1.0).

    Returns:
        float: The temperature at step t.
    """
    if t >= total_steps:
        return T_end

    if scale == 0:
        # lim_{scale->0} log(1 + scale*t) / log(1 + scale*N) == t / N;
        # evaluating the log form directly would divide 0 by 0.
        progress = t / total_steps
    else:
        # Normalized progress (0 to 1) based on a logarithmic scale
        progress = math.log(1 + scale * t) / math.log(1 + scale * total_steps)

    # Linear interpolation between T_start and T_end along that progress.
    return T_start - (T_start - T_end) * progress


# In[18]:


def nn_ising_hamiltonian(spins, noise_level, J=1.0):
    """Return the nearest-neighbour Ising energy of a periodic 2D lattice.

    Computes ``E = -J * sum(s_i * s_j)`` over horizontal and vertical
    neighbour pairs. Periodic boundaries come from ``np.roll``: each site is
    paired with its (wrapped) right and lower neighbour. When
    ``noise_level > 0``, zero-mean Gaussian noise with standard deviation
    ``noise_level * |E|`` is added, drawn from NumPy's global RNG.

    Args:
        spins: array-like with exactly two dimensions (rows, columns).
        noise_level (float): relative noise amplitude; values <= 0 disable noise.
        J (float): coupling constant (default 1.0).

    Returns:
        float: the (possibly noisy) total energy, not normalised by lattice size.

    Raises:
        ValueError: if ``spins`` is not two-dimensional.
    """
    spins = np.asarray(spins, dtype=np.float32)
    if spins.ndim != 2:
        raise ValueError(f"spins must be a 2D array, got shape {spins.shape}")

    energy = 0.0
    energy -= J * np.sum(spins * np.roll(spins, shift=-1, axis=1))  # horizontal
    energy -= J * np.sum(spins * np.roll(spins, shift=-1, axis=0))  # vertical

    if noise_level > 0:
        # Multiplicative noise: spread is proportional to the energy magnitude.
        sigma = noise_level * abs(energy)
        energy += np.random.normal(0.0, sigma)

    return float(energy)


# In[44]:


def metropolis(spin_arr, times, T_start, T_end, initial_energy, noise, numexpt):
    """
    Single-spin-flip Metropolis sampler driven by a (possibly noisy) energy estimate.

    At each step one site is chosen uniformly at random and a flip of that
    spin is proposed. The energy of the proposed configuration is the mean of
    ``numexpt`` calls to ``nn_ising_hamiltonian`` (each adds Gaussian noise
    when ``noise > 0``). It is compared with the running estimate of the
    current energy. The temperature follows ``logarithmic_cooling_fixed_range``
    (scale 1.0) from ``T_start`` to ``T_end``.

    Args:
        spin_arr: 2D spin lattice. It is copied, so the caller's array is not modified.
        times: the sampler performs ``times - 1`` steps (0 if ``times < 1``).
        T_start (float): initial temperature.
        T_end (float): final temperature.
        initial_energy: ignored. The starting energy is re-estimated
            internally in the same way as every proposal. Kept for
            backward compatibility with existing callers.
        noise (float): relative noise level forwarded to ``nn_ising_hamiltonian``.
        numexpt (int): number of energy evaluations averaged per estimate.

    Returns:
        tuple: ``(net_spins, net_energy, spin_arr, delta_E, curr_BJ)``, where
            net_spins: total spin after each step;
            net_energy: running energy estimate after each step;
            spin_arr: the final lattice;
            delta_E: proposed energy change at each step (accepted or not);
            curr_BJ: inverse temperature used at each step.
    """
    spin_arr = spin_arr.copy()
    n_steps = max(times - 1, 0)
    net_spins = np.zeros(n_steps, dtype=np.float64)
    curr_BJ = np.zeros(n_steps, dtype=np.float64)
    net_energy = np.zeros(n_steps, dtype=np.float64)
    delta_E = np.zeros(n_steps, dtype=np.float64)

    Lx, Ly = spin_arr.shape

    def estimate_energy(config):
        # Mean of numexpt independent (noisy) evaluations of the Hamiltonian.
        return np.mean([nn_ising_hamiltonian(config, noise) for _ in range(numexpt)])

    # The starting energy is estimated exactly like each proposal (with noise
    # and averaging), so current and proposed energies are comparable.
    current_energy = estimate_energy(spin_arr)

    for step in range(n_steps):
        curr_beta = 1 / logarithmic_cooling_fixed_range(step, times - 1, T_start, T_end, 1.0)
        curr_BJ[step] = curr_beta

        # Choose random spin to flip
        x = np.random.randint(0, Lx)
        y = np.random.randint(0, Ly)

        # Propose the flip in place (reverted below on rejection) instead of
        # copying the whole lattice on every step.
        spin_arr[x, y] = -spin_arr[x, y]
        flip_energy = estimate_energy(spin_arr)
        dE = flip_energy - current_energy

        # Metropolis acceptance criterion; the uniform draw only happens when
        # dE > 0 (short-circuit), matching the RNG usage of a plain if/elif.
        accept = dE <= 0 or np.random.random() < np.exp(-curr_beta * dE)
        if accept:
            current_energy += dE
        else:
            spin_arr[x, y] = -spin_arr[x, y]  # undo the proposed flip

        # Record observables
        net_spins[step] = spin_arr.sum()
        net_energy[step] = current_energy
        delta_E[step] = dE

        if step % 30000 == 0:
            avg_spin = net_spins[step] / (Lx * Ly)
            print(f"Time step: {step:6d}, T={1/curr_beta:.3f}, β={curr_beta:.3f}, dE={dE:6.1f}, <s>={avg_spin:.4f}")

    return net_spins, net_energy, spin_arr, delta_E, curr_BJ


# # Run Experiments

# In[54]:


# Parameter grid for the experiments.
times = 2_500_000  # Metropolis steps per run (traces hold times - 1 entries)
noise_levels = [0.01]
numexpt_values = [1]

# Earlier (start_beta, end_beta) grids, kept for reference:
#BJ_pairs = [(0.08, 1.2)]
#BJ_pairs = [(0.5, 0.5), (0.7, 0.7), (0.85, 0.85), (1, 1), (1.015,1.015), (1.03, 1.03), (1.06, 1.06), (1.09,1.09), (1.12,1.12), (1.15,1.15),
            #(1.18,1.18), (1.21,1.21), (1.24,1.24), (1.27,1.27), (1.3,1.3), (1.36,1.36), (1.4, 1.4), (1.5, 1.5), (1.65, 1.65), (1.8, 1.8),
            #(2.0, 2.0), (2.2, 2.2), (2.6, 2.6), (3.0, 3.0), (3.5, 3.5) ]

# Fixed-beta runs: start and end beta coincide for every entry.
BJ_pairs = [
    (beta, beta)
    for beta in (
        0.100, 0.200, 0.300, 0.380,
        0.400, 0.420, 0.430, 0.440,
        0.450, 0.470, 0.500, 0.550,
        0.600, 0.650, 0.700, 1.000, 2.000,
    )
]

# Same pairs expressed as temperatures, T = 1 / beta.
temperature_pairs = [(1 / b_start, 1 / b_end) for (b_start, b_end) in BJ_pairs]

# Run one Metropolis simulation per (noise, numexpt, temperature-pair)
# combination and save its step / beta / magnetization trace to CSV.
# (Summary-statistics collection is currently disabled; see the commented
# block inside the loop.)
for noise, numexpt, (T_initial, T_final) in itertools.product(noise_levels, numexpt_values, temperature_pairs):

    key = f"noise={noise}_numexpt={numexpt}_Tinit={T_initial:.3f}_Tfinal={T_final:.3f}"
    print(f"Running: {key}")

    # Run the Metropolis simulation.
    # NOTE: metropolis() does not use its initial_energy argument (it
    # re-estimates the starting energy itself). The nn_ising_hamiltonian call
    # is kept anyway: with noise > 0 it draws from the seeded global RNG, so
    # removing it would change every previously generated trajectory.
    spins, energies, final_lattice, delta_E, BJ = metropolis(
        lattice.copy(), times, T_initial, T_final,
        nn_ising_hamiltonian(lattice, noise), noise, numexpt,
    )

    # Ensure the folder exists
    folder = f"quenching/{Nx}_by_{Ny}/noise{noise}/numexpt{numexpt}"
    os.makedirs(folder, exist_ok=True)

    # The filename encodes the starting beta. When the run anneals
    # (end beta != start beta), the end beta is appended. Otherwise two
    # schedules sharing a start beta would overwrite each other's CSV.
    beta_tag = f"{1 / T_initial:.3f}"
    beta_end_tag = f"{1 / T_final:.3f}"
    if beta_end_tag != beta_tag:
        beta_tag += f"to{beta_end_tag}"
    filename = f"SA{Nx}x{Ny}_noise{noise}_numexpt{numexpt}_beta{beta_tag}.csv"
    filepath = os.path.join(folder, filename)

    # Prepare data; the step index matches the length of the returned traces.
    steps = np.arange(len(BJ))
    beta_values = np.array(BJ)
    magnetization = np.array(spins) / (Nx * Ny)  # mean spin per site

    # beta_mean = beta_values[-100000:].mean()
    # magnetization_mean = magnetization[-100000:].mean()
    # # Store summary statistics for plotting later
    # summary_data.append({
    #     "beta": round(1 / T_initial, 3),
    #     "noise": noise,
    #     "numexpt": numexpt,
    #     "beta_mean": round(beta_mean, 6),
    #     "magnetization_mean": round(magnetization_mean, 6),
    # })

    # Stack columns: step, beta, magnetization
    history = np.column_stack((steps, beta_values, magnetization))

    # Save to CSV
    np.savetxt(filepath, history, delimiter=",", header="Step,Beta,Magnetization", comments='')


# In[34]:


print("End of SA")  # signal that every simulated-annealing run has finished


# In[58]:


# Plot mean |magnetization| against temperature (1 / beta_mean)
# plt.figure(figsize=(10, 6))
# plt.plot(1/summary_df['beta_mean'], abs(summary_df['magnetization_mean']), marker='o', linestyle='-', color='blue')
# plt.title('Magnetization vs beta')
# plt.xlabel('Steps')
# plt.ylabel('Magnetization')
# plt.grid(True)
# plt.tight_layout()
# plt.show()
# plt.close()



# In[46]:


# def get_spin_energy(lattice, BJs):
#     ms = np.zeros(len(BJs))
#     E_means = np.zeros(len(BJs))
#     E_stds = np.zeros(len(BJs))
#     N = lattice.shape[0]
#     for i, bj in enumerate(BJs):
#         #spins, energies = metropolis(lattice, 1000000, bj, get_energy(lattice))
#         # Run the Metropolis simulation
#         T_start = 1 / bj
#         T_end = 1 / bj
#         spins, energies, final_lattice, delta_E, BJ = metropolis(lattice.copy(), 400000, T_start, T_end,
#                                                              nn_ising_hamiltonian(lattice, 0), 0, 1)
#         ms[i] = spins[-40000:].mean()/N**2
#         E_means[i] = energies[-40000:].mean()
#         E_stds[i] = energies[-40000:].std()
#     return ms, E_means, E_stds
    
# BJs = np.arange(0.2, 1, 0.05)
# ms_n, E_means_n, E_stds_n = get_spin_energy(lattice, BJs)


# In[52]:


# plt.figure(figsize=(8,5))
# plt.plot(1/BJs, abs(ms_n), 'o--', label='75% of spins started negative')
# #plt.plot(1/BJs, abs(ms_p), 'o--', label='75% of spins started positive')
# plt.xlabel(r'$\left(\frac{k}{J}\right)T$')
# plt.ylabel(r'$\bar{m}$')
# plt.xlim([0, 10])
# plt.legend(facecolor='white', framealpha=1)
# plt.show()


# In[ ]:




