import json
import argparse
import numpy as np
from decimal import Decimal, ROUND_HALF_UP

def softmax(x, axis=0):
    """Return the softmax of the scores in *x*, normalized along *axis*.

    *axis* defaults to 0, matching the original behavior: for a 1-D input the
    whole vector sums to 1, and for a 2-D input each column sums to 1.

    The maximum along *axis* is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (which would otherwise turn the output into NaN) for large
    scores. An empty input is returned as an empty float array of the same
    shape.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max has no identity for an empty reduction; exp of an empty array
        # already yields the empty float result the unshifted formula gave.
        return np.exp(x)
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

def compute_acc(
    seed: int = 0,
    dataset: str = 'bmlama',
    instance_num: int = 5000,
    mname: str | None = None,
) -> float:
    """Compute consistency between two languages."""

    train_data = f"seed{seed}_sample{instance_num}_{dataset}_baseline"
    post_mname = f"{mname.replace('/', '-')}"

    lang_en = ['en']
    if dataset == 'bmlama':
        langs = ['fr', 'nl', 'es', 'ru', 'ja', 'zh', 'ko', 'vi', 'el', 'hu', 'he', 'tr', 'ca', 'ar', 'uk', 'fa']
    elif dataset == 'mmmlu':
        langs = ['ar', 'de', 'es', 'fr', 'hi', 'id', 'it', 'ja', 'ko', 'pt', 'sw', 'yo', 'zh', 'bn']
        langs = ['ar', 'de', 'es', 'fr', 'hi', 'id', 'it', 'ja', 'ko', 'pt', 'zh']
    elif dataset == 'xcsqa':
        langs = ['zh', 'de', 'es', 'fr', 'it', 'ja', 'nl', 'pl', 'pt', 'ru', 'ar', 'vi', 'hi', 'sw', 'ur']
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    
    try:
        acc_en = json.load(open(f'./outputs/{train_data}/{post_mname}/en_Accuracy.json', 'r'))
        acc_en = round(acc_en*100, 2)
    except Exception as e:
        acc_en = -100.00  # Default value if error occurs

    acc_list = []
    for lang in langs:
        try:
            acc = json.load(open(f'./outputs/{train_data}/{post_mname}/{lang}_Accuracy.json', 'r'))
            acc = round(acc*100, 2)
            acc_list.append(acc)
        except Exception as e:
            # print(f"Error occurred for {lang}: {e}")
            acc_list.append(-100.00)  # Append -1 for languages that fail to compute consistency

    acc_list = [acc_en] + acc_list
    print(f"\\modelname{{{mname}}}" + " & " + " & ".join([f"{c:.2f}" if c != -100.00 else "Fail" for c in acc_list]) + r" \\")
    return acc_list

if __name__ == "__main__":
    # Command-line entry point: parse the run configuration and print the
    # LaTeX accuracy row for the selected model.
    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', type=int, default=0, help='random seed for data generation')
    cli.add_argument('--dataset', type=str, default='bmlama', help='dataset name')
    cli.add_argument('--instance_num', type=int, default=5000, help='number of instances')
    cli.add_argument('--mname', type=str, default='meta-llama/Llama-3.2-3B', help='model name')
    opts = cli.parse_args()

    # compute_acc prints the row itself; its returned list is not needed here.
    compute_acc(
        seed=opts.seed,
        dataset=opts.dataset,
        instance_num=opts.instance_num,
        mname=opts.mname,
    )
