import pandas as pd
import os
import os
import cv2
import pandas as pd
from PIL import Image
import torch
from tqdm import tqdm
import numpy as np
import argparse
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import warnings

BASE_DIR = '/home/datasets/EPIC-KITCHENS'
ACTION_RECOGNITION_DIR = os.path.join(BASE_DIR, 'action_recognition')

# EPIC-KITCHENS-100 action-recognition annotation CSVs (train / validation splits).
TRAIN_DF = os.path.join(ACTION_RECOGNITION_DIR, 'EPIC_100_train.csv')
VALIDATION_DF = os.path.join(ACTION_RECOGNITION_DIR, 'EPIC_100_validation.csv')
FRAME_SIZE = (224, 224)
# Minimum Grounding DINO score that both the before and the after frame must
# reach for an image pair to be saved.
DETECTION_THRESHOLD = 0.45
OUTPUT_DIR = "epic_image_pairs"
FRAME_MARGIN = 0
os.makedirs(OUTPUT_DIR, exist_ok=True)


# Verbs whose narration segments are processed; matched exactly against the
# annotation CSV's 'verb' column, so spelling must match the annotation keys.
VERBS = ['add', 'adjust', 'attach', 'break', 'close', 'cut', 'divide',
         'eat', 'empty', 'fill', 'flatten', 'flip', 'fold', 'increase',
         'lift', 'move', 'open', 'peel', 'pull', 'remove', 'roll', 'sharpen',
         'stretch']
# Nouns of interest, in the annotation key format where ':' separates a
# modifier (e.g. 'board:chopping').
NOUNS = ['meat', 'oil', 'pan', 'potato', 'salt', 'sauce', 'bag', 'hob', 'lid', 'oven',
         'tap', 'kettle', 'chicken', 'egg', 'garlic', 'bin', 'bottle', 'box', 'container',
         'cupboard', 'dishwasher', 'drawer', 'fridge', 'package', 'spice', 'aubergine',
         'bacon', 'bread', 'carrot', 'cheese', 'courgette', 'cucumber', 'mushroom', 'olive',
         'omelette', 'onion', 'peach', 'pepper', 'pizza', 'sausage', 'tomato', 'broccoli',
         'dough', 'plate', 'apple', 'food', 'pasta', 'rice', 'salad', 'spoon', 'bowl', 'coffee',
         'cup', 'jar', 'rack:drying', 'sink', 'glass', 'liquid', 'pot', 'board:chopping', 'mat',
         'paper', 'knife', 'liquid:washing', 'banana', 'squash', 'tray', 'leaf', 'mixture',
         'onion:spring', 'seed']




        

class GroundingDinoDetector:
    """Score how confidently Grounding DINO finds a named object in an image.

    Wraps a Hugging Face ``AutoModelForZeroShotObjectDetection`` checkpoint;
    ``detect_object`` returns the highest box score for a single text query.
    """

    def __init__(self, model_id="IDEA-Research/grounding-dino-base",
                 box_threshold=0.1, device=None):
        """Load the processor and model and move the model to a device.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained``.
            box_threshold: Score threshold handed to the processor's
                post-processing; boxes below it are discarded before the
                maximum is taken. Defaults to the previously hard-coded 0.1.
            device: Torch device string. When None, uses "cuda" if available,
                otherwise "cpu".

        Raises:
            Any exception raised while loading the processor or model is
            printed and re-raised.
        """
        print(f"Loading Grounding DINO model: {model_id}")
        try:
            self.processor = AutoProcessor.from_pretrained(model_id)
            self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        except Exception as e:
            # A detector without a model is unusable: report, then propagate.
            print(f"Error: {e}")
            raise

        self.box_threshold = box_threshold
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"Grounding DINO detector initialized on device: {self.device}")

    def detect_object(self, image_input, target_noun: str) -> float:
        """Return the highest Grounding DINO box score for ``target_noun``.

        Args:
            image_input: Image as a numpy array (cast to uint8 and wrapped as a
                PIL image) or a PIL image; converted to RGB if needed.
            target_noun: Object name. ':' separators (e.g. 'board:chopping')
                are replaced by spaces before the text query is built.

        Returns:
            The maximum score among boxes kept by post-processing, or 0.0 when
            no box survives, the input type is unsupported, or detection raises
            (the error and traceback are printed).
        """
        if isinstance(image_input, np.ndarray):
            # NOTE(review): channel order is assumed RGB (extract_frames converts
            # from BGR). A float image in [0, 1] would be truncated by this cast.
            pil_image = Image.fromarray(image_input.astype(np.uint8))
        elif isinstance(image_input, Image.Image):
            pil_image = image_input
        else:
            print("Error: Invalid image input type. Must be numpy array or PIL Image.")
            return 0.0

        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Grounding DINO expects lowercase text queries terminated by '.'.
        processed_noun = target_noun.replace(":", " ").strip().lower()
        text_query = f"{processed_noun}."

        try:
            inputs = self.processor(images=pil_image, text=text_query, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # PIL reports (width, height); post-processing wants (height, width).
            image_width, image_height = pil_image.size
            target_sizes = torch.tensor([[image_height, image_width]]).to(self.device)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                results = self.processor.post_process_grounded_object_detection(
                    outputs,
                    threshold=self.box_threshold,
                    target_sizes=target_sizes
                )

            # Single image, single query: every surviving box belongs to the noun.
            scores = results[0]["scores"]
            if scores.numel() == 0:
                return 0.0
            return scores.max().item()

        except Exception as e:
            # Best-effort per frame: a failed detection counts as "not found".
            print(f"Error during Grounding DINO detection for noun '{target_noun}' (query: '{text_query}'): {e}")
            import traceback
            traceback.print_exc()
            return 0.0

        
def extract_frames(video_path, start_frame_num, stop_frame_num):
    """Read one frame at ``start_frame_num`` and one at ``stop_frame_num``.

    The indices are passed unchanged to ``cv2.CAP_PROP_POS_FRAMES``, which
    OpenCV treats as a 0-based position of the next frame to decode.
    NOTE(review): confirm whether the annotation frame numbers are 1-based;
    if so, both indices are off by one here.

    Args:
        video_path: Path to the video file.
        start_frame_num: Frame position for the "before" image.
        stop_frame_num: Frame position for the "after" image.

    Returns:
        ``(before_rgb, after_rgb)`` as RGB numpy arrays, or ``(None, None)``
        when the video cannot be opened or either frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return None, None

        before_frame_idx = start_frame_num
        after_frame_idx = stop_frame_num

        cap.set(cv2.CAP_PROP_POS_FRAMES, before_frame_idx)
        ret_before, frame_before_bgr = cap.read()

        cap.set(cv2.CAP_PROP_POS_FRAMES, after_frame_idx)
        ret_after, frame_after_bgr = cap.read()
    finally:
        # Release the capture on every path, including a failed open or an
        # exception during seeking/decoding.
        cap.release()

    if not ret_before:
        print(f"Error: Could not read before_frame ({before_frame_idx}) from {video_path}")
        return None, None
    if not ret_after:
        print(f"Error: Could not read after_frame ({after_frame_idx}) from {video_path}")
        return None, None

    # OpenCV decodes to BGR; downstream code works in RGB.
    frame_before_rgb = cv2.cvtColor(frame_before_bgr, cv2.COLOR_BGR2RGB)
    frame_after_rgb = cv2.cvtColor(frame_after_bgr, cv2.COLOR_BGR2RGB)

    return frame_before_rgb, frame_after_rgb


def _resolve_video_path(base_video_dir, participant_id, video_id_short):
    """Return the first existing video file for ``video_id_short``, or None."""
    # Search order: <base>/<participant>/videos/<id>.MP4, .mp4, then the flat
    # layout <base>/<id>.MP4, .mp4.
    candidate_dirs = (
        os.path.join(base_video_dir, participant_id, 'videos'),
        base_video_dir,
    )
    for directory in candidate_dirs:
        for ext in ("MP4", "mp4"):
            candidate = os.path.join(directory, f"{video_id_short}.{ext}")
            if os.path.exists(candidate):
                return candidate
    return None


def process_epic_kitchens(df, base_video_dir, output_dir, detector):
    """Save before/after frame pairs where the narrated noun is detected in both.

    Rows of ``df`` whose 'verb' is in ``VERBS`` are processed. For each, the
    frames at 'start_frame' and 'stop_frame' are scored with
    ``detector.detect_object`` for the row's 'noun'. A pair is kept when both
    scores are at least ``DETECTION_THRESHOLD``. Kept frames are written as
    JPEGs into ``output_dir``, and their metadata is appended to (or creates)
    ``valid_image_pairs_metadata.csv`` there.

    Args:
        df: Annotation DataFrame with columns video_id, participant_id,
            narration_id, start_frame, stop_frame, verb, verb_class, noun,
            noun_class.
        base_video_dir: Root directory searched for the video files.
        output_dir: Destination for images and the metadata CSV (created if
            missing).
        detector: Object exposing ``detect_object(image, noun) -> float``.

    Returns:
        DataFrame with one row per kept pair (empty if none).
    """
    df_filtered = df[df['verb'].isin(VERBS)].copy()

    os.makedirs(output_dir, exist_ok=True)

    valid_pairs = []

    print(f"Processing {len(df_filtered)} narration segments...")
    for _, row in tqdm(df_filtered.iterrows(), total=len(df_filtered)):
        video_id_short = row['video_id']  # e.g., P01_01
        participant_id = row['participant_id']  # e.g., P01
        narration_id = row['narration_id']
        start_frame = int(row['start_frame'])
        stop_frame = int(row['stop_frame'])
        target_noun = row['noun']

        video_path = _resolve_video_path(base_video_dir, participant_id, video_id_short)
        if video_path is None:
            print(f"Warning: Video file not found for {video_id_short} (participant {participant_id}). Searched multiple patterns. Skipping.")
            continue

        frame_before_rgb, frame_after_rgb = extract_frames(video_path, start_frame, stop_frame)
        if frame_before_rgb is None or frame_after_rgb is None:
            continue

        before_score = detector.detect_object(frame_before_rgb, target_noun)
        after_score = detector.detect_object(frame_after_rgb, target_noun)

        if before_score < DETECTION_THRESHOLD or after_score < DETECTION_THRESHOLD:
            continue

        before_path = os.path.join(output_dir, f"{narration_id}_before.jpg")
        after_path = os.path.join(output_dir, f"{narration_id}_after.jpg")

        # cv2.imwrite expects BGR, so convert back from RGB.
        cv2.imwrite(before_path, cv2.cvtColor(frame_before_rgb, cv2.COLOR_RGB2BGR))
        cv2.imwrite(after_path, cv2.cvtColor(frame_after_rgb, cv2.COLOR_RGB2BGR))

        # Key order is kept unchanged so rows appended to an existing metadata
        # CSV (written without a header) stay aligned with its columns. Each key
        # now carries its own column's value; verb/verb_class and noun/noun_class
        # were previously swapped.
        valid_pairs.append({
            'narration_id': narration_id,
            'video_id': video_id_short,
            'participant_id': participant_id,
            'verb_class': row['verb_class'],
            'verb': row['verb'],
            'noun_class': row['noun_class'],
            'noun': target_noun,
            'before_score': before_score,
            'after_score': after_score,
            'before_image_path': before_path,
            'after_image_path': after_path,
        })

    valid_df = pd.DataFrame(valid_pairs)
    if valid_df.empty:
        print("No valid pairs found meeting the detection criteria.")
        return valid_df

    metadata_file = os.path.join(output_dir, "valid_image_pairs_metadata.csv")
    # Append without a header when the file already exists (e.g. from another
    # chunk run); otherwise create it with a header.
    if os.path.exists(metadata_file):
        valid_df.to_csv(metadata_file, mode='a', header=False, index=False)
        print(f"Appended {len(valid_df)} new pairs to valid_image_pairs_metadata.csv")
    else:
        valid_df.to_csv(metadata_file, index=False)
        print(f"Created new metadata file with {len(valid_df)} pairs")

    return valid_df




if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract before/after frame pairs from EPIC-KITCHENS-100 "
                    "annotations, processing one chunk of rows per run."
    )
    parser.add_argument(
        "--i", type=int, required=True,
        help="0-based chunk index; processes merged rows "
             "[i * chunk_size, (i + 1) * chunk_size).",
    )
    parser.add_argument(
        "--chunk-size", type=int, default=1000,
        help="Number of annotation rows per chunk (default: 1000).",
    )

    args = parser.parse_args()
    # Negative indices would silently select rows from the end via iloc slicing.
    if args.i < 0:
        parser.error("--i must be non-negative")
    if args.chunk_size <= 0:
        parser.error("--chunk-size must be positive")

    train_df = pd.read_csv(TRAIN_DF)
    val_df = pd.read_csv(VALIDATION_DF)
    merged_df = pd.concat([train_df, val_df], ignore_index=True)

    chunk_size = args.chunk_size
    start_idx = args.i * chunk_size
    end_idx = start_idx + chunk_size

    chunk_df = merged_df.iloc[start_idx:end_idx]

    if chunk_df.empty:
        # Skip the (expensive) model load when there is nothing to process.
        print(f"Chunk {args.i} is empty ({len(merged_df)} rows available); nothing to do.")
    else:
        detector = GroundingDinoDetector(model_id="IDEA-Research/grounding-dino-base")
        process_epic_kitchens(chunk_df, BASE_DIR, OUTPUT_DIR, detector)
