""" Task that plots the relationship for trained claim rewriting. """

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from glob import glob
from sklearn.linear_model import LinearRegression
from functools import lru_cache
from dataclasses import dataclass
from overrides import overrides

try:
    import ujson as json
except ImportError:
    import json
import os
from typing import List, Tuple, Text, Dict, Any, Union
from tasker import BaseTask


# Marker glyph used for each trained model.
# A `"base": "o"` entry is currently disabled.
model_name_to_marker = dict(DPO="s", SFT="^", ORPO="h")

# Line colour used for each trained model.
model_name_to_color = dict(DPO="blue", SFT="red", ORPO="orange")

# Module-level side effect: executed on import, this sets the default font
# weight in matplotlib's global rc settings to bold.
plt.rc('font', weight='bold')


@dataclass(eq=True, frozen=True)
class RewritingResult:
    """ One scored item, as consumed by the plotting helpers in this module. """

    # Numeric verbalization level the item belongs to.
    level: int
    # Name of the model the item came from.
    model: Text
    # Frequency value looked up for the item's topic.
    frequency: Text
    # Whether the item is flagged as an abstention.
    is_abstention: bool
    # Score of the item.
    score: float


@BaseTask.register("investigate-trained-rewriting")
class InvetigateTrainedRewritingTask(BaseTask):
    """ """
    
    __VERSION__ = '0.2.5'

    def __init__(
        self,
        input_dirs: List[Tuple[Text, Text]],
        verbalization_level: List[Tuple[Text, int]],
        base_result_dir: Text,
        frequency_map_path: Union[Text, List[Text]],
        output_dir: Text,
    ):
        """ """
        super().__init__(output_dir=output_dir)
        self._input_dirs = input_dirs
        self._base_result_dir = base_result_dir
        with open(os.path.join(self._base_result_dir, "agg_scores.json"), 'r', encoding='utf-8') as file_:
            self._base_average_score = json.load(file_)['average_score']
        self._frequency_map_paths = (
            frequency_map_path
            if not isinstance(frequency_map_path, str)
            else [frequency_map_path]
        )
        print('-' * 50)
        print(self._frequency_map_paths)
        print('-' * 50)
        self._verbalization_level = verbalization_level

        # load frequency map
        self._freq_map = {}
        for fp in self._frequency_map_paths:
            with open(fp, "r", encoding="utf-8") as file_:
                self._freq_map.update(json.load(file_))

    def _breakdown_by_level(self, results: List[RewritingResult]) -> mpl.figure.Figure:
        """ """
        
        fig = plt.figure()

        verbalized_levels = sorted(
            self._verbalization_level, key=lambda x: x[1], reverse=False
        )

        for model_name, _ in self._input_dirs:
            model_results = []
            for _, level in verbalized_levels:
                # filter with model + level
                model_results.append(
                    np.mean(
                        [
                            ins.score
                            for ins in filter(
                                lambda x: x.model == model_name and x.level == level and not x.is_abstention,
                                results,
                            )
                        ]
                    ).item()
                )
        
            if model_name == 'base':
                # plt.plot(
                #     list(range(1, len(verbalized_levels) + 1)),
                #     model_results,
                #     label=model_name,
                #     linestyle="--",
                #     color='grey',
                #     marker="o",
                #     markersize=10,
                #     markeredgewidth=2,
                #     markeredgecolor='white',
                # )
                # plt.scatter(
                #     list(range(1, len(verbalized_levels) + 1)),
                #     model_results,
                #     label=model_name,
                #     color='grey',
                #     marker="o",
                #     s=100,
                #     edgecolors='white',
                #     linewidths=2,
                # )
                reg = LinearRegression().fit(
                    np.array(list(range(1, len(verbalized_levels) + 1))).reshape(-1, 1),
                    np.array(model_results),
                )
                
                plt.plot(
                    list(range(1, len(verbalized_levels) + 1)),
                    reg.predict(np.array(list(range(1, len(verbalized_levels) + 1))).reshape(-1, 1)),
                    linestyle="-",
                    color='grey',
                    lw=3,
                    label=model_name
                )
                
            else:
                # plt.plot(
                #     list(range(1, len(verbalized_levels) + 1)),
                #     model_results,
                #     label=model_name,
                #     linestyle="--",
                #     marker=model_name_to_marker[model_name],
                #     color=model_name_to_color[model_name],
                #     markersize=10,
                #     markeredgewidth=2,
                #     markeredgecolor='white',
                # )
                reg = LinearRegression().fit(
                    np.array(list(range(1, len(verbalized_levels) + 1))).reshape(-1, 1),
                    np.array(model_results),
                )
                
                plt.plot(
                    list(range(1, len(verbalized_levels) + 1)),
                    reg.predict(np.array(list(range(1, len(verbalized_levels) + 1))).reshape(-1, 1)),
                    linestyle="-",
                    color=model_name_to_color[model_name],
                    lw=3,
                    label=model_name
                )

        # plot a horizontal line for the average score base
        plt.axhline(self._base_average_score, color='grey', linestyle='--')
            
        plt.xticks(list(range(1, len(verbalized_levels) + 1)), [x[0] for x in verbalized_levels], fontsize=20)
        plt.yticks([0.60, 0.7, 0.8], [0.6, 0.7, 0.8], fontsize=20, fontweight='bold')
        plt.ylim([0.59, 0.85])
        plt.xlabel("Controlled Confidence Level", fontsize=20, fontweight='bold')
        plt.ylabel("FActScore", fontsize=20, fontweight='bold')
        plt.legend(fontsize=20, loc="lower right", prop={'weight': 'bold', 'size': 20})
        
        fig.tight_layout()
        
        return fig
    
    def _breakdown_by_frequency(
        self,
        results: List[RewritingResult],
        at_level: int,
    ) -> mpl.figure.Figure:
        """ """
        
        # create bar plot for each model grouped by frequency
        fig = plt.figure()
        num_models = len(self._input_dirs)
        
        for midx, (model_name, _) in enumerate(self._input_dirs):
            model_results = []
            for freq in [
                "very rare",
                "rare",
                "medium",
                "frequent",
            ]:
                model_results.append(
                    np.mean(
                        [
                            ins.score
                            for ins in filter(
                                lambda x: x.model == model_name and x.frequency == freq and not x.is_abstention and x.level == at_level,
                                results,
                            )
                        ]
                    ).item()
                )

            plt.bar(
                np.arange(4) + (1 + midx) / (num_models + 2),
                model_results,
                width=1 / (num_models + 2),
                label=model_name,
            )

        plt.xticks(np.arange(4) + 0.5, ["very rare", "rare", "medium", "frequent"], fontsize=16)
        plt.xlabel("Freq", fontsize=16)
        plt.ylabel("FActScore", fontsize=16)
        plt.legend(fontsize=16, loc="right")
        
        fig.tight_layout()
        
        return fig

    @overrides
    def _run(self):
        """ """

        @lru_cache(maxsize=128)
        def _get_item(input_dir: Text) -> List[Dict[Text, Any]]:
            """ """

            items = []

            for filepath in glob(os.path.join(input_dir, "*.jsonl")):
                with open(filepath, "r", encoding="utf-8") as file_:
                    for line in file_:
                        items.append(json.loads(line))

            return items

        results = []
        for model_name, input_dir in self._input_dirs:
            for item in _get_item(input_dir):
                level = int(item["id_"].split("-")[1])
                score = item["aggregated_score"]
                topic = item["meta"]["topic"]

                results.append(
                    RewritingResult(
                        level=level,
                        model=model_name,
                        frequency=self._freq_map[topic],
                        is_abstention=item['meta']['is_abstention'],
                        score=score,
                    )
                )
                
        fig_level = self._breakdown_by_level(results)

        fig_freqs = {
            verbalization: self._breakdown_by_frequency(results, at_level=level)
            for verbalization, level in self._verbalization_level
        }
        
        return fig_level, fig_freqs
    
    @overrides
    def _write(self, outputs):
        """ """
        
        fig_level, fig_freqs = outputs
        fig_level.savefig(os.path.join(self._output_dir, "level.pdf"))
        plt.close(fig_level)

        for verbalization, fig in fig_freqs.items():
            fig.savefig(os.path.join(self._output_dir, f"freq_{verbalization}.pdf"))
            plt.close(fig)