import csv
import argparse
import os
import re
import sys
import evaluate
import time
import math
import json
from pathlib import Path

import pandas as pd
from tqdm import tqdm
import numpy as np
from more_itertools import batched
from multiprocessing import Pool, set_start_method
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
sys.path.insert(0, str(Path(__file__).parent))
from chexpert_labeler import ChexpertLabeler
from chexpert_labeler.constants import CATEGORIES
from chexbert_labeler import ChexbertLabeler

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with two string attributes: ``result_dir``
        (default ``"results"``) and ``annotation`` (default
        ``"annotation.csv"``).
    """
    # (flag, help text, default) for every supported option.
    option_specs = (
        ("--result_dir", "path to result dir containing txt result", "results"),
        ("--annotation", "path to annotation table", "annotation.csv"),
    )
    cli = argparse.ArgumentParser()
    for flag, description, fallback in option_specs:
        cli.add_argument(flag, type=str, help=description, default=fallback)
    return cli.parse_args()

def clean_report(report):
    """Normalise a free-text report into lower-case, ' . '-separated sentences.

    The steps are, in order: newlines become spaces; runs of underscores,
    spaces and dots are collapsed with a fixed number of replacement passes;
    list markers "1. " .. "5. " are removed; the text is lower-cased and split
    on ". "; finally quotes, slashes, backslashes and punctuation are removed
    from every sentence.

    Args:
        report: Raw report text (str).

    Returns:
        The cleaned sentences joined with " . " and terminated by " .".
        An empty input yields " ." because empty sentences are kept.
    """
    text = report.replace('\n', ' ')
    # Fixed pass counts: one str.replace pass only halves a long run, so the
    # number of passes bounds how long a run can be fully collapsed.
    for _ in range(7):
        text = text.replace('__', '_')
    for _ in range(6):
        text = text.replace('  ', ' ')
    for _ in range(8):
        text = text.replace('..', '.')
    # Strip enumeration markers: first "1. ", then ". N. ", then bare " N. ".
    text = text.replace('1. ', '')
    for n in '2345':
        text = text.replace(f'. {n}. ', '. ')
    for n in '2345':
        text = text.replace(f' {n}. ', '. ')
    sentences = text.strip().lower().split('. ')

    # Raw string so that \[ and \] are regex escapes rather than invalid
    # Python string escapes. The character class is otherwise unchanged:
    # note that ':-\[' is a *range* from ':' to '[' (so '<', '=', '>', '@'
    # are stripped too) and that '-' itself is therefore NOT removed.
    punctuation = re.compile(r'[.,?;*!%^&_+():-\[\]{}]')

    def clean_sentence(sentence):
        stripped = (sentence.replace('"', '').replace('/', '')
                    .replace('\\', '').replace("'", '').strip().lower())
        return punctuation.sub('', stripped)

    # NOTE(review): the previous filter compared each cleaned sentence (a
    # str) with [] and so never dropped anything; empty sentences are still
    # kept here so that the produced text is unchanged.
    tokens = [clean_sentence(sentence) for sentence in sentences]
    return ' . '.join(tokens) + ' .'



def calculate_nlg_metrics(preds, gts, study_ids=None, all_results=None):
    """Compute corpus-level (and optionally per-sample) BLEU, METEOR, ROUGE-L.

    Perplexity is not computed; its metric is no longer loaded because the
    result was never used.

    Args:
        preds: List of predicted report strings.
        gts: List of reference report strings, aligned with ``preds``.
        study_ids: Optional list aligned with ``preds``; when given, its
            entries are used as keys of ``all_results`` instead of the
            positional index.
        all_results: Optional dict filled in place. For every sample a fresh
            dict is stored with the keys ``BLEU@1`` .. ``BLEU@4``, ``METEOR``
            and ``ROUGE_L``.

    Returns:
        Dict with the corpus-level scores under ``BLEU@1`` .. ``BLEU@4``,
        ``METEOR`` and ``ROUGE_L``.
    """
    print("Calculating NLG metrics")
    # Load each metric once and reuse it for the corpus and per-sample passes.
    bleu = evaluate.load("bleu")
    meteor = evaluate.load("meteor")
    rouge = evaluate.load("rouge")

    results = {}
    gts_bleu = [[gt] for gt in gts]
    for n in range(1, 5):
        score = bleu.compute(predictions=preds, references=gts_bleu, max_order=n)
        results[f"BLEU@{n}"] = score["bleu"]
    results["METEOR"] = meteor.compute(predictions=preds, references=gts)["meteor"]
    results["ROUGE_L"] = rouge.compute(predictions=preds, references=gts)["rougeL"]

    if all_results is None:
        return results

    for index, (pred, gt) in enumerate(zip(preds, gts)):
        key = index if study_ids is None else study_ids[index]
        row = all_results[key] = {}
        # Empty strings are replaced by "." for BLEU only; METEOR and ROUGE
        # below still receive the original strings.
        bleu_pred = pred if len(pred) > 0 else "."
        bleu_gt = gt if len(gt) > 0 else "."
        for n in range(1, 5):
            score = bleu.compute(predictions=[bleu_pred], references=[bleu_gt], max_order=n)
            row[f"BLEU@{n}"] = score["bleu"]
        row["METEOR"] = meteor.compute(predictions=[pred], references=[gt])["meteor"]
        row["ROUGE_L"] = rouge.compute(predictions=[pred], references=[gt])["rougeL"]
    return results

def calculate_chexpert_metrics(preds, gts, study_ids=None, all_results=None, num_processes=None):
    """Label reports with ChexpertLabeler and score predictions against references.

    Every report is labelled in this process, one after another. A category
    counts as positive only when the labeler returns exactly 1 for it; any
    other value is mapped to 0.

    Args:
        preds: List of predicted report strings.
        gts: List of reference report strings, aligned with ``preds``.
        study_ids: Optional list aligned with ``preds``; when given, its
            entries are used as keys of ``all_results`` instead of the
            positional index.
        all_results: Optional dict updated in place with ``<category>_PRED``
            and ``<category>_GT`` entries per sample. Rows that do not exist
            yet are created.
        num_processes: Accepted but not used by this single-process variant.

    Returns:
        Dict with micro-averaged ``CHEXPERT_PRECISION``, ``CHEXPERT_RECALL``,
        ``CHEXPERT_F1_SCORE`` and ``SUPPORT``, plus one entry per category
        holding ``PRECISION``, ``RECALL``, ``F1_SCORE`` and ``SUPPORT``.
    """
    print("Calculating CheXpert metrics")
    labeler = ChexpertLabeler()

    def binarize(report):
        label = labeler.get_label(report)
        return [1 if label[k] == 1 else 0 for k in CATEGORIES]

    preds_labels, gts_labels = [], []
    for pred, gt in tqdm(zip(preds, gts), total=len(preds)):
        preds_labels.append(binarize(pred))
        gts_labels.append(binarize(gt))

    if all_results is not None:
        for index, (pred_row, gt_row) in enumerate(zip(preds_labels, gts_labels)):
            key = index if study_ids is None else study_ids[index]
            # setdefault: do not rely on another metric function having
            # created the per-sample dict beforehand.
            row = all_results.setdefault(key, {})
            for cat, value in zip(CATEGORIES, pred_row):
                row[f"{cat}_PRED"] = value
            for cat, value in zip(CATEGORIES, gt_row):
                row[f"{cat}_GT"] = value

    precision, recall, f1_score, support = precision_recall_fscore_support(
        gts_labels, preds_labels, average="micro")
    precision_cls, recall_cls, f1_score_cls, support_cls = precision_recall_fscore_support(
        gts_labels, preds_labels, average=None)
    result = {
        "CHEXPERT_PRECISION": precision,
        "CHEXPERT_RECALL": recall,
        "CHEXPERT_F1_SCORE": f1_score,
        "SUPPORT": support,
    }
    for p, r, f, s, c in zip(precision_cls, recall_cls, f1_score_cls, support_cls, CATEGORIES):
        result[c] = {"PRECISION": p, "RECALL": r, "F1_SCORE": f, "SUPPORT": int(s)}
    return result


def calculate_chexpert_metrics_mp(preds, gts, study_ids=None, all_results=None, num_processes=None):
    """Multi-process variant of the CheXpert scoring.

    Predictions and references are concatenated, split into at most
    ``num_processes`` contiguous chunks, and each chunk is labelled by
    ``get_label`` in its own worker process.

    Args:
        preds: List of predicted report strings.
        gts: List of reference report strings, aligned with ``preds``.
        study_ids: Optional list aligned with ``preds``; when given, its
            entries are used as keys of ``all_results`` instead of the
            positional index.
        all_results: Optional dict updated in place with ``<category>_PRED``
            and ``<category>_GT`` entries per sample. Rows that do not exist
            yet are created.
        num_processes: Upper bound on worker processes. ``None`` (default)
            means ``os.cpu_count() - 1``; the value is always clamped to at
            least 1 and at most the number of reports.

    Returns:
        Dict with micro-averaged ``CHEXPERT_PRECISION``, ``CHEXPERT_RECALL``,
        ``CHEXPERT_F1_SCORE`` and ``SUPPORT``, plus one entry per category
        holding ``PRECISION``, ``RECALL``, ``F1_SCORE`` and ``SUPPORT``.

    Raises:
        ValueError: If there is no report to label.
    """
    print("Calculating CheXpert metrics")
    preds_gts = preds + gts
    if not preds_gts:
        raise ValueError("calculate_chexpert_metrics_mp needs at least one report")

    # Resolve the default at call time: os.cpu_count() may return None, and
    # on a single-core host "cpu_count() - 1" would be 0 workers.
    if num_processes is None:
        num_processes = (os.cpu_count() or 2) - 1
    num_processes = max(1, min(num_processes, len(preds_gts)))

    chunk_size = math.ceil(len(preds_gts) / num_processes)
    preds_gts_chunks = list(batched(preds_gts, chunk_size))
    # Rounding the chunk size up can yield fewer chunks than requested
    # workers, so the pool is sized by the actual number of chunks.
    with Pool(len(preds_gts_chunks)) as pool:
        # map() blocks until every chunk is done and preserves chunk order.
        label_chunks = pool.map(get_label, preds_gts_chunks)
    preds_gts_labels = np.concatenate(label_chunks)
    preds_labels, gts_labels = preds_gts_labels[:len(preds)], preds_gts_labels[len(preds):]

    if all_results is not None:
        for index, (pred_row, gt_row) in enumerate(zip(preds_labels, gts_labels)):
            key = index if study_ids is None else study_ids[index]
            # setdefault: do not rely on another metric function having
            # created the per-sample dict beforehand.
            row = all_results.setdefault(key, {})
            for cat, value in zip(CATEGORIES, pred_row):
                row[f"{cat}_PRED"] = value
            for cat, value in zip(CATEGORIES, gt_row):
                row[f"{cat}_GT"] = value

    precision, recall, f1_score, support = precision_recall_fscore_support(
        gts_labels, preds_labels, average="micro")
    precision_cls, recall_cls, f1_score_cls, support_cls = precision_recall_fscore_support(
        gts_labels, preds_labels, average=None)
    result = {
        "CHEXPERT_PRECISION": precision,
        "CHEXPERT_RECALL": recall,
        "CHEXPERT_F1_SCORE": f1_score,
        "SUPPORT": support,
    }
    for p, r, f, s, c in zip(precision_cls, recall_cls, f1_score_cls, support_cls, CATEGORIES):
        result[c] = {"PRECISION": p, "RECALL": r, "F1_SCORE": f, "SUPPORT": int(s)}
    return result


def calculate_chexbert_metrics(preds, gts):
    """Label reports with ChexbertLabeler and return micro-averaged scores.

    Reports are sent to the labeler in batches of 8, predictions first and
    references second for every batch. A class counts as positive only when
    the labeler output equals 1; every other value is mapped to 0.

    Args:
        preds: List of predicted report strings.
        gts: List of reference report strings, aligned with ``preds``.

    Returns:
        Dict with ``CHEXBERT_PRECISION``, ``CHEXBERT_RECALL`` and
        ``CHEXBERT_F1_SCORE``.
    """
    print("Calculating CheXbert metrics")
    labeler = ChexbertLabeler()
    labeler = labeler.to(labeler.device)

    batch_size = 8
    pairs = list(zip(preds, gts))
    pred_outputs, gt_outputs = [], []
    # Iterating over a range (instead of a generator) lets tqdm show a total.
    for begin in tqdm(range(0, len(pairs), batch_size)):
        batch_preds, batch_gts = zip(*pairs[begin:begin + batch_size])
        pred_outputs.extend(labeler(list(batch_preds)).tolist())
        gt_outputs.extend(labeler(list(batch_gts)).tolist())

    y_hat_list = [[1 if value == 1 else 0 for value in row] for row in pred_outputs]
    y_list = [[1 if value == 1 else 0 for value in row] for row in gt_outputs]
    precision, recall, f1_score, _ = precision_recall_fscore_support(
        y_list, y_hat_list, average="micro")
    return {
        'CHEXBERT_PRECISION': precision,
        'CHEXBERT_RECALL': recall,
        'CHEXBERT_F1_SCORE': f1_score,
    }


def get_label(report_list):
    """Label a chunk of reports with a freshly created ChexpertLabeler.

    Used as the worker function of the multi-process CheXpert scoring. A
    report whose labelling raises is reported via ``print`` and gets an
    all-zero row, so one bad report does not abort the whole chunk.

    Args:
        report_list: Iterable of report strings.

    Returns:
        numpy array with one row per report and one 0/1 entry per name in
        ``CATEGORIES`` (1 only where the labeler returned exactly 1).
    """
    chexpert = ChexpertLabeler()
    rows = []
    for text in tqdm(report_list):
        try:
            raw = chexpert.get_label(text)
            row = [1 if raw[cat] == 1 else 0 for cat in CATEGORIES]
        except Exception as err:
            # Best effort: report the problem and fall back to "nothing found".
            print(err)
            row = [0] * len(CATEGORIES)
        rows.append(row)
    del chexpert
    return np.array(rows)



def _split_pred_gt(report_string):
    """Split the text of one result file into (prediction, reference).

    The prediction is the text from "PRED:" up to "GT:", the reference is
    everything from "GT:" on. A part whose marker is missing is returned as
    an empty string; in particular a missing "GT:" no longer cuts the last
    character off the prediction (str.find returns -1 in that case).
    """
    pred_idx = report_string.find("PRED:")
    gt_idx = report_string.find("GT:")
    pred_end = gt_idx if gt_idx >= 0 else len(report_string)
    pred_part = report_string[pred_idx:pred_end] if pred_idx >= 0 else ""
    gt_part = report_string[gt_idx:] if gt_idx >= 0 else ""
    pred = pred_part.replace("PRED:", "").lower().strip()
    gt = gt_part.replace("GT:", "").lower().strip()
    return pred, gt


def eval_result_dir(result_dir, valid_ids=None):
    """Evaluate every ``*.txt`` result file in a directory.

    Each file is expected to contain a ``PRED:`` section followed by a
    ``GT:`` section. NLG, CheXbert and CheXpert metrics are computed, printed,
    and written next to the inputs as ``result.json`` (aggregate scores) and
    ``all_result.csv`` (one row per study).

    Args:
        result_dir: Directory holding the result files; outputs are written
            into the same directory.
        valid_ids: Optional collection of study ids (file names without
            extension) to evaluate. ``None`` disables the filter, so every
            ``*.txt`` file is evaluated.

    Returns:
        Dict merging the aggregate NLG, CheXbert and CheXpert metrics.
    """
    result_dir = str(result_dir)
    # Build the lookup set once; None means "no filtering".
    allowed_ids = None if valid_ids is None else set(valid_ids)

    preds, gts, study_ids = [], [], []
    # Sorted so the row order of all_result.csv does not depend on the
    # arbitrary order returned by os.listdir.
    for file_name in sorted(os.listdir(result_dir)):
        if not file_name.endswith(".txt"):
            continue
        study_id = os.path.splitext(file_name)[0]
        if allowed_ids is not None and study_id not in allowed_ids:
            continue
        with open(os.path.join(result_dir, file_name), 'r') as f:
            # Only tabs and spaces are stripped; line breaks stay in the text.
            report_string = "".join(line.strip('\t ') for line in f)
        pred, gt = _split_pred_gt(report_string)
        preds.append(pred)
        gts.append(gt)
        study_ids.append(study_id)

    all_result = {}
    nlg_metrics = calculate_nlg_metrics(preds, gts, study_ids, all_result)
    print(nlg_metrics)
    chexbert_metrics = calculate_chexbert_metrics(preds, gts)
    print(json.dumps(chexbert_metrics, indent=2))
    chexpert_metrics = calculate_chexpert_metrics(preds, gts, study_ids, all_result, num_processes=48)
    print(json.dumps(chexpert_metrics, indent=2))
    results = {**nlg_metrics, **chexbert_metrics, **chexpert_metrics}

    pd.DataFrame(all_result).T.to_csv(os.path.join(result_dir, "all_result.csv"))
    with open(os.path.join(result_dir, "result.json"), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"{len(preds)} data processed")
    return results



def process_annotation(result_dir, annotation):
    """Collect the study ids that the annotation table marks as valid.

    The task is derived from the name of ``result_dir``: "correction",
    "template" or "history" (checked in that order). Directories whose name
    contains "comparison", or none of the task names, yield ``None`` and the
    annotation file is not read.

    Args:
        result_dir: Result directory; only its final path component (minus
            suffix) is inspected.
        annotation: Path of a CSV file with at least the columns
            ``record_path`` and ``A1``.

    Returns:
        ``None`` when no task applies; otherwise the list of study ids of all
        rows whose ``record_path`` base name contains the task name and whose
        ``A1`` equals "yes".

    Raises:
        AssertionError: If an id does not start with 's' or is not 9
            characters long.
    """
    dir_name = Path(result_dir).stem
    if "comparison" in dir_name:
        return None
    task = next(
        (name for name in ("correction", "template", "history") if name in dir_name),
        None,
    )
    if task is None:
        return None

    def belongs_to_task(record_path):
        return task in os.path.basename(record_path)

    table = pd.read_csv(annotation)
    table = table[table['record_path'].apply(belongs_to_task)]
    approved = table[table["A1"] == "yes"]

    valid_ids = []
    for record_path in approved["record_path"]:
        # The stem may still contain backslash-separated directories;
        # keep only the last component.
        study_id = Path(record_path).stem.split('\\')[-1]
        assert study_id[0] == 's' and len(study_id) == 9
        valid_ids.append(study_id)
    return valid_ids




if __name__ == "__main__":
    # Script entry point: evaluate one result directory and print the
    # elapsed wall-clock time in seconds.
    set_start_method("spawn")  # multiprocessing workers are spawned
    cli_args = parse_args()
    began = time.time()
    eval_result_dir(
        cli_args.result_dir,
        process_annotation(cli_args.result_dir, cli_args.annotation),
    )
    elapsed = time.time() - began
    print(elapsed)