import json
import os

from metrics.skr.grailqa.evaluator import evaluate as grailqa_evaluate
from metrics.skr.mtop.evaluator import evaluate as mtop_evaluate
from metrics.skr.spider.evaluator import evaluate as spider_evaluate
from metrics.skr.compwebq.evaluator import evaluate as compwebq_evaluate
from metrics.commonsense.evaluate import evaluate as commonsense_evaluate

# Test-split metadata, loaded once at import time.
# `spider_test_db_ids` is passed to the Spider evaluator and `compwebq_test_golds`
# replaces the caller-supplied golds for CompWebQ (see `evaluate`).
# Context managers close the files promptly; JSON is read as UTF-8 regardless
# of the platform's default locale encoding.
with open('data/cl/spider_test.json', 'r', encoding='utf-8') as _fh:
    spider_test_data = json.load(_fh)
spider_test_db_ids = [ex['meta']['db_id'] for ex in spider_test_data]
with open('data/cl/compwebq_test.json', 'r', encoding='utf-8') as _fh:
    compwebq_test_data = json.load(_fh)
compwebq_test_golds = [ex['meta']['gold_res'] for ex in compwebq_test_data]

# Task sequences keyed by order id (as a string). `evaluate_single_task`
# looks a task name up with a 1-based task id, i.e. `order[str(order_id)][id - 1]`.
skr_order = {
    "1": ["grailqa", "mtop", "spider", "compwebq"],
    "2": ["compwebq", "spider", "mtop", "grailqa"],
    "3": ["mtop", "grailqa", "compwebq", "spider"],
}

# Same layout for the "mix" stream; every order must list all 8 tasks.
mix_order = {
    "1": ["grailqa", 'boolq', "mtop", 'piqa', "spider", 'arce', "compwebq", 'winogrande'],
    "2": ["compwebq", 'winogrande', "spider", 'arce', "mtop", 'piqa', "grailqa", 'boolq'],
    # A missing comma previously fused 'winogrande' and "spider" into the single
    # entry "winograndespider", leaving this order with only 7 tasks.
    "3": ["mtop", 'piqa', "grailqa", 'boolq', "compwebq", 'winogrande', "spider", 'arce'],
}


def evaluate(task_name, preds, golds):
    """Score `preds` with the evaluator that belongs to `task_name`.

    Args:
        task_name: Name of the task. "grailqa", "mtop", "spider" and
            "compwebq" have dedicated evaluators; any other name is handed
            to the commonsense evaluator together with the name itself.
        preds: Model predictions.
        golds: Reference labels. Not used for "compwebq", which is scored
            against the module-level `compwebq_test_golds` instead.

    Returns:
        Whatever the selected evaluator returns.
    """
    if task_name == "grailqa":
        return grailqa_evaluate(preds, golds)

    if task_name == "mtop":
        return mtop_evaluate(preds, golds)

    if task_name == "spider":
        # The Spider evaluator additionally receives the test-set database ids.
        return spider_evaluate(preds, golds, spider_test_db_ids)

    if task_name == "compwebq":
        # Golds come from the test-set metadata, not from the caller.
        return compwebq_evaluate(preds, compwebq_test_golds)

    # Everything else is routed to the commonsense evaluator.
    return commonsense_evaluate(preds, golds, task_name)


def evaluate_single_task(method, model, stream, order_id, task_id, eval_task_id):
    """Evaluate one (task_id, eval_task_id) cell and cache the result on disk.

    All files live under
    ``ckpt/cl/{method}/{model}/{stream}_{order_id}/eval_results/{task_id}/{eval_task_id}``.
    If ``eval_results.json`` already exists there, its stored accuracy is
    returned without re-scoring. Otherwise ``generated_predictions.jsonl`` is
    read (one JSON object per line with ``predict`` and ``label`` keys), scored
    with `evaluate`, and the result is written to ``eval_results.json``.

    Args:
        method: Method name, used as a path component.
        model: Model name, used as a path component.
        stream: Either "skr" or "mix"; selects the task-order table.
        order_id: Key (after `str()`) into the selected order table.
        task_id: Path component for the first result-directory level.
        eval_task_id: 1-based index of the evaluated task in the order list.

    Returns:
        The accuracy, either freshly computed or loaded from the cache file.

    Raises:
        ValueError: If `stream` is neither "skr" nor "mix" and no cached
            result exists.
    """
    print(f"Evaluating {method} {model} {stream} {order_id} {task_id} {eval_task_id}")

    # Build the directory once instead of repeating the f-string for every file.
    eval_dir = f'ckpt/cl/{method}/{model}/{stream}_{order_id}/eval_results/{task_id}/{eval_task_id}'
    cache_file = f'{eval_dir}/eval_results.json'

    # Reuse a previously stored result when available.
    if os.path.exists(cache_file):
        with open(cache_file, 'r', encoding='utf-8') as f:
            return json.load(f)['accuracy']

    if stream == "skr":
        task_name = skr_order[str(order_id)][eval_task_id - 1]
    elif stream == "mix":
        task_name = mix_order[str(order_id)][eval_task_id - 1]
    else:
        raise ValueError(f"Invalid stream: {stream}")

    print(f"Task name: {task_name}")

    preds, golds = [], []
    results_file = f'{eval_dir}/generated_predictions.jsonl'
    with open(results_file, 'r', encoding='utf-8') as f:
        for line in f:
            example = json.loads(line)
            preds.append(example['predict'])
            golds.append(example['label'])

    acc = evaluate(task_name, preds, golds)

    # The context manager guarantees the cache file is flushed and closed
    # before anyone (e.g. `get_acc_by_task_id`) reads it back.
    with open(cache_file, 'w', encoding='utf-8') as f:
        json.dump(
            {'total_num': len(preds), 'accuracy': acc},
            f,
            indent=2,
            ensure_ascii=False,
        )
    return acc


def get_acc_by_task_id(method, model, stream, order_id, task_id, eval_task_id):
    """Read the stored accuracy for one (task_id, eval_task_id) cell.

    Loads ``eval_results.json`` from
    ``ckpt/cl/{method}/{model}/{stream}_{order_id}/eval_results/{task_id}/{eval_task_id}``
    and returns its ``accuracy`` field converted with `float()`.

    Raises:
        FileNotFoundError: If the result file has not been written yet.
    """
    results_path = (
        f'ckpt/cl/{method}/{model}/{stream}_{order_id}/eval_results/{task_id}/{eval_task_id}/eval_results.json'
    )
    # Use a context manager so the handle is closed right after parsing.
    with open(results_path, 'r', encoding='utf-8') as f:
        acc = json.load(f)['accuracy']

    return float(acc)


def compute_cl_metrics(method, model, stream, order_id):
    """Aggregate per-task accuracies into 'acc' and 'bwt' metrics.

    For every task id ``t`` of the stream (4 tasks for "skr", 8 for "mix"):

    * ``acc`` is the mean stored accuracy over eval tasks ``1..t`` in row ``t``.
    * ``bwt`` (only for ``t > 1``) is the mean, over eval tasks ``k < t``, of
      ``accuracy(t, k) - accuracy(k, k)``.

    The metrics are written to
    ``ckpt/cl/{method}/{model}/{stream}_{order_id}/cl_metrics.json`` and
    returned. An unrecognized `stream` yields (and writes) an empty dict.

    Args:
        method: Method name, used as a path component.
        model: Model name, used as a path component.
        stream: "skr" or "mix".
        order_id: Order id, used as a path component.

    Returns:
        Dict mapping task id (int) to ``{'acc': ..., 'bwt': ...}``.
    """
    # Number of sequential tasks in each supported stream.
    num_tasks_by_stream = {"skr": 4, "mix": 8}
    # Only the "skr" stream traces the individual accuracies to stdout.
    trace = stream == "skr"

    metrics = {}
    # accuracy(k, k): the result for task k stored under task id k. Remembered
    # so the bwt term does not need to re-read the file.
    ori_accs = {}
    for task_id in range(1, num_tasks_by_stream.get(stream, 0) + 1):
        # Read each result file of this row exactly once.
        cur_accs = {}
        acc_sum = 0
        for eval_task_id in range(1, task_id + 1):
            if trace:
                print(f"Task id: {task_id}, Eval task id: {eval_task_id}")
            acc = get_acc_by_task_id(method, model, stream, order_id, task_id, eval_task_id)
            if trace:
                print(acc)
            cur_accs[eval_task_id] = acc
            acc_sum += acc
        ori_accs[task_id] = cur_accs[task_id]

        metrics[task_id] = {'acc': acc_sum / task_id}

        if task_id > 1:
            bwt_sum = 0
            for eval_task_id in range(1, task_id):
                bwt_sum += cur_accs[eval_task_id] - ori_accs[eval_task_id]
            metrics[task_id]['bwt'] = bwt_sum / (task_id - 1)

    # Context manager ensures the metrics file is flushed and closed.
    with open(f'ckpt/cl/{method}/{model}/{stream}_{order_id}/cl_metrics.json', 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    return metrics


if __name__ == "__main__":
    # Script entry point: score every available result cell of the "skr"
    # stream for each model/order, then aggregate the CL metrics per order.
    method = "mem"

    # Alternative run: only recompute the aggregated metrics.
    # for order_id in [1, 2, 3]:
    #     for model in ["llama3.1", "qwen"]:
    #         for stream in ["skr"]:
    #             compute_cl_metrics(method, model, stream, order_id)

    stream = "skr"
    for model in ("llama3.1", "qwen"):
        for order_id in (1, 2, 3):
            for task_id in range(1, 5):
                # After task t, tasks 1..t+1 are evaluated, capped at the
                # last task (4).
                last_eval_task_id = min(task_id + 1, 4)
                for eval_task_id in range(1, last_eval_task_id + 1):
                    task_acc = evaluate_single_task(method, model, stream, order_id, task_id, eval_task_id)
                    print(task_acc)

            compute_cl_metrics(method, model, stream, order_id)


    # Alternative run: the "mix" stream (8 tasks).
    # stream = "mix"
    # for model in ["llama3.1", "qwen"]:
    #     for order_id in [1]:
    #         for task_id in [1, 2, 3, 4, 5,6,7,8]:
    #             if task_id == 8:
    #                 eval_task_id_max = 8
    #             else:
    #                 eval_task_id_max = task_id + 1
    #             for eval_task_id in range(1, eval_task_id_max + 1):
    #                 print(evaluate_single_task(method, model, stream, order_id, task_id, eval_task_id))
