import os
import sys
import warnings
import torch
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Add the 'feature_matching' and 'segmenter' sub-directories that sit next to
# this file to the end of the module search path (segmenter first, then
# feature_matching).
_script_dir = os.path.dirname(__file__)
generate_path = os.path.abspath(os.path.join(_script_dir, 'feature_matching'))
segment_path = os.path.abspath(os.path.join(_script_dir, 'segmenter'))
sys.path.extend([segment_path, generate_path])

# from segment_anything import sam_model_registry, SamPredictor
from segmenter.segment import process_image, loading_seg, seg_main, show_points
from feature_matching.generate_points import generate, loading_dino
from utils import calculate_distances, convert_to_edges, calculate_center_points

# Dataset / category selection (placeholders -- fill in before running).
# NOTE(review): "CATAGORY" is misspelled; the name is kept unchanged in case
# other modules import it.
DATASET = ''
CATAGORY = ''

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

IMAGE_SIZE = 560  # presumably the working image resolution in pixels -- confirm
USE_BIDIRECTIONAL_MATCHING = True  # presumably toggles two-way feature matching -- confirm

# Silence every warning, process-wide (same filter as
# warnings.filterwarnings("ignore")).
warnings.simplefilter("ignore")

def create_graph_environment(features, positive_indices, negative_indices, image_size, device, max_steps, pos_weights=None, neg_weights=None):
    """Build a multigraph over positive and negative point indices.

    Nodes are the indices in ``positive_indices`` (tagged ``category='pos'``)
    and ``negative_indices`` (tagged ``category='neg'``). Each node carries a
    ``weight`` looked up in ``pos_weights`` / ``neg_weights``, defaulting to 1.

    Edges come from the five distance families returned by
    ``calculate_distances``. Each family is added as its own batch of
    (parallel) edges, and the value is stored under an edge attribute named
    after the family: ``feature_pos``, ``physical_pos``, ``physical_neg``,
    ``feature_cross`` or ``physical_cross``.

    Args:
        features: Forwarded to ``calculate_distances``.
        positive_indices: Tensor of positive point indices (``.cpu().numpy()``
            is called on it).
        negative_indices: Tensor of negative point indices.
        image_size: Forwarded to ``calculate_distances``.
        device: Forwarded to ``calculate_distances``.
        max_steps: Currently unused; kept for interface compatibility.
        pos_weights: Optional mapping from positive index to node weight.
        neg_weights: Optional mapping from negative index to node weight.

    Returns:
        networkx.MultiGraph: The populated graph. Previously the graph was
        built but never returned, so callers received ``None``.
    """
    (
        feature_pos_distances,
        feature_cross_distances,
        physical_pos_distances,
        physical_neg_distances,
        physical_cross_distances,
    ) = calculate_distances(features, positive_indices, negative_indices, image_size, device)

    # (edge-attribute name, edges) per distance family. add_weighted_edges_from
    # expects (u, v, value) triples, so convert_to_edges is assumed to yield those.
    edge_groups = [
        ('feature_pos', convert_to_edges(positive_indices, positive_indices, feature_pos_distances)),
        ('physical_pos', convert_to_edges(positive_indices, positive_indices, physical_pos_distances)),
        ('physical_neg', convert_to_edges(negative_indices, negative_indices, physical_neg_distances)),
        ('feature_cross', convert_to_edges(positive_indices, negative_indices, feature_cross_distances)),
        ('physical_cross', convert_to_edges(positive_indices, negative_indices, physical_cross_distances)),
    ]

    G = nx.MultiGraph()

    for idx in positive_indices.cpu().numpy():
        weight = pos_weights.get(idx, 1) if pos_weights else 1
        G.add_node(idx, category='pos', weight=weight)

    # add_node on an existing node updates its attributes, so an index present
    # in both lists ends up tagged 'neg'.
    for idx in negative_indices.cpu().numpy():
        weight = neg_weights.get(idx, 1) if neg_weights else 1
        G.add_node(idx, category='neg', weight=weight)

    for weight_type, edge_list in edge_groups:
        G.add_weighted_edges_from(edge_list, weight=weight_type)

    return G

# Define paths (placeholders -- fill in before running).
BASE_DIR = " "
DATA_DIR = " "
REFERENCE_IMAGE_DIR = " "
MASK_DIR = " "
IMAGE_DIR = " "
RESULTS_DIR = " "
POINT_MASKS_DIR = " "
FINAL_PROMPTS_DIR = " "


def ensure_output_dir(path):
    """Create *path* (including parents) if it does not exist yet.

    Blank or whitespace-only paths are ignored. The directory constants above
    default to a single space, and ``os.makedirs(" ")`` would otherwise
    silently create a directory literally named " " in the current working
    directory as soon as this module is imported.
    """
    if path.strip():
        os.makedirs(path, exist_ok=True)


# Ensure the results directories exist
ensure_output_dir(POINT_MASKS_DIR)
ensure_output_dir(FINAL_PROMPTS_DIR)

# Load models for segmentation and feature generation
def load_models():
    """Load the segmentation model and the DINO feature extractor.

    ``loading_seg`` is called with the ``'vitl'`` identifier (presumably a
    ViT-L backbone -- confirm) and ``DEVICE``; ``loading_dino`` is called with
    ``DEVICE``.

    Returns:
        tuple: ``(model_seg, model_dino)``.

    If either loader raises, the error is reported on stderr and the process
    exits with status 1.
    """
    try:
        model_seg = loading_seg('vitl', DEVICE)
        model_dino = loading_dino(DEVICE)
    except Exception as e:  # top-level CLI boundary: report and abort
        print(f"Error loading models: {e}", file=sys.stderr)
        sys.exit(1)
    return model_seg, model_dino


def find_reference_mask_pairs(reference_images, reference_dir, mask_dir, mask_exts=('.png', '.jpg')):
    """Pair each reference image with its mask file, when one exists.

    A mask matches a reference image when it lives in *mask_dir* and has the
    same stem as the image plus one of *mask_exts*. Extensions are tried in
    order and the first existing file wins. Images without a mask are skipped.

    Args:
        reference_images: Reference image file names (not full paths).
        reference_dir: Directory that contains the reference images.
        mask_dir: Directory searched for masks.
        mask_exts: Candidate mask extensions, in priority order.

    Returns:
        list[tuple[str, str]]: ``(reference_image_path, mask_path)`` pairs, in
        the order of *reference_images*.
    """
    pairs = []
    for ref_image in reference_images:
        stem = os.path.splitext(ref_image)[0]
        candidates = (os.path.join(mask_dir, stem + ext) for ext in mask_exts)
        mask_path = next((p for p in candidates if os.path.exists(p)), None)
        if mask_path is not None:
            pairs.append((os.path.join(reference_dir, ref_image), mask_path))
    return pairs


if __name__ == "__main__":
    model_seg, model_dino = load_models()
    reference_images = sorted(os.listdir(REFERENCE_IMAGE_DIR))
    if not reference_images:
        print("No reference images found.")
        sys.exit(1)

    img_list = sorted(os.listdir(IMAGE_DIR))
    print(f"Found {len(img_list)} target images to process")

    # The reference/mask pairing does not depend on the target image, so
    # resolve it once instead of re-probing the file system for every image.
    reference_mask_pairs = find_reference_mask_pairs(reference_images, REFERENCE_IMAGE_DIR, MASK_DIR)

    for img_name in tqdm(img_list, desc="Processing images"):
        print(f"\nProcessing target image: {img_name}")

        if not reference_mask_pairs:
            print(f"Warning: No reference-mask pairs found for {img_name}, skipping...")
            continue

        print(f"Found {len(reference_mask_pairs)} reference-mask pairs for {img_name}")

    print("\nAll images processed successfully!")
