import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import pysindy as ps
import time
from tqdm import tqdm

# ─── CONFIG ────────────────────────────────────────────────────────────────
# Inputs: the folder searched for image files and the metadata CSV whose rows
# drive the processing loop.
IMAGE_DIR = "/drive2/kt/SGD/messidor2/IMAGES"
CSV_IN    = "/drive2/kt/SGD/messidor2/messidor_data.csv"
# Outputs: every artefact below is written inside OUTPUT_DIR.
OUTPUT_DIR = "/drive2/Kuntal/Pysindy-experiment/M2-output"
# One CSV row per successfully fitted image: metadata plus theta_0..theta_k.
CSV_OUT   = f"{OUTPUT_DIR}/messidor2_thetas.csv"
# All coefficient arrays stacked into a single array.
THETA_NPY = f"{OUTPUT_DIR}/messidor2_all_thetas.npy"
THETA_IDS = f"{OUTPUT_DIR}/messidor2_theta_ids.npy"  # image IDs, same order as THETA_NPY
# Per-image value returned by SINDY_MODEL.score().
ERROR_CSV = f"{OUTPUT_DIR}/messidor2_theta_errors.csv"

# Ensure output directory exists
# (runs at import time; exist_ok=True makes re-runs harmless)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# PySINDy settings
# Candidate features are polynomial terms up to degree 3; the optimizer is
# STLSQ with threshold 0.1 and column normalisation switched on.
# NOTE: SINDY_MODEL is a single shared instance that the processing loop
# re-fits for every image.
LIBRARY      = ps.PolynomialLibrary(degree=3)
OPTIMIZER    = ps.STLSQ(threshold=0.1, normalize_columns=True)
SINDY_MODEL  = ps.SINDy(feature_library=LIBRARY, optimizer=OPTIMIZER)

# ─── PREPROCESSING FUNCTION ─────────────────────────────────────────────────
def preprocess_retinal_image(img_path, size=(512, 512)):
    """Load a fundus image, black out its border and enhance its contrast.

    Steps:
      1. Open the file, convert it to RGB and resize it to ``size``.
      2. Zero every pixel outside a centred disc whose radius is
         ``min(width, height) // 2 - 10``.
      3. Apply CLAHE to the L channel in LAB space and convert back to RGB.

    Args:
        img_path: Path of the image file to load.
        size: Target size handed to ``PIL.Image.resize``.

    Returns:
        ``(enhanced, True)`` where ``enhanced`` is the processed RGB array,
        or ``(None, False)`` if any step raised. Failures are reported on
        stdout and never propagated, so a caller can skip the image.
    """
    try:
        # Image.open() is lazy and holds the file handle; the context manager
        # releases it deterministically instead of waiting for garbage
        # collection. convert()/resize() produce a new in-memory image, so the
        # result is still usable after the file is closed.
        with Image.open(img_path) as source:
            resized = source.convert("RGB").resize(size)
        img_array = np.array(resized)

        # Geometry of the fundus disc: centred, 10 px inside the shorter side.
        h, w = img_array.shape[:2]
        center = (w // 2, h // 2)
        radius = min(w, h) // 2 - 10

        # Filled circle (thickness -1) of 255s on a zero background.
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.circle(mask, center, radius, 255, -1)

        # A single-channel mask applies to every channel at once, so there is
        # no need to split, mask and re-merge R, G and B individually.
        masked = cv2.bitwise_and(img_array, img_array, mask=mask)

        # Contrast enhancement: CLAHE on lightness only, colour channels kept.
        # Note that CLAHE runs over the full L channel, blacked-out border
        # included; it is not restricted to the masked disc.
        lab = cv2.cvtColor(masked, cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lightness = clahe.apply(lightness)
        enhanced = cv2.cvtColor(
            cv2.merge([lightness, chan_a, chan_b]), cv2.COLOR_LAB2RGB
        )

        return enhanced, True
    except Exception as e:
        # Deliberate best-effort: one unreadable image must not stop the run.
        print(f"Error preprocessing image {img_path}: {e}")
        return None, False

# ─── FIND IMAGE FUNCTION ─────────────────────────────────────────────────────
def find_image_path(img_id, image_dir=None):
    """Resolve an image identifier from the CSV to an existing image file.

    The identifier is first tried verbatim inside the image directory. If
    that is not a file, its extension is stripped and a fixed list of image
    extensions is probed: all lower-case spellings first, then the same
    extensions in upper case (so ``x.jpg`` in the CSV still finds ``x.JPG``
    on a case-sensitive filesystem).

    Args:
        img_id: File name (with or without extension) taken from the CSV.
        image_dir: Directory to search. Defaults to the module-level
            ``IMAGE_DIR`` when omitted.

    Returns:
        The path of the first matching regular file, or ``None`` when
        nothing matches or ``img_id`` is empty / not path-like (for example
        a NaN coming from an empty CSV cell).
    """
    if image_dir is None:
        image_dir = IMAGE_DIR

    # Reject values that cannot name a file instead of letting os.path.join
    # raise TypeError and abort the whole run.
    try:
        img_id = os.fspath(img_id)
    except TypeError:
        return None
    if not isinstance(img_id, str) or not img_id:
        return None

    # Try direct path first (with extension from CSV). isfile() rather than
    # exists(): a directory must never be returned as an image.
    direct_path = os.path.join(image_dir, img_id)
    if os.path.isfile(direct_path):
        return direct_path

    # Try without extension and check for various formats. Lower-case
    # candidates keep their original priority; upper-case ones come after.
    img_base = os.path.splitext(img_id)[0]
    lower_exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
    upper_exts = tuple(ext.upper() for ext in lower_exts)
    for ext in lower_exts + upper_exts:
        test_path = os.path.join(image_dir, f"{img_base}{ext}")
        if os.path.isfile(test_path):
            return test_path

    return None

# ─── LOAD METADATA ─────────────────────────────────────────────────────────
# Read the metadata table; the processing loop visits one row per image.
print(f"Loading data from {CSV_IN}")
df = pd.read_csv(CSV_IN)
print(f"Found {len(df)} images in CSV")

# Show the header so a column-name mismatch is visible before the loop starts.
csv_columns = df.columns.tolist()
print("CSV columns:", csv_columns)

# Accumulators filled by the processing loop:
#   theta_rows - one dict per image (metadata + flattened coefficients)
#   all_thetas - coefficient arrays, one per image
#   all_ids    - image IDs, same order as all_thetas
#   errors     - one {"id_code", "score"} dict per image
theta_rows, all_thetas, all_ids, errors = [], [], [], []

# ─── PROCESS EACH IMAGE ────────────────────────────────────────────────────
start_time = time.time()
total_images = len(df)

# The 36 sampling directions (every 10 degrees) are identical for every radius
# and every image, so the trigonometry is evaluated once here rather than once
# per sampled pixel.
_RADIAL_DIRECTIONS = [
    (np.cos(angle_rad), np.sin(angle_rad))
    for angle_rad in (np.deg2rad(angle) for angle in range(0, 360, 10))
]


def _radial_profile(image, directions, margin=10):
    """Return the mean pixel value on concentric circles, centre outward.

    For each integer radius ``r`` in ``range(min(w, h) // 2 - margin)`` the
    image is sampled at ``centre + r * (cos, sin)`` for every entry of
    ``directions`` (coordinates truncated with ``int``), and the in-bounds
    samples are averaged per channel.

    Returns an array with one row per radius that produced a sample; the
    array is empty when no radius did.
    """
    h, w = image.shape[:2]
    cx, cy = w // 2, h // 2
    max_radius = min(w, h) // 2 - margin

    profile = []
    for r in range(max_radius):
        ring = []
        for cos_a, sin_a in directions:
            x = int(cx + r * cos_a)
            y = int(cy + r * sin_a)
            if 0 <= x < w and 0 <= y < h:
                ring.append(image[y, x])
        if ring:
            profile.append(np.mean(ring, axis=0))  # per-channel ring average
    return np.array(profile)


def _forward_difference(samples):
    """Return the forward difference of ``samples`` along axis 0.

    The last row has no successor, so it repeats the previous difference.
    Requires at least two rows.
    """
    deriv = np.zeros_like(samples)
    deriv[:-1] = samples[1:] - samples[:-1]
    deriv[-1] = deriv[-2]
    return deriv


# ``position`` counts rows visited (1-based). It is used for progress output
# instead of the DataFrame index label, which is only equal to the row number
# for a default RangeIndex.
for position, (_, row) in enumerate(
    tqdm(df.iterrows(), total=total_images, desc="Processing Images"), start=1
):
    # Messidor-2 uses 'path' for image filename
    img_id = row["path"]

    # Get DR grade and other metadata (the optional columns may be absent)
    label = row["label"]  # DR grade
    dme = row.get("adjudicated_dme", None)  # DME status
    gradable = row.get("adjudicated_gradable", None)  # Whether image is gradable

    # Find the image path (handling different extensions)
    img_path = find_image_path(img_id)
    if not img_path:
        print(f"Could not find image for {img_id}")
        continue

    # Preprocess image; failures are already reported by the helper
    processed_img, success = preprocess_retinal_image(img_path)
    if not success:
        continue

    # Radial trajectory: the radius plays the role of time for SINDy
    X = _radial_profile(processed_img, _RADIAL_DIRECTIONS)
    if len(X) < 2:
        print(f"Error: Not enough valid samples for {img_id}")
        continue

    dX = _forward_difference(X)
    t = np.arange(len(X))

    # Fit and score BEFORE touching any result list: if either call raises,
    # nothing has been recorded, so the NPY arrays and both CSVs stay aligned.
    try:
        SINDY_MODEL.fit(X, t=t, x_dot=dX)
        # Copy in case the shared model hands back a reference to an array it
        # reuses on the next fit (cheap insurance; not verified either way).
        theta = np.array(SINDY_MODEL.coefficients(), copy=True)
        score = SINDY_MODEL.score(X, t=t, x_dot=dX)
    except Exception as e:
        print(f"Error fitting SINDy model for {img_id}: {e}")
        continue

    # Store the theta array, its ID and the fit score
    all_thetas.append(theta)
    all_ids.append(img_id)
    errors.append({"id_code": img_id, "score": score})

    # Create output row with Messidor-2 metadata, then the flattened thetas
    row_out = {
        "id_code": img_id,
        "diagnosis": label,
    }
    if dme is not None:
        row_out["dme"] = dme
    if gradable is not None:
        row_out["gradable"] = gradable
    for i, val in enumerate(theta.flatten()):
        row_out[f"theta_{i}"] = val
    theta_rows.append(row_out)

    # Print progress information periodically
    if position % 20 == 0 or position == total_images:
        elapsed = time.time() - start_time
        images_per_sec = position / elapsed if elapsed > 0 else 0
        eta_seconds = (total_images - position) / images_per_sec if images_per_sec > 0 else 0
        eta_minutes = eta_seconds / 60
        print(f"Processed {position}/{total_images} - Speed: {images_per_sec:.2f} img/s - ETA: {eta_minutes:.1f} minutes")
        print(f"Latest: {img_id} - Score: {score:.4f} - Theta shape: {theta.shape}")

# ─── SAVE ALL THETAS AS SINGLE NPY FILE ─────────────────────────────────────
print("\nSaving results...")

# Turn the per-image lists into arrays and write each to its own .npy file.
all_thetas_array = np.array(all_thetas)
all_ids_array = np.array(all_ids)
for npy_target, npy_payload in (
    (THETA_NPY, all_thetas_array),
    (THETA_IDS, all_ids_array),
):
    np.save(npy_target, npy_payload)

# CSV with one row per image: metadata followed by the flattened thetas.
df_theta = pd.DataFrame(theta_rows)
df_theta.to_csv(CSV_OUT, index=False)

# CSV with the per-image model score.
df_errors = pd.DataFrame(errors)
df_errors.to_csv(ERROR_CSV, index=False)

# ─── PRINT SUMMARY ─────────────────────────────────────────────────────────
# Final report: how many images produced a model, the stored array shapes,
# the mean of the recorded scores, and where each output file was written.
separator = "=" * 50
print(separator)
print("Processing complete!")
print(f"Total images processed: {len(all_thetas)}/{len(df)}")
if all_thetas:
    # Shape and score details only make sense when something was fitted.
    recorded_scores = [entry['score'] for entry in errors]
    print(f"Theta shape per image: {all_thetas_array.shape[1:]}")
    print(f"Total thetas array shape: {all_thetas_array.shape}")
    print(f"Average reconstruction score: {np.mean(recorded_scores):.4f}")
print(separator)
print("Saved outputs:")
for summary_line in (
    f"1. All thetas: {THETA_NPY}",
    f"2. Image IDs: {THETA_IDS}",
    f"3. CSV with flattened thetas: {CSV_OUT}",
    f"4. Error metrics: {ERROR_CSV}",
):
    print(summary_line)