import datetime
import sys
from tqdm import tqdm

from bbh_loader import BBHDataLoader
from decode_bbh import *
from config_bbh import *
from extract_CoT_bbh import *
from tools_bbh import *
import os
import json

# Import BLEU and ROUGE related functions
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# Smoothing function for sentence-level BLEU (NLTK SmoothingFunction.method1).
smooth_fn = SmoothingFunction().method1
# ROUGE scorer with stemming. It is configured for both ROUGE-1 and ROUGE-L, but
# accumulate_stats_for_k_smethod_bmethod only reads the 'rouge1' F-measure;
# add or remove metric names here as needed.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

def process_single_example_k_smethod_bmethod(example, model, tokenizer, k_values, sample_methods, branch_methods, measure_methods, lowest_probs,task_type):
    """
    Decode one example under every configuration and apply each path-selection strategy.

    The example is decoded with get_token_k_j_path_prob_follow_up once for every
    combination of:
      - k_values: number of decoding paths (e.g. 1, 3, 5, 10)
      - sample_methods: e.g. "sequential", "fibonacci"
      - branch_methods: e.g. "first_token", "greedy_highest", "greedy_lowest"
      - lowest_probs: e.g. 0.2, 0.3
      - measure_methods: e.g. "logits", "gap", "gap_ratio"
    The follow-up texts and probabilities are then passed to four selectors:
    select_top1_path, select_max_path, select_open_aggregated_path and
    select_open_aggregated_path_new.

    Args:
        example: mapping with "question" (prompt text) and "answer" (either a
            single reference string or a list of reference strings).
        model, tokenizer: forwarded unchanged to the decoding function.
        k_values, sample_methods, branch_methods, measure_methods, lowest_probs:
            iterables of configuration values to sweep.
        task_type: key into task_instructions; also forwarded to the selectors.

    Returns:
        {
          "question": q,
          "ground_truth": example["answer"] (returned unchanged),
          "results_k_smethod_bm": {
             k_val: {sm: {bm: {lp: {mm: {
                 "top1_follow": ..., "max_follow": ..., "agg_follow": ..., "agg_new": ...
             }}}}}
          }
        }
    """
    # NOTE: task-specific prompt prefixes used in earlier experiments:
    # Co-classification task
    # q = "Task description: identify a broad class given several examples from that class\n" + "Question: " + example["question"]
    # Coreference resolution task
    # q = "Task description: An indirect anaphora resolution task that is cast as a context dependent question answering problem\n" + "Context: " + example["question"]
    # Sentence rephrasing task
    # q = "Task description: Rephrase the sentences below so that they retain their meaning but contain the specified keyword.\n" + "Sentence: " + \
    #     example["question"]
    # Disfluent text QA (current): the question is used verbatim.
    q = example["question"]

    ground_truth = example["answer"]
    # Targets may be a single string or a list of strings. Indexing a bare string
    # with [0] would hand only its first character to the selectors, so wrap it.
    references = [ground_truth] if isinstance(ground_truth, str) else list(ground_truth)
    # The selectors receive only the first reference ("" if none is given).
    primary_ref = references[0] if references else ""

    ins = task_instructions[task_type]
    # The prompt does not depend on any swept parameter, so build it once.
    query = TEMPLATE.format(question=q, instruction=ins)

    results_k_smethod_bm = {}
    for k_val in k_values:
        results_k_smethod_bm[k_val] = {}
        for sm in sample_methods:
            results_k_smethod_bm[k_val][sm] = {}
            for bm in branch_methods:
                results_k_smethod_bm[k_val][sm][bm] = {}
                for lp in lowest_probs:
                    lp_results = {}
                    results_k_smethod_bm[k_val][sm][bm][lp] = lp_results
                    for mm in measure_methods:
                        # Decode k paths for this configuration (with a follow-up prompt).
                        k_response, k_gen_probs, k_follow_probs = get_token_k_j_path_prob_follow_up(
                            model, tokenizer,
                            query=query,
                            follow_up_template=FOLLOW_TEMPLATE,
                            k=k_val,
                            j=2,
                            measure_method=mm,
                            sample_method=sm,
                            lowest_prob=lp,
                            branch_method=bm
                        )

                        # Remove the prompt text from each decoded response.
                        k_response = [response.replace(query, '') for response in k_response]

                        # Keep the text after the first FOLLOW_TEMPLATE marker (up to a second
                        # marker, if any); "" when the marker is missing.
                        follow_texts = [
                            resp.split(FOLLOW_TEMPLATE)[1] if FOLLOW_TEMPLATE in resp else ""
                            for resp in k_response
                        ]

                        # Strategy invocation
                        r_top1_follow = select_top1_path(k_follow_probs, follow_texts, primary_ref, task_type=task_type)
                        r_max_follow = select_max_path(k_follow_probs, follow_texts, primary_ref, task_type=task_type)
                        r_agg_follow = select_open_aggregated_path(k_follow_probs, follow_texts, primary_ref, task_type=task_type)
                        r_agg_new = select_open_aggregated_path_new(k_gen_probs, k_follow_probs, follow_texts, primary_ref, task_type=task_type)

                        lp_results[mm] = {
                            "top1_follow": r_top1_follow,
                            "max_follow": r_max_follow,
                            "agg_follow": r_agg_follow,
                            "agg_new": r_agg_new,
                        }
    return {
        "question": q,
        "ground_truth": ground_truth,
        "results_k_smethod_bm": results_k_smethod_bm
    }


def accumulate_stats_for_k_smethod_bmethod(example_result, stats_dict):
    """
    For multiple groups of (k, sample_method, branch_method, lowest_prob, measure_method) results for a single example,
    calculate BLEU and ROUGE-1 metrics between each strategy-generated answer and the reference answer, and accumulate them in stats_dict.
    Additionally, if the candidate answer contains any element from the reference answer list, the match metric is incremented by 1.
    Note: Here ground_truth is a list, and the metrics are calculated for each reference answer in the list, with the highest value taken as the result.
    """
    # ground_truth is a list, first lowercase and strip each reference answer
    refs = [gt.lower().strip() for gt in example_result["ground_truth"]]
    results_k_smethod_bm = example_result["results_k_smethod_bm"]

    for k_val, sm_dict in results_k_smethod_bm.items():
        for sm, bm_dict in sm_dict.items():
            for bm, lp_dict in bm_dict.items():
                for lp, mm_dict in lp_dict.items():
                    for mm, strat_dict in mm_dict.items():
                        # strat_dict may contain multiple strategy results
                        for key_name, result_obj in strat_dict.items():
                            if result_obj:
                                hyp = result_obj["answer"].lower().strip()
                                bleu_scores = []
                                rouge_scores = []
                                # For each reference answer in the ground_truth list, calculate BLEU and ROUGE-1 scores
                                for ref in refs:
                                    bleu_val = sentence_bleu([ref.split()], hyp.split(), smoothing_function=smooth_fn)
                                    bleu_scores.append(bleu_val)
                                    rouge_val = scorer.score(ref, hyp)['rouge1'].fmeasure
                                    rouge_scores.append(rouge_val)
                                # Take the highest BLEU and ROUGE scores from all reference answers
                                bleu = max(bleu_scores) if bleu_scores else 0.0
                                rouge1 = max(rouge_scores) if rouge_scores else 0.0

                                # Calculate match metric: if any element in the answer list appears in the candidate answer, match = 1; otherwise 0
                                match_value = 1 if any(ref in hyp for ref in refs) else 0

                                # Accumulate metrics into the statistics dictionary
                                stats_dict[k_val][sm][bm][lp][mm][key_name] += bleu
                                stats_dict[k_val][sm][bm][lp][mm][key_name + "_rouge1"] += rouge1
                                stats_dict[k_val][sm][bm][lp][mm][key_name + "_match"] += match_value


#############################################
# Summary function: convert cumulative scores into per-example averages
#############################################
def build_overall_results_for_k_smethod_bmethod(stats_dict, total_examples):
    """
    Convert cumulative scores into per-example averages.

    Every leaf counter of stats_dict[k][sm][bm][lp][mm] (BLEU, ROUGE-1 and MATCH
    totals for each strategy) is divided by total_examples. The result is a new
    nested dict with the same keys; stats_dict itself is not modified.

    Args:
        stats_dict: nested dict of accumulated scores.
        total_examples: number of examples the scores were accumulated over.
            When it is 0 (no examples processed), every average is 0.0 instead of
            raising ZeroDivisionError.

    Returns:
        overall_results[k][sm][bm][lp][mm] -> {metric_key: average}.
    """
    overall_results = {}
    for k_val, sm_dict in stats_dict.items():
        overall_results[k_val] = {}
        for sm, bm_dict in sm_dict.items():
            overall_results[k_val][sm] = {}
            for bm, lp_dict in bm_dict.items():
                overall_results[k_val][sm][bm] = {}
                for lp, mm_dict in lp_dict.items():
                    overall_results[k_val][sm][bm][lp] = {}
                    for mm, metric_dict in mm_dict.items():
                        overall_results[k_val][sm][bm][lp][mm] = {
                            key_name: (total_score / total_examples if total_examples else 0.0)
                            for key_name, total_score in metric_dict.items()
                        }
    return overall_results


#############################################
# Reporting helpers: every value printed below is a per-example average of BLEU, ROUGE-1 or MATCH
#############################################
def print_interim_results(overall_results, example_count):
    """
    Print interim averaged metrics and periodically append them to a log file.

    All metric lines are printed as percentages. When example_count is a multiple
    of 20, the same lines are also appended to
    f"bbh_result_cot_is_{cot}_model_is_{RUNNING_MODEL}_{shots}shot_gpu_is_{GPU_ID}.txt".
    They are preceded by an "example_count: N" line and followed by an empty
    separator line.

    Args:
        overall_results: overall_results[k][sm][bm][lp][mm] -> metric dict, as
            produced by build_overall_results_for_k_smethod_bmethod.
        example_count: number of examples processed so far.
    """
    write_to_file = (example_count % 20 == 0)
    lines = []

    def emit(text):
        # Print immediately; keep the line so the file (if any) is written in one go.
        print(text)
        lines.append(text)

    for k_val, sm_dict in overall_results.items():
        emit(f"\nk={k_val}:")
        for sm, bm_dict in sm_dict.items():
            emit(f"  {sm}:")
            for bm, lp_dict in bm_dict.items():
                emit(f"    {bm}:")
                for lp, mm_dict in lp_dict.items():
                    emit(f"      p>{lp}:")
                    for mm, d in mm_dict.items():
                        emit(f"        {mm}:")
                        for key_base in ("top1_follow", "max_follow", "agg_follow", "agg_new"):
                            emit(
                                f"          {key_base} BLEU: {d[key_base]:.1%}, "
                                f"ROUGE-1: {d[key_base + '_rouge1']:.1%}, "
                                f"MATCH: {d[key_base + '_match']:.1%}"
                            )

    if write_to_file:
        # Open the log once instead of once per line.
        file_name = f"bbh_result_cot_is_{cot}_model_is_{RUNNING_MODEL}_{shots}shot_gpu_is_{GPU_ID}.txt"
        with open(file_name, "a", encoding="utf-8") as f:
            f.write(f"example_count: {example_count}\n")
            f.writelines(line + "\n" for line in lines)
            # Empty line as a separator between snapshots.
            f.write("\n")


def print_final_results(overall_results):
    """Print the final averaged metrics for every configuration to stdout.

    Walks overall_results[k][sample_method][branch_method][lowest_prob][measure_method].
    For each configuration it prints BLEU / ROUGE-1 / MATCH with three decimals
    for the four selection strategies.
    """
    strategies = ("top1_follow", "max_follow", "agg_follow", "agg_new")
    for k, by_sample in overall_results.items():
        print(f"\n=== k={k} ===")
        for sample_method, by_branch in by_sample.items():
            print(f"  sample_method={sample_method}")
            for branch_method, by_prob in by_branch.items():
                print(f"    branch_method={branch_method}")
                for prob, by_measure in by_prob.items():
                    print(f"      lowest_prob={prob}")
                    for measure, scores in by_measure.items():
                        print(f"        measure_method={measure}")
                        for name in strategies:
                            bleu = scores[name]
                            rouge1 = scores[f"{name}_rouge1"]
                            match = scores[f"{name}_match"]
                            print(f"          {name} BLEU: {bleu:.3f}, ROUGE-1: {rouge1:.3f}, MATCH: {match:.3f}")


#############################################
# Main program entry: evaluate each task in RUNNING_TASK with the functions defined above
#############################################
if __name__=="__main__":
    measure_methods = config_base.get('measure_methods')
    sample_methods = config_base.get('sample_methods')
    branch_methods = config_base.get('branch_methods')
    lowest_probs = config_base.get('lowest_probs')
    k_values = config_base.get('k_values')

    # Strategy keys produced by process_single_example_k_smethod_bmethod and the
    # per-strategy metric suffixes accumulated by accumulate_stats_for_k_smethod_bmethod.
    strategy_names = ("top1_follow", "max_follow", "agg_follow", "agg_new")
    metric_suffixes = ("", "_rouge1", "_match")
    print_interval = 1   # print interim statistics after every N examples
    max_examples = 1000  # cap on the number of examples evaluated per task

    bbh_loader = BBHDataLoader()
    # Available task JSON file names (without extension); the tasks actually run come from RUNNING_TASK.
    task_names = bbh_loader.get_json_file_names()

    for current_task in RUNNING_TASK:
        # Load the task once and reuse it for both the preview and the evaluation.
        task_data = bbh_loader.get_by_json_name(current_task)
        task_type = bbh_loader.task_type_mapping.get(current_task, '')
        test_examples = task_data["json"].get("examples", [])
        if not test_examples:
            print(f"=======\nCurrent task: {current_task}, task type: {task_type}\nNo examples found, skipping.")
            continue
        print(f'''=======\nCurrent task: {current_task}, task type: {task_type}\nCurrent task example input: {test_examples[0]["input"]}\nCurrent task example answer: {test_examples[0]["target"]}''')

        # Expose the raw fields under the names process_single_example_k_smethod_bmethod reads.
        for ex in test_examples:
            ex["question"] = ex["input"]
            ex["answer"] = ex["target"]

        # Cumulative BLEU / ROUGE-1 / MATCH counters for every configuration and strategy.
        stats_dict = {}
        for k_val in k_values:
            stats_dict[k_val] = {}
            for sm in sample_methods:
                stats_dict[k_val][sm] = {}
                for bm in branch_methods:
                    stats_dict[k_val][sm][bm] = {}
                    for lp in lowest_probs:
                        stats_dict[k_val][sm][bm][lp] = {}
                        for mm in measure_methods:
                            stats_dict[k_val][sm][bm][lp][mm] = {
                                name + suffix: 0
                                for name in strategy_names
                                for suffix in metric_suffixes
                            }

        example_count = 0
        results_list = []
        for example in tqdm(test_examples[:max_examples], desc="Processing examples"):
            example_count += 1

            # Process a single example
            example_result = process_single_example_k_smethod_bmethod(
                example, model, tokenizer,
                k_values, sample_methods, branch_methods, measure_methods, lowest_probs, task_type
            )
            results_list.append(example_result)
            accumulate_stats_for_k_smethod_bmethod(example_result, stats_dict)

            # Print interim results every `print_interval` examples.
            if example_count % print_interval == 0:
                print(f"\n=== Processed {example_count} examples, current statistics ===")
                temp_overall = build_overall_results_for_k_smethod_bmethod(stats_dict, example_count)
                print_interim_results(temp_overall, example_count)

        overall_results = build_overall_results_for_k_smethod_bmethod(stats_dict, example_count)

        # Write aggregated and per-example results to a timestamped JSON file.
        output_data = {
            "overall_results": overall_results,
            "detailed_results": results_list
        }
        out_dir = f"result/bbh/{RUNNING_MODEL}/"
        os.makedirs(out_dir, exist_ok=True)
        current_time = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")
        out_path = f"{out_dir}RESULT_BASE_{current_task}_{task_type}_{current_time}.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)

        print("\n=== Final results ===")
        print_final_results(overall_results)
