#!/usr/bin/env python
# coding: utf-8

# In[1]:


import functools
import itertools
import json
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

import matplotlib
matplotlib.use('Agg')  # Non-interactive backend: figures are only written to disk via savefig, never shown on screen


# In[2]:


"""-----------------------------------------------------------------
NOISY BAYESIAN ISING
-------------------------------------------------------------------
* Variational mean field q(σ) with **REINFORCE** so noisy energy feedback
 moves the Bernoulli parameters `spin_probs`.
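* Block Metropolis **pseudo-marginal** kernel (averaging several noisy
 Fourier energies) to bring each sample close to the Boltzmann
 distribution even when single-spin ΔE ≲ noise. NOTE: no such kernel is
 implemented in this file; samples are currently drawn directly from q(σ).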
* **Inspection helpers** let you snapshot and visualise any full spin
 configuration produced during optimisation.
-------------------------------------------------------------------"""

# ----------------------------------------------------------------------
#  NOISE / ENERGY HELPERS
# ----------------------------------------------------------------------


# In[163]:


# ----------------------------------------------------------------------
#  CORE MODEL
# ----------------------------------------------------------------------

class VariationalNoisyIsing:
    """Variational mean-field Ising model driven only by *noisy* Fourier energies.

    Every lattice site is an independent Bernoulli spin with probability
    ``spin_probs[i, j]`` of being +1.  These parameters are trained by
    minimising the variational free energy ``F(q) = β·E_q[E(σ)] − H(q)``
    with a REINFORCE (score-function) gradient estimate, where each energy
    is a noisy measurement derived from the zero-frequency (DC) Fourier
    intensity of the spin lattice.

    Parameters
    ----------
    grid_shape : int | tuple[int, int]
        ``(M, K)`` lattice shape.  If an int, it must be a perfect square N
        and the lattice is √N × √N.
    noise_level : float
        Relative (multiplicative) noise σ/|E| of each energy measurement.
    beta_min, beta_max : float
        Initial inverse temperature, and the cap applied by ``anneal_step``.
    delta_beta : float
        Additive β increment applied by ``anneal_step``.
    anneal_factor : float
        Multiplicative annealing factor.  Only stored; ``anneal_step`` uses
        the additive ``delta_beta`` schedule.
    energy_evals_per_config : int
        Number of noisy energy measurements averaged per sampled
        configuration (must be >= 1).
    device : str
        Torch device; 'cuda' if available else 'cpu'.

    Raises
    ------
    ValueError
        If an int ``grid_shape`` is not a perfect square, or if
        ``energy_evals_per_config`` < 1.
    """

    # ----------------------------- init ----------------------------
    def __init__(
        self,
        grid_shape=(32, 32),  # M x K grid size
        noise_level: float = 0.03,  # 0.05 means 5% relative noise
        beta_min: float = 3,
        beta_max: float = 3,
        delta_beta: float = 0.0,
        anneal_factor: float = 1.01,
        energy_evals_per_config: int = 1,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        # Resolve lattice dimensions.  math.isqrt is exact for any int size,
        # whereas int(np.sqrt(n)) can be off by one through float rounding.
        if isinstance(grid_shape, int):
            side = math.isqrt(grid_shape)
            if side * side != grid_shape:
                raise ValueError("grid_shape int must be perfect square")
            self.M, self.K = side, side
        else:
            self.M, self.K = grid_shape

        # Averaging over zero measurements is undefined; fail fast here rather
        # than deep inside variational_step.
        if energy_evals_per_config < 1:
            raise ValueError("energy_evals_per_config must be >= 1")

        self.N = self.M * self.K  # total number of spins
        self.device = device
        self.noise_level = noise_level
        self.energy_evals_per_config = energy_evals_per_config

        # Variational Bernoulli parameters P(σ = +1), one per site.  They start
        # at 0.49 rather than 0.5 (presumably to break the ±m symmetry).
        self.spin_probs = torch.full(
            (self.M, self.K), 0.49, device=device, requires_grad=True
        )

        # Annealing schedule state.
        self.beta_min, self.beta_max = beta_min, beta_max
        self.current_beta = beta_min
        self.delta_beta = delta_beta
        self.anneal_factor = anneal_factor

        # Adam updates only spin_probs; its gradient is written by hand in
        # variational_step.
        self.opt = torch.optim.Adam([self.spin_probs], lr=0.05)

        # Filled in by variational_step for inspection.
        self.latest_samples = None   # [B, M, K] spins in {-1, +1}
        self.latest_energies = None  # [B] averaged noisy energies
        self.energy_eval_count = 0   # number of configuration energies measured

    # ---------------------- noisy Fourier energy ------------------
    def measure_energy_fourier(self, spins: torch.Tensor) -> torch.Tensor:
        """Return one noisy energy per configuration from its DC Fourier intensity.

        Each spin becomes a unit phasor (+1 → e^{i·0}, anything else → e^{i·π});
        the squared magnitude of the zero-frequency bin of the 2-D FFT is then
        (up to float error) ``(Σσ)²``.  The noiseless energy is
        ``-((Σσ)² − N) / (2N)``, i.e. ``-(1/N)·Σ_{i<j} σ_i σ_j``, and zero-mean
        Gaussian noise with std ``noise_level·|E|`` is added.

        Parameters
        ----------
        spins : torch.Tensor
            Batch of ±1 configurations, shape ``[B, M, K]``.

        Returns
        -------
        torch.Tensor
            Shape ``[B]`` noisy energies.  ``energy_eval_count`` grows by B
            (one per configuration).
        """
        batch = spins.shape[0]
        phases = torch.where(spins == 1, 0.0, math.pi)
        wave = torch.exp(1j * phases)

        # FFT over the lattice axes only (not the batch axis); after fftshift
        # the zero-frequency component sits at index (M//2, K//2).
        spectrum = torch.fft.fftshift(torch.fft.fft2(wave), dim=(-2, -1))
        dc = torch.abs(spectrum[:, self.M // 2, self.K // 2]) ** 2

        base_E = -((dc - self.N) / (2 * self.N))
        noise = torch.normal(0.0, self.noise_level * torch.abs(base_E))
        self.energy_eval_count += batch
        return base_E + noise

    # ---------------------- ELBO with REINFORCE -------------------
    @torch.no_grad()
    def variational_step(self, n_samples: int = 20, batch_size: int = 2000):
        """Perform one REINFORCE update of ``spin_probs``.

        Steps:

        * draw ``n_samples`` configurations σ ~ q, in chunks of at most
          ``batch_size``;
        * measure each configuration ``energy_evals_per_config`` times and
          average the noisy energies;
        * estimate ∂F/∂p as the mean-centred REINFORCE term for ``β·E_q[E]``
          (batch mean as baseline) plus the analytic ``∂(−H)/∂p = log(p/(1−p))``;
        * take one Adam step and clamp ``spin_probs`` to [0.01, 0.99].

        Runs under ``torch.no_grad`` because the gradient is computed by hand
        and assigned to ``spin_probs.grad``.

        Side effects: sets ``n_samples``, ``latest_samples`` ([B, M, K]) and
        ``latest_energies`` ([B]); increments ``energy_eval_count``.

        Raises
        ------
        ValueError
            If ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        self.n_samples = n_samples
        self.opt.zero_grad()
        eps = 1e-8  # keeps divisions and logs finite near p ∈ {0, 1}

        spin_chunks, energy_chunks = [], []
        for start in range(0, n_samples, batch_size):
            size = min(batch_size, n_samples - start)
            # Sample ±1 spins with P(+1) = spin_probs at every site.
            spins_batch = (
                2 * torch.bernoulli(self.spin_probs.expand(size, -1, -1)) - 1
            ).detach()

            # Average several independent noisy measurements per configuration.
            energy_sum = torch.zeros(size, device=spins_batch.device)
            for _ in range(self.energy_evals_per_config):
                energy_sum += self.measure_energy_fourier(spins_batch)

            spin_chunks.append(spins_batch)
            energy_chunks.append(energy_sum / self.energy_evals_per_config)

        spins = torch.cat(spin_chunks, dim=0)
        energies = torch.cat(energy_chunks, dim=0)
        self.latest_samples = spins
        self.latest_energies = energies.detach()

        # Score function d log q(σ)/dp of a ±1 Bernoulli spin:
        # +1/p when σ = +1 and −1/(1 − p) when σ = −1.
        logq_grad = (spins - (2 * self.spin_probs - 1)) / (
            2 * self.spin_probs * (1 - self.spin_probs) + eps
        )

        # Second-order noise correction −½β²σ²; the 0 prefactor switches it
        # off (presumably kept so it can be re-enabled).
        sigma2 = (0 * self.noise_level * energies.abs().detach()) ** 2
        energy_with_correction = (
            self.current_beta * energies - 0.5 * (self.current_beta ** 2) * sigma2
        )

        # Baseline-subtracted REINFORCE estimate of ∂(β·E_q[E])/∂p.
        baseline = energy_with_correction.mean().detach()
        centred = energy_with_correction.detach()[:, None, None] - baseline
        reinforce_grad_term = (centred * logq_grad).mean(dim=0)

        # ∂(−H(q))/∂p for independent Bernoulli sites.
        entropy_grad_term = torch.log(
            (self.spin_probs + eps) / (1 - self.spin_probs + eps)
        )

        self.spin_probs.grad = reinforce_grad_term + entropy_grad_term
        self.opt.step()
        with torch.no_grad():
            self.spin_probs.clamp_(0.01, 0.99)

    # ---------------------- helpers & inspection ------------------
    def magnetisation(self) -> float:
        """Return the mean-field magnetisation ⟨m⟩ = mean(2p − 1) of q."""
        return (2 * self.spin_probs - 1).mean().item()

    def anneal_step(self):
        """Increase β additively by ``delta_beta``, capped at ``beta_max``."""
        self.current_beta = min(self.current_beta + self.delta_beta, self.beta_max)

    def inspect_sample(self, idx: int = 0):
        """Return ``(config, energy)`` for sample *idx* of the latest variational step.

        ``config`` is an ``[M, K]`` NumPy array of ±1 spins drawn directly from
        q (no MCMC equilibration); ``energy`` is its averaged noisy energy as a
        Python float.

        Raises
        ------
        RuntimeError
            If ``variational_step`` has not been called yet.
        """
        if self.latest_samples is None:
            raise RuntimeError("No samples drawn yet – call variational_step() first.")
        cfg = self.latest_samples[idx].cpu().numpy()
        E = self.latest_energies[idx].item() if self.latest_energies is not None else None
        return cfg, E

    def save_energies(self, step, energies, beta, magnetization, save_dir="energy_logs"):
        """Save *energies* and run metadata for one step under *save_dir*.

        Writes ``step_XXXX_energies.npy`` (the tensor moved to CPU) and
        ``step_XXXX_meta.json`` holding ``step``, ``beta`` and ``magnetization``.
        """
        os.makedirs(save_dir, exist_ok=True)
        np.save(
            os.path.join(save_dir, f"step_{step:04d}_energies.npy"),
            energies.detach().cpu().numpy(),
        )
        metadata = {"step": step, "beta": beta, "magnetization": magnetization}
        with open(os.path.join(save_dir, f"step_{step:04d}_meta.json"), "w") as f:
            json.dump(metadata, f, indent=2)

    # ---------------------- outer annealing loop ------------------
    def run(self, n_steps: int = 30, steps_per_T: int = 10):
        """Run the outer annealing loop and return the recorded history.

        Each outer step performs ``steps_per_T`` calls to ``variational_step``
        at the current β, records ``(step, β, ⟨m⟩)``, stops early once
        |⟨m⟩| > 0.98, and otherwise calls ``anneal_step``.

        Returns
        -------
        dict
            Lists under the keys ``"step"``, ``"β"`` and ``"⟨m⟩"``.
        """
        self.n_steps = n_steps
        self.steps_per_T = steps_per_T
        history = {"step": [], "β": [], "⟨m⟩": []}
        print("Running noisy Bayesian annealing (REINFORCE)…")
        for step in tqdm(range(n_steps), desc="Anneal"):
            # 1) Variational updates at fixed β.
            for _ in range(self.steps_per_T):
                self.variational_step()

            # 2) Record metrics.
            m = self.magnetisation()
            history["step"].append(step)
            history["β"].append(self.current_beta)
            history["⟨m⟩"].append(m)
            print(f"Step {step:>3}, β = {self.current_beta:.3f}, ⟨m⟩ = {m:.3f}")

            # 3) Early exit if fully ordered.
            # NOTE(review): the [0.01, 0.99] clamp bounds |⟨m⟩| by 0.98 in exact
            # arithmetic, so this strict test fires only when (nearly) every
            # site sits at the clamp bound, via float32 rounding of 0.99 —
            # confirm this is the intended criterion.
            if abs(m) > 0.98:
                print(f"Ordered state reached at step {step} (β={self.current_beta:.2f}).")
                break

            # 4) Increase β for the next outer step.
            self.anneal_step()

        return history


# In[164]:


# ----------------------------------------------------------------------
#  DRIVER + SIMPLE PLOT
# ----------------------------------------------------------------------

def plot(history, filename: str = "magnetisation_plot.png"):
    """Plot ⟨m⟩ against anneal step from a ``run`` history and save it to *filename*.

    A fresh figure is created for every call and closed afterwards, so
    repeated calls (e.g. one per experiment) neither draw onto the same axes
    nor accumulate open figures.

    Parameters
    ----------
    history : dict
        Must contain the lists ``"step"`` and ``"⟨m⟩"``.
    filename : str
        Output image path (the format is inferred by matplotlib from the suffix).
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(history["step"], history["⟨m⟩"], label="⟨m⟩")
        ax.set_xlabel("anneal step")
        ax.set_ylabel("magnetisation")
        ax.grid(alpha=0.3)
        ax.legend()
        fig.savefig(filename)  # saved to disk instead of plt.show()
    finally:
        plt.close(fig)
    print(f"✅ Plot saved to '{filename}'")


# In[165]:


def save_history_to_csv(history, model, filename=None, base_dir="n_sample_effect"):
    """Write a ``run`` history to CSV in a folder keyed by the model settings.

    The file goes to
    ``<base_dir>/<M>_by_<K>/noise<noise_level>/numexpt<energy_evals_per_config>/``.
    When *filename* is None, a name encoding the run settings is generated;
    it reads ``model.n_steps`` and ``model.steps_per_T`` (set by ``run``) and
    ``model.n_samples`` (set by ``variational_step``), so the model must have
    been run first.

    Parameters
    ----------
    history : dict
        Lists under the keys ``"step"``, ``"β"`` and ``"⟨m⟩"``; written as the
        columns ``step``, ``current_beta`` and ``magnetisation``.
    model : VariationalNoisyIsing
        Source of the folder/file naming attributes.
    filename : str | None
        Explicit file name inside the folder.
    base_dir : str
        Root directory for all results.

    Returns
    -------
    str
        Full path of the written CSV file.
    """
    folder = os.path.join(
        base_dir,
        f"{model.M}_by_{model.K}",
        f"noise{model.noise_level}",
        f"numexpt{model.energy_evals_per_config}",
    )
    os.makedirs(folder, exist_ok=True)

    if filename is None:
        filename = (f"history_{model.M}_by_{model.K}_"
                    f"nsteps{model.n_steps}_stepsperT{model.steps_per_T}_nsamples{model.n_samples}_"
                    f"beta{model.current_beta}_noise{model.noise_level}_numexpt{model.energy_evals_per_config}.csv")

    filepath = os.path.join(folder, filename)

    df = pd.DataFrame({
        "step": history["step"],
        "current_beta": history["β"],
        "magnetisation": history["⟨m⟩"],
    })
    df.to_csv(filepath, index=False)
    # Report the full path: the file lives in a nested folder, not the cwd.
    print(f"✅ History saved to {filepath}")
    return filepath



def run_experiment():
    """Sweep over annealing/sampling settings, running and saving one model per combination.

    For every element of the Cartesian product of the parameter lists below,
    a ``VariationalNoisyIsing`` model is built at fixed β
    (``beta_min == beta_max``, ``delta_beta = 0``) and run. Its history is
    written with ``save_history_to_csv``, and the total number of energy
    evaluations is printed.

    The RNGs are seeded once, before the sweep, so each experiment's random
    stream depends on the experiments that ran before it.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    # Parameter grids (extend the lists to sweep more values).
    noise_levels = [0.03]
    beta_values = [0.5, 0.7, 0.85, 1.0, 1.015, 1.03, 1.06, 1.09, 1.12, 1.15, 1.18, 1.21,
                   1.24, 1.27, 1.3, 1.36, 1.4, 1.5, 1.65, 1.8, 2.0, 2.2, 2.6, 3.0]
    num_energy_evals_values = [1]
    n_steps_values = [200]
    steps_per_T_values = [1]
    # Earlier sweeps used [240000, 200000, 160000, 120000, 80000, 40000, 20000, 10000, 4000].
    n_samples_values = [4000]

    param_combinations = list(itertools.product(
        steps_per_T_values,
        n_steps_values,
        n_samples_values,
        noise_levels,
        beta_values,
        num_energy_evals_values,
    ))
    total = len(param_combinations)
    print(f"Running {total} experiments...")

    for i, (steps_per_T, n_steps, n_samples, noise_level, beta_value, num_energy_evals) in enumerate(
            param_combinations, start=1):
        print(f"\n--- Experiment {i}/{total} ---")
        print(f"Parameters: steps_per_T={steps_per_T}, n_steps={n_steps}, "
              f"n_samples={n_samples}, noise_level={noise_level}, beta={beta_value}, num_energy_evals={num_energy_evals}")

        model = VariationalNoisyIsing(
            noise_level=noise_level,
            beta_min=beta_value,
            beta_max=beta_value,
            energy_evals_per_config=num_energy_evals,
            delta_beta=0.0,  # fixed β for the whole run
        )

        # run() calls variational_step() with no arguments, so pin this
        # experiment's sample count on the instance.
        model.variational_step = functools.partial(model.variational_step, n_samples=n_samples)
        hist = model.run(n_steps=n_steps, steps_per_T=steps_per_T)

        save_history_to_csv(hist, model)

        # Demonstrate the inspection API (raises if no samples were drawn).
        model.inspect_sample()
        print("Total energy evaluations:", model.energy_eval_count)


# In[167]:


if __name__ == "__main__":
    # Guarded so that importing this module does not launch the full
    # parameter sweep; running the script (or the notebook) still does.
    try:
        run_experiment()
    except RuntimeError as e:
        # Best-effort driver: report RuntimeErrors (torch's usual error type)
        # instead of aborting with a traceback.
        print("Runtime error:", e)

    print('hi_5')  # end-of-run marker

