from src.reliability_eval.common.enums.score_types import QuestionAggregateTypes, SequenceAggregateTypes, TokenScoreTypes
from src.reliability_eval.pipeline.evaluation_pipelines.types import PipelineType
from src.reliability_eval.pipeline.config import ScoreAccessConfig



# One ScoreAccessConfig per pipeline. Every entry uses MEAN for both the
# sequence aggregate and the question aggregate; only the token score type
# differs between pipelines. Pairs are listed in the original key order, so
# dict insertion order is preserved.
SCORE_ACCESS_CONFIGS = {
    pipeline: ScoreAccessConfig(
        pipeline_name=pipeline,
        token_score_type=token_score_type,
        sequence_aggregate_type=SequenceAggregateTypes.MEAN,
        question_aggregate_type=QuestionAggregateTypes.MEAN,
    )
    for pipeline, token_score_type in (
        (PipelineType.NLL, TokenScoreTypes.NLL),
        (PipelineType.CONFIDENCE, TokenScoreTypes.CROSS_ENTROPY),
        (PipelineType.ENTROPY, TokenScoreTypes.ENTROPY),
        (PipelineType.TOPK, TokenScoreTypes.TOP_K),
        (PipelineType.SEMANTIC_ENTROPY, TokenScoreTypes.SEMANTIC_ENTROPY),
    )
}

# One ScoreAccessConfig per pipeline, with only the question-level aggregate
# set (QuestionAggregateTypes.ACCURACY). The token score type and the
# sequence aggregate are left as None. Pipelines are listed in the original
# key order, so dict insertion order is preserved.
ACCURACY_ACCESS_CONFIGS = {
    pipeline: ScoreAccessConfig(
        pipeline_name=pipeline,
        token_score_type=None,
        sequence_aggregate_type=None,
        question_aggregate_type=QuestionAggregateTypes.ACCURACY,
    )
    for pipeline in (
        PipelineType.NLL,
        PipelineType.CONFIDENCE,
        PipelineType.ENTROPY,
        PipelineType.TOPK,
        PipelineType.SEMANTIC_ENTROPY,
    )
}

# One ScoreAccessConfig per pipeline, with only the token score type set.
# Both the sequence aggregate and the question aggregate are None, so no
# aggregation is requested. Pairs are listed in the original key order, so
# dict insertion order is preserved.
TOKEN_ID_ACCESS_CONFIGS = {
    pipeline: ScoreAccessConfig(
        pipeline_name=pipeline,
        token_score_type=token_score_type,
        sequence_aggregate_type=None,
        question_aggregate_type=None,
    )
    for pipeline, token_score_type in (
        (PipelineType.NLL, TokenScoreTypes.NLL),
        (PipelineType.CONFIDENCE, TokenScoreTypes.CROSS_ENTROPY),
        (PipelineType.ENTROPY, TokenScoreTypes.ENTROPY),
        (PipelineType.TOPK, TokenScoreTypes.TOP_K),
        (PipelineType.SEMANTIC_ENTROPY, TokenScoreTypes.SEMANTIC_ENTROPY),
    )
}
