import collections
import json
from pprint import pprint
from typing import List, Optional

import numpy as np
from tqdm import tqdm
from rouge import Rouge
import re
from sentence_transformers import SentenceTransformer, util

from util.globals import *


# Word-token pattern used to normalise text before ROUGE scoring.
_ROUGE_WORD_RE = re.compile(r'\b\w+\b')


def _normalize_for_rouge(text):
    """Lower-case *text* and keep only its word tokens, joined by single spaces."""
    return " ".join(_ROUGE_WORD_RE.findall(text.lower()))


def main(
    dir_name,
    runs: Optional[List],
    first_n_cases=None,
    get_uncompressed=False,
    abs_path=False,
    device="cuda:0",
):  # runs = None -> all runs
    """Score every run directory under ``dir_name`` and write a per-run summary.

    For each run directory, every ``*case_*.json`` file is loaded and its
    ``generation_result`` entries are scored with ROUGE-L (precision, recall,
    F1) and with the cosine similarity between sentence embeddings of the
    reference answer and the prediction (reported as ``BERT-Score``). Each
    entry is indexed as ``[0]`` question, ``[1]`` reference answer and ``[3]``
    model output; the output is cut at the first ``<|eot_id|>`` /
    ``<|im_end|>`` marker before scoring.

    Side effects: prints each run summary and writes ``_result.json`` (the
    summary plus per-case details) into the run directory.

    Args:
        dir_name: Directory to scan. Joined onto ``RESULTS_DIR`` unless
            ``abs_path`` is true, in which case it is used as given.
        runs: Substrings selecting run directories; a run is kept when any of
            them occurs in its path. ``None`` keeps every run.
        first_n_cases: If given, stop reading a run at the first case whose
            ``case_id`` is ``>= first_n_cases``.
        get_uncompressed: Return the raw per-case metric lists instead of the
            aggregated (mean, std) summaries.
        abs_path: Treat ``dir_name`` as a complete path (``str`` or ``Path``).
        device: Device string handed to ``SentenceTransformer``.

    Returns:
        A list with one dict per run that produced at least one scored case:
        raw metric lists when ``get_uncompressed`` is true, otherwise
        ``(mean, std)`` tuples. Both variants carry ``run_dir`` and
        ``num_cases``.
    """
    # Local import so the block does not rely on the star import providing Path.
    from pathlib import Path

    summaries = []
    uncompressed = []
    rouge = Rouge()
    bert = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    # Coerce so that a plain string is accepted when abs_path=True.
    root = Path(dir_name) if abs_path else RESULTS_DIR / dir_name

    for run_dir in root.iterdir():
        # Skip if we're not interested
        if runs is not None and all(run not in str(run_dir) for run in runs):
            continue

        # Iterate through all case files, ordered by the trailing number in
        # the file name.
        cur_sum = collections.defaultdict(list)
        case_detail = []
        files = list(run_dir.glob("*case_*.json"))
        files.sort(key=lambda x: int(str(x).split("_")[-1].split(".")[0]))
        for case_file in tqdm(files, ncols=100):
            try:
                with open(case_file, "r") as f:
                    data = json.load(f)
            except json.JSONDecodeError:
                print(f"Could not decode {case_file} due to format error; skipping.")
                # Without this, `data` would be undefined (first file) or
                # still hold the previous case, which would be scored twice.
                continue

            case_id = data["case_id"]
            # NOTE(review): the early exit assumes the file-name order tracks
            # case_id -- confirm against the writer of these files.
            if first_n_cases is not None and case_id >= first_n_cases:
                break

            # data['generation_result'] = data['generation_result'][2:]
            generation = data['generation_result']
            questions = [item[0] for item in generation]
            answers = [item[1] for item in generation]
            predictions = [
                item[3].split('<|eot_id|>')[0].split('<|im_end|>')[0]
                for item in generation
            ]

            try:
                rouge_ori = rouge.get_scores(
                    [_normalize_for_rouge(p) for p in predictions],
                    [_normalize_for_rouge(a) for a in answers],
                )
            except Exception:
                print('ROUGE Compute Error:', case_file)
                continue

            try:
                ans_emb = bert.encode(answers, convert_to_tensor=True)
                pred_emb = bert.encode(predictions, convert_to_tensor=True)
                # The diagonal pairs answer i with prediction i.
                bert_scores_each = util.cos_sim(ans_emb, pred_emb).diagonal().cpu().numpy().tolist()
            except Exception as e:
                print(e, 'BERT-Score Compute Error:', case_file)
                continue

            # Record "time" only once scoring has succeeded so that every list
            # in cur_sum has one entry per scored case and num_cases is exact.
            if "time" in data:
                cur_sum["time"].append(data["time"])

            rouge_l = [each['rouge-l'] for each in rouge_ori]
            cur_sum['ROUGE-l-Precision'].append(np.mean([s['p'] for s in rouge_l]))
            cur_sum['ROUGE-l-Recall'].append(np.mean([s['r'] for s in rouge_l]))
            cur_sum['ROUGE-l-F1'].append(np.mean([s['f'] for s in rouge_l]))
            cur_sum['BERT-Score'].append(np.mean(bert_scores_each))
            # -1 marks a case file that carries no such result.
            cur_sum['MMLU'].append(data.get('mmlu_result', -1))
            cur_sum['Neighborhood'].append(data.get('neighborhood_result', -1))

            case_data = {'case_id': case_id, 'text': data['text'], 'info': []}
            for score, que, ans, pred, bert_score in zip(
                rouge_l, questions, answers, predictions, bert_scores_each
            ):
                case_data['info'].append({
                    'question': que,
                    'answer': ans,
                    'prediction': pred,
                    'rouge-l-p': score['p'],
                    'rouge-l-r': score['r'],
                    'rouge-l-f': score['f'],
                    'bert-score': bert_score,
                })
            case_detail.append(case_data)

        # Nothing was scored for this run (no case files, or all of them failed).
        if not cur_sum:
            continue

        num_items = len(cur_sum[next(iter(cur_sum.keys()))])
        metadata = {
            "run_dir": str(run_dir),
            "num_cases": num_items,
        }

        uncompressed.append(dict(cur_sum, **metadata))

        cur_sum = {k: (np.mean(v), np.std(v)) for k, v in cur_sum.items()}
        for k, v in cur_sum.items():
            if all(exclude not in k for exclude in ["essence_score", "time"]):
                # Constant multiplication scales linearly with mean and stddev
                cur_sum[k] = tuple(np.around(z * 100, 2).item() for z in v)

        cur_sum.update(metadata)
        pprint(cur_sum)
        summaries.append(cur_sum)

        with open(run_dir / '_result.json', 'w', encoding='utf-8') as f:
            json.dump([cur_sum, case_detail], f, indent=4, ensure_ascii=False)

    return uncompressed if get_uncompressed else summaries


if __name__ == "__main__":
    # Command-line entry point: parse the options and summarize the runs.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dir_name",
        type=str,
        # Required: without it main() would receive None and fail while
        # building the results path.
        required=True,
        help="Name of directory to scan for runs.",
    )
    parser.add_argument(
        "--runs",
        type=str,
        default=None,
        help="By default, summarizes each run in <dir_name>. "
        "If runs are specified, only evaluates those specific runs.",
    )
    parser.add_argument(
        "--first_n_cases",
        type=int,
        default=None,
        help="Restricts evaluation to first n cases in dataset. "
        "Useful for comparing different in-progress runs on the same slice of data.",
    )
    args = parser.parse_args()

    # --runs is a comma-separated list. Empty entries (e.g. from a trailing
    # comma) are dropped because an empty string is a substring of every run
    # path and would silently disable the filter. If nothing remains, fall
    # back to None, which means "all runs".
    selected_runs = None
    if args.runs is not None:
        selected_runs = [
            name.strip() for name in args.runs.split(",") if name.strip()
        ] or None

    main(
        args.dir_name,
        selected_runs,
        args.first_n_cases,
    )
