import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import pysindy as ps
import time
from tqdm import tqdm
import gc  # Garbage collection for memory management

# ─── CONFIG ────────────────────────────────────────────────────────────────
# EyePACS dataset has multiple directories
# Directories searched, in this order, when resolving an image id to a file
# (see find_image_path).
IMG_DIRS = [
    "/drive2/kt/eyepac_data/train_1/train",
    "/drive2/kt/eyepac_data/train_2/train",
    "/drive2/kt/eyepac_data/train_3/train",
    "/drive2/kt/eyepac_data/train_4/train",
    "/drive2/kt/eyepac_data/train_5/train",
]
CSV_IN = "/drive2/kt/SGD/trainLabels.csv"  # EyePACS label file (read for its "image" and "level" columns)
OUTPUT_DIR = "/drive2/Kuntal/Pysindy-experiment/eyepacs_theta_data"  # Base output directory

# Output file paths (all written by main() once processing has finished)
CSV_OUT = f"{OUTPUT_DIR}/eyepacs_thetas.csv"  # one row per image: id, label, flattened thetas
THETA_NPY = f"{OUTPUT_DIR}/eyepacs_all_thetas.npy"  # stacked coefficient arrays
THETA_IDS = f"{OUTPUT_DIR}/eyepacs_theta_ids.npy"  # image ids, same order as THETA_NPY
ERROR_CSV = f"{OUTPUT_DIR}/theta_errors.csv"  # per-image model score

# Checkpoint files (to resume processing if interrupted)
CHECKPOINT_DIR = f"{OUTPUT_DIR}/checkpoints"
BATCH_SIZE = 1000  # A checkpoint is written each time another 1000 CSV rows have been passed

# Ensure output directories exist
# NOTE: this runs at import time, so importing this module touches the filesystem.
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# PySINDy settings
# A single model instance is shared and re-fitted for every image.
LIBRARY = ps.PolynomialLibrary(degree=3)
OPTIMIZER = ps.STLSQ(threshold=0.1, normalize_columns=True)
SINDY_MODEL = ps.SINDy(feature_library=LIBRARY, optimizer=OPTIMIZER)

# ─── PREPROCESSING FUNCTION ─────────────────────────────────────────────────
def preprocess_retinal_image(img_path, size=(512, 512)):
    """Load a fundus photograph, mask it to a central disc and enhance contrast.

    Steps: open as RGB, resize to ``size``, zero every pixel outside a centred
    circle of radius ``min(w, h) // 2 - 10``, then apply CLAHE to the L channel
    in LAB colour space and convert back to RGB.

    Args:
        img_path: Path of the image file to load.
        size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        ``(image, True)`` with ``image`` an RGB ``numpy`` array on success, or
        ``(None, False)`` if any step raised. Failures are printed, never
        propagated, so one unreadable file cannot abort a long batch run.
    """
    try:
        # The context manager releases the file handle even when decoding
        # fails part-way (e.g. a truncated JPEG).
        with Image.open(img_path) as raw:
            img_array = np.array(raw.convert("RGB").resize(size))

        # Filled circular mask covering the fundus region.
        h, w = img_array.shape[:2]
        center = (w // 2, h // 2)
        radius = min(w, h) // 2 - 10
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.circle(mask, center, radius, 255, -1)

        # A single masked AND zeroes all three channels outside the circle.
        masked = cv2.bitwise_and(img_array, img_array, mask=mask)

        # Contrast enhancement: CLAHE on lightness only, colour channels untouched.
        lab = cv2.cvtColor(masked, cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_lab = cv2.merge([clahe.apply(lightness), chan_a, chan_b])
        enhanced = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

        return enhanced, True
    except Exception as e:  # deliberate best-effort: report and skip this image
        print(f"Error preprocessing image {img_path}: {e}")
        return None, False

# ─── FIND IMAGE FUNCTION ─────────────────────────────────────────────────────
def find_image_path(img_id):
    """Return the first existing file for ``img_id``, or ``None`` if there is none.

    Candidates are tried directory by directory (in ``IMG_DIRS`` order) and,
    within each directory, as ``.jpeg``, then ``.jpg``, then ``.png``.
    """
    suffixes = ('.jpeg', '.jpg', '.png')
    candidates = (
        os.path.join(folder, f"{img_id}{suffix}")
        for folder in IMG_DIRS
        for suffix in suffixes
    )
    # next() stops at the first hit, so later candidates are never stat'ed.
    return next((path for path in candidates if os.path.exists(path)), None)

# ─── LOAD CHECKPOINT ─────────────────────────────────────────────────────────
def load_checkpoint():
    """Load the most recent checkpoint from ``CHECKPOINT_DIR``, if one exists.

    Only files named ``batch_*.npz`` are considered; the one with the highest
    batch number is loaded.

    Returns:
        A 5-tuple ``(all_thetas, all_ids, theta_rows, errors, completed_count)``.
        The first four are plain Python lists restored from the checkpoint
        (all empty when no checkpoint exists). ``completed_count`` is
        ``len(all_ids)``: the number of stored results, which is not
        necessarily a row index into the label CSV.
    """
    # Sort by (length, name): for zero-padded numeric names this is numeric
    # order and, unlike a plain string sort, stays correct past batch 9999.
    checkpoints = sorted(
        (
            name
            for name in os.listdir(CHECKPOINT_DIR)
            if name.startswith("batch_") and name.endswith(".npz")
        ),
        key=lambda name: (len(name), name),
    )

    if not checkpoints:
        return [], [], [], [], 0

    latest_checkpoint = os.path.join(CHECKPOINT_DIR, checkpoints[-1])
    print(f"Loading checkpoint: {latest_checkpoint}")

    # NOTE: allow_pickle=True is needed because the arrays are saved with
    # dtype=object; unpickling can execute arbitrary code, so only load
    # checkpoint files produced by this script.
    # The context manager closes the underlying zip file once the data is read.
    with np.load(latest_checkpoint, allow_pickle=True) as checkpoint_data:
        all_thetas = checkpoint_data["thetas"].tolist()
        all_ids = checkpoint_data["ids"].tolist()
        theta_rows = checkpoint_data["rows"].tolist()
        errors = checkpoint_data["errors"].tolist()

    completed_count = len(all_ids)

    return all_thetas, all_ids, theta_rows, errors, completed_count

# ─── SAVE CHECKPOINT ─────────────────────────────────────────────────────────
def save_checkpoint(all_thetas, all_ids, theta_rows, errors, batch_num):
    """Atomically write the accumulated results to ``batch_<batch_num>.npz``.

    Args:
        all_thetas: Coefficient arrays collected so far.
        all_ids: Image ids, parallel to ``all_thetas``.
        theta_rows: Per-image dicts destined for the output CSV.
        errors: Per-image score dicts.
        batch_num: Batch number used (zero-padded to 4 digits) in the file name.

    The archive is first written under a temporary name and then renamed into
    place, so a run interrupted mid-write never leaves a truncated
    ``batch_*.npz`` behind to be picked up as the latest checkpoint.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"batch_{batch_num:04d}.npz")
    # The temp name deliberately does not start with "batch_", so a leftover
    # from a hard kill is never mistaken for a real checkpoint.
    tmp_path = os.path.join(CHECKPOINT_DIR, f".tmp_batch_{batch_num:04d}.npz")

    try:
        np.savez(
            tmp_path,
            thetas=np.array(all_thetas, dtype=object),
            ids=np.array(all_ids),
            rows=np.array(theta_rows, dtype=object),
            errors=np.array(errors, dtype=object),
        )
        os.replace(tmp_path, checkpoint_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up a partial file when the write or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Saved checkpoint: {checkpoint_path}")

# ─── PROCESS IMAGES IN BATCHES ───────────────────────────────────────────────
def _resume_index(df, all_ids, fallback):
    """Return the CSV row position right after the last checkpointed image.

    Rows are processed in CSV order, so the last stored id marks how far the
    previous run got. ``fallback`` is returned when there is no stored id or
    it cannot be found in ``df["image"]``.
    """
    if not all_ids:
        return fallback
    hits = np.flatnonzero((df["image"] == all_ids[-1]).to_numpy())
    return int(hits[-1]) + 1 if hits.size else fallback


def _radial_mean_profile(image, angle_step=10):
    """Return the mean pixel value on each concentric circle around the centre.

    For every integer radius ``r`` in ``range(min(w, h) // 2 - 10)`` the image
    is sampled every ``angle_step`` degrees on the circle of radius ``r`` and
    the in-bounds samples are averaged per channel. Circles with no in-bounds
    sample are dropped. The result has one row per kept radius.
    """
    h, w = image.shape[:2]
    cx, cy = w // 2, h // 2
    max_radius = min(w, h) // 2 - 10

    # The sampling directions are the same for every radius: compute them once.
    angles = np.deg2rad(np.arange(0, 360, angle_step))
    cos_a, sin_a = np.cos(angles), np.sin(angles)

    profile = []
    for r in range(max_radius):
        xs = (cx + r * cos_a).astype(int)  # truncates toward zero, like int()
        ys = (cy + r * sin_a).astype(int)
        inside = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
        if inside.any():
            profile.append(image[ys[inside], xs[inside]].mean(axis=0))
    return np.array(profile)


def _theta_for_image(img_id):
    """Run the full per-image pipeline; return ``(theta, score)`` or ``None``.

    ``None`` means the image was skipped (file not found, preprocessing
    failed, too few samples, or the model fit/score raised); the reason is
    printed where the original pipeline printed one.
    """
    img_path = find_image_path(img_id)
    if img_path is None:
        print(f"Could not find image for {img_id}")
        return None

    processed_img, success = preprocess_retinal_image(img_path)
    if not success:
        return None

    # Radial trajectory from the centre outward plays the role of the state.
    X = _radial_mean_profile(processed_img)
    if len(X) < 2:
        print(f"Error: Not enough valid samples for {img_id}")
        return None

    # "Time" derivative along the radius: forward difference, with the last
    # valid difference repeated for the final sample.
    dX = np.zeros_like(X)
    dX[:-1] = X[1:] - X[:-1]
    dX[-1] = dX[-2]
    t = np.arange(len(X))

    try:
        SINDY_MODEL.fit(X, t=t, x_dot=dX)
        # Copy defensively: the same model object is re-fitted for every image.
        theta = np.array(SINDY_MODEL.coefficients())
        score = SINDY_MODEL.score(X, t=t, x_dot=dX)
    except Exception as e:  # best-effort: one bad image must not stop the run
        print(f"Error fitting SINDy model for {img_id}: {e}")
        return None
    return theta, score


def process_images(df, start_idx=0):
    """Fit a SINDy model to every image listed in ``df``, with checkpointing.

    Args:
        df: Label table with an ``"image"`` (id) and a ``"level"`` (label) column.
        start_idx: Row position to start from when no checkpoint exists. When
            a checkpoint is found, processing instead resumes at the row after
            the last image stored in it.

    Returns:
        ``(all_thetas, all_ids, theta_rows, errors)``: four parallel lists
        with one entry per successfully fitted image (coefficient array, image
        id, flattened CSV row dict, score dict).

    Images that cannot be found, preprocessed or fitted are skipped. Results
    are appended to all four lists only after both the fit and the score
    succeeded, so the lists always stay aligned. Each time another
    ``BATCH_SIZE`` rows have been passed, a checkpoint and intermediate CSVs
    are written; failures while saving are reported and do not stop the run.
    """
    # Initialize or load from checkpoint
    all_thetas, all_ids, theta_rows, errors, completed_count = load_checkpoint()

    # Skip already processed images. The number of stored results is NOT the
    # row index reached (skipped images are not stored), so locate the last
    # stored id in the CSV instead of resuming at completed_count.
    if completed_count > 0:
        start_idx = _resume_index(df, all_ids, completed_count)
        print(f"Resuming from image {start_idx} of {len(df)}")

    batch_num = start_idx // BATCH_SIZE
    start_time = time.time()

    for idx in tqdm(range(start_idx, len(df)), desc="Processing images"):
        row = df.iloc[idx]
        img_id = row["image"]  # EyePACS CSV uses "image" instead of "id_code"
        label = row["level"]   # EyePACS CSV uses "level" instead of "diagnosis"

        fitted = _theta_for_image(img_id)
        if fitted is not None:
            theta, score = fitted
            all_thetas.append(theta)
            all_ids.append(img_id)
            errors.append({"id_code": img_id, "score": score})

            # Flatten and record for CSV
            row_out = {"id_code": img_id, "diagnosis": label}
            for i, val in enumerate(theta.flatten()):
                row_out[f"theta_{i}"] = val
            theta_rows.append(row_out)

        done = idx + 1

        # Print progress
        if done % 50 == 0:
            elapsed = time.time() - start_time
            images_per_sec = (done - start_idx) / elapsed if elapsed > 0 else 0
            eta_seconds = (len(df) - done) / images_per_sec if images_per_sec > 0 else 0
            eta_hours = eta_seconds / 3600

            print(f"Processed {done}/{len(df)} - Speed: {images_per_sec:.2f} img/s - ETA: {eta_hours:.1f} hours")

        # Save checkpoint after each batch
        current_batch = done // BATCH_SIZE
        if current_batch > batch_num:
            batch_num = current_batch
            try:
                save_checkpoint(all_thetas, all_ids, theta_rows, errors, batch_num)

                # Explicitly trigger garbage collection to manage memory
                gc.collect()

                # Optionally save intermediate CSV files
                save_intermediate_results(all_thetas, all_ids, theta_rows, errors, batch_num)
            except Exception as e:  # keep going; report the real cause
                print(f"Error saving checkpoint for batch {batch_num}: {e}")

    return all_thetas, all_ids, theta_rows, errors

# ─── SAVE INTERMEDIATE RESULTS ─────────────────────────────────────────────────
def save_intermediate_results(all_thetas, all_ids, theta_rows, errors, batch_num):
    """Dump the rows and scores collected so far to per-batch CSV files.

    Writes ``intermediate_batch_<NNNN>.csv`` (from ``theta_rows``) and
    ``errors_batch_<NNNN>.csv`` (from ``errors``) into ``OUTPUT_DIR``.
    ``all_thetas`` and ``all_ids`` are accepted but not used here.
    """
    # (destination, records) pairs, written in this order.
    outputs = (
        (f"{OUTPUT_DIR}/intermediate_batch_{batch_num:04d}.csv", theta_rows),
        (f"{OUTPUT_DIR}/errors_batch_{batch_num:04d}.csv", errors),
    )
    for csv_path, records in outputs:
        pd.DataFrame(records).to_csv(csv_path, index=False)

    print(f"Saved intermediate results for batch {batch_num}")

# ─── MAIN EXECUTION ─────────────────────────────────────────────────────────
def main():
    """Run the whole pipeline: read labels, process every image, write outputs.

    Outputs are the stacked theta array and id array (``.npy``), the CSV of
    flattened thetas, and the CSV of per-image scores, followed by a printed
    summary.
    """
    # Load CSV data
    print(f"Loading CSV data from {CSV_IN}")
    df = pd.read_csv(CSV_IN)
    total_listed = len(df)
    print(f"Found {total_listed} images in CSV")

    # Process images
    print("Starting image processing...")
    all_thetas, all_ids, theta_rows, errors = process_images(df)

    # Convert lists to numpy arrays
    print("Converting to numpy arrays...")
    theta_stack = np.array(all_thetas)
    id_array = np.array(all_ids)

    # Save final results: NPY files first, then the two CSVs
    print("Saving final results...")
    np.save(THETA_NPY, theta_stack)
    np.save(THETA_IDS, id_array)
    pd.DataFrame(theta_rows).to_csv(CSV_OUT, index=False)
    pd.DataFrame(errors).to_csv(ERROR_CSV, index=False)

    # Print summary
    separator = "=" * 50
    print(separator)
    print("Processing complete!")
    print(f"Total images processed: {len(all_thetas)}/{total_listed}")

    if all_thetas:
        print(f"Theta shape per image: {theta_stack.shape[1:]}")
        print(f"Total thetas array shape: {theta_stack.shape}")
        mean_score = np.mean([entry['score'] for entry in errors])
        print(f"Average reconstruction score: {mean_score:.4f}")

    print(separator)
    print("Saved outputs:")
    saved = (
        ("All thetas", THETA_NPY),
        ("Image IDs", THETA_IDS),
        ("CSV with flattened thetas", CSV_OUT),
        ("Error metrics", ERROR_CSV),
    )
    for position, (title, location) in enumerate(saved, start=1):
        print(f"{position}. {title}: {location}")

if __name__ == "__main__":
    main()