import os
import sys
from argparse import Namespace
from typing import Dict, Tuple

import torch
from scipy.special import softmax
from sklearn.metrics import classification_report

sys.path.append("..")
from codes.trainer import ModelTrainer
from common.utils import get_timestamp

class ModelEvaluator(ModelTrainer):
    """
    Evaluator class for trained PCG classification models.

    This class extends ModelTrainer to provide evaluation-specific functionality.
    It inherits model setup, data loading, and evaluation methods from ModelTrainer
    while customizing weight loading and result saving for evaluation workflows.

    Unlike ModelTrainer, this class does not support training operations and is
    optimized for inference and metric computation only.
    """

    def __init__(self, args: Namespace, dump_loc: str, device: str) -> None:
        """
        Initialize the model evaluator with configuration and output location.

        Note: ModelTrainer.__init__ is intentionally not called here; only the
        attributes below are set, and ``self.model`` stays None until the
        model is built elsewhere.

        Args:
            args: Namespace containing model and evaluation configuration including:
                - modelname: Model architecture name
                - dataset: Dataset identifier
                - Any model-specific parameters (emb_dim, depth, etc.)
                This object is mutated: ``args.device`` is set to ``device``.
            dump_loc: Root directory where evaluation results will be saved.
                A timestamped subdirectory is created inside it (and any
                missing parent directories).
            device: Device string for model inference (e.g., "cuda:0", "cpu")
        """
        self.args = args
        # Stored on args as well, presumably so inherited model-setup code
        # that only sees `args` can read the target device.
        self.args.device = device

        self.device = device
        self.model = None

        timestamp = get_timestamp()
        self.dump_loc = os.path.join(dump_loc, timestamp)

        os.makedirs(self.dump_loc, exist_ok=True)

    def set_weight(self, weight_file):
        """
        Load trained model weights from a checkpoint file.

        The checkpoint is loaded onto the CPU. A leading 'module.' prefix (as
        added by DataParallel / DistributedDataParallel wrappers) is stripped
        from each key. The weights are loaded into the model, which is then
        moved to ``self.device``.

        Args:
            weight_file: Path to the saved model checkpoint (.pth file)

        Raises:
            AssertionError: If self.model is None (model not initialized)
        """
        assert self.model is not None, "Model must be initialized before loading weights."

        # Load on CPU so checkpoints saved on any device can be restored; the
        # model is moved to the target device once the weights are in place.
        self.model.to("cpu")

        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints.
        state_dict = torch.load(weight_file, map_location="cpu")

        # Temporary workaround for checkpoints saved from a parallel wrapper.
        # Strip only a *leading* "module." — a plain str.replace would also
        # mangle inner names such as "encoder.submodule.weight".
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.model.load_state_dict(state_dict)

        self.model.to(self.device)

    def run(self, loader, dump_errors=False) -> Tuple[Dict, str]:
        """
        Execute model evaluation on a dataset and generate classification report.

        Runs the inherited ``_evaluate`` on the loader (with
        ``store_sample=False``) and builds a scikit-learn classification report.
        Predicted labels are the per-row argmax of the softmaxed scores.

        Args:
            loader: DataLoader for the evaluation dataset
            dump_errors: Boolean flag for dumping error samples (currently unused).
                Must be False as error dumping is not implemented (default: False)

        Returns:
            tuple: (result_dict, report) where:
                - result_dict: Dictionary returned by ``_evaluate``. It must
                  contain at least "y_trues" (true labels) and "y_preds"
                  (per-class scores).
                - report: Formatted classification report string with precision,
                  recall, F1-score for each class (5 decimal digits, undefined
                  metrics reported as 0)

        Raises:
            AssertionError: If dump_errors is True (not implemented)
        """
        assert not dump_errors, "dump_errors=True is not implemented."
        result_dict = self._evaluate(loader, store_sample=False)

        # Presumably y_preds is (n_samples, n_classes) — TODO confirm.
        # Softmax is monotonic within each row, so it does not change which
        # class the argmax picks.
        predicted_labels = softmax(result_dict["y_preds"], axis=1).argmax(axis=1)

        report = classification_report(
            result_dict["y_trues"],
            predicted_labels,
            digits=5,
            zero_division=0.0,
        )
        return result_dict, report