#!/usr/bin/env python
# coding: utf-8

# In[62]:


import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import os
import json
import itertools

import matplotlib
matplotlib.use('Agg')  # Non-interactive backend


# In[63]:


"""-----------------------------------------------------------------
NOISY BAYESIAN ISING
-------------------------------------------------------------------
* Variational mean-field q(σ) trained with **REINFORCE**, so that noisy
 energy feedback moves the Bernoulli parameters `spin_probs`.
* Block-Metropolis **pseudo-marginal** kernel (averaged noisy Fourier
 energy) to bring each sample close to the Boltzmann distribution even
 when the single-spin ΔE ≲ noise. (Not implemented in the current
 `VariationalNoisyIsing` class.)
* **Inspection helpers** let you snapshot (and then visualise) any full
 spin configuration produced during optimisation.
-------------------------------------------------------------------"""

# ----------------------------------------------------------------------
#  NOISE / ENERGY HELPERS
# ----------------------------------------------------------------------


# In[67]:


# ----------------------------------------------------------------------
#  CORE MODEL
# ----------------------------------------------------------------------

class VariationalNoisyIsing:
   """Bayesian Ising model driven only by *noisy* Fourier energies.

   Parameters
   ----------
   grid_shape : int | tuple
       If int → treated as N spins on a √N × √N square.
   noise_level : float
       Relative (multiplicative) noise σ/|E| in the Fourier energy.
   anneal_factor : float
       β ← β·anneal_factor after each outer step.
   device : str
       'cuda' if available else 'cpu'.
   """

   # ----------------------------- init ----------------------------
   def __init__(
       self,
       grid_shape=(32, 32),  # M x K grid size
       noise_level: float = 0.03,  # Noise level, 0.05 means 5% noise
       beta_min: float = 1.05,    # Range of Beta - beta_min
       beta_max: float = 1.05,     # Range of Beta - beta_max
       delta_beta: float = 0.0,
       anneal_factor: float = 1.01,  # scale Beta by anneal factor after certain steps
       energy_evals_per_config = 1, # Default value for energy_evals_per_config
       device: str = "cuda" if torch.cuda.is_available() else "cpu",
   ):
       
       # Determining Grid Dimensions
       # isinstance checks at runtime whether the variable grid_shape is exactly of type int.
       # isinstance(x, T) returns True if x is an instance of type T (or a subclass).
       
       if isinstance(grid_shape, int):
           side = int(np.sqrt(grid_shape))
           if side * side != grid_shape:
               raise ValueError("grid_shape int must be perfect square")
           self.M, self.K = side, side
       else:
           self.M, self.K = grid_shape  # M xK grid

       self.N = self.M * self.K  # self.N holds the total number of spins.
       self.device = device
       self.noise_level = noise_level

       # variational Bernoulli parameters (prob σ=+1)
       #spin_probs is an 𝑀 × 𝐾  tensor of probabilities that spin is +1

       self.spin_probs = torch.full((self.M, self.K), 0.49,  # # shape: M rows × K columns. # fill every entry with 0.5
                                    device=device, requires_grad=True)  # track gradients for learning

       self.energy_evals_per_config = energy_evals_per_config 
       
       # annealing schedule
       self.beta_min, self.beta_max = beta_min, beta_max
       self.current_beta = beta_min
       self.delta_beta = delta_beta
       self.anneal_factor = anneal_factor

       # optimiser for variational params. Uses Adam to update ONLY the spin_probs tensor.
       self.opt = torch.optim.Adam([self.spin_probs], lr=0.5)

       # placeholders for inspection
       self.latest_samples = None  # [B, M, K] # Will hold sampled spin configurations [B, M, K]
       self.latest_energies = None # [B] # Will hold their corresponding energies [B]
       self.energy_eval_count = 0


    # ---------------------- noisy Fourier energy ------------------

    # This method turns a spin configuration into a complex wave, reads out its zero-frequency (DC) 
    #Fourier amplitude, interprets that as an energy, and then injects controlled noise.

   def measure_energy_fourier(self, spins: torch.Tensor) -> torch.Tensor:
       B = spins.shape[0]  # batch size
    
       # Compute nearest-neighbor Ising Hamiltonian for each sample in the batch
       energies = []
    
       for b in range(B):
           spin_grid = spins[b]  # Shape: [M, K]
           J = 1.0
           energy = 0.0
        
           # Horizontal nearest-neighbor interactions
           energy -= J * torch.sum(spin_grid * torch.roll(spin_grid, shifts=-1, dims=1))
        
           # Vertical nearest-neighbor interactions  
           energy -= J * torch.sum(spin_grid * torch.roll(spin_grid, shifts=-1, dims=0))
        
           energies.append(energy)
    
       # Convert to tensor
       base_E = torch.stack(energies)  # Shape: [B]
    
       #print(base_E)
       # Add noise (keeping the same noise model as original)
       noise = torch.normal(0.0, self.noise_level * torch.abs(base_E))
    
       # Increment energy count PER SPIN (keeping the same counting as original)
       self.energy_eval_count += B
    
       return base_E + noise  # shape [B] 


    # ---------------------- block Metropolis ---------------------

    
    
   @torch.no_grad()
    # ---------------------- ELBO with REINFORCE -------------------
    # Maximizing the Evidence Lower BOund (ELBO) using a REINFORCE gradient estimator for the discrete spins.
   def variational_step(self, n_samples: int = 20):
       
       self.n_samples = n_samples
       self.opt.zero_grad() # clear out the optimizer’s accumulated gradients on self.spin_probs
       
       # spins = (2 * torch.bernoulli(self.spin_probs.expand(n_samples, -1, -1)) - 1).detach()
       # self.latest_samples = spins
       
       all_energies = []
       all_spins = []
    
       # Process in batches
       batch_size = 2000
       for i in range(0, n_samples, batch_size):
           current_batch_size = min(batch_size, n_samples - i)
        
           # Sample batch
           spins_batch = (2 * torch.bernoulli(
           self.spin_probs.expand(current_batch_size, -1, -1)) - 1).detach()

           # Compute energy K times per configuration
           energies_accum = torch.zeros(current_batch_size, device=spins_batch.device)
           for _ in range(self.energy_evals_per_config):
               energies_accum += self.measure_energy_fourier(spins_batch)
               averaged_energies = energies_accum / self.energy_evals_per_config
        
           # Compute energies for batch
           #energies_batch = self.measure_energy_fourier(spins_batch)
        
           all_spins.append(spins_batch)
           all_energies.append(averaged_energies)

       # Concatenate results
       spins = torch.cat(all_spins, dim=0)
       energies = torch.cat(all_energies, dim=0) 
       #print('var_energy:', energies)
       self.latest_samples = spins
       #energies = self.measure_energy_fourier(spins)  # [B]
       self.latest_energies = energies.detach()

       #baseline = energies.mean().detach()
       #baseline=0
       
       eps = 1e-8 # eps prevents log(0)
       
       logq_grad = (spins - (2 * self.spin_probs - 1)) / (2 * self.spin_probs * (1 - self.spin_probs) + eps)

       sigma2 = (self.noise_level * energies.abs().detach()) ** 2
       energy_with_correction = (self.current_beta * energies) - 0.5 * (self.current_beta ** 2) * sigma2 #* energies.pow(2)

       #energy_with_correction = (energies) - 0.5 * (self.current_beta) * sigma2

       # Baseline (mean of energy term)
       baseline = energy_with_correction.mean().detach()
       
       # REINFORCE term  
       reinforce_grad_term =  (
               ((energy_with_correction.detach().unsqueeze(-1).unsqueeze(-1) - baseline) * logq_grad).mean(dim=0))


       # entropy of q
       entropy_grad_term = torch.log((self.spin_probs + eps) / (1 - self.spin_probs + eps))

       #entropy_grad_term = (1/self.current_beta) * (torch.log((self.spin_probs + eps) / (1 - self.spin_probs + eps)))

       grad_estimate = reinforce_grad_term + entropy_grad_term
       #print('grad_estimate', grad_estimate)
       
       self.spin_probs.grad = grad_estimate
       
       self.opt.step()
       with torch.no_grad():
           self.spin_probs.clamp_(0.01, 0.99)

        
    # ---------------------- helpers & inspection ------------------
   def magnetisation(self):
       return (2 * self.spin_probs - 1).mean().item()

   def anneal_step(self):
       #self.current_beta = min(self.current_beta * self.anneal_factor, self.beta_max)
       self.current_beta = min(self.current_beta + self.delta_beta, self.beta_max)
       # self.current_beta = beta_min * (1 + scale * math.log(1 + t))

   # > Inspection API
   def inspect_sample(self, idx: int = 0):
       """Return the last MH equilibrated sample‐lattice *and* its energy."""
       if self.latest_samples is None:
           raise RuntimeError("No samples drawn yet – call variational_step() first.")
       cfg = self.latest_samples[idx].cpu().numpy()
       E  = self.latest_energies[idx].item() if self.latest_energies is not None else None
       return cfg, E

   def save_energies(self, step, energies, beta, magnetization, save_dir="energy_logs"):
    os.makedirs(save_dir, exist_ok=True)
    
    # Save energies
    np.save(os.path.join(save_dir, f"step_{step:04d}_energies.npy"), energies.detach().cpu().numpy())
    
    # Save metadata
    metadata = {
        "step": step,
        "beta": beta,
        "magnetization": magnetization
    }
    with open(os.path.join(save_dir, f"step_{step:04d}_meta.json"), "w") as f:
        json.dump(metadata, f, indent=2)
        

    # ---------------------- outer annealing loop ------------------
   def run(self, n_steps: int = 30, steps_per_T: int = 10):
    self.n_steps = n_steps
    self.steps_per_T = steps_per_T 
    history = {"step": [], "β": [], "⟨m⟩": []}
    print("Running noisy Bayesian annealing (REINFORCE)…")
    for step in tqdm(range(n_steps), desc="Anneal"):
        # 1) Variational updates at fixed beta
        
        for _ in range(self.steps_per_T):
            self.variational_step()

        # 2) Record metrics
        history["step"].append(step)
        history["β"].append(self.current_beta)
        history["⟨m⟩"].append(self.magnetisation())

        #self.save_energies(step, self.latest_energies, self.current_beta, self.magnetisation())

        if step % 1 == 0:
            print(f"Step {step:>3}, β = {self.current_beta:.3f}, ⟨m⟩ = {self.magnetisation():.3f}")

        # 3) Early exit if fully ordered
        if abs(history["⟨m⟩"][-1]) > 0.98:
            print(f"Ordered state reached at step {step} (β={self.current_beta:.2f}).")
            break

        # 4) Increase β for next iteration
        self.anneal_step()
        #self.current_beta = 0.5 #25
        #self.current_beta = self.beta_min * (1 + math.log(1 + step))


    return history


# In[68]:


def save_history_to_csv(history, model, filename=None, base_dir="quenching"):
    """Save an annealing history dict to CSV and return the written path.

    The file goes to
    ``{base_dir}/{M}_by_{K}/noise{noise_level}/numexpt{energy_evals_per_config}/``.
    The folder is created if needed.

    Parameters
    ----------
    history : dict
        Lists under the keys "step", "β" and "⟨m⟩" (as produced by
        ``VariationalNoisyIsing.run``).
    model
        Object exposing ``M``, ``K``, ``noise_level`` and
        ``energy_evals_per_config``. When ``filename`` is None it must also
        expose ``n_steps``, ``steps_per_T``, ``n_samples`` and ``current_beta``.
    filename : str | None
        File name inside the folder. Defaults to a name built from the
        model's run settings.
    base_dir : str
        Root output folder (default "quenching").

    Returns
    -------
    str
        Full path of the CSV file that was written.
    """
    folder = os.path.join(
        base_dir,
        f"{model.M}_by_{model.K}",
        f"noise{model.noise_level}",
        f"numexpt{model.energy_evals_per_config}",
    )
    os.makedirs(folder, exist_ok=True)

    if filename is None:
        filename = (
            f"history_{model.M}_by_{model.K}_"
            f"nsteps{model.n_steps}_stepsperT{model.steps_per_T}_nsamples{model.n_samples}_"
            f"beta{model.current_beta}_noise{model.noise_level}_"
            f"numexpt{model.energy_evals_per_config}.csv"
        )
    filepath = os.path.join(folder, filename)

    # Map the Unicode history keys onto plain ASCII column headers.
    df = pd.DataFrame({
        "step": history["step"],
        "current_beta": history["β"],
        "magnetisation": history["⟨m⟩"],
    })
    df.to_csv(filepath, index=False)

    # Report the real location: the file lives under `folder`, not the CWD.
    print(f"✅ History saved to {filepath}")
    return filepath


# In[75]:


def run_experiment(
    noise_levels=(0.06,),
    beta_values=(0.25,),
    num_energy_evals_values=(1,),
    n_steps_values=(3000,),
    steps_per_T_values=(1,),
    n_samples_values=(3000,),
    seed=69,
):
    """Run every combination of the given settings at fixed β and save each history to CSV.

    Each combination builds a fresh ``VariationalNoisyIsing`` with
    ``beta_min == beta_max == beta`` (no annealing) and runs it. The history
    is saved with ``save_history_to_csv`` and the total number of energy
    evaluations is printed. A β sweep used previously was
    (0.1, 0.2, 0.30, 0.4, 0.42, 0.43, 0.44, 0.45, 0.47, 0.5, 0.55, 0.6,
    0.65, 0.7, 1.0, 2.0).

    Parameters
    ----------
    noise_levels, beta_values, num_energy_evals_values, n_steps_values,
    steps_per_T_values, n_samples_values : iterables
        Values combined with ``itertools.product``. The defaults reproduce
        the original single experiment.
    seed : int
        Seed for torch and numpy. It is set once per call, so later
        experiments in the grid depend on the RNG state left by earlier ones.
    """
    import functools

    torch.manual_seed(seed)
    np.random.seed(seed)

    param_combinations = list(itertools.product(
        steps_per_T_values,
        n_steps_values,
        n_samples_values,
        noise_levels,
        beta_values,
        num_energy_evals_values,
    ))
    n_total = len(param_combinations)
    print(f"Running {n_total} experiments...")

    for i, (steps_per_T, n_steps, n_samples, noise_level, beta_value, num_energy_evals) in enumerate(param_combinations):
        print(f"\n--- Experiment {i+1}/{n_total} ---")
        print(f"Parameters: steps_per_T={steps_per_T}, n_steps={n_steps}, "
              f"n_samples={n_samples}, noise_level={noise_level}, beta={beta_value}, num_energy_evals={num_energy_evals}")

        model = VariationalNoisyIsing(
            noise_level=noise_level,
            beta_min=beta_value,
            beta_max=beta_value,
            energy_evals_per_config=num_energy_evals,
            delta_beta=0.0,
        )

        # run() calls self.variational_step() with no arguments; bind this
        # experiment's sample count now (partial binds eagerly, unlike a
        # closure over the loop variable).
        model.variational_step = functools.partial(model.variational_step, n_samples=n_samples)
        hist = model.run(n_steps=n_steps, steps_per_T=steps_per_T)

        save_history_to_csv(hist, model)

        # Demonstrate the inspection API (raises RuntimeError if nothing was sampled).
        cfg, energy = model.inspect_sample()
        print("Total energy evaluations:", model.energy_eval_count)


# In[ ]:


# Run the experiment only when executed as a script (or notebook), not on
# import, since it trains for thousands of steps and writes CSV files.
if __name__ == "__main__":
    try:
        run_experiment()
    except RuntimeError as e:
        # Best-effort: report the failure instead of aborting the session.
        print("Runtime error:", e)

    # In[184]:

    print('hi_5')  # completion marker


