""" """

import glob
import numpy as np
import random
import os
import matplotlib.pyplot as plt
import ujson as json
from overrides import overrides
from typing import Text, Dict, Any, List, Tuple
from tasker import BaseTask
from ..data_readers.simpleqa import (
    SimpleQAScoringDataReader,
    SimpleQAScoredQuestion,
)
from ..utils.bounds import hb_p_value


# Make bold the default font weight for all matplotlib text drawn by this module.
# Optional (currently disabled): switch the font family to Nimbus Roman.
# plt.rcParams['font.family'] = 'Nimbus Roman'
plt.rc('font', weight='bold')


@BaseTask.register("nq-ltt")
class NQLTT(BaseTask):
    """Multiplicity-threshold calibration task on scored NQ backoffs (``nq-ltt``).

    For every multiplicity threshold, each question contributes the score of its
    backoff with the smallest multiplicity strictly above the threshold. On the
    calibration split a grid of risk levels ``alpha`` is tested with
    ``hb_p_value`` (empirical risk = 1 - mean score). On the test split the mean
    selected score is reported. Results are plotted against an abstention
    baseline loaded from ``cflm_dir``.

    NOTE(review): "LTT" presumably means Learn-then-Test and ``hb_p_value`` a
    Hoeffding-Bentkus p-value -- confirm against ``..utils.bounds``.
    """

    __VERSION__ = "0.1.10"

    def __init__(
        self,
        input_dir: Text,
        cflm_dir: Text,
        raw_backoff_input_dir: Text,
        output_dir: Text,
        cal_ratio: float = 0.5,
        seed: int = 31,
        significance_level: float = 0.05,
    ):
        """Store configuration; no data is read until ``_run``.

        Args:
            input_dir: Directory whose ``*.jsonl`` files hold the scored
                questions that are split into calibration and test sets.
            cflm_dir: Directory containing ``conformal_factual_guarantee.json``
                with the ``abstention_ratios`` / ``accuracies`` baseline curve.
            raw_backoff_input_dir: Directory whose ``*.jsonl`` files hold raw
                backoff scores; only test-split questions (matched by
                ``index``) are used.
            output_dir: Forwarded to ``BaseTask``; ``_write`` writes its
                artifacts into ``self._output_dir``.
            cal_ratio: Fraction of each answer-type group put in calibration.
            seed: Seed for the shuffle used by the split.
            significance_level: A null hypothesis is rejected when its p-value
                is strictly below this value.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._cflm_dir = cflm_dir
        self._raw_backoff_input_dir = raw_backoff_input_dir
        self._cal_ratio = cal_ratio
        self._seed = seed
        # Descending grid of risk levels, from 0.95 towards (excluding) 0.25 in
        # steps of 0.005; built in float32, then converted to Python floats.
        self._alphas = [alpha.item() for alpha in np.arange(0.95, 0.25, -0.005, dtype=np.float32)]
        # Thresholds 0..99; a backoff qualifies when multiplicity > threshold.
        self._threshold_to_test = list(range(100))
        self._significance_level = significance_level

    @staticmethod
    def _stratified_split(
        ratio: float,
        dataset: List[SimpleQAScoredQuestion],
        seed: int,
    ) -> Tuple[List[SimpleQAScoredQuestion], List[SimpleQAScoredQuestion]]:
        """Split ``dataset`` into ``(calibration, test)``, stratified by ``answer_type``.

        Groups are visited in first-appearance order and shuffled with a single
        ``random.Random(seed)``. The first ``int(len(group) * ratio)`` items of
        each shuffled group go to calibration and the rest go to test.
        ``dataset`` itself is not reordered.
        """
        random_obj = random.Random(seed)
        answer_type_grouped: Dict[Any, List[SimpleQAScoredQuestion]] = {}
        for item in dataset:
            answer_type_grouped.setdefault(item.answer_type, []).append(item)

        cal: List[SimpleQAScoredQuestion] = []
        test: List[SimpleQAScoredQuestion] = []
        for items in answer_type_grouped.values():
            cal_size = int(len(items) * ratio)
            random_obj.shuffle(items)
            cal.extend(items[:cal_size])
            test.extend(items[cal_size:])
        return cal, test

    @staticmethod
    def _gather_scores(sorted_backoffs_per_item: List[List[Any]], threshold: int) -> List[Any]:
        """Return, per question, the score of the first backoff with multiplicity > ``threshold``.

        ``sorted_backoffs_per_item`` holds each question's backoffs sorted by
        ascending multiplicity, so the first match has the smallest qualifying
        multiplicity. Questions with no qualifying backoff contribute nothing.
        """
        gathered = []
        for sorted_backoffs in sorted_backoffs_per_item:
            for backoff in sorted_backoffs:
                if backoff.multiplicity > threshold:
                    gathered.append(backoff.score)
                    break
        return gathered

    @overrides
    def _run(self):
        """Sweep thresholds, test every alpha on calibration data, and build the figure.

        Returns:
            ``(result_grouped_by_threshold, cal, test, fig)`` as consumed by
            ``_write``. ``result_grouped_by_threshold`` has one dict per threshold
            with keys ``threshold``, ``control`` (one entry per alpha:
            ``threshold``, ``p_val``, ``alpha``, ``reject_null``) and
            ``test_accuracy`` (``None`` when no test question qualifies).
        """
        iterator = list(SimpleQAScoringDataReader(
            glob.glob(os.path.join(self._input_dir, "*.jsonl"))
        ))
        raw_backoff_iterator = list(SimpleQAScoringDataReader(
            glob.glob(os.path.join(self._raw_backoff_input_dir, "*.jsonl"))
        ))

        cal, test = self._stratified_split(self._cal_ratio, iterator, self._seed)
        test_indices_set = {item.index for item in test}
        raw_test = [item for item in raw_backoff_iterator if item.index in test_indices_set]

        # The multiplicity ordering does not depend on the threshold, so sort once
        # per question instead of once per (question, threshold). ``sorted`` is
        # stable, so ties resolve exactly as before.
        cal_sorted = [sorted(item.backoffs, key=lambda x: x.multiplicity) for item in cal]
        test_sorted = [sorted(item.backoffs, key=lambda x: x.multiplicity) for item in test]

        result_grouped_by_threshold = []
        for threshold in self._threshold_to_test:
            gathered = self._gather_scores(cal_sorted, threshold)
            test_gathered = self._gather_scores(test_sorted, threshold)

            # Empirical calibration risk; independent of alpha.
            r_hat = 1 - np.mean(gathered) if gathered else None

            result_grouped_by_alpha = []
            for alpha in self._alphas:
                if gathered:
                    p_val = hb_p_value(r_hat=r_hat, n=len(gathered), alpha=alpha)
                else:
                    # No calibration question qualifies at this threshold: there is
                    # no evidence against the null, so use the trivially valid
                    # p-value 1.0 rather than calling the bound with n=0 and NaN.
                    p_val = 1.0

                result_grouped_by_alpha.append({
                    "threshold": threshold,
                    "p_val": p_val,
                    "alpha": alpha,
                    "reject_null": 1 if p_val < self._significance_level else 0,
                })

            result_grouped_by_threshold.append({
                "threshold": threshold,
                "control": result_grouped_by_alpha,
                # None (JSON null) instead of the NaN + RuntimeWarning that
                # np.mean([]) would produce when no test question qualifies.
                "test_accuracy": np.mean(test_gathered) if test_gathered else None,
            })

        fig = self._build_figure(result_grouped_by_threshold, raw_test)
        return (result_grouped_by_threshold, cal, test, fig)

    def _build_figure(
        self,
        result_grouped_by_threshold: List[Dict[Text, Any]],
        raw_test: List[SimpleQAScoredQuestion],
    ) -> Any:
        """Plot test accuracy per threshold, the certified band, and the abstention baseline.

        Reads ``conformal_factual_guarantee.json`` from ``self._cflm_dir``.
        Returns the matplotlib figure created with ``plt.subplots``.
        """
        # Load the baseline before creating the figure so a missing or malformed
        # file does not leave an orphaned pyplot figure behind.
        with open(
            os.path.join(self._cflm_dir, "conformal_factual_guarantee.json"), "r", encoding='utf-8'
        ) as file_:
            cflm_data = json.load(file_)
        abstention_ratio = cflm_data["abstention_ratios"]
        accuracies = cflm_data["accuracies"]

        # Thresholds are 0-based; shift by one so the x axis runs 1..100.
        shifted_thresholds = np.array([group["threshold"] for group in result_grouped_by_threshold]) + 1
        test_accuracies = [
            np.nan if group["test_accuracy"] is None else group["test_accuracy"]
            for group in result_grouped_by_threshold
        ]
        # Largest certified factuality level (1 - alpha) per threshold. NaN when
        # no alpha was rejected: fill_between leaves a gap instead of max([])
        # raising ValueError.
        certified_levels = [
            max(
                (1 - item["alpha"] for item in group["control"] if item["reject_null"] == 1),
                default=np.nan,
            )
            for group in result_grouped_by_threshold
        ]

        # Naive accuracy of the raw backoffs at each observed multiplicity.
        # NOTE: only consumed by the disabled "Naive Acc" plot below.
        multiplicity_set = sorted({bf.multiplicity for item in raw_test for bf in item.backoffs})
        m_paired_results = []
        for raw_targ_mul in multiplicity_set:
            examination = []
            for item in raw_test:
                for bf in item.backoffs:
                    if bf.multiplicity == raw_targ_mul:
                        examination.append(bf.score)
                        break
            m_paired_results.append(np.mean(examination).item())

        fig, ax = plt.subplots()
        ax.plot(shifted_thresholds, test_accuracies, label="CLC", color='b')
        # ax.plot(
        #     [1] + multiplicity_set, test_accuracies[:1] + m_paired_results,
        #     color="grey",
        #     linestyle='--',
        #     alpha=0.8
        # )
        # ax.scatter(
        #     [1] + multiplicity_set, test_accuracies[:1] + m_paired_results,
        #     label="Naive Acc",
        #     color="#F5BD1E",
        #     marker='o', s=200, edgecolors='white', linewidths=2,
        #     zorder=100
        # )
        ax.plot(
            np.array(abstention_ratio) * 100,
            np.array(accuracies),
            label="Abstention",
            color="orange"
        )
        ax.fill_between(
            shifted_thresholds,
            certified_levels,
            [1.] * len(certified_levels),
            color='b',
            alpha=0.1,
            label="Guarantee",
        )

        ax.set_ylim([.25, .701])
        ax.set_yticks([0.3, 0.4, 0.5, 0.6, 0.7])
        ax.set_xlim([1, 100])
        ax.set_xticks([1, 20, 40, 60, 80, 100])
        ax.tick_params(axis='both', which='major', labelsize=20)
        ax.set_xlabel("Percentage (%)", fontsize=20, fontweight='bold')
        ax.set_ylabel("Factuality (NQ)", fontsize=20, fontweight='bold')

        # ``prop`` fully determines the legend font; matplotlib ignores
        # ``fontsize`` whenever ``prop`` is given, so it is not passed.
        ax.legend(loc='lower right', prop={'weight': 'bold', 'size': 20})
        fig.tight_layout()
        return fig

    @overrides
    def _write(self, outputs):
        """Persist the tuple produced by ``_run`` into ``self._output_dir``.

        Writes ``ltt_results.json`` (per-threshold results), ``calibration.jsonl``
        and ``test.jsonl`` (one question per line), and ``ltt_results.pdf``
        (the figure).
        """
        with open(os.path.join(self._output_dir, "ltt_results.json"), "w", encoding='utf-8') as file_:
            json.dump(outputs[0], file_, ensure_ascii=False, indent=2)

        self._write_questions_jsonl(outputs[1], os.path.join(self._output_dir, "calibration.jsonl"))
        self._write_questions_jsonl(outputs[2], os.path.join(self._output_dir, "test.jsonl"))

        fig = outputs[3]
        fig.savefig(os.path.join(self._output_dir, "ltt_results.pdf"))
        # The figure was created via pyplot, which keeps it open until closed;
        # release it so repeated runs in one process do not accumulate figures.
        plt.close(fig)

    @staticmethod
    def _write_questions_jsonl(items: List[SimpleQAScoredQuestion], path: Text) -> None:
        """Write scored questions to ``path`` as JSON Lines, one object per question."""
        with open(path, "w", encoding='utf-8') as file_:
            for item in items:
                file_.write(json.dumps({
                    "index": item.index,
                    "question": item.question,
                    "gold_answer": item.gold_answer,
                    "answer_type": item.answer_type,
                    "backoffs": [
                        {
                            "score": backoff.score,
                            "backoff": backoff.backoff,
                            "multiplicity": backoff.multiplicity,
                        } for backoff in item.backoffs
                    ]
                }) + "\n")