"""
Plot the mean CPMI of conformal-backoff records against the multiplicity threshold.

The script reads a JSONL file passed via ``--input-path``. For each record it
selects a backoff score per multiplicity threshold and averages
``-log(bts(score) + 1e-8)`` over all records. The resulting curve is saved as
a PNG under ``data/conformal-backoff``.
"""


import os
import numpy as np
from typing import (
    Any,
    Dict,
    List,
    Text,
    Literal
)
try:
    import ujson as json
except ImportError:
    import json
import click
from matplotlib import pyplot as plt
from src.utils.transforms import (
    _inverse_sigmoid_unli,
    _beta_sigmoid,
    _discretize_gaussian,
    _BETA_POSITIVE_,
    _BETA_NEGATIVE_
)


# Apply matplotlib's built-in 'ggplot' style sheet to every figure this script creates.
plt.style.use('ggplot')


def bts(score: float) -> float:
    """Map ``score`` through the piecewise beta-sigmoid and return a Python scalar.

    The input is first rescaled as ``100 * score - 50``, so 0.5 maps to 0
    (presumably ``score`` lies in [0, 1]; TODO confirm against the data).
    ``_BETA_POSITIVE_`` is used when the rescaled value is strictly positive,
    and ``_BETA_NEGATIVE_`` is used otherwise, including at exactly zero.
    ``.item()`` converts the output of ``_beta_sigmoid`` to a plain scalar.
    """
    shifted = 100 * score - 50
    beta = _BETA_POSITIVE_ if shifted > 0 else _BETA_NEGATIVE_
    return _beta_sigmoid(shifted, beta).item()


@click.command()
@click.option(
    "--input-path",
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help="Path to the input JSONL file",
)
def main(
    input_path
):
    """Plot mean CPMI against the multiplicity threshold for a JSONL file.

    Each non-blank line of ``input_path`` is parsed as a JSON object with a
    ``backoffs`` list. Every backoff has a ``multiplicity`` and a score under
    the ``'fine'`` key. For each threshold in ``0..99``, the backoff with the
    smallest multiplicity that is still >= the threshold is selected. Its
    ``fine`` score is passed through :func:`bts`, and ``-log(bts + 1e-8)`` is
    averaged over all records. The curve is saved as ``<input stem>_cpmi.png``
    under ``data/conformal-backoff``, which is created if missing.

    Raises:
        click.ClickException: if the file contains no records.
        ValueError: if a record has no backoff whose multiplicity reaches a
            threshold in the plotted range.
    """
    with open(input_path, 'r', encoding='utf-8') as file_:
        # Skip blank lines (e.g. a trailing empty line) instead of failing
        # inside json.loads('').
        data = [
            json.loads(stripped)
            for stripped in (line.strip() for line in file_)
            if stripped
        ]

    if not data:
        raise click.ClickException(f"No records found in {input_path}")

    thresholds = range(100)

    # Sort each record's backoffs once, by ascending multiplicity, rather than
    # re-sorting them for every threshold.
    sorted_backoffs = [
        sorted(item['backoffs'], key=lambda bck: bck['multiplicity'])
        for item in data
    ]

    def _get_score(
        backoffs: List[Dict[Text, Any]],
        multiplicity_threshold: int,
        config: Literal["coarse", "fine"],
    ) -> float:
        """Return ``-log(bts(score) + 1e-8)`` for the first backoff meeting the threshold.

        ``backoffs`` must already be sorted by ascending multiplicity, so the
        selected backoff is the one with the smallest multiplicity that is
        still >= ``multiplicity_threshold``.

        Raises:
            ValueError: if no backoff reaches the threshold. Without this
                check, ``bts(None)`` would fail with an opaque ``TypeError``.
        """
        for bck in backoffs:
            if bck['multiplicity'] >= multiplicity_threshold:
                # The 1e-8 offset keeps the log finite when bts(...) is 0.
                return -np.log(bts(bck[config]) + 1e-8).item()
        raise ValueError(
            f"No backoff with multiplicity >= {multiplicity_threshold} "
            f"(available: {[bck['multiplicity'] for bck in backoffs]})"
        )

    scores = [
        np.mean([
            _get_score(backoffs, threshold, config='fine')
            for backoffs in sorted_backoffs
        ])
        for threshold in thresholds
    ]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(
        thresholds,
        scores,
        label='CPMI',
        color='blue',
        marker='o',
        markersize=5,
        markerfacecolor='blue',
        markeredgecolor='white',
        linewidth=2,
    )

    # remove upper and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_xlabel('Multiplicity Threshold', fontsize=14)
    ax.set_ylabel('CPMI', fontsize=14)

    fig.tight_layout()

    output_dir = os.path.join("data", "conformal-backoff")
    # savefig does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    # splitext keeps 'x.jsonl' -> 'x_cpmi.png' and also produces a valid PNG
    # name for other extensions, which str.replace('.jsonl', ...) did not.
    stem = os.path.splitext(os.path.basename(input_path))[0]
    fig.savefig(
        os.path.join(output_dir, f"{stem}_cpmi.png"),
        dpi=300,
    )
    plt.close(fig)
    

if __name__ == "__main__":
    # `main` is a click command; click parses sys.argv (e.g. --input-path).
    main()