import numpy as np
from utils.orchestrator import Evaluator


# parameters
data_name = 'recl'      # recl, rwth, rtraffic, rillness, rett
n_repeats = 20          # number of samples for the filtering
batch_size = 4096       # batch size to use to run through the entire dataset
model = 'dm'            # dm, moment, timer


# check model
allowed_model = ['dm', 'moment', 'timer']
model_not_implemented_error = f'model={model} not in {allowed_model}.'
# Raise explicitly rather than assert: assert statements are removed when
# Python runs with -O, which would let an unsupported model slip through.
if model not in allowed_model:
    raise NotImplementedError(model_not_implemented_error)

# init evaluator
expt_name = data_name
evaluator = Evaluator(expt_name)

# context length passed to the 'moment' evaluation, shared by the sampled and
# the deterministic runs so both always use the same value
context_len = 24


def _compute_ce_and_re(task, deterministic):
    """Run one evaluation pass of the configured model on `task`.

    Dispatches on the module-level `model` setting and reads the module-level
    `evaluator`, `batch_size` and `context_len`, so the sampling loop and the
    deterministic reference run share a single dispatch.

    Args:
        task: task identifier taken from `evaluator.allowed_tasks`.
        deterministic: forwarded as the `deterministic` keyword for 'moment'
            and 'timer'; not forwarded for 'dm', whose evaluation call is
            made without that keyword.

    Returns:
        Tuple `(combined_error, relation_error)`: the first two of the three
        values returned by the evaluator; the third value is discarded.

    Raises:
        NotImplementedError: if `model` is not 'dm', 'moment' or 'timer'.
    """
    if model == 'dm':
        result = evaluator.compute_ce_and_re_dm(task, batch_size)
    elif model == 'moment':
        result = evaluator.compute_ce_and_re_moment(task, batch_size, context_len, deterministic=deterministic)
    elif model == 'timer':
        result = evaluator.compute_ce_and_re_timer(task, batch_size, deterministic=deterministic)
    else:
        raise NotImplementedError(model_not_implemented_error)
    combined_error, relation_error, _ = result
    return combined_error, relation_error


# loop to compute filtered relation error divided by a reference relation error for each prompt
filtered_re_div_mean_re = {}
for task in evaluator.allowed_tasks:

    # get responses for the same prompt 'n_repeats' times
    samples = [_compute_ce_and_re(task, deterministic=False) for _ in range(n_repeats)]

    # stack repeats along the last axis: shape (-1, n_repeats)
    # (one row per prompt, assuming each call returns one error value per prompt)
    combined_errors = np.array([ce for ce, _ in samples]).T
    relation_errors = np.array([rel for _, rel in samples]).T

    # find corresponding relation error for argmin of combined error for each prompt
    best_repeat = np.argmin(combined_errors, axis=1)
    filtered_relation_error = relation_errors[np.arange(len(combined_errors)), best_repeat]

    # reference is the mean relation error over the repeats for 'dm',
    # and the relation error of a deterministic run for the FMs
    if model == 'dm':
        reference_re = relation_errors.mean(axis=1)
    else:
        reference_re = _compute_ce_and_re(task, deterministic=True)[1]
    filtered_re_div_mean_re[task] = filtered_relation_error / reference_re


# report, per task, the mean and standard deviation of the stored ratio
print(f'Model={model}, Dataset={data_name}, Filtered RE / Mean RE:')
for task in evaluator.allowed_tasks:
    ratios = filtered_re_div_mean_re[task]
    ratio_mean, ratio_std = ratios.mean(), ratios.std()
    print(f'\nTask={task}: {ratio_mean:.4g} +- {ratio_std:.4g}')
