import torch
import numpy as np
import os
import json


import openai
from groundingdino.util.inference import load_model, load_image, predict, annotate
from groundingdino.util.slconfig import SLConfig
from groundingdino.models import build_model
from groundingdino.util.utils import clean_state_dict
from segment_anything import sam_model_registry, SamPredictor


class ConceptPipeline:
    """Generate, cache and validate per-class visual concepts.

    Concepts are short physical descriptions (e.g. ``"long beak"``) requested
    from an OpenAI chat model and cached as JSON in
    ``config.CONCEPT_CACHE_DIR/concepts.json``.
    """

    def __init__(self, api_key=None):
        """Store the API key and make sure the cache directory exists.

        Args:
            api_key: OpenAI API key. When omitted, the ``OPENAI_API_KEY``
                environment variable is used instead.
        """
        # NOTE(review): `config` is not imported in the visible part of this
        # file -- confirm it is brought into scope elsewhere.
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            print("No OpenAI API Key provided. Concept generation will fail if cache is missing.")

        self.cache_path = os.path.join(config.CONCEPT_CACHE_DIR, "concepts.json")
        os.makedirs(config.CONCEPT_CACHE_DIR, exist_ok=True)

    @staticmethod
    def _parse_concepts(text):
        """Split a comma-separated LLM reply into clean, unique, lowercase concepts."""
        concepts = []
        for piece in text.split(','):
            # The reply may be wrapped in quotes or end with a period
            # (the prompt's own example is quoted), so strip those too.
            concept = piece.strip().strip("\"'.").strip().lower()
            if concept and concept not in concepts:
                concepts.append(concept)
        return concepts

    def _generate_for_class(self, client, cls_name):
        """Ask the LLM for the concepts of one class; raise if none come back."""
        # like the paper (Label-free concept bottleneck)
        prompt = (
            f"List the most important visual parts or attributes that distinguish a '{cls_name}' "
            f"from other objects. Return only a comma-separated list of short physical descriptions "
            f"(e.g., 'long beak, red wings'). Do not include abstract concepts."
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini",  # Using the model specified in paper
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
        concepts = self._parse_concepts(response.choices[0].message.content)
        if not concepts:
            raise ValueError("LLM returned no usable concepts")
        return concepts

    def load_or_generate_concepts(self, class_list):
        """
        Loads concepts from cache and generates any missing classes via LLM.

        Args:
            class_list: Iterable of class names that need concepts.

        Returns:
            dict mapping class name -> list of concept strings. Contains every
            cached class plus every class in ``class_list``. A class whose
            generation failed maps to ``[cls_name]`` as a fallback.

        Raises:
            ValueError: If at least one requested class is not cached and no
                API key is available.
        """
        concepts_db = {}
        if os.path.exists(self.cache_path):
            print(f"Loading concepts from {self.cache_path}...")
            with open(self.cache_path, 'r', encoding="utf-8") as f:
                concepts_db = json.load(f)

        # Only classes absent from the cache need an LLM call; dict.fromkeys
        # drops duplicates while keeping the caller's order.
        missing = list(dict.fromkeys(
            cls_name for cls_name in class_list if cls_name not in concepts_db
        ))
        if not missing:
            return concepts_db

        if not self.api_key:
            raise ValueError("OpenAI API Key required to generate new concepts.")

        print(f"Generating concepts for {len(missing)} classes using LLM...")
        client = openai.OpenAI(api_key=self.api_key)
        generated = {}
        fallbacks = {}

        for cls_name in missing:
            try:
                concepts = self._generate_for_class(client, cls_name)
            except Exception as e:  # best effort: one failing class must not abort the run
                print(f"Error generating for {cls_name}: {e}")
                fallbacks[cls_name] = [cls_name]  # Fallback
            else:
                generated[cls_name] = concepts
                print(f"Generated for {cls_name}: {concepts}")

        # Persist only real LLM results. Fallbacks stay out of the cache so a
        # transient API failure is retried on the next run.
        if generated:
            concepts_db.update(generated)
            with open(self.cache_path, 'w', encoding="utf-8") as f:
                json.dump(concepts_db, f, indent=4)

        return {**concepts_db, **fallbacks}

    @staticmethod
    def _score_concept(concept, samples, grounder):
        """Return ``(occurrence_rate, mean IoU)`` of one concept over ``samples``."""
        num_samples = len(samples)
        if num_samples == 0:
            # No evidence at all: report zeros instead of dividing by zero.
            return 0.0, 0.0

        occurrences = 0
        total_iou = 0.0
        for img, gt_mask in samples:
            # Get concept mask
            concept_mask = grounder.get_concept_mask(img, [concept])

            # Check occurrence (is the concept present?)
            if concept_mask.sum() > 0:
                occurrences += 1

            # Check spatial coverage (IoU with Object Ground Truth)
            # assumes gt_mask is a 0/1 tensor shaped like concept_mask -- TODO confirm
            inter = (concept_mask * gt_mask).sum()
            union = (concept_mask + gt_mask).clamp(0, 1).sum()
            if union > 0:
                total_iou += (inter / union).item()

        return occurrences / num_samples, total_iou / num_samples

    def validate_concepts(self, concepts_db, validation_dataset, grounder):
        """
        Keeps the concepts that pass both the occurrence-rate and coverage checks.

        A concept is kept when its occurrence rate is at least
        ``config.OCCURRENCE_THRESHOLD`` and its mean IoU with the ground-truth
        object mask is at least ``config.COVERAGE_THRESHOLD``. The mean IoU is
        a simplification of the coverage metric.

        Args:
            concepts_db: dict mapping class name -> candidate concepts.
            validation_dataset: Object exposing
                ``get_samples_for_class(cls_name, k=...)`` that yields
                ``(image, gt_mask)`` pairs.
            grounder: Object exposing ``get_concept_mask(image, concepts)``.

        Returns:
            dict mapping class name -> list of accepted concepts. A class
            without validation samples gets an empty list.
        """
        print("Validating concepts...")
        validated_db = {}

        for cls_name, candidates in concepts_db.items():
            # Get validation samples for this class (P=30)
            samples = validation_dataset.get_samples_for_class(cls_name, k=config.SAMPLES_PER_CLASS_FOR_CONCEPTS)

            if len(samples) == 0:
                # Nothing to measure against, so no concept can be confirmed.
                print(f"No validation samples for {cls_name}; no concepts validated.")
                validated_db[cls_name] = []
                continue

            valid_concepts = []
            for concept in candidates:
                occurrence_rate, avg_iou = self._score_concept(concept, samples, grounder)
                if occurrence_rate >= config.OCCURRENCE_THRESHOLD and avg_iou >= config.COVERAGE_THRESHOLD:
                    valid_concepts.append(concept)

            validated_db[cls_name] = valid_concepts

        return validated_db


class GroundingSAMWrapper:
    """Grounding DINO + SAM wrapper that turns text concepts into a binary mask.

    Grounding DINO proposes boxes for the text prompt and SAM segments each
    box. The union of those segments is the concept mask.
    """

    def __init__(self, 
                 dino_config_path="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
                 dino_checkpoint="weights/groundingdino_swint_ogc.pth",
                 sam_checkpoint="weights/sam_vit_h_4b8939.pth"):
        """Load both models onto ``config.DEVICE``.

        Loading failures are reported with ``print`` and leave
        ``self.is_ready`` set to ``False`` instead of raising, so the wrapper
        can still be used (returning empty masks) when weights are absent.

        Args:
            dino_config_path: Path to the Grounding DINO config file.
            dino_checkpoint: Path to the Grounding DINO weights.
            sam_checkpoint: Path to the SAM ViT-H weights.
        """
        # NOTE(review): `config` is not imported in the visible part of this
        # file -- confirm it is brought into scope elsewhere.
        self.device = config.DEVICE
        
        try:
            self.grounding_dino = load_model(dino_config_path, dino_checkpoint, device=self.device)
            
            self.sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
            self.sam.to(device=self.device)
            self.sam_predictor = SamPredictor(self.sam)
            self.is_ready = True
        except Exception as e:
            print(f"Failed to load GroundingSAM models: {e}")
            self.is_ready = False

    @staticmethod
    def _boxes_to_pixel_xyxy(boxes, width, height):
        """Convert normalised ``(cx, cy, w, h)`` boxes to pixel ``(x0, y0, x1, y1)``."""
        scale = boxes.new_tensor([width, height, width, height])
        cx, cy, w, h = (boxes * scale).unbind(-1)
        half_w = w / 2
        half_h = h / 2
        return torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=-1)

    def get_concept_mask(self, image_np, concepts, box_threshold=0.3, text_threshold=0.25):
        """
        Generates S(I): The semantic guidance mask.

        Args:
            image_np: Image array indexed as (H, W, C). It is passed to
                ``Image.fromarray`` and ``SamPredictor.set_image``.
            concepts: List of concept phrases, joined with " . " into one
                Grounding DINO prompt.
            box_threshold: Box confidence threshold forwarded to ``predict``.
            text_threshold: Text confidence threshold forwarded to ``predict``.

        Returns:
            CPU float tensor holding 1.0 wherever any detected concept was
            segmented and 0.0 elsewhere; expected shape (H, W). All zeros when
            the models are not loaded, ``concepts`` is empty, or nothing is
            detected.
        """
        H, W = image_np.shape[:2]

        if not self.is_ready:
            # Fallback for testing without weights
            return torch.zeros((H, W))

        if len(concepts) == 0:
            # Nothing to ground: skip the models rather than send an empty prompt.
            return torch.zeros((H, W))
            
        import groundingdino.datasets.transforms as T
        transform = T.Compose([
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # NOTE(review): `Image` (presumably PIL.Image) is not imported in the
        # visible part of this file -- confirm it is in scope.
        image_pil = Image.fromarray(image_np)
        image_tensor, _ = transform(image_pil, None)
        
        prompt = " . ".join(concepts)
        
        boxes, logits, phrases = predict(
            model=self.grounding_dino,
            image=image_tensor,
            caption=prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device
        )
        
        if boxes.shape[0] == 0:
            return torch.zeros((H, W))

        self.sam_predictor.set_image(image_np)
        
        # Grounding DINO returns normalised centre-format boxes, while SAM
        # box prompts are corner-format pixel coordinates. Scaling alone is
        # not enough; the centre -> corner conversion is required as well.
        boxes_xyxy = self._boxes_to_pixel_xyxy(boxes, W, H).to(self.device)
        
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_np.shape[:2])
        
        masks, _, _ = self.sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        
        # Union of the per-box masks (one mask per box: multimask_output=False).
        final_mask = torch.any(masks.squeeze(1), dim=0).float().cpu()
        
        return final_mask

class MaskGenerator:
    """
    Dataset pre-processing driver that owns a concept pipeline and a grounder.
    """
    def __init__(self):
        # Both collaborators are built with their default arguments.
        self.pipeline = ConceptPipeline()
        self.grounder = GroundingSAMWrapper()

    def generate_dataset_masks(self, dataset_root, class_list):
        """
        Prepare mask generation for the given classes.

        Fetches the concepts for ``class_list`` and creates
        ``config.MASK_CACHE_DIR`` if needed. The per-image masking loop is
        not implemented yet (see the sketch at the end); ``dataset_root`` is
        currently unused and nothing is returned.
        """
        concepts_by_class = self.pipeline.load_or_generate_concepts(class_list)

        mask_dir = config.MASK_CACHE_DIR
        os.makedirs(mask_dir, exist_ok=True)

        print("Starting mask generation...")

        # Sketch of the missing per-image step:
        #   cls = get_class(img_path)
        #   mask = self.grounder.get_concept_mask(load_img(img_path), concepts_by_class[cls])
        #   save_mask(mask, mask_dir / rel_path)