import os
import pickle
from typing import Dict
from argparse import Namespace

import numpy as np
from sklearn.metrics import (
    classification_report,
    confusion_matrix
)
from scipy.special import softmax

from codes.evaluator import ModelEvaluator

def get_confmat_text(ytrues, ypreds):
    """
    Build a printable confusion matrix from true labels and raw logits.

    Each sample's predicted class is the index of the largest softmax
    probability along axis 1 of ``ypreds``.

    Args:
        ytrues: Array-like of true class labels, one per sample.
        ypreds: Array of raw prediction logits (before softmax). Because
            softmax/argmax are taken over axis 1, this presumably has shape
            (n_samples, n_classes) -- TODO confirm with callers.

    Returns:
        str: The confusion matrix rendered by ``np.array2string`` using
        ``', '`` as the element separator.
    """
    probabilities = softmax(ypreds, axis=1)
    predicted_labels = probabilities.argmax(axis=1)
    matrix = confusion_matrix(ytrues, predicted_labels)
    return np.array2string(matrix, separator=', ')

class ReportManager:
    """
    Accumulate labelled evaluation results into a plain-text report.

    The report is kept as one string in ``self.report``. It begins with an
    "EVAL TARGET" header and grows by one ``key: content`` row per call to
    :meth:`add_row`.
    """

    def __init__(self, eval_target: str):
        """
        Start a new report headed by the evaluated checkpoint.

        Args:
            eval_target: Path to the evaluated model checkpoint; written
                verbatim beneath the "EVAL TARGET" header.
        """
        header_parts = ["\n\nEVAL TARGET\n", eval_target, "\n\n"]
        self.report = "".join(header_parts)

    def add_row(self, key, content, n_rep: int=2):
        """
        Append a ``key: content`` row followed by newline padding.

        Args:
            key: Label for the row (interpolated via an f-string).
            content: Value shown after the label (interpolated via an f-string).
            n_rep: Number of newline characters appended after the row
                (default: 2). Zero or negative values add no newlines.
        """
        row = f"{key}: {content}"
        self.report = self.report + row
        # Padding is appended separately so the row is already recorded even
        # if n_rep turns out to be an invalid type.
        self.report = self.report + "\n" * n_rep

    def get_report(self):
        """
        Return the accumulated report text.

        Returns:
            str: The report with leading and trailing whitespace stripped.
        """
        return self.report.strip()

def run_eval(
    eval_target: str,
    device: str,
    dump_loc: str,
    multiseed_run: bool=True,
    dump_errors: bool=False,
    overwrite_params: dict=None
):
    """
    Evaluate a saved model checkpoint on the validation and test splits.

    Pipeline:
    1. Load the pickled parameters (``params.pkl``) from ``eval_target``.
    2. Reset ``data_lim`` to None, then apply any ``overwrite_params``.
    3. Build a ``ModelEvaluator``, set model / loss function, and load
       weights from ``net.pth``.
    4. Run the evaluator on the "val" split and on the test split.
    5. Build confusion matrices from the true labels and logits.
    6. Write a formatted ``report.txt`` to the output directory.

    Test split selection:
    - ``params.dataset == "syn"``: the "val" split is reused as the test set.
    - Otherwise: the "test" split is used.

    Args:
        eval_target: Path to the trained model checkpoint directory
            containing 'params.pkl' and 'net.pth'.
        device: Device string for model inference (e.g., "cuda:0", "cpu").
        dump_loc: Root directory where evaluation results are saved.
        multiseed_run: If True, results are saved under
            ``dump_loc/multirun/eval/`` instead of ``dump_loc`` (default: True).
        dump_errors: Forwarded to ``evaluator.run`` for the test split only;
            presumably dumps misclassified samples -- behaviour is defined by
            ``ModelEvaluator.run`` (default: False).
        overwrite_params: Optional mapping of parameter names to values that
            replace the saved configuration. Keys must be valid keyword names
            because they are passed to ``argparse.Namespace``. Keys read
            directly by this function include 'dataset' (selects the test
            split) and 'data_lim'.

    Returns:
        tuple: ``(val_result, test_result)``, the result dicts returned by
        ``ModelEvaluator.run`` for the validation and test loaders. Each
        contains at least 'y_trues' (true labels) and 'y_preds' (raw logits);
        any other keys are determined by ``ModelEvaluator.run``.

    Side Effects:
        - Creates the output directory (including ``multirun/eval`` when
          ``multiseed_run`` is True) if it does not exist.
        - Writes 'report.txt' with metrics and confusion matrices there.
    """
    report_manager = ReportManager(eval_target)

    if multiseed_run:
        dump_loc = os.path.join(dump_loc, "multirun", "eval")

    # Settings
    param_file = os.path.join(eval_target, "params.pkl")
    weightfile = os.path.join(eval_target, "net.pth")

    # NOTE(review): pickle.load can execute arbitrary code -- only evaluate
    # checkpoints that come from a trusted source.
    with open(param_file, "rb") as fp:
        params = pickle.load(fp)
    # Evaluate on the full data by default; an override may still set it.
    params.data_lim = None
    if overwrite_params is not None:
        # Merge into a copy rather than mutating the loaded object's __dict__.
        merged = dict(vars(params))
        merged.update(overwrite_params)
        params = Namespace(**merged)

    report_manager.add_row("Model", params.modelname)
    report_manager.add_row("Parameters", str(params))

    # The output directory (especially the multirun/eval sub-directory) is
    # not guaranteed to exist; without this the report write below can fail.
    os.makedirs(dump_loc, exist_ok=True)

    # Evaluator
    evaluator = ModelEvaluator(params, dump_loc, device)
    evaluator.set_model()
    evaluator.set_lossfunc()
    evaluator.set_weight(weightfile)

    # Validation split
    loader = evaluator.prepare_dataloader("val", is_train=False)
    val_result, val_report = evaluator.run(loader)
    val_confmat = get_confmat_text(
        val_result["y_trues"], val_result["y_preds"]
    )

    # Synthetic data has no dedicated test split, so "val" doubles as test.
    test_set = "val" if params.dataset == "syn" else "test"
    loader = evaluator.prepare_dataloader(test_set, is_train=False)
    test_result, test_report = evaluator.run(loader, dump_errors=dump_errors)
    test_confmat = get_confmat_text(
        test_result["y_trues"], test_result["y_preds"]
    )

    report_manager.add_row("Validation set result\n", val_report)
    report_manager.add_row("Validation set confusion matrix\n", val_confmat)
    report_manager.add_row("Test set result\n", test_report)
    report_manager.add_row("Test set confusion matrix\n", test_confmat)

    report = report_manager.get_report()
    with open(os.path.join(dump_loc, "report.txt"), "w", encoding="utf-8") as f:
        f.write(report)

    return val_result, test_result
