""" """

import glob
import numpy as np
import random
import os
import matplotlib.pyplot as plt
import ujson as json
from overrides import overrides
from typing import Text, Dict, Any, List, Tuple
from tasker import BaseTask
from matplotlib.patches import Circle
import matplotlib.transforms as transforms
from ..data_readers.simpleqa import (
    SimpleQAScoringDataReader,
    SimpleQAScoredQuestion,
)
from ..utils.bounds import hb_p_value


# Global matplotlib styling, applied once when this module is imported.
# Disabled: switch the font family to Nimbus Roman No9 L.
# plt.rcParams['font.family'] = 'Nimbus Roman'
# Make bold the default font weight for figure text.
plt.rc('font', weight='bold')


@BaseTask.register("simpleqa-ltt")
class SimpleQALTT(BaseTask):
    """Test risk levels per multiplicity threshold on SimpleQA and plot the outcome.

    The scored questions are split into calibration and test parts. For every
    threshold in ``range(100)`` the score of the lowest-multiplicity backoff
    above the threshold is collected per question; the calibration scores
    feed ``hb_p_value`` for a grid of ``alpha`` values, and the test scores
    give the reported accuracy. The results are plotted next to an abstention
    baseline loaded from ``cflm_dir``.

    NOTE(review): "LTT" presumably stands for Learn-then-Test and
    ``hb_p_value`` for a Hoeffding-Bentkus p-value -- confirm against
    ``utils.bounds``.
    """

    __VERSION__ = "0.4.10"

    def __init__(
        self,
        input_dir: Text,
        cflm_dir: Text,
        output_dir: Text,
        cal_ratio: float = 0.5,
        seed: int = 31,
        significance_level: float = 0.05,
    ):
        """Store the task configuration.

        Args:
            input_dir: Directory whose ``*.jsonl`` files are handed to
                ``SimpleQAScoringDataReader``.
            cflm_dir: Directory containing ``conformal_factual_guarantee.json``
                (keys ``abstention_ratios`` and ``accuracies``).
            output_dir: Forwarded to ``BaseTask``; ``_write`` stores its files
                under ``self._output_dir``.
            cal_ratio: Fraction of each answer-type group assigned to the
                calibration split.
            seed: Seed of the random generator used to shuffle before splitting.
            significance_level: A null hypothesis is rejected when its p-value
                is strictly below this value.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._cflm_dir = cflm_dir
        self._cal_ratio = cal_ratio
        self._seed = seed
        # Risk levels to test: 0.95 down towards 0.55 in steps of 0.005.
        self._alphas = [alpha.item() for alpha in np.arange(0.95, 0.55, -0.005, dtype=np.float32)]
        self._threshold_to_test = list(range(100))
        self._significance_level = significance_level

    @staticmethod
    def _stratified_split(
        ratio: float,
        dataset: List[SimpleQAScoredQuestion],
        seed: int,
    ) -> Tuple[List[SimpleQAScoredQuestion], List[SimpleQAScoredQuestion]]:
        """Split ``dataset`` into (calibration, test), stratified by ``answer_type``.

        Each answer-type group is shuffled with one shared seeded generator and
        its first ``int(len(group) * ratio)`` items go to calibration.
        """
        answer_type_grouped = {}
        for item in dataset:
            answer_type_grouped.setdefault(item.answer_type, []).append(item)

        random_obj = random.Random(seed)
        cal = []
        test = []
        for items in answer_type_grouped.values():
            cal_size = int(len(items) * ratio)
            random_obj.shuffle(items)
            cal.extend(items[:cal_size])
            test.extend(items[cal_size:])

        return cal, test

    @staticmethod
    def _backoffs_by_multiplicity(items: List[SimpleQAScoredQuestion]) -> List[list]:
        """Return, for each question, its backoffs sorted by ascending multiplicity."""
        return [
            sorted(item.backoffs, key=lambda backoff: backoff.multiplicity)
            for item in items
        ]

    @staticmethod
    def _scores_above(sorted_backoffs_per_item: List[list], threshold) -> list:
        """Collect, per question, the score of the first backoff with multiplicity > threshold.

        Questions without any backoff above the threshold contribute nothing.
        """
        scores = []
        for sorted_backoffs in sorted_backoffs_per_item:
            for backoff in sorted_backoffs:
                if backoff.multiplicity > threshold:
                    scores.append(backoff.score)
                    break
        return scores

    @overrides
    def _run(self):
        """Run the per-threshold tests and build the comparison figure.

        Returns:
            A 4-tuple ``(results, cal, test, fig)``: the per-threshold result
            dictionaries, the calibration and test splits, and the matplotlib
            figure.
        """
        # glob.glob yields paths in arbitrary order; sort them so that the
        # seeded split below is reproducible across runs and machines.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))
        dataset = list(SimpleQAScoringDataReader(filepaths))

        cal, test = self._stratified_split(self._cal_ratio, dataset, self._seed)

        # The ordering of backoffs does not depend on the threshold, so sort once.
        cal_backoffs = self._backoffs_by_multiplicity(cal)
        test_backoffs = self._backoffs_by_multiplicity(test)

        result_grouped_by_threshold = []
        for threshold in self._threshold_to_test:
            cal_scores = self._scores_above(cal_backoffs, threshold)
            test_scores = self._scores_above(test_backoffs, threshold)

            # Empirical risk on the calibration split; undefined without data.
            empirical_risk = 1 - np.mean(cal_scores) if cal_scores else None

            result_grouped_by_alpha = []
            for alpha in self._alphas:
                if empirical_risk is None:
                    # No calibration evidence: use the most conservative
                    # p-value so the null is never rejected.
                    p_val = 1.0
                else:
                    p_val = hb_p_value(
                        r_hat=empirical_risk,
                        n=len(cal_scores),
                        alpha=alpha,
                    )

                result_grouped_by_alpha.append({
                    "threshold": threshold,
                    "p_val": p_val,
                    "alpha": alpha,
                    "reject_null": 1 if p_val < self._significance_level else 0,
                })

            result_grouped_by_threshold.append({
                "threshold": threshold,
                "control": result_grouped_by_alpha,
                # None (JSON null) instead of NaN when no test question qualifies.
                "test_accuracy": float(np.mean(test_scores)) if test_scores else None,
            })

        fig = self._plot_results(result_grouped_by_threshold)

        return (result_grouped_by_threshold, cal, test, fig)

    def _plot_results(self, result_grouped_by_threshold):
        """Plot test accuracy, the guaranteed band and the abstention baseline.

        Args:
            result_grouped_by_threshold: Per-threshold dictionaries as built
                in ``_run``.

        Returns:
            The matplotlib figure.
        """
        nan = float("nan")
        fig, ax = plt.subplots()

        thresholds = []
        test_accuracies = []
        guarantee_lower_bounds = []
        for threshold_group in result_grouped_by_threshold:
            thresholds.append(threshold_group['threshold'])

            accuracy = threshold_group['test_accuracy']
            test_accuracies.append(nan if accuracy is None else accuracy)

            # Largest 1 - alpha among the rejected tests. When nothing was
            # rejected there is no bound: NaN leaves a gap in the shaded band
            # instead of max() raising on an empty sequence.
            guarantee_lower_bounds.append(
                max(
                    (
                        1 - item['alpha']
                        for item in threshold_group['control']
                        if item['reject_null'] == 1
                    ),
                    default=nan,
                )
            )

        # Thresholds are 0-based; the x axis is labelled 1..100.
        percentages = np.array(thresholds) + 1

        ax.plot(percentages, test_accuracies, label="CLC", color='b')

        # Abstention baseline curve precomputed under cflm_dir.
        with open(
            os.path.join(self._cflm_dir, "conformal_factual_guarantee.json"), "r", encoding='utf-8'
        ) as file_:
            cflm_data = json.load(file_)

        abstention_ratio = cflm_data["abstention_ratios"]
        accuracies = cflm_data["accuracies"]

        ax.plot(
            np.array(abstention_ratio) * 100,
            np.array(accuracies),
            label="Abstention",
            color="orange"
        )

        # Grey dashed reference line at y = 38.2% for GPT-4o performance.
        ax.axhline(y=0.382, color='grey', linestyle='--', label="GPT-4o")
        ax.fill_between(
            percentages,
            guarantee_lower_bounds,
            [1.] * len(guarantee_lower_bounds),
            color='b',
            alpha=0.1,
            label="Guarantee",
        )

        # ax.set_ylim([.25, .701])
        # ax.set_yticks([0.3, 0.4, 0.5, 0.6, 0.7])
        ax.set_ylim([0.05, 0.501])
        ax.set_yticks([0.1, 0.2, 0.3, 0.4, 0.5])
        ax.set_xlim([1, 100])
        ax.set_xticks([1, 20, 40, 60, 80, 100])

        # Tick label size.
        ax.tick_params(axis='both', which='major', labelsize=20)
        ax.set_xlabel("Percentage (%)", fontsize=20, fontweight='bold')
        ax.set_ylabel("Factuality (SimpleQA)", fontsize=20, fontweight='bold')
        ax.legend(loc='upper left', fontsize=20, prop={'weight': 'bold', 'size': 20})
        fig.tight_layout()

        # Highlight the test accuracy at x = 20, 40 and 80 (x - 1 is the list index).
        for percentage, colour in ((20, '#FFF347'), (40, '#F5BD1E'), (80, '#7BB662')):
            ax.scatter(
                [percentage],
                [test_accuracies[percentage - 1]],
                color=colour,
                s=200,
                edgecolors='white',
                linewidth=2,
                zorder=100,
            )

        return fig

    @staticmethod
    def _write_questions_jsonl(path: Text, items: List[SimpleQAScoredQuestion]) -> None:
        """Write one JSON object per question (with its backoffs) to ``path``."""
        with open(path, "w", encoding='utf-8') as file_:
            for item in items:
                file_.write(json.dumps({
                    "index": item.index,
                    "question": item.question,
                    "gold_answer": item.gold_answer,
                    "answer_type": item.answer_type,
                    "backoffs": [
                        {
                            "score": backoff.score,
                            "backoff": backoff.backoff,
                            "multiplicity": backoff.multiplicity,
                        } for backoff in item.backoffs
                    ]
                }) + "\n")

    @overrides
    def _write(self, outputs):
        """Persist the outputs of ``_run`` under ``self._output_dir``.

        Writes ``ltt_results.json``, ``calibration.jsonl``, ``test.jsonl`` and
        the figure as ``ltt_results.pdf`` and ``ltt_results.svg``.

        Args:
            outputs: The ``(results, cal, test, fig)`` tuple returned by ``_run``.
        """
        results, cal, test, fig = outputs[0], outputs[1], outputs[2], outputs[3]

        with open(os.path.join(self._output_dir, "ltt_results.json"), "w", encoding='utf-8') as file_:
            json.dump(results, file_, ensure_ascii=False, indent=2)

        self._write_questions_jsonl(os.path.join(self._output_dir, "calibration.jsonl"), cal)
        self._write_questions_jsonl(os.path.join(self._output_dir, "test.jsonl"), test)

        fig.savefig(os.path.join(self._output_dir, "ltt_results.pdf"))
        fig.savefig(os.path.join(self._output_dir, "ltt_results.svg"))