import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import pysindy as ps
import time
from tqdm import tqdm

# ─── CONFIG ────────────────────────────────────────────────────────────────
# Inputs: folder holding the image files and the CSV that lists them.
_DATA_ROOT = "/drive2/kt/SGD/Mesidor-1_data"
IMAGE_DIR = f"{_DATA_ROOT}/images"
CSV_IN = f"{_DATA_ROOT}/mesidor_data.csv"

# Outputs: every result file is written into the same directory.
_OUTPUT_ROOT = "/drive2/Kuntal/Pysindy-experiment/M1-output"
CSV_OUT = f"{_OUTPUT_ROOT}/messidor_thetas.csv"            # one row per image, flattened coefficients
THETA_NPY = f"{_OUTPUT_ROOT}/messidor_all_thetas.npy"      # all coefficient arrays in a single NPY file
THETA_IDS = f"{_OUTPUT_ROOT}/messidor_theta_ids.npy"       # image IDs, same order as THETA_NPY
ERROR_CSV = f"{_OUTPUT_ROOT}/messidor_theta_errors.csv"    # per-image model score

# Create the output directory up front so the save steps cannot fail on a missing folder.
os.makedirs(_OUTPUT_ROOT, exist_ok=True)

# PySINDy settings: cubic polynomial feature library with a sequentially
# thresholded least-squares optimizer. One model object is shared by the
# whole run and refitted for every image.
LIBRARY = ps.PolynomialLibrary(degree=3)
OPTIMIZER = ps.STLSQ(threshold=0.1, normalize_columns=True)
SINDY_MODEL = ps.SINDy(optimizer=OPTIMIZER, feature_library=LIBRARY)

# ─── PREPROCESSING FUNCTION ─────────────────────────────────────────────────
def preprocess_retinal_image(img_path, size=(512, 512)):
    """Load an image, keep only a centred disc and enhance its local contrast.

    Steps: open the file, convert to RGB, resize to ``size``, zero every pixel
    outside a centred circle (radius = half the shorter side minus 10 px), then
    apply CLAHE to the lightness channel in LAB colour space and convert back
    to RGB.

    Args:
        img_path: Path of the image file to read.
        size: Target size handed to PIL's ``Image.resize``.

    Returns:
        ``(image, True)`` with the processed RGB array on success, or
        ``(None, False)`` if any step raised; the error is printed, not raised.
    """
    try:
        # Open inside a context manager so the file handle is released as soon
        # as the pixels are copied into the array (formats such as TIFF may
        # otherwise keep the file open until the image object is collected).
        with Image.open(img_path) as source:
            rgb = np.array(source.convert("RGB").resize(size))

        # Circular mask: 255 inside the centred disc, 0 elsewhere.
        height, width = rgb.shape[:2]
        disc_center = (width // 2, height // 2)
        disc_radius = min(width, height) // 2 - 10
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.circle(mask, disc_center, disc_radius, 255, -1)

        # A single masked AND zeroes all three channels outside the disc.
        masked = cv2.bitwise_and(rgb, rgb, mask=mask)

        # CLAHE is applied to the lightness channel only, leaving the two
        # colour channels untouched.
        lab = cv2.cvtColor(masked, cv2.COLOR_RGB2LAB)
        lightness, chroma_a, chroma_b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_lab = cv2.merge([clahe.apply(lightness), chroma_a, chroma_b])
        enhanced = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

        return enhanced, True
    except Exception as e:
        # Best-effort batch processing: report the failure and let the caller
        # skip this image instead of aborting the whole run.
        print(f"Error preprocessing image {img_path}: {e}")
        return None, False

# ─── LOAD METADATA ─────────────────────────────────────────────────────────
def _find_column(expected, columns):
    """Return the column in ``columns`` that best matches ``expected``.

    Matching is case-insensitive. An exact match (ignoring surrounding
    whitespace) always wins; only if there is none is the first column whose
    name contains ``expected`` as a substring used. Returns ``None`` when
    nothing matches.
    """
    target = expected.lower()
    candidates = [(col, str(col).lower()) for col in columns]

    # Pass 1: exact name, so e.g. "Image" is preferred over "Image path".
    for col, lowered in candidates:
        if lowered.strip() == target:
            return col

    # Pass 2: fall back to the first header that contains the expected name.
    for col, lowered in candidates:
        if target in lowered:
            return col

    return None


print(f"Loading data from {CSV_IN}")
df = pd.read_csv(CSV_IN)
print(f"Found {len(df)} images in CSV")

# Print column names to debug
print("CSV columns:", df.columns.tolist())

# Map expected column names to the actual CSV headers. A value of None marks
# a column that could not be located; later code checks for that.
expected_columns = ["Image", "Retinopathy grade", "Risk of macular edema", "Ophthalmologic department"]
column_map = {}
for expected in expected_columns:
    column_map[expected] = _find_column(expected, df.columns)
    if column_map[expected] is None:
        print(f"WARNING: Could not find column matching '{expected}'")

print("Column mapping:", column_map)

# Initialize storage
theta_rows = []  # One dict per image: metadata plus flattened coefficients
all_thetas = []  # Will store all theta arrays
all_ids = []     # Will store corresponding image IDs
errors = []      # Will store model fit quality

# ─── PROCESS EACH IMAGE ────────────────────────────────────────────────────
def _radial_profile(image, margin=10, angle_step=10):
    """Average the pixels lying on concentric circles around the image centre.

    For every integer radius from 0 up to (half the shorter side - ``margin``),
    pixels are sampled every ``angle_step`` degrees on the circle of that
    radius and averaged per channel. Sample points that fall outside the image
    are ignored, and a radius with no valid point contributes no row.

    Returns:
        Array with one row per radius and one column per image channel
        (empty if no radius produced a sample).
    """
    height, width = image.shape[:2]
    center_x, center_y = width // 2, height // 2
    max_radius = min(width, height) // 2 - margin

    # The direction of each sampled angle does not depend on the radius, so
    # the trigonometry is evaluated once instead of once per (radius, angle).
    directions = [
        (np.cos(np.deg2rad(angle)), np.sin(np.deg2rad(angle)))
        for angle in range(0, 360, angle_step)
    ]

    profile = []
    for radius in range(max_radius):
        ring = []
        for cos_a, sin_a in directions:
            x = int(center_x + radius * cos_a)
            y = int(center_y + radius * sin_a)
            if 0 <= x < width and 0 <= y < height:
                ring.append(image[y, x])
        if ring:
            profile.append(np.mean(ring, axis=0))  # Average channel values on the circle
    return np.array(profile)


# Resolve the diagnosis column once instead of rescanning the headers per row.
# If the mapped column is missing, fall back to the first header mentioning
# "grade" or "retinopathy"; with no candidate at all every label defaults to 0.
label_column = column_map["Retinopathy grade"]
if not label_column:
    label_column = next(
        (col for col in df.columns if "grade" in col.lower() or "retinopathy" in col.lower()),
        None,
    )

start_time = time.time()
total_images = len(df)

# `position` is the 1-based count of rows visited; it drives progress/ETA so
# the report does not depend on what the DataFrame index labels happen to be.
for position, (idx, row) in enumerate(
    tqdm(df.iterrows(), total=total_images, desc="Processing Images"), start=1
):
    # Get image ID - fall back to a synthetic one if the column is missing
    img_id = row[column_map["Image"]] if column_map["Image"] else f"image_{idx}"
    label = row[label_column] if label_column else 0

    # Create full image path; str() keeps a non-string ID from crashing the join
    img_path = os.path.join(IMAGE_DIR, str(img_id))

    if not os.path.exists(img_path):
        print(f"Could not find image for {img_id}")
        continue

    # Preprocess image (errors are reported by the helper itself)
    processed_img, success = preprocess_retinal_image(img_path)
    if not success:
        continue

    # Radial trajectory: mean colour as a function of distance from the centre
    X = _radial_profile(processed_img)

    if len(X) < 2:
        print(f"Error: Not enough valid samples for {img_id}")
        continue

    # Forward difference along the radius; the final row repeats the last
    # available difference so dX has the same length as X.
    steps = np.diff(X, axis=0)
    dX = np.concatenate([steps, steps[-1:]])

    # Fit and score first, and only then record anything, so a failure in
    # either call cannot leave the NPY outputs and the CSV outputs out of sync.
    t = np.arange(len(X))
    try:
        SINDY_MODEL.fit(X, t=t, x_dot=dX)
        # Defensive copy: the model object is shared and refitted for the
        # next image, so do not keep a reference to its internal array.
        theta = np.array(SINDY_MODEL.coefficients())
        score = SINDY_MODEL.score(X, t=t, x_dot=dX)
    except Exception as e:
        print(f"Error fitting SINDy model for {img_id}: {e}")
        continue

    # Store the theta array, ID and fit quality together
    all_thetas.append(theta)
    all_ids.append(img_id)
    errors.append({"id_code": img_id, "score": score})

    # Create output row with available metadata
    row_out = {"id_code": img_id, "diagnosis": label}

    # Add edema risk if available
    if column_map["Risk of macular edema"]:
        row_out["edema_risk"] = row[column_map["Risk of macular edema"]]

    # Add department if available
    if column_map["Ophthalmologic department"]:
        row_out["department"] = row[column_map["Ophthalmologic department"]]

    # Add flattened theta values as theta_0, theta_1, ...
    for i, val in enumerate(theta.flatten()):
        row_out[f"theta_{i}"] = val

    theta_rows.append(row_out)

    # Print progress information periodically (only after a successful fit,
    # since the message reports that fit's score and coefficient shape)
    if position % 10 == 0 or position == total_images:
        elapsed = time.time() - start_time
        images_per_sec = position / elapsed if elapsed > 0 else 0
        eta_seconds = (total_images - position) / images_per_sec if images_per_sec > 0 else 0
        print(f"Processed {position}/{total_images} - Speed: {images_per_sec:.2f} img/s - ETA: {eta_seconds:.0f} seconds")
        print(f"Latest: {img_id} - Score: {score:.4f} - Theta shape: {theta.shape}")

# ─── SAVE ALL THETAS AS SINGLE NPY FILE ─────────────────────────────────────
print("\nSaving results...")

# Stack the per-image results and write each to its own NPY file;
# the coefficient array and the ID array share the same ordering.
all_thetas_array = np.array(all_thetas)
all_ids_array = np.array(all_ids)
np.save(THETA_NPY, all_thetas_array)
np.save(THETA_IDS, all_ids_array)

# ─── WRITE CSV WITH FLATTENED THETAS ───────────────────────────────────────
df_theta = pd.DataFrame(theta_rows)
df_theta.to_csv(CSV_OUT, index=False)

# ─── SAVE ERROR METRICS ───────────────────────────────────────────────────
df_errors = pd.DataFrame(errors)
df_errors.to_csv(ERROR_CSV, index=False)

# ─── PRINT SUMMARY ─────────────────────────────────────────────────────────
separator = "=" * 50
processed_count = len(all_thetas)

print(separator)
print("Processing complete!")
print(f"Total images processed: {processed_count}/{len(df)}")
if processed_count > 0:
    # Shape and score statistics only make sense when at least one fit succeeded.
    mean_score = np.mean([entry["score"] for entry in errors])
    print(f"Theta shape per image: {all_thetas_array.shape[1:]}")
    print(f"Total thetas array shape: {all_thetas_array.shape}")
    print(f"Average reconstruction score: {mean_score:.4f}")
print(separator)

print("Saved outputs:")
saved_outputs = (
    ("All thetas", THETA_NPY),
    ("Image IDs", THETA_IDS),
    ("CSV with flattened thetas", CSV_OUT),
    ("Error metrics", ERROR_CSV),
)
for number, (description, path) in enumerate(saved_outputs, start=1):
    print(f"{number}. {description}: {path}")