import json
import re
from typing import Dict, List, Optional
from langchain.llms import Ollama
from config.prompts import ATTRIBUTION_RATING_PROMPTS
from pathlib import Path
from tqdm import tqdm
from config.domains import *
import openai


class RatingCriteria:
    """Score reframed utterances against attribution-style criteria with an LLM.

    Every ``reframed_utterance_*`` value found in an input JSON item is appended
    to the criterion prompt taken from ``ATTRIBUTION_RATING_PROMPTS`` and sent
    either to a local Ollama model or, when ``gpt`` is true, to the OpenAI chat
    API. The reply is parsed as a 1-5 rating and stored on a copy of the item.
    """

    # A rating digit at the very start of a reply, optionally preceded by
    # whitespace or quote/markdown wrappers. The lookahead rejects digits that
    # are part of a longer number or a decimal such as "10" or "3.5".
    _LEADING_RATING = re.compile(r"""^[\s"'`*_(\[]*([1-5])(?!\.?\d)""")

    def __init__(self, model_name: str, gpt: bool):
        """Create the evaluator.

        Args:
            model_name: Name of the Ollama model used when ``gpt`` is false.
            gpt: When true, prompts are sent to the OpenAI chat API instead of
                the Ollama model.
        """
        import os  # local import so this class does not depend on file-level edits

        self.model_name = model_name
        self.llm = Ollama(model=model_name)
        self.gpt = gpt

        # Prefer the OPENAI_API_KEY environment variable so that no real key
        # has to live in source control; the placeholder is only a fallback.
        self.OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "sk-REDACTED")
        # NOTE(review): this overwrites the Ollama model name stored above.
        # From here on ``self.model_name`` is the OpenAI model used by
        # ``_complete`` (``self.llm`` was already built with the original name).
        self.model_name = "gpt-4o-2024-11-20"

    def parse_rating(self, response: str) -> int:
        """Map a letter reply ('A', 'B', 'C') to 1, 2 or 3.

        Surrounding whitespace and letter case are ignored. Anything that is
        not a recognised letter (including non-string input) maps to 2.
        """
        if not isinstance(response, str):
            return 2
        return {"A": 1, "B": 2, "C": 3}.get(response.strip().upper(), 2)

    def parse_rating_from_5(self, response: str) -> Optional[int]:
        """Extract a 1-5 rating from a model reply.

        The reply must begin with the rating digit; leading whitespace and
        quote/markdown characters as well as any trailing text are tolerated
        (``" 4\\n"``, ``"**4**"``, ``"4. Because ..."``).

        Returns:
            The rating as an int, or ``None`` when the reply does not start
            with a standalone digit between 1 and 5.
        """
        if not isinstance(response, str):
            return None
        found = self._LEADING_RATING.match(response)
        return int(found.group(1)) if found else None

    def _complete(self, prompt):
        """Send ``prompt`` to the OpenAI chat API and return the stripped reply text.

        The request is deterministic (``temperature=0``) and capped at three
        tokens because only a short rating is expected.
        NOTE(review): ``openai.ChatCompletion`` is the module-level interface of
        the older OpenAI SDK; confirm the installed SDK version still offers it.
        """
        openai.api_key = self.OPENAI_API_KEY
        response = openai.ChatCompletion.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a strict evaluator. Follow the instructions exactly."},
                {"role": "user", "content": prompt}
            ],
            temperature=0,
            max_tokens=3,
        )
        return response["choices"][0]["message"]["content"].strip()

    def _process_file(self, input_file: Path, output_file: Path, criteria: Optional[List[str]] = None) -> Optional[Path]:
        """Rate every reframed utterance in one JSON file and write the results.

        Args:
            input_file: JSON file containing a list of items (dicts).
            output_file: Destination path. In GPT mode the file is written to a
                sibling directory whose name carries a ``_gpt`` suffix instead.
            criteria: Keys of ``ATTRIBUTION_RATING_PROMPTS`` to evaluate. It is
                effectively required: with ``None`` the first item that has a
                reframed utterance triggers the generic error path below.

        Returns:
            The path actually written, or ``None`` when anything failed (the
            error is printed, not raised).
        """
        try:
            with open(input_file, 'r', encoding='utf-8') as file:
                data = json.load(file)

            rated_items = []

            for item in tqdm(data, desc=f"Processing {criteria}"):
                reframed_items = {
                    key: value for key, value in item.items()
                    if key.startswith("reframed_utterance_")
                }

                # Shallow copy: ratings are added beside the original fields.
                result = item.copy()

                for reframed_key, reframed_post in reframed_items.items():
                    for criterion in criteria:
                        prompt = ATTRIBUTION_RATING_PROMPTS[criterion] + reframed_post
                        response = self._complete(prompt) if self.gpt else self.llm.predict(prompt)
                        result[f"{reframed_key}_{criterion}"] = self.parse_rating_from_5(response)

                rated_items.append(result)

            # Accept plain strings as well as Path objects in both modes.
            output_file = Path(output_file)

            # Save under a parallel directory if GPT mode is used
            if self.gpt:
                parent = output_file.parent
                output_file = parent.parent / (parent.name + "_gpt") / output_file.name

            output_file.parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, 'w', encoding='utf-8') as file:
                json.dump(rated_items, file, indent=2, ensure_ascii=False)

            return output_file

        except FileNotFoundError:
            print(f"File not found: {input_file}")
            return None
        except json.JSONDecodeError:
            print(f"JSON decode error: {input_file}")
            return None
        except Exception as e:
            # Best-effort batch job: report and let the caller move on.
            print(f"Error processing file: {str(e)}")
            return None

    def process_all_files(self, input_dir: Path, output_dir: Path, file_numbers: Optional[List[int]] = None, criteria: Optional[List[str]] = None) -> List[Path]:
        """Rate every eligible ``*.json`` file in ``input_dir``.

        Args:
            input_dir: Directory scanned (non-recursively) for ``*.json`` files.
            output_dir: Directory that receives one result file per input file.
            file_numbers: Optional filter. A file is kept only when the
                third-from-last ``_``-separated token of its stem is an integer
                contained in this list. Files whose name does not fit that
                pattern are skipped.
            criteria: Criterion keys forwarded to ``_process_file``.

        Returns:
            Paths of the result files that were written successfully.
        """
        processed_files = []
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)

        # Sorted so the processing order does not depend on the filesystem.
        json_files = sorted(input_dir.glob("*.json"))

        if file_numbers is not None:
            selected = []
            for json_path in json_files:
                try:
                    file_number = int(json_path.stem.split('_')[-3])
                except (ValueError, IndexError):
                    # Token is not numeric, or the stem has fewer than three
                    # tokens: the file does not follow the numbering scheme.
                    continue
                if file_number in file_numbers:
                    selected.append(json_path)
            json_files = selected

        for json_path in tqdm(json_files, desc="Processing files"):
            try:
                written = self._process_file(json_path, output_dir / json_path.name, criteria)

                if written:
                    processed_files.append(written)
                    print(f"Successfully processed: {json_path} -> {written}")
            except Exception as e:
                print(f"Error processing {json_path}: {str(e)}")
                continue

        return processed_files


def process_model_tasks(rating_criteria, model, domain_configs, file_numbers=None):
    """Run the rating pass for every attribution domain of one model.

    For each domain name, a configuration mapping is looked up in this module's
    globals under ``"<domain>_<model>"`` with every ``:``, ``-`` and ``.``
    replaced by ``_``. The mapping must provide ``<key>_path``,
    ``<key>_result_path`` and ``<key>_criterion_configs``, where ``<key>`` is
    the leading run of letters of the lower-cased domain name.

    Args:
        rating_criteria: Object exposing ``process_all_files`` (e.g. a
            ``RatingCriteria`` instance).
        model: Model name used to build the configuration variable name.
        domain_configs: Iterable of domain names.
        file_numbers: Optional file-number filter forwarded unchanged.

    Raises:
        ValueError: A domain name does not start with a letter.
        KeyError: No matching configuration variable exists, or the
            configuration lacks one of the expected keys.
    """
    # One translation table replaces the three chained ``str.replace`` calls.
    separators = str.maketrans({":": "_", "-": "_", ".": "_"})

    for domain in domain_configs:
        prefix = re.match(r'^[a-z]+', domain.lower())
        if prefix is None:
            raise ValueError(f"Domain name must start with a letter: {domain!r}")
        domain_key = prefix.group()

        config_var = f"{domain}_{model}".translate(separators)
        try:
            config = globals()[config_var]
        except KeyError:
            raise KeyError(
                f"No configuration variable {config_var!r} found for domain "
                f"{domain!r} and model {model!r}"
            ) from None

        rating_criteria.process_all_files(
            input_dir=config[f"{domain_key}_path"],
            output_dir=config[f"{domain_key}_result_path"],
            file_numbers=file_numbers,
            criteria=config[f"{domain_key}_criterion_configs"]
        )
