"""
PointLLM Provider for 3D Point Cloud Analysis

This provider handles the loading and inference of PointLLM models for analyzing
3D point clouds and providing textual feedback about geometric structures.
"""

import os
import logging
import warnings
import numpy as np
import torch
from typing import Dict, Any, Optional, Tuple

# Suppress the noisy FutureWarning emitted by transformers when loading
# trusted local checkpoints with torch.load(weights_only=False).
# NOTE: the warnings module treats `message` as a regular expression matched
# against the start of the warning text, and `module` as a regular expression
# matched against the name of the module the warning is attributed to.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="You are using `torch.load` with `weights_only=False`",
    module="transformers.modeling_utils"
)

# Simple in-memory cache so multiple providers reuse the same loaded model
# instead of hitting disk/GPU repeatedly. PointLLMProvider keys it by
# (model_path, torch_dtype name, device type) and stores, per entry, the
# 'model', 'tokenizer', 'point_backbone_config', 'mm_use_point_start_end'
# and 'conv_mode' values produced by a successful load.
_POINTLLM_MODEL_CACHE: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

# Import PointLLM components. The imports are optional at module-import time:
# on failure POINTLLM_AVAILABLE is set to False and PointLLMProvider.__init__
# raises RuntimeError instead.
try:
    import sys
    # Add PointLLM to path. The path is relative, so it resolves against the
    # current working directory of the process, not this file's location.
    if './PointLLM' not in sys.path:
        sys.path.append('./PointLLM')

    from transformers import AutoTokenizer
    from pointllm.model import PointLLMLlamaForCausalLM
    from pointllm.conversation import conv_templates, SeparatorStyle
    from pointllm.utils import disable_torch_init
    from pointllm.model.utils import KeywordsStoppingCriteria
    POINTLLM_AVAILABLE = True
except ImportError as e:
    # Any missing dependency (transformers or pointllm) lands here; the error
    # is reported on stdout and the module stays importable.
    POINTLLM_AVAILABLE = False
    print(f"PointLLM import error: {e}")


class PointLLMProvider:
    """
    Provider for PointLLM model inference.

    Handles model loading, point cloud preprocessing, and text generation
    for 3D shape analysis. Loaded models are shared between provider
    instances through the module-level ``_POINTLLM_MODEL_CACHE``.
    """

    # Local checkpoint directory that is preferred when it exists on disk.
    _DEFAULT_LOCAL_PATH = './models/PointLLM_7B_v1.2'
    # HuggingFace model ID used when no local checkpoint is present.
    _DEFAULT_HF_PATH = 'RunsenXu/PointLLM_7B_v1.2'
    # Conversation template used for every checkpoint this provider loads.
    _CONV_MODE = "vicuna_v1_1"
    # Hard upper bound on generated tokens, regardless of ``max_length``.
    _MAX_NEW_TOKENS_CAP = 768

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the PointLLM provider and load the model eagerly.

        Args:
            config: Configuration dictionary; settings are read from its
                ``'pointllm_critic'`` sub-dictionary (``model_path`` and
                ``torch_dtype`` are the recognised keys).

        Raises:
            RuntimeError: If the PointLLM package could not be imported.
            Exception: Whatever ``load_model`` raises when loading fails.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        # ``or {}`` also covers a key that is present but set to None
        # (e.g. an empty section in a YAML file).
        self.config = config.get('pointllm_critic') or {}

        if not POINTLLM_AVAILABLE:
            raise RuntimeError("PointLLM is not installed. Please install it first.")

        # Model configuration - support both HuggingFace ID and local path.
        # The default is the local checkpoint when present, else the hub ID;
        # an explicit 'model_path' in the config always wins.
        if os.path.exists(self._DEFAULT_LOCAL_PATH):
            default_path = self._DEFAULT_LOCAL_PATH
        else:
            default_path = self._DEFAULT_HF_PATH
        self.model_path = self.config.get('model_path', default_path)

        # Report what the *selected* path is, not what the default would
        # have been, so the log stays truthful when the config overrides it.
        if os.path.exists(self.model_path):
            self.logger.info(f"Using local model at: {self.model_path}")
        else:
            self.logger.info(f"Using HuggingFace model: {self.model_path}")

        self.torch_dtype = self.config.get('torch_dtype', 'float16')

        # Determine torch dtype
        dtype_mapping = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        if self.torch_dtype not in dtype_mapping:
            self.logger.warning(
                f"Unknown torch_dtype '{self.torch_dtype}', using float16"
            )
        self.dtype = dtype_mapping.get(self.torch_dtype, torch.float16)

        # Prefer the first GPU when CUDA is usable, otherwise run on CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")  # Use first GPU explicitly
            self.logger.info("CUDA available. Using GPU 0")
        else:
            self.device = torch.device("cpu")
            self.logger.warning("CUDA not available, falling back to CPU")

        self.model = None
        self.tokenizer = None
        self.point_backbone_config = None
        self.conv_template = None
        self.keywords = None
        self.mm_use_point_start_end = False

        # Load model immediately to ensure it's available
        self._model_loaded = False
        self.load_model()

    def _cache_key(self) -> Tuple[str, str, str]:
        """Return the key under which this provider's model is cached."""
        return (self.model_path, self.torch_dtype, self.device.type)

    def _configure_conversation(self, conv_mode: str) -> None:
        """Give this provider its own conversation template and stop keyword."""
        # A private copy is taken so that resetting/appending messages during
        # inference never touches the shared template or other providers.
        self.conv_template = conv_templates[conv_mode].copy()
        if self.conv_template.sep_style != SeparatorStyle.TWO:
            stop_str = self.conv_template.sep
        else:
            stop_str = self.conv_template.sep2
        self.keywords = [stop_str]

    @staticmethod
    def _generation_token_limits(max_length: int) -> Tuple[int, int]:
        """
        Derive ``(min_new_tokens, max_new_tokens)`` from ``max_length``.

        The maximum is capped at 768 tokens and never drops below 1. The
        minimum is a quarter of ``max_length`` bounded to the range 64..256,
        and is additionally clamped so it can never exceed the maximum
        (which would otherwise happen for ``max_length`` below 64).
        """
        max_new = max(1, min(max_length, PointLLMProvider._MAX_NEW_TOKENS_CAP))
        min_new = min(max_new, min(256, max(64, max_length // 4)))
        return min_new, max_new

    def load_model(self):
        """
        Load the PointLLM model and tokenizer.

        Does nothing when the model is already loaded. A model previously
        loaded with the same path, dtype name and device type is reused from
        ``_POINTLLM_MODEL_CACHE``; otherwise it is loaded and then cached.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        if self._model_loaded:
            return

        try:
            cache_key = self._cache_key()
            cached = _POINTLLM_MODEL_CACHE.get(cache_key)

            if cached is not None:
                self.logger.info(f"Reusing cached PointLLM model from {self.model_path}")
                self.model = cached['model']
                self.tokenizer = cached['tokenizer']
                self.point_backbone_config = cached['point_backbone_config']
                self.mm_use_point_start_end = cached['mm_use_point_start_end']
                self._configure_conversation(cached['conv_mode'])
                self._model_loaded = True
                return

            self.logger.info(f"Loading PointLLM model from {self.model_path}")

            # Disable torch initialization for faster loading
            disable_torch_init()

            # Load tokenizer
            self.logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

            self.logger.info(f"Loading model with dtype {self.torch_dtype} directly to {self.device}...")

            load_kwargs = {
                'low_cpu_mem_usage': True,
                'use_cache': True,
                'torch_dtype': self.dtype,
            }
            if self.device.type == "cuda":
                # Place all weights on GPU 0 while loading to avoid a
                # separate CPU->GPU transfer afterwards.
                load_kwargs['device_map'] = {"": 0}

            self.model = PointLLMLlamaForCausalLM.from_pretrained(
                self.model_path, **load_kwargs
            )
            if self.device.type == "cuda":
                self.logger.info(f"Model loaded directly to GPU. Memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB")

            # Initialize tokenizer and point backbone config
            self.model.initialize_tokenizer_point_backbone_config_wo_embedding(self.tokenizer)

            self.model.eval()

            # Get configuration
            self.mm_use_point_start_end = getattr(self.model.config, "mm_use_point_start_end", False)
            self.point_backbone_config = self.model.get_model().point_backbone_config

            # Set up conversation template and stopping keyword. The same
            # template is used for every checkpoint.
            conv_mode = self._CONV_MODE
            self._configure_conversation(conv_mode)

            # Store in cache for subsequent providers
            _POINTLLM_MODEL_CACHE[cache_key] = {
                'model': self.model,
                'tokenizer': self.tokenizer,
                'point_backbone_config': self.point_backbone_config,
                'mm_use_point_start_end': self.mm_use_point_start_end,
                'conv_mode': conv_mode,
            }

            self._model_loaded = True
            self.logger.info(f"PointLLM model loaded successfully on {self.device}")

        except Exception as e:
            # No fallback is attempted here: the failure is reported and
            # propagated so the caller can decide what to do.
            self.logger.error(f"Failed to load PointLLM model: {e}")
            self._model_loaded = False
            raise

    def analyze_point_cloud(self,
                           point_cloud: np.ndarray,
                           prompt: str,
                           color_mapping: Optional[Dict[str, str]] = None,
                           max_length: int = 1024) -> str:
        """
        Analyze a point cloud and generate textual feedback.

        Args:
            point_cloud: Array of shape (N, 6) with xyz + RGB values
            prompt: Text prompt describing what to analyze
            color_mapping: Optional mapping of color names to component names;
                when given it is appended to the prompt as a bullet list
            max_length: Maximum length of generated text (capped at 768 new
                tokens)

        Returns:
            Generated text analysis of the point cloud

        Raises:
            Exception: Any error during preprocessing or generation is logged
                and re-raised.
        """
        if not self._model_loaded:
            self.load_model()

        try:
            # Add a batch dimension, then move the tensor to the model's
            # device and cast it to the dtype the model was loaded with so
            # the input matches the weights on both GPU and CPU.
            point_cloud_tensor = torch.from_numpy(np.asarray(point_cloud)).unsqueeze(0)
            point_cloud_tensor = point_cloud_tensor.to(device=self.device, dtype=self.dtype)

            # Get point token configuration
            point_token_len = self.point_backbone_config['point_token_len']
            default_point_patch_token = self.point_backbone_config['default_point_patch_token']
            default_point_start_token = self.point_backbone_config.get('default_point_start_token', '')
            default_point_end_token = self.point_backbone_config.get('default_point_end_token', '')

            # Add color mapping info to prompt if provided
            if color_mapping:
                color_info = "\nThe point cloud uses the following color mapping:\n"
                color_info += "".join(
                    f"- {color_name}: {component_name}\n"
                    for color_name, component_name in color_mapping.items()
                )
                prompt = prompt + color_info

            # Reset conversation so earlier calls leave no history behind
            self.conv_template.reset()

            # Prepare the question with point placeholder tokens
            point_tokens = default_point_patch_token * point_token_len
            if self.mm_use_point_start_end:
                point_tokens = default_point_start_token + point_tokens + default_point_end_token
            question = point_tokens + '\n' + prompt

            # Add to conversation; the assistant turn is left open (None)
            self.conv_template.append_message(self.conv_template.roles[0], question)
            self.conv_template.append_message(self.conv_template.roles[1], None)

            # Get prompt and tokenize
            full_prompt = self.conv_template.get_prompt()
            inputs = self.tokenizer([full_prompt])
            input_ids = torch.as_tensor(inputs.input_ids).to(self.device)

            # Set up stopping criteria
            stopping_criteria = KeywordsStoppingCriteria(self.keywords, self.tokenizer, input_ids)
            stop_str = self.keywords[0]

            min_new_tokens, max_new_tokens = self._generation_token_limits(max_length)

            # Generate response (greedy decoding)
            with torch.inference_mode():
                output_ids = self.model.generate(
                    input_ids,
                    point_clouds=point_cloud_tensor,
                    do_sample=False,
                    temperature=None,
                    top_p=1.0,
                    repetition_penalty=1.05,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=min_new_tokens,
                    stopping_criteria=[stopping_criteria],
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens
            input_token_len = input_ids.shape[1]
            outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            outputs = outputs.strip()

            # Drop a trailing stop keyword if the model emitted it
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            return outputs

        except Exception as e:
            self.logger.error(f"Failed to analyze point cloud: {e}")
            raise  # Re-raise exception instead of fallback

    def _create_statistical_prompt(self, point_cloud: np.ndarray, base_prompt: str) -> str:
        """
        Create an enhanced prompt with point cloud statistics.

        Args:
            point_cloud: Array of shape (N, 6) with xyz + RGB
            base_prompt: Original prompt

        Returns:
            ``base_prompt`` followed by a "Point Cloud Statistics" section
            listing the point count, per-axis extent, number of distinct
            colors and centroid.
        """
        xyz = point_cloud[:, :3]
        rgb = point_cloud[:, 3:]

        # Compute each bound once and reuse it for the extent.
        xyz_min = xyz.min(axis=0)
        xyz_max = xyz.max(axis=0)

        stats = {
            'num_points': len(point_cloud),
            'xyz_min': xyz_min.tolist(),
            'xyz_max': xyz_max.tolist(),
            'xyz_center': xyz.mean(axis=0).tolist(),
            'xyz_spread': (xyz_max - xyz_min).tolist(),
            # Distinct colors = distinct RGB rows
            'num_colors': len(np.unique(rgb, axis=0))
        }

        # Create enhanced prompt
        return "".join([
            f"{base_prompt}\n\n",
            "Point Cloud Statistics:\n",
            f"- Number of points: {stats['num_points']}\n",
            f"- Spatial extent: {stats['xyz_spread']}\n",
            f"- Number of distinct colors: {stats['num_colors']}\n",
            f"- Center position: {stats['xyz_center']}\n",
        ])

    def unload_model(self):
        """
        Unload the model to free memory.

        Removes this provider's entry from the shared cache and drops the
        provider's own references. Memory is only actually released once no
        other provider still references the same model object.
        """
        if not self._model_loaded:
            return

        _POINTLLM_MODEL_CACHE.pop(self._cache_key(), None)

        # Reset to the same "not loaded" state __init__ establishes, rather
        # than deleting the attributes, so later attribute access stays safe.
        self.model = None
        self.tokenizer = None
        self.point_backbone_config = None
        self.conv_template = None
        self.keywords = None
        self.mm_use_point_start_end = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._model_loaded = False
        self.logger.info("PointLLM model unloaded")


# Simplified stub for testing without actual PointLLM model
class PointLLMProviderStub:
    """
    Lightweight stand-in for the real provider that never loads a model.

    Responses are canned text selected from the number of distinct colors
    found in the point cloud, which makes it usable in tests.
    """

    def __init__(self, config: Dict[str, Any]):
        """Keep the configuration as given and set up a class-named logger."""
        self.config = config
        self.logger = logging.getLogger(type(self).__name__)

    def analyze_point_cloud(self,
                           point_cloud: np.ndarray,
                           prompt: str,
                           color_mapping: Optional[Dict[str, str]] = None,
                           max_length: int = 512) -> str:
        """
        Produce a simulated analysis of a point cloud.

        Each distinct row in the columns from index 3 onward is counted as
        one component. ``prompt`` and ``max_length`` are accepted for
        interface compatibility but do not influence the result.
        """
        # One component per distinct color row
        color_rows = point_cloud[:, 3:]
        num_components = len(np.unique(color_rows, axis=0))

        parts = [
            "Based on 3D point cloud analysis:\n\n",
            f"The object consists of {num_components} distinct components.\n",
        ]

        # Pick the canned findings for this component count
        if num_components > 5:
            findings = (
                "- Complex multi-part structure detected\n",
                "- Some components appear disconnected from the main body\n",
                "- Recommend checking physical connections between parts\n",
            )
        elif num_components > 2:
            findings = (
                "- Well-defined multi-component structure\n",
                "- Components appear properly connected\n",
                "- Good structural integrity observed\n",
            )
        else:
            findings = (
                "- Simple structure with few components\n",
                "- May be missing expected parts\n",
            )
        parts.extend(findings)

        # Mention at most the first three mapped components
        if color_mapping:
            parts.append("\nComponent-specific observations:\n")
            leading_entries = list(color_mapping.items())[:3]
            parts.extend(
                f"- {component_name} ({color_name}): Properly positioned and sized\n"
                for color_name, component_name in leading_entries
            )

        return "".join(parts)

    def load_model(self):
        """Nothing to load for the stub."""
        return None

    def unload_model(self):
        """Nothing to unload for the stub."""
        return None
