import json
import os
import time
from typing import List, Dict, Any, Optional

import torch
from PIL import Image
from tqdm import tqdm
import argparse
from modules.object_detector import ObjectDetector
from modules.attribute_generator import AttributeGenerator
from modules.relation_extractor import RelationExtractor
from modules.model_loader import ModelLoader, MODEL_TYPES
from utils.image_processor import ImageProcessor
from utils.hierarchy_tree_builder import HierarchyTreeBuilder


class SceneGraphGenerator:
    """Build scene graphs (objects, hierarchy, attributes, relations) for images.

    A single language model, loaded through ``ModelLoader``, is shared by the
    object detector, the attribute generator and the relation extractor.
    """

    def __init__(
        self,
        llmdet_model_path: str = "/home//hsgg/llmdet",
        llm_model_path: Optional[str] = None,
        model_type: MODEL_TYPES = "llavaonevision1_5",
        cuda_device: int = 6,
        det_threshold: float = 0.3,
        nms_threshold: float = 0.5,
        temperature: float = 0.1,
        use_vcd: bool = True,
        cd_alpha: float = 1.0,
        cd_beta: float = 0.1,
        noise_step: int = 500,
        amateur_use_image: bool = False,
        skip_attributes: bool = False
    ):
        """Load the models and wire up every pipeline stage.

        Args:
            llmdet_model_path: Path of the detection model handed to
                ``ObjectDetector``.
            llm_model_path: Path of the language model. When ``None`` the
                default path for ``model_type`` is taken from
                ``ModelLoader.get_default_path``.
            model_type: Model family identifier understood by ``ModelLoader``.
            cuda_device: CUDA device index used when CUDA is available;
                otherwise everything runs on the CPU.
            det_threshold: Forwarded to ``ObjectDetector`` as ``threshold``.
            nms_threshold: Forwarded to ``ObjectDetector``.
            temperature: Forwarded to both the attribute generator and the
                relation extractor.
            use_vcd: Forwarded unchanged to ``RelationExtractor``.
            cd_alpha: Forwarded unchanged to ``RelationExtractor``.
            cd_beta: Forwarded unchanged to ``RelationExtractor``.
            noise_step: Forwarded unchanged to ``RelationExtractor``.
            amateur_use_image: Forwarded unchanged to ``RelationExtractor``.
            skip_attributes: When true, ``generate`` skips attribute
                generation and reports an empty attribute dict.
        """
        if torch.cuda.is_available():
            torch.cuda.set_device(cuda_device)
            self.device = f"cuda:{cuda_device}"
        else:
            self.device = "cpu"

        self.model_type = model_type

        if llm_model_path is None:
            llm_model_path = ModelLoader.get_default_path(model_type)
        self.llm_model_path = llm_model_path

        # The third value returned by the loader is not needed here.
        self.llm_model, self.llm_processor, _ = ModelLoader.load_model(
            model_type=model_type,
            model_path=llm_model_path,
            device=self.device,
            output_attentions=False
        )

        self.object_detector = ObjectDetector(
            llmdet_model_path=llmdet_model_path,
            llm_model=self.llm_model,
            llm_processor=self.llm_processor,
            device=self.device,
            threshold=det_threshold,
            nms_threshold=nms_threshold
        )

        self.attribute_generator = AttributeGenerator(
            model=self.llm_model,
            processor=self.llm_processor,
            device=self.device,
            temperature=temperature
        )

        self.relation_extractor = RelationExtractor(
            model=self.llm_model,
            processor=self.llm_processor,
            device=self.device,
            use_vcd=use_vcd,
            cd_alpha=cd_alpha,
            cd_beta=cd_beta,
            noise_step=noise_step,
            amateur_use_image=amateur_use_image,
            temperature=temperature
        )

        self.image_processor = ImageProcessor()
        self.tree_builder = HierarchyTreeBuilder(device=self.device)
        self.skip_attributes = skip_attributes

    def _detect_objects(self, image_path: str, output_dir: str) -> Dict[str, Any]:
        """Run object detection and return boxes in original-image coordinates.

        Images whose longer side exceeds 800 px are downscaled to a temporary
        JPEG inside ``output_dir`` for detection; the detected boxes are then
        scaled back. The temporary file is removed even if detection fails.
        """
        max_size = 800

        # Close the underlying file handle as soon as the pixels are decoded.
        with Image.open(image_path) as raw_image:
            original_image = raw_image.convert("RGB")
        orig_width, orig_height = original_image.size

        longest_side = max(orig_width, orig_height)
        if longest_side <= max_size:
            return self.object_detector.detect(image_path)

        scale_ratio = max_size / longest_side
        # Clamp to 1 px so extreme aspect ratios never produce a zero-sized image.
        new_width = max(1, int(orig_width * scale_ratio))
        new_height = max(1, int(orig_height * scale_ratio))
        resized_image = original_image.resize((new_width, new_height), Image.LANCZOS)

        temp_image_path = os.path.join(output_dir, "temp_resized.jpg")
        resized_image.save(temp_image_path)
        try:
            objects_data = self.object_detector.detect(temp_image_path)
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

        # Map the boxes found on the downscaled copy back to the original size.
        objects_data['boxes'] = [
            [coord / scale_ratio for coord in box]
            for box in objects_data['boxes']
        ]
        return objects_data

    def generate(
        self,
        image_path: str,
        output_dir: str = "./output",
        save_intermediate: bool = True,
        save_images: bool = True
    ) -> Optional[Dict[str, Any]]:
        """Generate the scene graph of a single image.

        Args:
            image_path: Path of the image to analyse.
            output_dir: Directory for JSON results, the temporary resized
                image and (optionally) the cropped / union images. Created if
                missing.
            save_intermediate: When true, write ``objects.json``,
                ``hierarchy_mapping.json``, ``attributes.json``,
                ``relations.json`` and ``scene_graphs.json`` for this image.
            save_images: When true, cropped objects and union images are
                written below ``output_dir``.

        Returns:
            A dict with the keys ``image_path``, ``objects``, ``hierarchy``,
            ``attributes`` and ``relations``, or ``None`` when fewer than two
            boxes were detected.
        """
        os.makedirs(output_dir, exist_ok=True)

        objects_data = self._detect_objects(image_path, output_dir)

        # The remaining stages are skipped when fewer than two boxes are found.
        if len(objects_data['boxes']) < 2:
            return None

        idx_to_sorted_label_pre = self.tree_builder.compute_sorted_label_map(objects_data['labels'])

        crop_dir = os.path.join(output_dir, "cropped_objects") if save_images else None
        cropped_objects = self.image_processor.crop_objects(
            image_path,
            objects_data['boxes'],
            objects_data['labels'],
            output_dir=crop_dir,
            sorted_label_map=idx_to_sorted_label_pre
        )

        tree_roots, hierarchy_dict, idx_to_sorted_label = self.tree_builder.build_hierarchy_tree(
            objects_data,
            cropped_objects,
            image_path=image_path
        )

        if self.skip_attributes:
            attributes = {}
        else:
            attributes = self.attribute_generator.generate_attributes(
                objects_data,
                cropped_objects,
                tree_roots
            )

        # Depth map left behind by the tree builder, if it exposes one.
        depth_map = getattr(self.tree_builder, '_last_depth_map', None)

        union_dir = os.path.join(output_dir, "union_objects") if save_images else None
        union_images = self.image_processor.generate_union_images(
            image_path,
            objects_data['boxes'],
            objects_data['labels'],
            output_dir=union_dir,
            sorted_label_map=idx_to_sorted_label,
            tree_roots=tree_roots,
            depth_map=depth_map
        )

        relations = self.relation_extractor.extract_relations(
            image_path,
            objects_data,
            cropped_objects,
            union_images,
            attributes,
            idx_to_sorted_label,
            depth_map=depth_map
        )

        scene_graph = {
            'image_path': image_path,
            'objects': objects_data,
            'hierarchy': hierarchy_dict,
            'attributes': attributes,
            'relations': relations
        }

        if save_intermediate:
            # A single image is stored in the same layout as a one-element batch.
            self._save_batch_results([scene_graph], output_dir)

        return scene_graph

    def generate_batch(
        self,
        image_paths: List[str],
        output_dir: str = "./output",
        save_interval: int = 20,
        save_images: bool = False
    ) -> List[Dict[str, Any]]:
        """Generate scene graphs for many images, checkpointing periodically.

        Images that raise an exception or yield no scene graph are skipped;
        the failure is reported on stderr and processing continues.

        Args:
            image_paths: Paths of the images to process, in order.
            output_dir: Directory receiving the aggregated JSON files.
            save_interval: Write a checkpoint after every ``save_interval``
                processed images. A value ``<= 0`` disables checkpoints; the
                results are then written once at the end.
            save_images: Forwarded to ``generate``.

        Returns:
            The scene graphs of all successfully processed images.
        """
        import sys
        import traceback

        results = []
        os.makedirs(output_dir, exist_ok=True)

        if save_images:
            os.makedirs(os.path.join(output_dir, "cropped_objects"), exist_ok=True)
            os.makedirs(os.path.join(output_dir, "union_objects"), exist_ok=True)

        # Number of results contained in the last written checkpoint; ``None``
        # means nothing has been written yet.
        saved_count = None

        for idx, image_path in enumerate(tqdm(image_paths)):
            try:
                scene_graph = self.generate(
                    image_path,
                    output_dir=output_dir,
                    save_intermediate=False,
                    save_images=save_images
                )

                if scene_graph is not None:
                    results.append(scene_graph)

                if save_interval > 0 and (idx + 1) % save_interval == 0:
                    self._save_batch_results(results, output_dir)
                    saved_count = len(results)

            except Exception:
                # Best effort: one bad image must not abort the whole batch.
                print(f"Failed to process {image_path}", file=sys.stderr)
                traceback.print_exc()
                continue

        # Flush whatever the last checkpoint did not cover. Comparing against
        # the checkpointed count (rather than len(results) % save_interval)
        # stays correct when some images failed or were skipped.
        if saved_count != len(results):
            self._save_batch_results(results, output_dir)

        return results

    def _save_json(self, data: Any, path: str):
        """Write ``data`` to ``path`` as indented UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    def _save_batch_results(self, results: List[Dict[str, Any]], output_dir: str):
        """Write the aggregated JSON files for ``results`` into ``output_dir``.

        Produces ``objects.json``, ``hierarchy_mapping.json``,
        ``attributes.json`` and ``relations.json`` (each keyed by the image
        file name without extension) plus ``scene_graphs.json`` holding the
        full list. Existing files are overwritten.

        NOTE: images that share the same file stem overwrite each other in
        the per-image files; ``scene_graphs.json`` keeps all of them.
        """
        all_objects = {}
        all_hierarchy = {}
        all_attributes = {}
        all_relations = {}

        for scene_graph in results:
            img_path = scene_graph['image_path']
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            all_objects[img_name] = scene_graph['objects']
            all_hierarchy[img_name] = scene_graph.get('hierarchy', {})
            all_attributes[img_name] = scene_graph.get('attributes', {})
            all_relations[img_name] = scene_graph['relations']

        self._save_json(all_objects, os.path.join(output_dir, "objects.json"))
        self._save_json(all_hierarchy, os.path.join(output_dir, "hierarchy_mapping.json"))
        self._save_json(all_attributes, os.path.join(output_dir, "attributes.json"))
        self._save_json(all_relations, os.path.join(output_dir, "relations.json"))

        self._save_json(results, os.path.join(output_dir, "scene_graphs.json"))

    def cleanup(self):
        """Clear the image processor cache and release cached CUDA memory."""
        self.image_processor.clear_cache()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    """Command-line entry point for single-image or batch scene graph generation.

    With ``--image`` a single image is processed and its JSON files are
    written to ``--output_dir``. Otherwise ``--image_list`` must point to a
    JSON list of either file names (joined with ``--image_dir``) or objects
    carrying an ``image_path`` key, and the whole list is processed as a batch.

    Raises:
        ValueError: If the image list file does not contain a JSON list.
    """
    parser = argparse.ArgumentParser(description='scene graph generator')
    parser.add_argument('--image', type=str, default=None, help='single image path')
    parser.add_argument('--image_list', type=str, default="/home//hsgg/eval/zeroshotvg.json",
                        help='JSON file listing the images of a batch run')
    parser.add_argument('--image_dir', type=str, default="/home//hsgg/dataset/vg/VG_100K",
                        help='directory prepended to the entries of --image_list '
                             'when they are plain file names')
    parser.add_argument('--output_dir', type=str, default='./output/zeroshotvg')
    parser.add_argument('--cuda', type=int, default=6, help='CUDA device number (default: 6)')
    parser.add_argument('--llmdet_model', type=str,
                        default='/home//hsgg/checkpoint/llmdet',
                        help='detection model path')
    parser.add_argument('--llm_model', type=str, default=None)
    parser.add_argument('--model_type', type=str, default='llavaonevision1_5')

    # NOTE(review): the CLI default (1) differs from the SceneGraphGenerator
    # constructor default (0.1) -- confirm this is intended.
    parser.add_argument('--temperature', type=float, default=1)

    parser.add_argument('--use_vcd', action='store_true', default=True)
    parser.add_argument('--no_vcd', action='store_false', dest='use_vcd')
    parser.add_argument('--cd_alpha', type=float, default=1.0)
    parser.add_argument('--cd_beta', type=float, default=0.1)
    parser.add_argument('--noise_step', type=int, default=500)
    parser.add_argument('--amateur_use_image', action='store_true', default=False)

    parser.add_argument('--skip_attributes', action='store_true', default=False)

    parser.add_argument('--save_interval', type=int, default=10)
    parser.add_argument('--save_images', action='store_true', default=True)
    # --save_images defaults to True, so an explicit switch is needed to turn
    # image saving off (mirrors --use_vcd / --no_vcd).
    parser.add_argument('--no_save_images', action='store_false', dest='save_images',
                        help='do not write cropped / union images')

    args = parser.parse_args()

    generator = SceneGraphGenerator(
        llmdet_model_path=args.llmdet_model,
        llm_model_path=args.llm_model,
        model_type=args.model_type,
        cuda_device=args.cuda,
        temperature=args.temperature,
        use_vcd=args.use_vcd,
        cd_alpha=args.cd_alpha,
        cd_beta=args.cd_beta,
        noise_step=args.noise_step,
        amateur_use_image=args.amateur_use_image,
        skip_attributes=args.skip_attributes
    )

    try:
        if args.image:
            generator.generate(
                args.image,
                args.output_dir,
                save_intermediate=True,
                save_images=args.save_images
            )

        elif args.image_list:
            with open(args.image_list, 'r', encoding='utf-8') as f:
                image_data = json.load(f)

            if not isinstance(image_data, list):
                raise ValueError("Unsupported image list format")

            # The first entry decides the list flavour; an empty list simply
            # yields an empty batch instead of an IndexError.
            if image_data and isinstance(image_data[0], str):
                image_paths = [os.path.join(args.image_dir, img) for img in image_data]
            else:
                image_paths = [item['image_path'] for item in image_data]

            generator.generate_batch(
                image_paths,
                args.output_dir,
                save_interval=args.save_interval,
                save_images=args.save_images
            )

    finally:
        # Release caches even when generation fails.
        generator.cleanup()


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()



