"""
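Incremental entity summarizer implementing the summarization methods from the paper: generate-once, generate-update, generate-merge, CoK, and two DSPy-style prompt variants.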
"""

import copy
import json
import logging
from typing import Dict, Union, List, Tuple

from utils import (
    create_client,
    time_llm_call,
    parse_to_summary,
    parse_to_cok,
    merge_summary_objects,
    apply_cok_operations,
    robust_json_extraction,
    clean_dspy_response,
)
from schema import CoK, Summary


class IncrementalSummarizer:
    """Summarize a set of notes about one entity using one of several methods.

    Methods (selected by name in :meth:`summarize`):

    * ``generate_once``: one LLM call over all notes concatenated.
    * ``generate_update``: summarize the first note, then revise the running
      summary with each following note.
    * ``generate_merge``: summarize each note independently, merge the partial
      summaries with ``merge_summary_objects`` and, if a
      ``duplicate_removal_prompt`` is configured, ask the LLM to deduplicate.
    * ``cok``: summarize each note independently, then ask the LLM for
      ``UPDATE``/``ADD`` operations that ``apply_cok_operations`` applies to
      the running summary's attributes.
    * ``dspy_generate_concat`` / ``dspy_generate_update``: system+user prompt
      variants whose responses are post-processed by ``clean_dspy_response``.

    Prompt templates are read from ``config["prompts"][<method name>]``.
    """

    def __init__(self, config, model, verbosity="medium"):
        """Create the LLM client and store configuration.

        Args:
            config: Mapping with a ``"prompts"`` section keyed by method name.
            model: Mapping with at least ``"model_name"``; ``"provider"`` is
                optional and defaults to ``"ollama"``. Passed unchanged to
                ``create_client``.
            verbosity: ``"low"``, ``"medium"`` or ``"high"`` (case-insensitive).
                Any other value falls back to ``"medium"``.
        """
        self.logger = logging.getLogger(__name__)
        self.verbosity = verbosity.lower()
        if self.verbosity not in ("low", "medium", "high"):
            self.verbosity = "medium"

        self.config = config
        self.client = create_client(model)
        self.deployment = model["model_name"]
        self.provider = model.get("provider", "ollama")

        # gpt-oss models are called without a schema constraint; their raw
        # output is cleaned with robust_json_extraction before parsing.
        self.parse_json_manually = "gpt-oss" in self.deployment.lower()

        self._log_info(
            f"Initialized with model: {self.deployment}, provider: {self.provider}"
        )
        if self.parse_json_manually:
            self._log_info("Enabled robust JSON parsing for gpt-oss model")

    # ------------------------------------------------------------------ logging

    def _log_debug(self, message: str):
        """Emit ``message`` at DEBUG level for medium and high verbosity."""
        if self.verbosity in ["medium", "high"]:
            self.logger.debug(message)

    def _log_info(self, message: str):
        """Emit ``message`` at INFO level for every recognized verbosity."""
        if self.verbosity in ["low", "medium", "high"]:
            self.logger.info(message)

    def _log_verbose(self, message: str):
        """Emit ``message`` at INFO level, prefixed ``[VERBOSE]``, only for high verbosity."""
        if self.verbosity == "high":
            self.logger.info(f"[VERBOSE] {message}")

    def _log_prompt_and_response(self, prompt: str, response: str, context: str = ""):
        """Log the full prompt and response (high verbosity only)."""
        if self.verbosity == "high":
            self.logger.info(f"[VERBOSE] {context} PROMPT:\n{prompt}")
            self.logger.info(f"[VERBOSE] {context} RESPONSE:\n{response}")

    def _log_summary_stats(self, summary: Dict, context: str = ""):
        """Log the attribute count (low/medium) or the full summary JSON (high)."""
        if self.verbosity in ["low", "medium"]:
            attr_count = len(summary.get("attributes", {}))
            self.logger.info(
                f"{context} Summary generated with {attr_count} attributes"
            )
        elif self.verbosity == "high":
            self.logger.info(
                f"[VERBOSE] {context} Full summary: {json.dumps(summary, indent=2)}"
            )

    # ------------------------------------------------------------ LLM plumbing

    def _parse_llm_response(
        self, response: str, format_type: str = "summary"
    ) -> Union[Dict, object]:
        """Parse raw LLM text into a summary or CoK object.

        For gpt-oss models the text first goes through
        ``robust_json_extraction``.

        Raises:
            ValueError: If ``format_type`` is not ``"summary"`` or ``"cok"``.
        """
        if self.parse_json_manually:
            self._log_debug("Applying robust JSON extraction for gpt-oss model")
            response = robust_json_extraction(response)

        if format_type == "summary":
            return parse_to_summary(response)
        elif format_type == "cok":
            return parse_to_cok(response)
        else:
            raise ValueError(f"Unknown format_type: {format_type}")

    @time_llm_call
    def call_llm(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """Send chat ``messages`` to the configured provider and return the text.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}``).
            **kwargs: Forwarded to the provider client. For Ollama an
                ``options`` mapping is merged over the defaults
                (``temperature=0``, ``num_ctx=16384``). For Azure, a
                ``response_format`` kwarg selects the ``beta...parse`` endpoint.

        Returns:
            The message content of the first/only response.

        Raises:
            ValueError: If the provider is neither ``"ollama"`` nor ``"azure"``
                (previously this surfaced as an ``UnboundLocalError``).
        """
        self._log_debug(
            f"Calling LLM with {len(messages)} messages using {self.provider}"
        )

        if self.verbosity == "high" and messages:
            prompt_text = "\n".join(
                f"{msg['role']}: {msg['content']}" for msg in messages
            )
            self._log_verbose(f"LLM INPUT:\n{prompt_text}")

        if self.provider == "ollama":
            # Caller-supplied options override the defaults instead of
            # colliding with the explicit ``options=`` keyword.
            options = {
                "temperature": 0,
                "num_ctx": 16384,
                **(kwargs.pop("options", None) or {}),
            }
            response = self.client.chat(
                model=self.deployment,
                messages=messages,
                options=options,
                **kwargs,
            )
            content = response["message"]["content"]
        elif self.provider == "azure":
            # OpenAI-compatible client; structured output needs the parse endpoint.
            if "response_format" in kwargs:
                response = self.client.beta.chat.completions.parse(
                    model=self.deployment, messages=messages, **kwargs
                )
            else:
                response = self.client.chat.completions.create(
                    model=self.deployment, messages=messages, **kwargs
                )
            content = response.choices[0].message.content
        else:
            self.logger.error(f"Unsupported provider: {self.provider}")
            raise ValueError(f"Unsupported provider: {self.provider}")

        if self.verbosity == "low":
            self._log_info(f"LLM response received ({len(content)} chars)")
        elif self.verbosity == "medium":
            self._log_debug(f"LLM response length: {len(content)} chars")
        elif self.verbosity == "high":
            self._log_verbose(f"LLM OUTPUT:\n{content}")

        return content

    def _generate_structured(
        self, messages: List[Dict[str, str]], format_type: str = "summary"
    ) -> Union[Dict, object]:
        """Call the LLM asking for schema-shaped output, then parse it.

        Request strategy:

        * gpt-oss models (any provider): no schema constraint; the raw text is
          cleaned in :meth:`_parse_llm_response`.
        * Ollama: the JSON schema is passed via Ollama's ``format`` argument.
        * Otherwise: an OpenAI-style ``response_format`` of type
          ``json_schema`` is passed.

        Args:
            messages: Chat messages to send.
            format_type: ``"summary"`` (``Summary`` schema) or ``"cok"``
                (``CoK`` schema).

        Returns:
            The parsed response from :meth:`_parse_llm_response`.

        Raises:
            ValueError: If ``format_type`` is not ``"summary"`` or ``"cok"``.
        """
        if format_type == "summary":
            schema_model, schema_name = Summary, "Summary"
        elif format_type == "cok":
            schema_model, schema_name = CoK, "CoK"
        else:
            raise ValueError(f"Unknown format_type: {format_type}")

        if self.parse_json_manually:
            response = self.call_llm(messages)
        elif self.provider == "ollama":
            response = self.call_llm(messages, format=schema_model.model_json_schema())
        else:
            response = self.call_llm(
                messages,
                response_format={
                    "type": "json_schema",
                    "json_schema": {
                        "name": schema_name,
                        "schema": schema_model.model_json_schema(),
                    },
                },
            )
        return self._parse_llm_response(response, format_type)

    @staticmethod
    def _combine_notes(notes: Dict[str, str]) -> str:
        """Join note texts as ``P1. ...``, ``P2. ...`` separated by blank lines."""
        return "\n\n".join(
            f"P{position}. {note}"
            for position, note in enumerate(notes.values(), start=1)
        )

    def _require_notes(self, notes: Dict[str, str], method: str) -> None:
        """Raise ``ValueError`` if ``notes`` is empty (for methods that need a first note)."""
        if not notes:
            self.logger.error(f"{method} requires at least one note")
            raise ValueError(f"{method} requires at least one note")

    # ------------------------------------------------------------ DSPy methods

    def dspy_generate_concat(
        self, notes: Dict[str, str], entity_name: str = "ENTITY"
    ) -> str:
        """Run the DSPy "concat" prompt over all notes at once.

        ``entity_name`` is accepted for interface parity with the other methods
        but is not used by the prompt template.

        Returns:
            ``json.dumps`` (indent 2) of the ``clean_dspy_response`` output.
        """
        prompts = self.config["prompts"]["dspy_generate_concat"]
        messages = [
            {"role": "system", "content": prompts["system_prompt"]},
            {
                "role": "user",
                "content": prompts["user_prompt"].format(
                    note=self._combine_notes(notes)
                ),
            },
        ]

        cleaned_response = clean_dspy_response(self.call_llm(messages))
        return json.dumps(cleaned_response, indent=2)

    def dspy_generate_update(
        self, notes: Dict[str, str], entity_name: str = "ENTITY"
    ) -> Tuple[object, Dict[str, object]]:
        """Run the DSPy "update" prompt note by note.

        For each note the LLM receives the cleaned responses to all previous
        notes (as assistant messages), then the system prompt and the user
        prompt for the current note. ``entity_name`` is accepted for interface
        parity but is not used by the prompt template.

        Returns:
            ``(last_cleaned_response, {note_id: cleaned_response})``.

        Raises:
            ValueError: If ``notes`` is empty.
        """
        self._require_notes(notes, "dspy_generate_update")
        prompts = self.config["prompts"]["dspy_generate_update"]
        history = []
        responses = {}
        cleaned_response = None

        for note_id, note in notes.items():
            # Format exactly once: re-formatting an already formatted prompt
            # would treat any literal "{"/"}" it contains as placeholders.
            turn = [
                {"role": "system", "content": prompts["system_prompt"]},
                {"role": "user", "content": prompts["user_prompt"].format(note=note)},
            ]
            cleaned_response = clean_dspy_response(self.call_llm(history + turn))

            history.append({"role": "assistant", "content": cleaned_response})
            responses[note_id] = cleaned_response

        return cleaned_response, responses

    # -------------------------------------------------- summarization methods

    def generate_once(self, notes: Dict[str, str], entity_name: str = "ENTITY") -> Dict:
        """Summarize all notes with a single LLM call.

        Returns:
            The parsed summary dict.
        """
        self._log_info(
            f"Starting generate_once for {len(notes)} notes (entity: {entity_name})"
        )

        combined_notes = self._combine_notes(notes)
        self._log_debug(f"Combined notes length: {len(combined_notes)} chars")

        user_prompt = self.config["prompts"]["generate_once"]["initial_prompt"].format(
            entity_name=entity_name, paragraph=combined_notes
        )
        summary = self._generate_structured(
            [{"role": "user", "content": user_prompt}], "summary"
        )

        self._log_summary_stats(summary, "generate_once:")
        return summary

    def generate_update(
        self,
        notes: Dict[str, str],
        entity_name: str = "ENTITY",
    ) -> Dict:
        """Summarize the first note, then revise the summary with each later note.

        Each revision sends the current summary (as JSON) plus the next note.

        Returns:
            The summary dict after the last note.

        Raises:
            ValueError: If ``notes`` is empty.
        """
        self._log_info(
            f"Starting generate_update for {len(notes)} notes (entity: {entity_name})"
        )
        self._require_notes(notes, "generate_update")

        prompts = self.config["prompts"]["generate_update"]
        note_texts = list(notes.values())

        first_note = note_texts[0]
        self._log_debug(f"Starting with initial note, length: {len(first_note)} chars")
        initial_prompt = prompts["initial_prompt"].format(
            entity_name=entity_name, paragraph=f"P1. {first_note}"
        )
        current_summary = self._generate_structured(
            [{"role": "user", "content": initial_prompt}], "summary"
        )
        self._log_verbose(
            f"Initial summary keys: {list(current_summary['attributes'].keys())}"
        )

        remaining_notes = note_texts[1:]
        self._log_info(
            f"Processing {len(remaining_notes)} remaining notes incrementally"
        )

        for position, note_content in enumerate(remaining_notes, start=2):
            self._log_debug(f"Processing incremental note {position}/{len(notes)}")
            update_prompt = prompts["update_prompt"].format(
                entity_name=entity_name,
                paragraph=f"P{position}. {note_content}",
                existing_summary=json.dumps(current_summary, indent=2),
            )
            current_summary = self._generate_structured(
                [{"role": "user", "content": update_prompt}], "summary"
            )

        self._log_summary_stats(current_summary, "generate_update:")
        return current_summary

    def generate_merge(
        self, notes: Dict[str, str], entity_name: str = "ENTITY"
    ) -> Tuple[Dict, List[Dict]]:
        """Summarize each note separately, merge, then optionally deduplicate.

        Returns:
            ``(final_summary, partial_summaries)``. ``final_summary`` is the
            programmatic merge, or the LLM-deduplicated version when
            ``duplicate_removal_prompt`` is configured.
        """
        self._log_info(
            f"Starting generate_merge for {len(notes)} notes (entity: {entity_name})"
        )

        merge_prompts = self.config["prompts"]["generate_merge"]
        partial_prompt = merge_prompts["partial_prompt"]

        partial_summaries = []
        self._log_info("Generating partial summaries for each note")
        for position, note_content in enumerate(notes.values(), start=1):
            self._log_debug(f"Creating partial summary {position}/{len(notes)}")
            user_prompt = partial_prompt.format(
                entity_name=entity_name, paragraph=f"P{position}. {note_content}"
            )
            partial_summaries.append(
                self._generate_structured(
                    [{"role": "user", "content": user_prompt}], "summary"
                )
            )

        self._log_info(
            f"Generated {len(partial_summaries)} partial summaries, merging..."
        )
        merged_summary = merge_summary_objects(partial_summaries)

        if "duplicate_removal_prompt" in merge_prompts:
            self._log_info("Applying duplicate removal")
            duplicate_removal_prompt = merge_prompts["duplicate_removal_prompt"].format(
                existing_summary=json.dumps(merged_summary, indent=2)
            )
            final_summary = self._generate_structured(
                [{"role": "user", "content": duplicate_removal_prompt}], "summary"
            )
        else:
            final_summary = merged_summary

        self._log_summary_stats(final_summary, "generate_merge:")
        return final_summary, partial_summaries

    def cok(
        self, notes: Dict[str, str], entity_name: str = "ENTITY"
    ) -> Tuple[Dict, Dict[str, Union[Dict, CoK]]]:
        """CoK summarization using structured UPDATE/ADD operations.

        Every note is first summarized on its own. The first summary seeds the
        running summary. For each later note the LLM proposes CoK operations
        that are applied to the running summary's ``attributes``.

        Returns:
            ``(final_summary, responses)``. ``responses`` maps each note id to
            that note's own summary. It is not modified by later CoK steps.

        Raises:
            ValueError: If ``notes`` is empty.
        """
        self._log_info(
            f"Starting CoK processing for {len(notes)} notes (entity: {entity_name})"
        )
        self._require_notes(notes, "cok")

        current_summary = None
        responses = {}

        for position, (note_id, note_content) in enumerate(notes.items(), start=1):
            self._log_info(f"Processing note {position}/{len(notes)} (ID: {note_id})")

            if position == 1:
                self._log_debug("Creating initial summary from first note")
            else:
                self._log_debug(
                    "Generating summary for current note and CoK operations"
                )

            note_prompt = self.config["prompts"]["cok"]["initial_prompt"].format(
                entity_name=entity_name, paragraph=f"P{position}. {note_content}"
            )
            note_summary = self._generate_structured(
                [{"role": "user", "content": note_prompt}], "summary"
            )
            responses[note_id] = note_summary

            if position == 1:
                # Work on a deep copy: the running summary is mutated below, and
                # that must not rewrite the first note's recorded response.
                current_summary = copy.deepcopy(note_summary)
                self._log_debug(
                    f"Initial summary created with {len(current_summary.get('attributes', {}))} attributes"
                )
                continue

            current_summary["attributes"] = self._apply_cok_step(
                entity_name, current_summary, note_summary
            )

        self._log_summary_stats(current_summary, "CoK:")
        return current_summary, responses

    def _apply_cok_step(
        self, entity_name: str, current_summary: Dict, new_summary: Dict
    ) -> Dict:
        """Ask the LLM for CoK operations against ``current_summary`` and apply them.

        Returns:
            The new ``attributes`` mapping produced by ``apply_cok_operations``.
        """
        existing_keys = list(current_summary["attributes"].keys())
        new_keys = list(new_summary["attributes"].keys())
        existing_key_set = set(existing_keys)
        # Ordered lists rather than sets: set iteration order of str keys varies
        # with hash randomization, which made the prompt differ between runs.
        relevant_keys = [key for key in new_keys if key in existing_key_set]
        missing_keys = [key for key in new_keys if key not in existing_key_set]

        self._log_debug(
            f"Key analysis - Existing: {len(existing_keys)}, New: {len(new_keys)}, Overlap: {len(relevant_keys)}, Missing: {len(missing_keys)}"
        )

        update_prompt = self.config["prompts"]["cok"]["update_and_add_prompt"].format(
            entity_name=entity_name,
            new_summary=json.dumps(new_summary, indent=2),
            existing_summary=json.dumps(current_summary, indent=2),
            existing_keys=existing_keys,
            relevant_keys=relevant_keys,
            new_keys=new_keys,
            missing_keys=missing_keys,
            update_paths=[f'"$.attributes.{k}"' for k in relevant_keys],
            add_paths=[f'"$.attributes.{k}"' for k in missing_keys],
        )
        cok_operations = self._generate_structured(
            [{"role": "user", "content": update_prompt}], "cok"
        )

        update_ops = cok_operations["UPDATE"]
        add_ops = cok_operations["ADD"]
        if self.verbosity == "low":
            self._log_info(
                f"CoK operations - {len(update_ops)} updates, {len(add_ops)} additions"
            )
        elif self.verbosity == "medium":
            self._log_debug(
                f"CoK operations - Updates: {list(update_ops.keys()) if update_ops else []}, Adds: {list(add_ops.keys()) if add_ops else []}"
            )
        elif self.verbosity == "high":
            self._log_verbose(
                f"Full CoK operations: {json.dumps(cok_operations, indent=2)}"
            )

        self._log_debug("Applying CoK operations to update summary")
        return apply_cok_operations(current_summary["attributes"], cok_operations)

    # ------------------------------------------------------------- entry point

    def summarize(
        self,
        method: str,
        notes: Dict[str, str],
        entity_name: str = "ENTITY",
    ) -> Union[Dict, Tuple, str]:
        """Run summarization method ``method`` over ``notes``.

        Raises:
            ValueError: If ``method`` has no prompt section in the config or is
                not an implemented method.
        """
        self._log_info(
            f"Starting {method} summarization with {len(notes)} notes for entity '{entity_name}'"
        )

        if method not in self.config["prompts"]:
            self.logger.error(f"Unknown method: {method}")
            raise ValueError(f"Unknown method: {method}")

        handlers = {
            "generate_once": self.generate_once,
            "generate_update": self.generate_update,
            "generate_merge": self.generate_merge,
            "cok": self.cok,
            "dspy_generate_concat": self.dspy_generate_concat,
            "dspy_generate_update": self.dspy_generate_update,
        }
        handler = handlers.get(method)
        if handler is None:
            self.logger.error(f"Unknown method: {method}")
            raise ValueError(f"Unknown method: {method}")

        result = handler(notes, entity_name)

        self._log_info(f"Completed {method} summarization successfully")
        return result
