import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import pysindy as ps

# ─── CONFIG ────────────────────────────────────────────────────────────────
# Inputs: APTOS fundus images plus the label CSV (read for its id_code and
# diagnosis columns).
IMAGE_DIR = "/drive2/kt/SGD/aptos_train_data/train_images"
CSV_IN = "/drive2/kt/SGD/aptos_train_data/train.csv"

# Outputs: every artefact lives in the same output directory.
CSV_OUT = "/drive2/Kuntal/Synidi/output/aptos_thetas.csv"        # flattened thetas + labels
THETA_NPY = "/drive2/Kuntal/Synidi/output/aptos_all_thetas.npy"  # stacked theta arrays
THETA_IDS = "/drive2/Kuntal/Synidi/output/aptos_theta_ids.npy"   # id_code per THETA_NPY row

# Create the output directory up front so later writes cannot fail on it.
os.makedirs(name=os.path.dirname(CSV_OUT), exist_ok=True)

# PySINDy settings: cubic polynomial candidate library fitted with the
# STLSQ optimizer (threshold 0.1, column normalisation enabled).
LIBRARY = ps.PolynomialLibrary(degree=3)
OPTIMIZER = ps.STLSQ(
    threshold=0.1,
    normalize_columns=True,
)
SINDY_MODEL = ps.SINDy(
    feature_library=LIBRARY,
    optimizer=OPTIMIZER,
)

# ─── PREPROCESSING FUNCTION ─────────────────────────────────────────────────
def preprocess_retinal_image(img_path, size=(512, 512), margin=10,
                             clip_limit=2.0, tile_grid_size=(8, 8)):
    """Load a fundus image, mask it to a central disc and apply CLAHE.

    The image is converted to RGB and resized to ``size``. Pixels outside a
    centred disc of radius ``min(w, h) // 2 - margin`` are zeroed. CLAHE is
    then applied to the L channel in LAB space. CLAHE equalises per tile, so
    it can lift the zeroed background above black. The mask is therefore
    re-applied afterwards, and the enhancement only ever shows inside the disc.

    Args:
        img_path: Path of the image file to load.
        size: ``(width, height)`` handed to ``PIL.Image.Image.resize``.
        margin: Pixels subtracted from half the shorter side to get the disc
            radius.
        clip_limit: CLAHE clip limit.
        tile_grid_size: CLAHE tile grid size.

    Returns:
        ``(enhanced, True)`` on success, where ``enhanced`` is an
        ``(h, w, 3)`` uint8 RGB array. Returns ``(None, False)`` if any step
        raised; the error is printed and the caller is expected to skip the
        image.
    """
    try:
        # Context manager releases the file handle as soon as pixels are read.
        with Image.open(img_path) as src:
            img_array = np.array(src.convert("RGB").resize(size))

        h, w = img_array.shape[:2]
        center = (w // 2, h // 2)
        radius = min(w, h) // 2 - margin

        # Filled circle: 255 inside the fundus disc, 0 elsewhere.
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.circle(mask, center, radius, 255, -1)

        # A single-channel mask zeroes all three channels in one call.
        masked = cv2.bitwise_and(img_array, img_array, mask=mask)

        # Enhance lightness only; the a/b colour channels pass through.
        lab = cv2.cvtColor(masked, cv2.COLOR_RGB2LAB)
        l_chan, a_chan, b_chan = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
        enhanced = cv2.cvtColor(
            cv2.merge([clahe.apply(l_chan), a_chan, b_chan]),
            cv2.COLOR_LAB2RGB,
        )

        # CLAHE's clip-limit redistribution maps L=0 to a small non-zero value
        # in mostly-black tiles; restore the black background outside the disc.
        enhanced[mask == 0] = 0

        return enhanced, True
    except Exception as e:
        # Deliberate best-effort: report and let the caller skip this image.
        print(f"Error preprocessing image {img_path}: {e}")
        return None, False

# ─── LOAD METADATA ─────────────────────────────────────────────────────────
# id_code is read as text because it is used verbatim to build image file names.
df = pd.read_csv(filepath_or_buffer=CSV_IN, dtype={"id_code": str})

# Per-image accumulators, filled by the processing loop:
#   theta_rows - dict per image: id_code, diagnosis, theta_0..theta_k
#   all_thetas - theta array per image
#   all_ids    - id_code aligned with all_thetas
#   errors     - {"id_code", "score"} fit-quality record per image
theta_rows, all_thetas, all_ids, errors = [], [], [], []

# ─── PROCESS EACH IMAGE ────────────────────────────────────────────────────
def _find_image_path(img_id, image_dir, extensions=(".png", ".jpg", ".jpeg")):
    """Return the first existing ``<image_dir>/<img_id><ext>`` path, or None."""
    for ext in extensions:
        candidate = os.path.join(image_dir, f"{img_id}{ext}")
        if os.path.exists(candidate):
            return candidate
    return None


def _radial_profile(img, margin=10, step_deg=10):
    """Mean pixel value on concentric circles around the image centre.

    Row ``r`` of the result averages the pixels sampled every ``step_deg``
    degrees on the circle of radius ``r``, for ``r`` in
    ``0 .. min(h, w)//2 - margin - 1``. The default margin matches the mask
    radius used by ``preprocess_retinal_image``. Samples are truncated to
    integer pixel coordinates; out-of-bounds samples are dropped, and radii
    with no in-bounds sample are skipped.
    """
    h, w = img.shape[:2]
    cx, cy = w // 2, h // 2
    max_radius = min(w, h) // 2 - margin

    # Angle table is loop-invariant; compute it once per image.
    angles = np.deg2rad(np.arange(0, 360, step_deg))
    cos_a, sin_a = np.cos(angles), np.sin(angles)

    rows = []
    for r in range(max_radius):
        xs = (cx + r * cos_a).astype(int)  # truncation, like int()
        ys = (cy + r * sin_a).astype(int)
        inside = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
        if inside.any():
            rows.append(img[ys[inside], xs[inside]].mean(axis=0))
    return np.array(rows)


def _forward_difference(X):
    """Forward difference along axis 0; the last row repeats the previous one.

    Requires ``len(X) >= 2``.
    """
    dX = np.empty_like(X)
    dX[:-1] = np.diff(X, axis=0)
    dX[-1] = dX[-2]
    return dX


for idx, row in df.iterrows():
    img_id = row["id_code"]
    label = row["diagnosis"]

    img_path = _find_image_path(img_id, IMAGE_DIR)
    if img_path is None:
        print(f"Could not find image for {img_id}")
        continue

    processed_img, success = preprocess_retinal_image(img_path)
    if not success:
        continue

    # Radial trajectory: mean colour from the centre outward, treated as a
    # time series indexed by radius.
    X = _radial_profile(processed_img)
    if len(X) < 2:
        print(f"Error: Not enough valid samples for {img_id}")
        continue
    dX = _forward_difference(X)
    t = np.arange(len(X))

    # Fit and score before touching any output list. A failure in score()
    # must not leave all_thetas/all_ids longer than theta_rows/errors.
    try:
        SINDY_MODEL.fit(X, t=t, x_dot=dX)
        # Copy in case the returned array is shared with the optimizer state
        # (not verified here) and would change on the next fit.
        theta = np.copy(SINDY_MODEL.coefficients())
        score = SINDY_MODEL.score(X, t=t, x_dot=dX)
    except Exception as e:
        print(f"Error fitting SINDy model for {img_id}: {e}")
        continue

    # Commit to all outputs together so NPY, CSV and error rows stay aligned.
    all_thetas.append(theta)
    all_ids.append(img_id)
    errors.append({"id_code": img_id, "score": score})

    row_out = {"id_code": img_id, "diagnosis": label}
    row_out.update({f"theta_{i}": val for i, val in enumerate(theta.flatten())})
    theta_rows.append(row_out)

    print(f"Processed {img_id} ({idx+1}/{len(df)}) - Score: {score:.4f} - Theta shape: {theta.shape}")

# ─── SAVE ALL THETAS AS SINGLE NPY FILE ─────────────────────────────────────
# Stack the per-image theta arrays along a new leading axis. Row i of
# THETA_NPY belongs to the id_code in row i of THETA_IDS; both lists are
# appended together.
all_thetas_array = np.array(all_thetas)
all_ids_array = np.array(all_ids)

np.save(THETA_NPY, all_thetas_array)
np.save(THETA_IDS, all_ids_array)

# ─── WRITE CSV WITH FLATTENED THETAS ───────────────────────────────────────
df_theta = pd.DataFrame(theta_rows)
df_theta.to_csv(CSV_OUT, index=False)

# ─── SAVE ERROR METRICS ───────────────────────────────────────────────────
# Build the path once so the file written and the path reported below agree.
errors_csv = os.path.join(os.path.dirname(CSV_OUT), "theta_errors.csv")
df_errors = pd.DataFrame(errors)
df_errors.to_csv(errors_csv, index=False)

# ─── PRINT SUMMARY ─────────────────────────────────────────────────────────
n_processed = len(all_thetas)
print("=" * 50)
print("Processing complete!")
print(f"Total images processed: {n_processed}/{len(df)}")
if n_processed:
    print(f"Theta shape per image: {all_thetas_array.shape[1:]}")
    print(f"Total thetas array shape: {all_thetas_array.shape}")
    mean_score = np.mean([e["score"] for e in errors])
    print(f"Average reconstruction score: {mean_score:.4f}")
else:
    # np.mean([]) would emit a RuntimeWarning and print nan; the shape lines
    # would be meaningless. Report the empty run explicitly instead.
    print("No images were processed successfully; no theta shapes or scores to report.")
print("=" * 50)
print("Saved outputs:")
print(f"1. All thetas: {THETA_NPY}")
print(f"2. Image IDs: {THETA_IDS}")
print(f"3. CSV with flattened thetas: {CSV_OUT}")
print(f"4. Error metrics: {errors_csv}")