#!/usr/bin/env python3
"""
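Semantic Length Predictor: uses the model to classify a prompt's task type, predicts a
max_tokens budget for chain-of-thought reasoning from it, and runs the adaptive
CoT-reasoning + final-answer generation pipeline.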
"""

import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
import json
import re
from typing import Dict, List, Tuple
import argparse

class SemanticLengthPredictor:
    """Predict a max_tokens budget for CoT reasoning from a prompt's task type.

    Prediction has two model-driven stages: the prompt is first classified into
    one of the ``task_profiles`` keys, then the model is asked for a recommended
    length, which is rescaled by that task's profile and clamped to [64, 1024].
    Both stages fall back to fixed values if generation raises.
    """

    def __init__(self, t2i_model, vl_chat_processor):
        """
        Initialize the semantic length predictor

        Args:
            t2i_model: T2I model instance; its ``language_model.generate`` is
                used for every text-generation call.
            vl_chat_processor: Visual language processor supplying the SFT chat
                template (``apply_sft_template_for_multi_turn_prompts``) and the
                tokenizer.
        """
        self.model = t2i_model
        self.vl_chat_processor = vl_chat_processor
        self.tokenizer = vl_chat_processor.tokenizer

        # Task type configuration. 'base' is the fallback length when the
        # model's suggestion cannot be obtained; 'base' and 'scale' together
        # rescale the suggestion (see task_specific_length_prediction).
        # 'default' is the catch-all profile used for unrecognized task types.
        self.task_profiles = {
            'color': {'base': 300, 'scale': 2.0},
            'position': {'base': 150, 'scale': 1.2},
            'count': {'base': 200, 'scale': 1.5},
            'relation': {'base': 400, 'scale': 2.2},
            'default': {'base': 250, 'scale': 1.8}
        }

    def predict_optimal_length(self, prompt_text: str) -> int:
        """
        Semantic task recognition + length prediction

        Args:
            prompt_text: Input prompt text

        Returns:
            int: Predicted optimal token length, clamped to [64, 1024]
        """
        print(f"🔍 Analyzing task complexity: {prompt_text[:50]}...")

        # Step 1: Semantic task recognition
        task_type = self.semantic_task_classification(prompt_text)
        print(f"📋 Identifying task type: {task_type}")

        # Step 2: Fine-grained length prediction based on task type
        length = self.task_specific_length_prediction(prompt_text, task_type)
        print(f"📊 Predicted optimal length: {length} tokens")

        return length

    def semantic_task_classification(self, prompt_text: str) -> str:
        """
        Ask the model to classify the prompt into one of the task profile types.

        Args:
            prompt_text: Prompt text to be classified

        Returns:
            str: A key of ``task_profiles``; 'default' when the response names no
            known type or when generation raises.
        """
        classification_prompt = f"""Analyze the semantic type of the following visual task:

Task: {prompt_text}

Optional types:
- color: Color recognition, description, differentiation
- position: Spatial position, coordinates, direction
- count: Quantity statistics, counting
- relation: Object relationship, interaction, scene understanding
- default: Other types

Based on semantic analysis, output the most matching task type.

Semantic type:"""

        try:
            response = self._generate_response(classification_prompt, max_tokens=50)
            task_type = response.strip().lower()
        except Exception as e:
            print(f"Classification failed, using default type: {e}")
            return 'default'

        # Substring match tolerates extra words around the type name
        # (e.g. "semantic type: count"); the first profile in dict order wins.
        for profile_type in self.task_profiles:
            if profile_type in task_type:
                return profile_type

        return 'default'

    def task_specific_length_prediction(self, prompt_text: str, task_type: str) -> int:
        """
        Task-specific semantic length prediction

        Args:
            prompt_text: Prompt text
            task_type: Task type; unknown types use the 'default' profile

        Returns:
            int: Predicted token length, clamped to [64, 1024]. If generation or
            parsing fails, the profile's 'base' length (clamped) is returned.
        """
        # Resolve the profile once, outside the try: previously an unknown
        # task_type raised KeyError here AND again in the except handler.
        profile = self._profile_for(task_type)

        prediction_prompt = f"""As an expert in {task_type} tasks, please predict the optimal reasoning length for the following task:

Task: {prompt_text}

Consider factors:
- Task-specific complexity
- Necessary semantic reasoning depth
- Typical thinking pattern of {task_type} tasks

Recommended max_tokens value (based on {task_type} task characteristics):"""

        try:
            prediction = self._generate_response(prediction_prompt, max_tokens=80)
            raw_length = self._extract_semantic_length(prediction)

            # Use task profile as prior constraint: a raw suggestion of 512
            # maps to base * scale. Kept inside the try because a huge parsed
            # integer makes raw_length / 512 raise OverflowError.
            adjusted_length = int(profile['base'] * (raw_length / 512) * profile['scale'])
        except Exception as e:
            print(f"Length prediction failed, using default value: {e}")
            return min(max(profile['base'], 64), 1024)

        # Ensure within reasonable range
        return min(max(adjusted_length, 64), 1024)

    def _profile_for(self, task_type: str) -> Dict:
        """Return the profile for ``task_type``, falling back to the 'default' profile."""
        profile = self.task_profiles.get(task_type)
        if profile is None:
            profile = self.task_profiles['default']
        return profile

    def _device(self):
        """Return the device holding the model (its ``.device``), or 'cuda' if it has none."""
        device = getattr(self.model, "device", None)
        return "cuda" if device is None else device

    def _generate_response(self, prompt: str, max_tokens: int = 100) -> str:
        """
        Generate a sampled text response to a single-turn user prompt.

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of newly generated tokens

        Returns:
            str: Decoded response (prompt tokens excluded), whitespace-stripped
        """
        # Build conversation format; the empty Assistant turn is the slot the
        # model completes.
        conversation = [
            {
                "role": "User",
                "content": prompt,
            },
            {"role": "Assistant", "content": ""},
        ]

        # Apply SFT template
        sft_format = self.vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=self.vl_chat_processor.sft_format,
            system_prompt="You are a helpful assistant that analyzes visual tasks.",
        )

        # Tokenize
        inputs = self.vl_chat_processor.tokenizer(
            text=[sft_format],
            return_tensors="pt",
            padding=True,
            padding_side="right",
            add_special_tokens=True
        )

        # Place inputs on the model's device rather than assuming 'cuda', which
        # would mismatch a CPU model or one on a non-default GPU.
        device = self._device()
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Generate response
        with torch.inference_mode():
            outputs = self.model.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (slice off the prompt).
        response = self.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    def _extract_semantic_length(self, prediction_text: str) -> int:
        """
        Extract length value from prediction text

        Args:
            prediction_text: Prediction text

        Returns:
            int: The first integer in the text, or a keyword-based estimate
            (200 simple/easy, 800 complex/difficult, 500 medium/moderate, else 400)
        """
        # Use the first integer found.
        # NOTE(review): a leading list marker such as "1." would be picked up
        # here instead of the actual recommendation — consider tightening.
        numbers = re.findall(r'\d+', prediction_text)
        if numbers:
            return int(numbers[0])

        # If no number is found, estimate based on text content
        lowered = prediction_text.lower()
        if 'simple' in lowered or 'easy' in lowered:
            return 200
        elif 'complex' in lowered or 'difficult' in lowered:
            return 800
        elif 'medium' in lowered or 'moderate' in lowered:
            return 500
        else:
            return 400  # Default value

def t2i_r1_adaptive_generation(prompt_text: str, t2i_model, vl_chat_processor, predictor: SemanticLengthPredictor) -> Dict:
    """
    Complete adaptive T2I-R1 generation process.

    Three stages run in order: ``predictor`` predicts a reasoning-length budget,
    CoT reasoning is generated within that budget, and a short final answer is
    generated from the reasoning.

    Args:
        prompt_text: Input prompt
        t2i_model: T2I model
        vl_chat_processor: Visual language processor
        predictor: Semantic length predictor

    Returns:
        Dict: keys 'predicted_length' (int token budget), 'cot_reasoning' (str)
        and 'final_answer' (str).
    """
    # Stage 1: semantic-level length prediction.
    print("🔍 Semantic analysis task complexity...")
    budget = predictor.predict_optimal_length(prompt_text)
    print(f"📊 Predicted optimal length: {budget} tokens")

    # Stage 2: CoT reasoning capped at the predicted budget.
    print("🤔 Use predicted length for CoT reasoning...")
    reasoning = _generate_cot_reasoning(prompt_text, t2i_model, vl_chat_processor, budget)

    # Stage 3: final answer with a fixed, short generation length.
    print("🎯 Generate final answer...")
    answer = _generate_final_answer(reasoning, t2i_model, vl_chat_processor)

    result = dict(
        predicted_length=budget,
        cot_reasoning=reasoning,
        final_answer=answer,
    )
    return result

def _generate_cot_reasoning(prompt_text: str, t2i_model, vl_chat_processor, max_tokens: int) -> str:
    """Generate CoT reasoning about an image-generation prompt.

    Args:
        prompt_text: The image-generation task to analyze.
        t2i_model: Model whose ``language_model.generate`` produces the text.
        vl_chat_processor: Processor supplying the SFT template and tokenizer.
        max_tokens: Upper bound on newly generated tokens (the predicted budget).

    Returns:
        str: Decoded reasoning (prompt tokens excluded), whitespace-stripped.
    """
    cot_prompt = f"""Please analyze the complexity and requirements of the following image generation task:

Task: {prompt_text}

Please think from the following aspects:
1. Complexity of visual elements
2. Spatial relationship handling
3. Description of color and material
4. Lighting and shadow effects
5. Overall composition and style

Detailed analysis:"""

    # The empty Assistant turn is the slot the model completes.
    conversation = [
        {
            "role": "User",
            "content": cot_prompt,
        },
        {"role": "Assistant", "content": ""},
    ]

    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="You are a helpful assistant that analyzes visual tasks.",
    )

    tokenizer = vl_chat_processor.tokenizer
    inputs = tokenizer(
        text=[sft_format],
        return_tensors="pt",
        padding=True,
        padding_side="right",
        add_special_tokens=True
    )

    # Move inputs to the model's device instead of assuming 'cuda', which would
    # mismatch a CPU model or one on a non-default GPU.
    device = getattr(t2i_model, "device", None)
    if device is None:
        device = 'cuda'
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.inference_mode():
        outputs = t2i_model.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (slice off the prompt).
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response.strip()

def _generate_final_answer(cot_reasoning: str, t2i_model, vl_chat_processor) -> str:
    """Generate a short final answer from previously generated CoT reasoning.

    Args:
        cot_reasoning: Reasoning text to condition the answer on.
        t2i_model: Model whose ``language_model.generate`` produces the text.
        vl_chat_processor: Processor supplying the SFT template and tokenizer.

    Returns:
        str: Decoded answer (at most 100 new tokens), whitespace-stripped.
    """
    # The instruction refers to "the above reasoning", so the reasoning must
    # precede it (it previously followed the instruction).
    final_prompt = f"""{cot_reasoning}

Based on the above reasoning, give the final answer:"""

    # The empty Assistant turn is the slot the model completes.
    conversation = [
        {
            "role": "User",
            "content": final_prompt,
        },
        {"role": "Assistant", "content": ""},
    ]

    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="You are a helpful assistant.",
    )

    tokenizer = vl_chat_processor.tokenizer
    inputs = tokenizer(
        text=[sft_format],
        return_tensors="pt",
        padding=True,
        padding_side="right",
        add_special_tokens=True
    )

    # Move inputs to the model's device instead of assuming 'cuda', which would
    # mismatch a CPU model or one on a non-default GPU.
    device = getattr(t2i_model, "device", None)
    if device is None:
        device = 'cuda'
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.inference_mode():
        outputs = t2i_model.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,  # Fixed short answer
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (slice off the prompt).
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response.strip()

def main():
    """Command-line test entry point for the adaptive generation pipeline.

    Parses ``--model_path`` and ``--prompt`` (both required), loads the processor
    and model from ``--model_path``, casts the model to bfloat16 on CUDA in eval
    mode, runs ``t2i_r1_adaptive_generation`` on the prompt and prints the result.
    """
    cli = argparse.ArgumentParser(description="Semantic Length Predictor")
    cli.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
    cli.add_argument("--prompt", type=str, required=True, help="Prompt to analyze")
    opts = cli.parse_args()

    print("Loading model...")
    processor = VLChatProcessor.from_pretrained(opts.model_path)
    model = (
        AutoModelForCausalLM.from_pretrained(opts.model_path, trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
        .eval()
    )

    predictor = SemanticLengthPredictor(model, processor)
    result = t2i_r1_adaptive_generation(opts.prompt, model, processor, predictor)

    separator = "=" * 50
    print("\n" + separator)
    print("Generated result:")
    for label, value in (
        ("Predicted length", f"{result['predicted_length']} tokens"),
        ("CoT reasoning", result['cot_reasoning']),
        ("Final answer", result['final_answer']),
    ):
        print(f"{label}: {value}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
