import contextlib
import json
import os
import statistics
import sys
import time
from collections import defaultdict

from langchain_community.llms import Ollama
from tqdm import tqdm

from config.models import MODEL_CONFIGS
from config.domains import ATTRIBUTION_DOMAINS
from config.prompts import ATTRIBUTION_PROMPTS


class AttributionClassifier:
    """Label JSON records with an Ollama-served LLM and score them against ground truth.

    For every index ``i`` in ``[start_idx, end_idx]`` the input file
    ``os.path.join(base_path, spec_path).format(i)`` is classified ``runs`` times.
    Each run writes the labelled records and the mismatched records as JSON under
    ``output_path``; accuracies and confusion matrices are printed into a
    timestamped report file in ``log_dir``.
    """

    # Record field appended to the prompt when no ``input_key`` is given. This is
    # the key the class originally hard-coded, kept as the default for compatibility.
    _DEFAULT_INPUT_KEY = "llama3.3_40b_label"
    # Class-level fallback so instances that bypass ``__init__`` still have a key.
    input_key = _DEFAULT_INPUT_KEY

    def __init__(
        self,
        model_name,
        attribution_domain,
        suffix,
        prompt,
        base_path,
        output_path,
        spec_path,
        log_dir,
        shot,
        start_idx=1,
        end_idx=1,
        runs=1,
        input_key=_DEFAULT_INPUT_KEY,
    ):
        """Store the run configuration and create the Ollama client.

        Args:
            model_name: Model identifier passed to ``Ollama(model=...)``.
            attribution_domain: Domain name; used in output directory and report names.
            suffix: Appended to the sanitized model name to form ``formatted_name``.
            prompt: Prompt prefix; each record's ``input_key`` text is appended to it.
            base_path: Directory that ``spec_path`` is joined onto.
            output_path: Root directory for result and mismatch files.
            spec_path: Input path template, formatted with the file index.
            log_dir: Directory for the report file.
            shot: Prompting variant name; used in output directory and report names.
            start_idx: First input index (inclusive).
            end_idx: Last input index (inclusive).
            runs: Number of classification runs per input file.
            input_key: Record field whose text is appended to ``prompt``.
        """
        self.model_name = model_name
        self.attribution_domain = attribution_domain
        self.suffix = suffix
        self.shot = shot
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.runs = runs
        self.prompt = prompt
        self.log_dir = log_dir
        self.input_key = input_key

        self.folder_base = base_path
        self.output_path = output_path
        # Template formatted with the file index in ``_process_input_file``.
        self.input_json_template = os.path.join(self.folder_base, spec_path)

        # ":" is replaced so the model name can be used inside file/directory names.
        self.formatted_name = self.model_name.replace(":", "_") + self.suffix
        self.json_label = f"{self.formatted_name}_label"
        self.llm = Ollama(model=self.model_name)

    def classify_single(self, input_json, output_json):
        """Predict a label for every record in ``input_json`` and save them to ``output_json``.

        The prediction (stripped, with newlines removed) is stored on each record
        under ``self.json_label``; all other fields are kept unchanged.
        """
        with open(input_json, "r", encoding="utf-8") as file:
            data = json.load(file)

        for element in tqdm(data):
            prompt_ = self.prompt + element[self.input_key]
            # ``invoke`` is the Runnable API; ``predict`` is deprecated in LangChain.
            label = self.llm.invoke(prompt_)
            element[self.json_label] = label.strip().replace("\n", "")

        with open(output_json, "w", encoding="utf-8") as file:
            json.dump(data, file, indent=2)

    def calculate_accuracy(self, output_json, mismatch_json):
        """Score the predictions in ``output_json`` against each record's ``label`` field.

        Records whose prediction (``self.json_label``, or ``None`` when absent)
        differs from ``label`` are written to ``mismatch_json``.

        Returns:
            ``(overall_accuracy, class_accuracies, confusion_matrix)`` where
            ``class_accuracies`` maps each true label to correct/total for that
            label and ``confusion_matrix[true][predicted]`` is a count.
        """
        with open(output_json, "r", encoding="utf-8") as file:
            data = json.load(file)

        class_counts = defaultdict(lambda: {"correct": 0, "total": 0})
        confusion_matrix = defaultdict(lambda: defaultdict(int))
        mismatched_elements = []

        for element in data:
            true_label = element["label"]
            predicted_label = element.get(self.json_label)

            confusion_matrix[true_label][predicted_label] += 1
            class_counts[true_label]["total"] += 1

            if true_label == predicted_label:
                class_counts[true_label]["correct"] += 1
            else:
                mismatched_elements.append(element)

        class_accuracies = {
            label: stats["correct"] / stats["total"] if stats["total"] > 0 else 0
            for label, stats in class_counts.items()
        }

        total_correct = sum(stats["correct"] for stats in class_counts.values())
        total_samples = sum(stats["total"] for stats in class_counts.values())
        overall_accuracy = total_correct / total_samples if total_samples > 0 else 0

        with open(mismatch_json, "w", encoding="utf-8") as file:
            json.dump(mismatched_elements, file, indent=2)

        return overall_accuracy, class_accuracies, confusion_matrix

    @staticmethod
    def format_confusion_matrix(confusion_matrix):
        """Render ``confusion_matrix[true][predicted] -> count`` as an aligned text table.

        Rows are true labels, columns are predicted labels, and both use the union
        of all labels seen. Missing cells print as 0. The input is not modified.
        Plain dicts and defaultdicts are both accepted.
        """
        corner = "True/Pred"
        all_labels = set(confusion_matrix) | {
            pred for row in confusion_matrix.values() for pred in row
        }
        try:
            labels = sorted(all_labels)
        except TypeError:
            # Mixed label types, e.g. None for a record that had no prediction.
            labels = sorted(all_labels, key=str)

        def count(true_label, pred_label):
            # ``.get`` avoids inserting keys into defaultdict inputs.
            return confusion_matrix.get(true_label, {}).get(pred_label, 0)

        # Each column must fit both its header label and its widest count.
        first_width = max([len(corner)] + [len(str(label)) for label in labels])
        col_widths = {
            pred: max([len(str(pred))] + [len(str(count(true, pred))) for true in labels])
            for pred in labels
        }

        header = corner.ljust(first_width) + "".join(
            f" | {str(pred):<{col_widths[pred]}}" for pred in labels
        )
        rows = [header, "-" * len(header)]
        for true_label in labels:
            cells = "".join(
                f" | {count(true_label, pred):<{col_widths[pred]}}" for pred in labels
            )
            rows.append(f"{str(true_label):<{first_width}}{cells}")

        return "\n".join(rows)

    def run_single_classification(self, input_json, output_json_template, mismatch_json_template):
        """Classify ``input_json`` ``self.runs`` times and summarize the accuracies.

        Args:
            input_json: Path of the records to classify.
            output_json_template: Path with a ``{}`` placeholder for the 1-based run
                number; the labelled records of each run are written there.
            mismatch_json_template: Same, for each run's mismatched records.

        Returns:
            ``(mean_accuracy, sample_std, per_run_accuracies)``. The std is 0 for
            fewer than two runs; mean and std are 0 when there are no runs.
        """
        accuracies = []

        for run in range(1, self.runs + 1):
            print(f"Processing: {input_json}, Run {run}...")
            output_json = output_json_template.format(run)
            mismatch_json = mismatch_json_template.format(run)

            self.classify_single(input_json, output_json)
            overall_accuracy, class_accuracies, confusion_matrix = self.calculate_accuracy(
                output_json, mismatch_json
            )

            print(f"Run {run} - Overall Accuracy: {overall_accuracy:.2%}")
            print(f"Run {run} - Class-wise Accuracies: {class_accuracies}")
            print(f"Run {run} - Confusion Matrix:")
            print(self.format_confusion_matrix(confusion_matrix))
            accuracies.append(overall_accuracy)

        if not accuracies:
            # runs < 1: nothing to average (avoids ZeroDivisionError).
            return 0.0, 0, accuracies

        average_accuracy = sum(accuracies) / len(accuracies)
        std_accuracy = statistics.stdev(accuracies) if len(accuracies) > 1 else 0

        return average_accuracy, std_accuracy, accuracies

    @staticmethod
    def ensure_dir_exists(directory):
        """Create ``directory`` and any missing parents; an empty path is a no-op."""
        if directory:
            os.makedirs(directory, exist_ok=True)

    def process_all_files(self):
        """Classify every input file in ``[start_idx, end_idx]`` and write a report.

        Everything printed while processing goes into a timestamped report in
        ``log_dir``: skipped files, per-run accuracies, confusion matrices and
        per-file summaries. Only the final "Results saved" line reaches the real
        stdout.
        """
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        # ":" and "/" can occur in model names but are not safe in a single file name.
        model_str = self.model_name.replace(":", "_").replace("/", "_")

        self.ensure_dir_exists(self.log_dir)
        # NOTE: the report is plain text despite the ".json" extension; the naming
        # scheme is left as-is so report file names stay stable.
        log_filename = os.path.join(
            self.log_dir,
            f"{self.attribution_domain}_{model_str}_{self.shot}_{timestamp}.json",
        )

        # Output locations do not depend on the file index, so compute them once.
        run_dir_prefix = f"{self.shot}_{self.formatted_name}"
        domain_dir = os.path.join(self.output_path, self.attribution_domain)
        output_dir = os.path.join(domain_dir, f"{run_dir_prefix}_result")
        mismatch_dir = os.path.join(domain_dir, f"{run_dir_prefix}_mismatch")

        with open(log_filename, "w", encoding="utf-8") as log_file:
            # redirect_stdout restores sys.stdout even if classification raises.
            with contextlib.redirect_stdout(log_file):
                for index in range(self.start_idx, self.end_idx + 1):
                    self._process_input_file(index, output_dir, mismatch_dir)

        print(f"Results saved in {log_filename}")

    def _process_input_file(self, index, output_dir, mismatch_dir):
        """Run all classification runs for the input file at ``index`` and print a summary."""
        input_json = self.input_json_template.format(index)
        start_time = time.perf_counter()

        if not os.path.exists(input_json):
            print(f"Skipping {input_json}: File not found")
            return

        # Strip only the final extension (str.replace would also hit inner ".json").
        base_name = os.path.splitext(os.path.basename(input_json))[0]

        self.ensure_dir_exists(output_dir)
        self.ensure_dir_exists(mismatch_dir)
        # "{{}}" leaves a literal "{}" that run_single_classification fills with the run number.
        output_json_template = os.path.join(output_dir, f"{base_name}_run_{{}}.json")
        mismatch_json_template = os.path.join(
            mismatch_dir, f"{base_name}_mismatch_run_{{}}.json"
        )

        avg_acc, std_acc, acc_list = self.run_single_classification(
            input_json, output_json_template, mismatch_json_template
        )
        total_time = time.perf_counter() - start_time

        print(f"\nSummary for {input_json}:")
        print(f"  - Average Accuracy: {avg_acc:.2%}")
        print(f"  - Standard Deviation: {std_acc:.6f}")
        print(f"  - Per Run Accuracy: {', '.join(f'{acc:.2%}' for acc in acc_list)}")
        print(f"  - Total execution time: {total_time / 60:.2f} minutes")

    @classmethod
    def run_classification(
        cls,
        model_name: str,
        attribution_domain: str,
        shot: str,
        start_idx: int = 1,
        end_idx: int = 1,
        runs: int = 1,
        base_path: str = None,
        output_path: str = None,
        log_dir: str = None,
        input_key: str = _DEFAULT_INPUT_KEY,
    ):
        """Build a classifier from the config registries and process all of its files.

        Args:
            model_name: Key into ``MODEL_CONFIGS``.
            attribution_domain: Key into ``ATTRIBUTION_DOMAINS`` and ``ATTRIBUTION_PROMPTS``.
            shot: Key into ``ATTRIBUTION_PROMPTS[attribution_domain]`` selecting the prompt.
            start_idx: First input index (inclusive).
            end_idx: Last input index (inclusive).
            runs: Classification runs per input file.
            base_path: Directory the domain's ``spec_path`` is joined onto (required).
            output_path: Root directory for result and mismatch files (required).
            log_dir: Directory for the report file (required).
            input_key: Record field appended to the prompt.

        Returns:
            The classifier, after ``process_all_files`` has completed.

        Raises:
            ValueError: On an unknown model, domain or shot, or a missing path argument.
        """
        if model_name not in MODEL_CONFIGS:
            raise ValueError(f"Unknown model: {model_name}. Available models: {list(MODEL_CONFIGS.keys())}")
        if attribution_domain not in ATTRIBUTION_DOMAINS:
            raise ValueError(f"Unknown domain: {attribution_domain}. Available domains: {list(ATTRIBUTION_DOMAINS.keys())}")
        domain_prompts = ATTRIBUTION_PROMPTS.get(attribution_domain, {})
        if shot not in domain_prompts:
            raise ValueError(f"Unknown shot type: {shot} for domain {attribution_domain}")

        # The path parameters default to None only because they follow defaulted
        # parameters; fail early with a clear message instead of a TypeError later.
        missing = [
            name
            for name, value in (
                ("base_path", base_path),
                ("output_path", output_path),
                ("log_dir", log_dir),
            )
            if value is None
        ]
        if missing:
            raise ValueError(f"Missing required path argument(s): {', '.join(missing)}")

        model_config = MODEL_CONFIGS[model_name]
        domain_config = ATTRIBUTION_DOMAINS[attribution_domain]
        prompt = domain_prompts[shot]

        classifier = cls(
            model_name=model_config["name"],
            attribution_domain=domain_config["name"],
            suffix=model_config["suffix"],
            prompt=prompt,
            base_path=base_path,
            output_path=output_path,
            log_dir=log_dir,
            spec_path=domain_config["spec_path"],
            shot=shot,
            start_idx=start_idx,
            end_idx=end_idx,
            runs=runs,
            input_key=input_key,
        )

        classifier.process_all_files()
        return classifier
