#!/usr/bin/env python3
"""
AugmentPipeline: An integrated augmentation pipeline for QuestionAugment, CoTAugment, CaptionAugment and VisualizeAugment.
"""

import json
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
import argparse

from .QuestionAugment import QuestionAugment
from .CoTAugment import CoTAugment
from .CaptionAugment import CaptionAugment
from .VisualizeAugment import VisualizeAugment
from ...utils.model_urls import get_base_url_for_model
from ...utils.augment_config import AugmentConfig


class AugmentPipeline:
    """Integrated pipeline that combines multiple augmentation steps.

    For every data item the pipeline runs, in order: question augmentation,
    CoT (reasoning process) augmentation, caption/description augmentation,
    and visual QA pair generation from the item's plotting code. Token usage
    reported by each step (via the augmenter's ``last_token_usage``) is summed
    into ``total_token_usage``.
    """

    # Keys tracked in ``total_token_usage``.
    _TOKEN_KEYS = ("prompt_tokens", "completion_tokens", "total_tokens")

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        question_model: Optional[str] = None,
        cot_model: Optional[str] = None,
        caption_model: Optional[str] = None,
        visualize_model: Optional[str] = None,
        base_url: Optional[str] = None,
        delay: Optional[float] = None,
        max_retries: Optional[int] = None,
        num_question_versions: int = 3,
        num_cot_versions: int = 3,
        num_caption_versions: int = 2,
        num_visualize_versions: int = 2,
    ):
        """
        Initialize the augmentation pipeline.

        Args:
            api_key: API Key, shared by all augmentation steps.
            model: Default model name.
            question_model: Model for question augmentation (if None, use `model`).
            cot_model: Model for CoT (reasoning process) augmentation (if None, use `model`).
            caption_model: Model for caption/description augmentation (if None, use `model`).
            visualize_model: Model for generating visual QA pairs (if None, use `model`).
            base_url: Fallback base API URL. For each step, a URL derived from
                that step's model name via `get_base_url_for_model` takes
                precedence; `base_url` is used only when that lookup yields
                nothing (or the step has no model).
            delay: Delay (in seconds) slept after each API-backed step; None or
                a non-positive value disables sleeping.
            max_retries: Maximum number of retries, passed to every augmenter.
            num_question_versions: Number of question versions to generate.
            num_cot_versions: Number of reasoning process (CoT) versions to generate.
            num_caption_versions: Number of caption/description versions to generate.
            num_visualize_versions: Number of visual QA pair versions to generate.
        """
        question_model = question_model or model
        cot_model = cot_model or model
        caption_model = caption_model or model
        visualize_model = visualize_model or model

        # Automatically obtain base_url according to model name; the explicit
        # `base_url` argument is only the fallback (see `... or base_url` below).
        question_base_url = get_base_url_for_model(question_model) if question_model else None
        cot_base_url = get_base_url_for_model(cot_model) if cot_model else None
        caption_base_url = get_base_url_for_model(caption_model) if caption_model else None
        visualize_base_url = get_base_url_for_model(visualize_model) if visualize_model else None

        self.question_augment = QuestionAugment(
            api_key=api_key,
            model=question_model,
            base_url=question_base_url or base_url,
            max_retries=max_retries,
        )

        self.cot_augment = CoTAugment(
            api_key=api_key,
            model=cot_model,
            base_url=cot_base_url or base_url,
            max_retries=max_retries,
        )

        self.caption_augment = CaptionAugment(
            api_key=api_key,
            model=caption_model,
            base_url=caption_base_url or base_url,
            max_retries=max_retries,
        )

        self.visualize_augment = VisualizeAugment(
            api_key=api_key,
            model=visualize_model,
            base_url=visualize_base_url or base_url,
            max_retries=max_retries,
        )

        self.delay = delay
        self.num_question_versions = num_question_versions
        self.num_cot_versions = num_cot_versions
        self.num_caption_versions = num_caption_versions
        self.num_visualize_versions = num_visualize_versions
        # Token usage statistics
        self._reset_token_usage()

    @staticmethod
    def _sleep(delay: Optional[float]) -> None:
        """Sleep for the specified delay (in seconds); no-op for None or <= 0."""
        if delay and delay > 0:
            time.sleep(delay)

    def _reset_token_usage(self) -> None:
        """Reset the accumulated token usage counters to zero."""
        self.total_token_usage = {key: 0 for key in self._TOKEN_KEYS}

    def _accumulate_token_usage(self, usage_dict: Optional[Dict[str, int]]) -> None:
        """
        Accumulate token usage statistics.

        Args:
            usage_dict: A dictionary that contains token usage information
                (`prompt_tokens`, `completion_tokens`, `total_tokens`). None or
                an empty dict is ignored.
        """
        if not usage_dict:
            return
        for key in self._TOKEN_KEYS:
            # Treat a missing *or None* count as 0 so an incomplete usage
            # report cannot abort a long augmentation run.
            self.total_token_usage[key] += usage_dict.get(key) or 0

    def _augment_text_field(
        self,
        augmenter: Any,
        item: Dict[str, Any],
        key: str,
        num_versions: int,
        index: Optional[int],
        placeholder: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Run one text augmentation step on `item[key]`, or return `placeholder` if the key is absent."""
        if key not in item:
            return placeholder
        versions = augmenter.augment(item[key], index, num_versions)
        self._accumulate_token_usage(augmenter.last_token_usage)
        self._sleep(self.delay)
        return versions

    def augment_item(
        self,
        item: Dict[str, Any],
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Augment a single data item and generate multiple augmented versions.

        Args:
            item: Original data item (containing fields such as `problem`, `cot`, `answer`, `perception`, etc.).
                It is not mutated.
            index: Optional index, forwarded to each augmenter (used for logging).

        Returns:
            A shallow copy of `item` in which the augmented fields are replaced
            by their generated versions, other fields (e.g. `answer`) are kept
            unchanged, and `plotting_code` is removed, e.g.:
            {
                "problem": {"Q1": "...", "Q2": "...", "Q3": "..."},
                "cot": {"CoT1": "...", "CoT2": "...", "CoT3": "..."},
                "answer": "8",  # the answer is kept unchanged
                "perception": {"C1": "...", "C2": "..."},
                "visualize": {"V1": {"question": "...", "answer": "..."}, "V2": {...}},
                ...
            }
            Fields missing from `item` get a single empty placeholder version.
        """
        # Create a shallow copy for the augmented data item
        augmented_item = item.copy()

        # Augment question, then the reasoning process (CoT).
        augmented_item["problem"] = self._augment_text_field(
            self.question_augment, item, "problem",
            self.num_question_versions, index, {"Q1": ""},
        )
        augmented_item["cot"] = self._augment_text_field(
            self.cot_augment, item, "cot",
            self.num_cot_versions, index, {"CoT1": ""},
        )

        # The answer is kept unchanged and is not augmented.

        # Augment caption/description. Keep the original caption for the
        # visual QA step below, which must see the source text, not a rewrite.
        original_caption = item.get("perception", "")
        augmented_item["perception"] = self._augment_text_field(
            self.caption_augment, item, "perception",
            self.num_caption_versions, index, {"C1": ""},
        )

        # Generate visual QA pairs from the plotting code and the original caption.
        # When `perception` is absent an empty caption is passed; VisualizeAugment
        # presumably rejects that (TODO confirm against VisualizeAugment).
        plotting_code = item.get("plotting_code")
        if plotting_code is not None:
            augmented_item["visualize"] = self.visualize_augment.augment(
                plotting_code,
                caption=original_caption,
                index=index,
                num_versions=self.num_visualize_versions,
            )
            self._accumulate_token_usage(self.visualize_augment.last_token_usage)
            self._sleep(self.delay)
        else:
            augmented_item["visualize"] = {"V1": {"question": "", "answer": ""}}

        # Remove plotting code from the final output
        augmented_item.pop("plotting_code", None)
        return augmented_item

    def augment_file(
        self,
        input_file: str,
        output_file: Optional[str] = None,
        start_index: int = 0,
        end_index: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Augment the data items of a JSON array file and save the results.

        Args:
            input_file: Path to the input JSON file (must contain a JSON array).
            output_file: Path to the output JSON file. If None, it is written next
                to the input as `<stem>_augmented<suffix>`. Missing parent
                directories are created.
            start_index: Start index, inclusive (for batch processing). Must be >= 0.
            end_index: End index, exclusive; clipped to the number of items
                (if None, process until the end of file). Must be >= 0.

        Returns:
            The list of augmented data items (also written to the output file).

        Raises:
            ValueError: If `start_index`/`end_index` is negative or the input is
                not a JSON array.
            FileNotFoundError: If the input file does not exist.
        """
        # Negative indices would otherwise wrap around (range(-2, n) -> data[-2], ...)
        # or silently select nothing and overwrite the output with an empty list.
        if start_index < 0 or (end_index is not None and end_index < 0):
            raise ValueError(
                f"start_index and end_index must be non-negative, "
                f"got start_index={start_index}, end_index={end_index}."
            )

        # Read input file
        input_path = Path(input_file)
        if not input_path.exists():
            raise FileNotFoundError(f"Input file does not exist: {input_file}")

        print(f"Reading file: {input_file}")
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError("The input file must contain a JSON array.")

        # Determine processing range
        total_items = len(data)
        end_idx = total_items if end_index is None else min(end_index, total_items)
        num_to_process = max(end_idx - start_index, 0)
        print(f"Total {total_items} items, processing indices from {start_index} to {end_idx - 1}.")

        # Reset token statistics
        self._reset_token_usage()

        # Augment data items
        augmented_data = []
        for position, i in enumerate(range(start_index, end_idx), start=1):
            item = data[i]
            item_index = item.get("index", i)
            # `position` is 1-based within this batch, so the counter reads k/N.
            print(f"\nProcessing sample {item_index} ({position}/{num_to_process})...")

            augmented_item = self.augment_item(item, index=item_index)
            augmented_data.append(augmented_item)

            # Log number of generated versions
            num_q = len(augmented_item.get("problem", {}))
            num_cot = len(augmented_item.get("cot", {}))
            num_c = len(augmented_item.get("perception", {}))
            num_v = len(augmented_item.get("visualize", {}))
            print(
                f"[Sample {item_index}] generated {num_q} question versions, "
                f"{num_cot} CoT versions, {num_c} caption versions, {num_v} visual QA pair versions."
            )

        # Save results
        if output_file is None:
            # Automatically generate output file name
            output_path = input_path.parent / f"{input_path.stem}_augmented{input_path.suffix}"
        else:
            output_path = Path(output_file)

        print(f"\nSaving results to: {output_path}")
        # Avoid losing a whole (expensive) run because the target directory is missing.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(augmented_data, f, ensure_ascii=False, indent=2)

        # Print total token usage statistics
        print(f"\n{'='*60}")
        print("Token usage statistics (total):")
        print(f"  Prompt Tokens:     {self.total_token_usage['prompt_tokens']:,}")
        print(f"  Completion Tokens: {self.total_token_usage['completion_tokens']:,}")
        print(f"  Total Tokens:      {self.total_token_usage['total_tokens']:,}")
        print(f"{'='*60}")

        print(f"Done! Augmented {len(augmented_data)} items.")

        return augmented_data


def main(argv: Optional[List[str]] = None) -> None:
    """Entry point for the CLI.

    Parses command-line arguments and optionally loads an `AugmentConfig` file.
    It then builds an `AugmentPipeline` from the `AugmentConfig` values and
    augments the requested slice of the input file.

    Args:
        argv: Arguments to parse, excluding the program name. None (the
            default) makes argparse read `sys.argv[1:]`, as before.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Augmentation pipeline: question augmentation -> CoT augmentation "
            "-> caption augmentation -> visual QA generation"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=None,
        help="Path to the config file (JSON format, optional)",
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Path to the input JSON file",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output JSON file path (default: <input stem>_augmented<input suffix> next to the input)",
    )
    parser.add_argument(
        "--start_index",
        type=int,
        default=0,
        help="Start index, inclusive (default 0)",
    )
    parser.add_argument(
        "--end_index",
        type=int,
        default=None,
        help="End index, exclusive (default: process until the end of file)",
    )

    args = parser.parse_args(argv)

    # Load config (if a config file is provided)
    if args.config:
        AugmentConfig.load_from_file(args.config)

    # Create the augmentation pipeline from AugmentConfig values. AugmentConfig
    # presumably loads defaults lazily on first attribute access (TODO confirm).
    pipeline = AugmentPipeline(
        api_key=AugmentConfig.API_KEY,
        model=AugmentConfig.MODEL,
        question_model=AugmentConfig.QUESTION_MODEL,
        cot_model=AugmentConfig.COT_MODEL,
        caption_model=AugmentConfig.CAPTION_MODEL,
        visualize_model=AugmentConfig.VISUALIZE_MODEL,
        base_url=AugmentConfig.BASE_URL,
        delay=AugmentConfig.DELAY,
        max_retries=AugmentConfig.MAX_RETRIES,
        num_question_versions=AugmentConfig.NUM_QUESTION_VERSIONS,
        num_cot_versions=AugmentConfig.NUM_COT_VERSIONS,
        num_caption_versions=AugmentConfig.NUM_CAPTION_VERSIONS,
        num_visualize_versions=AugmentConfig.NUM_VISUALIZE_VERSIONS,
    )

    # Process file
    pipeline.augment_file(
        input_file=str(args.input),
        output_file=str(args.output) if args.output else None,
        start_index=args.start_index,
        end_index=args.end_index,
    )


if __name__ == "__main__":
    # Run the CLI only when this module is executed directly, not on import.
    main()

