import numpy as np
from scipy.optimize import nnls
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch

import sys


from experiments.Automaton import LifeLikeAutomaton2D, CoarseWrapper, calc_noise_sensitivity, calc_gowers_norm, est_spectral_weights_curve
from experiments.Automaton import CellularAutomaton2D
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd



def calc_noise_stability_batched(wrapper, rho, num_samples=100, batch_size=32):
    """
    Estimate the noise stability E[f(x) * f(y)] of ``wrapper``'s output map.

    ``x`` is a random binary grid (each cell True with probability 0.5) and
    ``y`` is ``x`` with each cell flipped independently with probability
    ``delta = (1 - rho) / 2``. So rho=1 means no noise (y == x), rho=0 means
    50% flips (y independent of x) and rho=-1 means every cell is flipped.

    Outputs are mapped 0 -> +1 and 1 -> -1 before multiplying. The product is
    averaged over every output entry of every sample.

    Args:
        wrapper: Object exposing ``grid_size`` and ``_forward_sample(batch)``.
            The second element of the returned tuple is used as the output
            batch (numpy array or torch tensor).
            NOTE(review): assumes outputs are in {0, 1} and that their leading
            axis follows the input batch order; confirm against the wrapper.
        rho: Correlation between x and y, in [-1, 1].
        num_samples: Number of (x, y) pairs to draw; must be positive.
        batch_size: Pairs per forward call; must be positive. Each call
            processes ``2 * batch`` grids because x and y are concatenated.

    Returns:
        Sample-weighted mean of f(x) * f(y).

    Raises:
        ValueError: If ``rho`` is outside [-1, 1] or ``num_samples`` /
            ``batch_size`` is not positive.
    """
    # Without these checks a bad rho silently saturates the flip probability,
    # a non-positive batch_size silently yields 0.0 or a cryptic range()
    # error, and num_samples=0 ends in a ZeroDivisionError.
    if not -1.0 <= rho <= 1.0:
        raise ValueError(f"rho must be in [-1, 1], got {rho}")
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # rho = 1 - 2*delta  =>  delta = (1 - rho) / 2
    delta = (1.0 - rho) / 2.0

    grid_size = wrapper.grid_size
    corr_accum = 0.0

    for start_idx in range(0, num_samples, batch_size):
        current_batch_size = min(batch_size, num_samples - start_idx)
        shape = (current_batch_size, grid_size, grid_size)

        x = np.random.uniform(0, 1, shape) < 0.5
        noise = np.random.uniform(0, 1, shape) < delta
        x_prime = np.logical_xor(x, noise)

        # One forward pass for both halves: first half is x, second is x'.
        combined_input = np.concatenate([x, x_prime], axis=0)
        _, combined_output = wrapper._forward_sample(combined_input)

        if isinstance(combined_output, torch.Tensor):
            combined_output = combined_output.detach().cpu().numpy()

        # 0 -> +1, 1 -> -1 (the sign convention does not affect the product).
        polar_output = 1.0 - 2.0 * combined_output

        y, y_prime = np.split(polar_output, 2, axis=0)

        # Weight each batch by its size so a short final batch is not over-counted.
        batch_corr = np.mean(y * y_prime)
        corr_accum += batch_corr * current_batch_size

    return corr_accum / num_samples

def estimate_noise_stability(wrapper, num_rhos=20, num_samples=1000):
    """
    Estimate the noise sensitivity NS_delta of ``wrapper`` on a grid of noise rates.

    Building the grid:
      1. Take ``num_rhos`` geometrically spaced values of ``1 - rho`` in
         [0.001, 0.9].
      2. Round the resulting correlations ``rho`` to 4 decimals and drop
         duplicates.
      3. Convert each rho to a flip probability ``delta = (1 - rho) / 2``
         (about 0.0005 to 0.45).
      4. Label each delta by rounding it to 3 decimals.

    Deltas that share a label are measured once, at exactly the labelled
    value, so every ``NS_{d}`` entry is the sensitivity at flip probability
    ``d``.

    Stability is converted to sensitivity via NS_delta = (1 - Stab_rho) / 2
    with rho = 1 - 2 * delta.

    Args:
        wrapper: Object accepted by ``calc_noise_stability_batched``.
        num_rhos: Number of grid points before rounding and de-duplication.
            The number of returned entries is typically smaller.
        num_samples: Samples per noise rate.

    Returns:
        Dict ``{"NS_<delta>": sensitivity}`` ordered from largest to smallest delta.
    """
    # Geometric grid over 1 - rho (= 2 * delta), not over delta itself.
    one_minus_rho = np.geomspace(0.001, 0.9, num_rhos)
    # Round to 4 decimals to merge near-identical correlations. np.unique
    # also sorts ascending, i.e. from largest to smallest flip probability.
    rhos = np.unique(np.round(1.0 - one_minus_rho, 4))
    deltas = (1.0 - rhos) / 2.0

    # Several rhos can round to the same 3-decimal label. Keeping each label
    # once (first-seen order) avoids computing stabilities that would be
    # overwritten in the result dict, and avoids mislabelling the survivor.
    labels = dict.fromkeys(round(d * 1000) / 1000.0 for d in deltas)

    results = {}
    for d in labels:
        # Measure at exactly the labelled flip probability.
        stability = calc_noise_stability_batched(wrapper, 1.0 - 2.0 * d, num_samples)
        # Convert stability back to noise sensitivity.
        results[f"NS_{d}"] = (1.0 - stability) / 2.0

    return results

from tqdm import tqdm
def _parse_rule_list(raw):
    """
    Parse a list-style rule string such as ``"[0 1 1 0]"`` or ``"[0, 1, 1, 0]"``.

    Entries may be separated by commas and/or any run of whitespace, so
    multi-line or space-padded array printouts parse as well.

    Raises:
        AttributeError: If ``raw`` is not a string (e.g. an integer rule).
        ValueError: If the list is empty or an entry is not an integer.
    """
    tokens = raw.strip().strip("[]").replace(",", " ").split()
    if not tokens:
        raise ValueError(f"Empty rule list: {raw!r}")
    return [int(tok) for tok in tokens]


def _build_coarse_wrapper(row, grid_size):
    """
    Build the ``CoarseWrapper`` for one dataframe row.

    HACK: the automaton type is detected by trial. The row is first treated as
    a list rule for ``CellularAutomaton2D``. If anything in that path fails,
    it is retried as a single-int rule for ``LifeLikeAutomaton2D``.
    """
    try:
        rule = _parse_rule_list(row["rule"])
        time_factor = int(row["time_factor"])
        automaton = CellularAutomaton2D(rule, grid_size)
        return CoarseWrapper(automaton, time_factor, spatial_factor=1, only_output_coarse=True)
    except Exception:
        # Not a bare ``except`` so Ctrl-C / SystemExit still stop the run.
        rule = int(row["rule"])
        time_factor = int(row["time_factor"])
        automaton = LifeLikeAutomaton2D(rule, grid_size)
        return CoarseWrapper(automaton, time_factor, spatial_factor=1, only_output_coarse=True)


def analyze_ca_dataframeRho(df, grid_size=32, save_path="rho_values.csv", *, num_rhos=100, num_samples=512):
    """
    Compute a noise-sensitivity curve for every rule in ``df`` and save the result.

    For each row:
      - The automaton is built from ``rule`` and ``time_factor``. List rules go
        to ``CellularAutomaton2D``; if that path fails, the rule is read as a
        single int for ``LifeLikeAutomaton2D``.
      - The automaton is wrapped in ``CoarseWrapper`` with
        ``spatial_factor=1``.
      - ``estimate_noise_stability`` supplies the ``NS_<delta>`` columns.

    Args:
        df: Input rows; must have ``rule`` and ``time_factor`` columns.
        grid_size: Side length of the square grid.
        save_path: CSV path the merged table is written to (without index).
        num_rhos: Grid size passed to ``estimate_noise_stability``.
        num_samples: Samples per noise rate.

    Returns:
        ``df`` with the NS columns appended. Rows that failed get NaN there.
    """
    results_list = []

    print(f"Starting analysis on {len(df)} rules...")
    print(f"Settings: Grid={grid_size}x{grid_size}, Samples={num_samples}")

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Analyzing Rules"):
        try:
            coarse = _build_coarse_wrapper(row, grid_size)
            results_list.append(
                estimate_noise_stability(coarse, num_rhos=num_rhos, num_samples=num_samples)
            )
        except Exception as e:
            # Per-row boundary: report and continue; this row gets NaN columns.
            print(f"Error processing rule {row.get('rule', 'unknown')}: {e}")
            results_list.append({})

    # Merge and save: results are aligned to df's index row by row.
    results_df = pd.DataFrame(results_list, index=df.index)
    final_df = pd.concat([df, results_df], axis=1)

    final_df.to_csv(save_path, index=False)
    print(f"\nAnalysis complete. Results saved to {save_path}")

    return final_df

import time 
if __name__ == "__main__":
    from pathlib import Path

    # ADD paths to cached pandas dataframes (CSV files) here.
    paths = []
    for p in paths:
        src = Path(p)
        # Replace only the extension. A ".csv" elsewhere in the path is left
        # alone, and an input without a ".csv" suffix is never overwritten.
        # The "512Samples" tag reflects the 512 samples per noise rate used
        # by analyze_ca_dataframeRho.
        save_path = str(src.with_name(f"{src.stem}_NS_Rho_512Samples.csv"))
        df = pd.read_csv(src)
        analyze_ca_dataframeRho(df, grid_size=32, save_path=save_path)