"""
LLM Prompt Templates
Centralized management of all LLM prompts used in the communication system
"""

class PromptTemplates:
    """Centralized prompt template management.

    Every builder is a ``@staticmethod`` that returns a prompt string. The
    prompts are assembled from f-string templates plus caller-supplied text.
    """

    @staticmethod
    def get_task_context_I_T(task_description, obs_shape, obs_dim_desc, detail_content):
        """I_T: Task objectives and environment characteristics.

        Args:
            task_description: Free-text description of the task.
            obs_shape: Observation shape, rendered verbatim via ``str``.
            obs_dim_desc: Text describing the axes of the observation tensor.
            detail_content: Text describing what each observation dimension means.

        Returns:
            The task-context section used inside the larger prompts.
        """
        return f"""
**Task Context - Task Objectives and Environment Characteristics**:

**Task Description**:
{task_description}

**Environment Observation Structure**:
- Observation tensor shape: {obs_shape}
- {obs_dim_desc}
- Each dimension meaning: {detail_content}

**Environment Characteristics**:
- Multi-agent partially observable environment
- Agents must coordinate under incomplete information
- Communication enables sharing of non-locally observable information
"""

    @staticmethod
    def _format_output_shape(obs_shape):
        """Render *obs_shape* as a tuple string whose last axis is extended by ``message_dim``.

        Accepts a string such as ``"(batch_size, n_agents, obs_dim)"`` or a
        tuple-like of dimension values (e.g. ``(32, 3, 10)``) of any rank.

        >>> PromptTemplates._format_output_shape("(batch_size, n_agents, obs_dim)")
        '(batch_size, n_agents, obs_dim + message_dim)'
        """
        if isinstance(obs_shape, str):
            text = obs_shape.strip()
            # Drop only one pair of outer parentheses so that nested
            # expressions inside a dimension (e.g. "len(x)") stay intact.
            if text.startswith("(") and text.endswith(")"):
                text = text[1:-1]
            # Skip empty pieces so a trailing comma like "(n,)" is tolerated.
            dims = [dim.strip() for dim in text.split(",") if dim.strip()]
        else:
            dims = [str(dim) for dim in obs_shape]
        if not dims:
            return "(message_dim)"
        dims[-1] = f"{dims[-1]} + message_dim"
        return f"({', '.join(dims)})"

    @staticmethod
    def get_protocol_instructions_I_P(obs_shape, indexing_example, additional_msg_prompt):
        """I_P: Protocol generation instructions for communication design.

        Args:
            obs_shape: Observation shape as a string like ``"(batch_size, n_agents, obs_dim)"``
                or as a tuple of dimensions. Any number of axes is accepted. The last
                axis is extended by ``message_dim`` in the "Structured Output" rule.
            indexing_example: Text showing how to index the observation tensor.
            additional_msg_prompt: Extra protocol requirements appended at the end.

        Returns:
            The protocol-instruction section used inside the larger prompts.
        """
        output_shape = PromptTemplates._format_output_shape(obs_shape)
        return f"""
**Protocol Generation Instructions**:

**Communication Design Key Principles**:
1. **Task-Oriented Communication**:
- Explicitly identify observation dimensions crucial to solving the task.
- Messages must clearly relate these dimensions to task objectives.

2. **Uniqueness, Sufficiency & Compactness**:
- Each agent should communicate information that others do not already possess or cannot easily infer based on their own observations.
- Communication should ensure sufficiency, meaning that agents exchange enough information to enable effective inference and coordination under partial observability.
- At the same time, messages should maintain compactness, minimizing redundancy and avoiding the transmission of unnecessary or easily inferable data.

3. **Contextual and Interaction-Aware**:
- Communication should be based on the agent's own observations, actively leveraging behavior-relevant information derived from its "perceived possibilities and recent behavior patterns".

4. **Explicitness and Clarity**:
- Avoid overly abstract messages. All information critical for solving the task must be included explicitly, in a clear and interpretable form.

5. **Structured Output**:
- The final output should be a tensor of shape {output_shape}.

6. **Communication Protocol**:
- Specify whether information is exchanged via peer-to-peer (agent-specific) or broadcast (global).
- Each agent should customize received messages based on the context and utility.
- Messages produced by an agent must be distributed and concatenated into the observations of other agents, never into only its own.
- Each message must include a sender identity field (one-hot encoded vector).

7. **Computational Efficiency**:
- No trainable components (e.g., neural networks) in the communication function.

**Observation Access Pattern**:
For example: {indexing_example}

**Protocol Requirements**: 
{additional_msg_prompt}
"""

    @staticmethod
    def get_llm_d_prompt(detail_content_state, task_description):
        """Important state dimensions extraction prompt (from llm_core.py).

        Asks the LLM to write ``select_important_state()``, which returns the
        indices of the state dimensions judged important for the task.
        """
        return f"""
You are a reasoning agent designing an importance extractor function for multi-agent reinforcement learning.
=========================================================
The agents' task description is:
{task_description}
=========================================================
Important notes:
- The explanation of each state dimension is provided here:
{detail_content_state}
=========================================================
Your task:
Write a Python function named `select_important_state()` that:

**Task-driven Hypothesis (Initial Reasoning)**:  
- Based on the task description and the meaning of each state dimension, form an initial hypothesis about which dimensions are likely important for task success.  
- In this partially observable multi-agent environment, each agent only perceives a limited view of the global state. Therefore, dimensions that are hard to perceive individually but critical when inferred through inter-agent communication should be prioritized. These dimensions are assumed to contribute significantly to coordinated decision-making and ultimately to task success.

Let's think step by step. Below is an illustrative example of the expected output:

```python
import numpy
def select_important_state():
    # Your implementation here
    # A brief explanation as an in-code comment about why you selected these dimensions
    return important_dims  # e.g., [idx,...]
```
"""

    @staticmethod
    def get_phase_0_protocol_prompt(important_dims, task_description=None, task_additional_description=None, detail_content=None,
                                    obs_shape=None, obs_dim_desc=None, indexing_example=None,
                                    additional_msg_prompt=""):
        """Build the phase-0 prompt requesting an initial communication protocol.

        Args:
            important_dims: Important state dimensions (rendered verbatim).
            task_description: Task text. Defaults to a generic description.
            task_additional_description: Extra task notes. Defaults to empty.
            detail_content: Per-dimension meanings. Defaults to a generic text.
            obs_shape: Observation shape (string or tuple). Defaults to
                ``"(batch_size, n_agents, obs_dim)"``.
            obs_dim_desc: Axis description. Defaults to a generic text.
            indexing_example: Indexing example. Defaults to a generic example.
            additional_msg_prompt: Extra protocol requirements. Defaults to empty.

        Returns:
            The full prompt asking for ``message_design_instruction()`` and
            ``communication(o)``.
        """
        # Each missing field is defaulted independently. This way a
        # partially specified call neither crashes (e.g. obs_shape=None reaching
        # the shape formatter) nor discards values the caller did provide.
        if task_description is None:
            task_description = "Multi-agent coordination task"
        if detail_content is None:
            detail_content = "Agent observations and environment state information"
        if obs_shape is None:
            obs_shape = "(batch_size, n_agents, obs_dim)"
        if obs_dim_desc is None:
            obs_dim_desc = "Batch size, number of agents, observation dimensions"
        if indexing_example is None:
            indexing_example = "o[batch_idx, agent_idx, :] represents agent observation"
        # Avoid rendering the literal text "None" into the prompt.
        if task_additional_description is None:
            task_additional_description = ""
        if additional_msg_prompt is None:
            additional_msg_prompt = ""

        I_T = PromptTemplates.get_task_context_I_T(task_description, obs_shape, obs_dim_desc, detail_content)
        I_P = PromptTemplates.get_protocol_instructions_I_P(obs_shape, indexing_example, additional_msg_prompt)

        return f"""
You are a communication design agent for Multi-Agent Reinforcement Learning (MARL).
Your goal is to design a task-specific communication protocol that allows agents to share only essential and non-redundant information to enhance coordination and decision-making.
Based on the task description and observation dimensions, identify which information should be exchanged and structure it to maximize task performance.
=========================================================
Task Description :
=========================================================

{I_T}

**Reasoning Tokens - Important State Dimensions**:
Based on previous analysis, the following state dimensions were identified as critical:
{important_dims}
{task_additional_description}
These dimensions require inter-agent communication for effective coordination under partial observability.

=========================================================
PROTOCOL GENERATION INSTRUCTIONS:
=========================================================

{I_P}

Using the important state dimensions reasoning Tokens above, design a communication protocol that enables agents to share information about these critical dimensions to improve coordination.

You are required to create two Python functions:

1. `message_design_instruction()`:
- Clearly describes how the message content is constructed based on the important state dimensions and observation context

2. `communication(o)`:
- Input: Observation tensor `o`
- Output: Enhanced observation tensor with integrated task-specific messages
- Focus on communicating information related to the important dimensions identified above

Both functions must be executable and ready for direct integration with MARL algorithms.
Caution!: Create Python functions that **minimizes the use of "for loops"** when handling batch processing to optimize computational efficiency.
{additional_msg_prompt}

Let's think step by step. Below is an illustrative example of the expected output:

```python
import torch as th
def message_design_instruction():
    # Explain how this protocol helps agents coordinate using these critical dimensions
    return message_description

def communication(o):
    # Your communication implementation focusing on important dimensions
    # You should design the communication protocol based on the message design instruction
    # use same device as input to avoid CUDA/CPU mismatch
    return messages_o
```
"""

    @staticmethod
    def get_llm_m_update_prompt(task_description, detail_content, obs_shape, obs_dim_desc,
                                indexing_example, message_concat_axis, timestep_additional_prompt,
                                task_additional_prompt, additional_msg_prompt, feedback):
        """Build the prompt that asks the LLM to revise the protocol given evaluation feedback.

        The task context (I_T) and protocol instructions (I_P) are embedded,
        followed by ``feedback`` and the extra prompt fragments. All values are
        rendered verbatim.
        """
        I_T = PromptTemplates.get_task_context_I_T(task_description, obs_shape, obs_dim_desc, detail_content)
        I_P = PromptTemplates.get_protocol_instructions_I_P(obs_shape, indexing_example, additional_msg_prompt)

        return f"""
You are a communication design agent for Multi-Agent Reinforcement Learning (MARL).
Your goal is to design a task-specific communication protocol that allows agents to share only essential and non-redundant information to enhance coordination and decision-making.
Based on the task description and observation dimensions, identify which information should be exchanged and structure it to maximize task performance.
=========================================================
Task Description :
=========================================================

{I_T}

=========================================================
PROTOCOL GENERATION INSTRUCTIONS:
=========================================================

{I_P}

You are required to create two Python functions:

1. `message_design_instruction()`:
- Clearly describes how the message content is constructed based on the important state dimensions and observation context

2. `communication(o)`:
- Input: Observation tensor `o`
- Output: Enhanced observation tensor with integrated task-specific messages

Both functions must be executable and ready for direct integration with MARL algorithms.
Caution!: Create two Python functions that **minimizes the use of "for loops"** when handling batch processing to optimize computational efficiency.

=========================================================
Here is the feedback from the previous communication protocol evaluation:
{feedback}

Reflect the feedback by designing messages from each agent’s own observations, prioritizing information that enables all agents to achieve consistent state prediction and shared understanding.

{timestep_additional_prompt}
=========================================================
{task_additional_prompt}
{additional_msg_prompt}
Let's think step by step. Below is an illustrative example of the expected output:

```python
import torch as th

def message_design_instruction():
    # Your message design instruction goes here
    return message_description

def communication(o):
    # input : {obs_shape}
    # Your communication implementation goes here
    # use same device as input to avoid CUDA/CPU mismatch
    # {message_concat_axis}
    # Strict Rule : Only concatenate new, non-overlapping fields into each agent’s observation; exclude any information already included in the previous protocol.
    return messages_o 
```
"""

    @staticmethod
    def get_analysis_prompt(analysis_data, task_description, detail_content, obs_shape,
                            obs_tensor_desc, obs_example, predictability_calc,
                            timewise_additional_prompt, next_k_input_data,
                            task_additional_description, cur_communication_method, json_data,
                            phase_info=None):
        """Build the analysis prompt that critiques the current protocol for one phase.

        Args:
            phase_info: Phase descriptor dict (see ``get_phase1_info``). Defaults
                to phase 1 when None.
            timewise_additional_prompt, next_k_input_data, task_additional_description:
                Optional extra context. Non-blank values are stripped and listed
                after the phase's focus areas.
            cur_communication_method: Text of the protocol being analysed.
            json_data: Per-dimension performance results (rendered verbatim).
            predictability_calc: Description of the discriminator evaluation.

        ``analysis_data`` and ``obs_example`` are accepted for interface
        compatibility but are not used in the prompt.

        Returns:
            A prompt asking for a JSON object with ``Evaluation``,
            ``Missing_Information_Hypothesis`` and ``Improvement_Suggestions``.
        """
        # Resolve the default before phase_info is first read. The fallback
        # previously ran after the first access, so phase_info=None raised.
        if phase_info is None:
            phase_info = PromptTemplates.get_phase1_info()

        task_context = PromptTemplates.get_task_context_I_T(
            f"{task_description}", obs_shape, obs_tensor_desc, detail_content)

        extras = [
            text.strip()
            for text in (timewise_additional_prompt, next_k_input_data, task_additional_description)
            if text and text.strip()
        ]
        extras_block = "\n".join(extras)

        goals = "\n".join(f"- {g}" for g in phase_info.get("enhancement_goals", []))
        phase_instruction = (
            f"PHASE {phase_info['phase_num']}: {phase_info['phase_name']}\n"
            f"Goal: {phase_info['phase_goal']}\n"
            f"Objective: {phase_info['phase_instruction']}\n"
            f"Instruction:\n{phase_info['specific_instruction']}\n"
            f"Focus Areas:\n{goals}"
        )
        if extras_block:
            # Start the extra context on its own line instead of gluing it to
            # the last focus-area bullet.
            phase_instruction += f"\n{extras_block}"

        meta_information = (
            "**Important State Dimensions Performance**:\n"
            f"{json_data}\n\n"
            "**Discriminator Evaluation Method**:\n"
            f"{predictability_calc}"
        )

        phase_label = phase_info['phase_name'].lower()

        return f"""
You are an analysis agent tasked with improving communication strategies in a multi-agent reinforcement learning (MARL) system.
{task_context}
=========================================================
{phase_instruction}
=========================================================
**Previous Protocol Under Analysis**:
{cur_communication_method}
=========================================================
**DISCRIMINATOR EVALUATION RESULTS**:
{meta_information}
=========================================================
**Analysis Context**:
- Each agent combines its own local observation with received messages to infer important state dimensions
- You are analyzing predictability results showing how accurately each agent can infer critical dimensions
- Performance differences across agents indicate areas where communication protocol needs improvement

Expected Output Format (JSON):
{{
  "Evaluation": "Assessment of current communication effectiveness from {phase_label} perspective, identifying performance gaps and protocol limitations.",
  "Missing_Information_Hypothesis": "Hypothesis about what information is missing or inadequately communicated, focusing on {phase_label} requirements.",
  "Improvement_Suggestions": "Specific suggestions to improve communication content and structure based on {phase_label} analysis results."
}}
"""

    @staticmethod
    def get_analysis_prompt_phase1(analysis_data, task_description, detail_content, obs_shape,
                                   obs_tensor_desc, obs_example, predictability_calc,
                                   timewise_additional_prompt, next_k_input_data,
                                   task_additional_description, cur_communication_method, json_data):
        """Phase 1: Recognition Enhancement Analysis"""
        return PromptTemplates.get_analysis_prompt(
            analysis_data, task_description, detail_content, obs_shape,
            obs_tensor_desc, obs_example, predictability_calc,
            timewise_additional_prompt, next_k_input_data,
            task_additional_description, cur_communication_method, json_data,
            phase_info=PromptTemplates.get_phase1_info()
        )

    @staticmethod
    def get_analysis_prompt_phase2(analysis_data, task_description, detail_content, obs_shape,
                                   obs_tensor_desc, obs_example, predictability_calc,
                                   timewise_additional_prompt, next_k_input_data,
                                   task_additional_description, cur_communication_method, json_data):
        """Phase 2: Sharing Enhancement Analysis"""
        return PromptTemplates.get_analysis_prompt(
            analysis_data, task_description, detail_content, obs_shape,
            obs_tensor_desc, obs_example, predictability_calc,
            timewise_additional_prompt, next_k_input_data,
            task_additional_description, cur_communication_method, json_data,
            phase_info=PromptTemplates.get_phase2_info()
        )

    @staticmethod
    def get_error_augmentation_prompt(base_prompt, attempt, stage, exc, short_tb):
        """Append retry context (failed attempt, stage, exception type, traceback tail) to *base_prompt*.

        Args:
            base_prompt: The prompt that produced the failing code.
            attempt: Number of the failed attempt (rendered verbatim).
            stage: Label of the stage that failed (rendered verbatim).
            exc: The raised exception. Only its class name is included.
            short_tb: Trailing traceback lines, shown in a fenced block.
        """
        return (
            f"{base_prompt}\n\n"
            f"---\n"
            f"[Retry context] Previous attempt #{attempt} FAILED\n"
            f"Stage: {stage}\n"
            f"Exception: {type(exc).__name__}\n"
            f"Details (last lines):\n```\n{short_tb}\n```\n"
            "Please fix the issue above. Output ONLY a single Python fenced block:\n"
            "```python\n# your fixed code\n```\n"
            "Requirements: provide functions `communication(o)` and `message_design_instruction()`; "
            "no trainable params; respect the required tensor shapes; avoid for-loops over batch/time dims."
        )

    @staticmethod
    def get_phase1_info():
        """Phase 1: Recognition Enhancement (a fresh descriptor dict on every call)."""
        return {
            'phase_num': 1,
            'phase_name': 'RECOGNITION ENHANCEMENT',
            'phase_goal': 'Ensure each agent can recover state dimensions accurately',
            'feedback_instruction_label': 'x̃^(1)',
            'feedback_source': 'Initial Analysis Phase',
            'phase_instruction': 'Individual agent enhancement through improved communication',
            'specific_instruction': """Considering the provided feedback, you should refine the communication approach to reduce uneven prediction across agents and timesteps. When certain agents can infer a state dimension while others cannot, this asymmetry must be explicitly addressed. The protocol should therefore include positional and behavioral cues that not only convey task-relevant information but also make clear which agents have reliable knowledge and which do not, enabling the group to synchronize understanding and eliminate inconsistent inference.""",
            'enhancement_goals': [
                'Improve individual agent prediction accuracy for critical state dimensions',
                'Ensure each agent receives information needed for better state inference',
                'Address agent-specific limitations in state dimension recognition while avoiding over-customized peer-specific messages, ensuring broadly useful information that supports consistent state inference',
                'Enhance communication content to support individual agent observation utilization',
                'Focus on helping agents better recognize and interpret shared information'
            ]
        }

    @staticmethod
    def get_phase2_info():
        """Phase 2: Sharing Enhancement (a fresh descriptor dict on every call)."""
        return {
            'phase_num': 2,
            'phase_name': 'SHARING ENHANCEMENT',
            'phase_goal': 'Achieve consistent state recovery across all agents',
            'feedback_instruction_label': 'x̃^(2)',
            'feedback_source': 'Recognition Enhancement Phase',
            'phase_instruction': 'Address inconsistencies from phase 1 analysis',
            'specific_instruction': """Considering the provided feedback, you should improve upon the current communication approach by emphasizing the agents' positional information and behavioral patterns to enhance prediction capabilities and coordination.
You must reflect this feedback in your communication design by including novel, task-relevant information—such as movement intent or coordination cues—that addresses weakly predictable state dimensions and cannot be inferred from local observations.""",
            'enhancement_goals': [
                'If one agent can recover a state dimension, all agents should be able to do so consistently',
                'Address information imbalance across agents that leads to prediction disagreements',
                'Promote uniformity in state representations while maintaining efficiency',
                'Focus on dimensions where agents show inconsistent prediction performance'
            ]
        }
