"""Core module for attributional style transformation."""

import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from enum import Enum
from tqdm import tqdm
from langchain.llms import Ollama
from datetime import datetime

from config.prompts import SHIFT_PROMPTS
from config.domains import ATTRIBUTION_DOMAINS

class AttributionTransfer:
    """Attributional style transformer.

    Reads utterances from JSON files, asks an Ollama-served LLM to reframe
    each one ``num_samples`` times using the prompt registered for the chosen
    attribution type, and writes the results to timestamped JSON files.
    """

    def __init__(self, model_name: str, attribution_type: str, num_samples: int = 5, temperature: float = 0.7):
        """
        Initialize the transformer.

        Args:
            model_name: Name of the model to use.
            attribution_type: Attribution dimension to be transformed (e.g., IAS→EAS).
                Must be a key of both ``SHIFT_PROMPTS`` and ``ATTRIBUTION_DOMAINS``.
            num_samples: Number of samples to generate per input.
            temperature: Sampling temperature. Higher values yield more diverse outputs,
                         lower values produce more conservative outputs. Recommended range: 0.1–1.0

        Raises:
            KeyError: If ``attribution_type`` is not configured, or its domain
                configuration has no ``"input_path"`` entry.
        """
        # Resolve the configuration first so a bad attribution type fails
        # before the LLM client is constructed.
        try:
            prompt = SHIFT_PROMPTS[attribution_type]
            domain_config = ATTRIBUTION_DOMAINS[attribution_type]
        except KeyError as err:
            raise KeyError(f"Unknown attribution_type: {attribution_type!r}") from err

        self.model_name = model_name
        self.attribution_type = attribution_type
        self.num_samples = num_samples
        self.prompt = prompt
        self.input_path = domain_config["input_path"]
        self.llm = Ollama(
            model=model_name,
            temperature=temperature,
            top_p=0.9,
            top_k=40,
            repeat_penalty=1.1
        )

    @staticmethod
    def _safe_model_name(model_name: str) -> str:
        """Return ``model_name`` reduced to characters safe for one path component.

        Alphanumerics, ``.`` and ``_`` are kept; every other character
        (``:``, ``-``, ``/``, ``\\``, spaces, ...) becomes ``_`` so that a
        namespaced model name cannot introduce extra directory levels.
        """
        return "".join(ch if ch.isalnum() or ch in "._" else "_" for ch in model_name)

    @staticmethod
    def _file_number(file: Path) -> Optional[int]:
        """Return the integer after the last ``_`` in the file stem, or None if it is not an integer."""
        try:
            return int(Path(file).stem.split('_')[-1])
        except ValueError:
            return None

    def _process_post(self, post: str) -> List[str]:
        """
        Generate multiple reframed outputs for a single post.

        Args:
            post: Input text (utterance).

        Returns:
            A list of ``num_samples`` generated reframed versions. For each
            generation, only the text after the first ``"Post:"`` marker is
            kept (the whole output if the marker is absent), stripped of
            surrounding whitespace.
        """
        full_prompt = self.prompt + post

        # Prefer the Runnable-style ``invoke`` when the client offers it;
        # ``predict`` is kept as a fallback for clients that only expose that.
        generate = getattr(self.llm, "invoke", None)
        if not callable(generate):
            generate = self.llm.predict

        results = []
        for _ in range(self.num_samples):
            raw_output = generate(full_prompt)
            results.append(raw_output.split("Post:", 1)[-1].strip())
        return results

    def _save_results(self, data_list: List[Dict], output_file: Path):
        """
        Save transformed results to a JSON file.

        Parent directories are created if they do not exist; an existing
        file at ``output_file`` is overwritten.

        Args:
            data_list: List of result dictionaries.
            output_file: Output file path.
        """
        output_file = Path(output_file)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as json_file:
            # ensure_ascii=False keeps non-ASCII text readable in the output.
            json.dump(data_list, json_file, indent=4, ensure_ascii=False)

    def _process_file(self, input_file: Path, output_path: Path) -> Optional[Path]:
        """
        Process a single input file and generate outputs.

        Each item with a non-empty ``"utterance"`` yields one result entry
        holding the original utterance, its ``"id"`` and ``"label"`` (empty
        string when missing) and ``reframed_utterance_1..N`` keys.

        Args:
            input_file: Input JSON file.
            output_path: Directory to save results.

        Returns:
            Path to the output file if successful, None otherwise. Any
            exception raised while processing is printed and turned into None.
        """
        try:
            input_file = Path(input_file)
            with open(input_file, 'r', encoding='utf-8') as file:
                data = json.load(file)

            data_list = []

            for item in tqdm(data, desc=f"Processing {self.attribution_type}"):
                post = item.get("utterance", "")
                if not post:
                    # Nothing to reframe for this item.
                    continue

                result_entry = {
                    "utterance": post,
                    "id": item.get("id", ""),
                    "label": item.get("label", ""),
                }
                for i, result_str in enumerate(self._process_post(post), 1):
                    result_entry[f"reframed_utterance_{i}"] = result_str

                data_list.append(result_entry)

            # The timestamp keeps repeated runs from overwriting earlier outputs.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = (
                Path(output_path)
                / self.attribution_type
                / self._safe_model_name(self.model_name)
                / f"{input_file.stem}_{timestamp}{input_file.suffix}"
            )
            self._save_results(data_list, output_file)
            return output_file

        except Exception as e:
            # Best-effort boundary: one bad file must not abort the batch.
            print(f"Error processing file {input_file}: {str(e)}")
            return None

    def process_all_files(self, input_dir: Path, output_dir: Path, file_numbers: Optional[List[int]] = None) -> List[Path]:
        """
        Process all files within a directory and apply attributional style transformation.

        Input files are the ``*.json`` files directly inside
        ``input_dir / self.input_path``, processed in sorted path order.

        Args:
            input_dir: Directory containing input JSON files.
            output_dir: Directory to save output files.
            file_numbers: Optional list of specific file index numbers to process.
                A file matches when the integer after the last ``_`` in its
                stem is in this list; files without such a number are skipped.

        Returns:
            List of output file paths that were successfully generated.
        """
        processed_files = []
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)

        input_path = input_dir / self.input_path
        if not input_path.is_dir():
            # Make a misconfigured path visible instead of silently doing nothing.
            print(f"Input directory not found: {input_path}")

        # glob() yields files in arbitrary order; sort for reproducible runs.
        json_files = sorted(input_path.glob("*.json"))

        if file_numbers is not None:
            wanted = set(file_numbers)
            json_files = [file for file in json_files if self._file_number(file) in wanted]
            print(f"Processing specific files: {file_numbers}")

        for input_file in tqdm(json_files, desc="Processing files"):
            try:
                output_file = self._process_file(input_file, output_dir)
                if output_file:
                    processed_files.append(output_file)
                    print(f"Successfully processed: {input_file} -> {output_file}")
            except Exception as e:
                print(f"Error processing {input_file}: {str(e)}")
                continue

        return processed_files
