import argparse
import numpy as np
import pandas as pd
import policy
import random
import os
import json
from config import *
from create_tensor import load_metadata
from tqdm import tqdm
from utils import load_val_predictions, load_test_predictions


def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    converter accepts the usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false), got {!r}".format(value))


def main(argv=None):
    """Run the baseline, short-m and DUCHESS policies and print accuracies.

    Parses command-line options, loads metadata plus validation/test
    predictions for the selected dataset and model, derives a threshold from
    the validation predictions, then evaluates the three policies from the
    ``policy`` module on the rows with ``train == 0`` and prints the mean of
    each result's ``accuracy`` column.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--model_id', type=str, required=True)
    parser.add_argument('--prediction', type=str, required=True)
    parser.add_argument('--epoch', type=int, required=True)

    parser.add_argument('--num_chains', type=int, required=True)
    parser.add_argument('--threshold_val_pct', type=int, default=90)
    parser.add_argument('--warm_up', type=int, default=2)
    parser.add_argument('--patience_high', type=int, default=2)
    parser.add_argument('--interval', type=int, default=16)
    parser.add_argument('--consensus_frac', type=float, default=0.6)
    parser.add_argument('--voting_frac', type=float, default=1)
    parser.add_argument('--sample_lambda', type=float, default=1)

    parser.add_argument('--m_frac', type=float, default=0.6)

    # type=bool would turn "--resample_chains False" into True; parse the
    # string explicitly instead.
    parser.add_argument('--resample_chains', type=_parse_bool, default=False)
    parser.add_argument('--record_dir', type=str, required=False, default=None)

    args = parser.parse_args(argv)

    # Directory shared by the prediction files and the optional trace output.
    model_dir = os.path.join(
        DSET_TO_DIR[args.dataset],
        MODEL_IDS[args.model_id]
    )

    metadata = load_metadata(args.dataset, args.model_id, for_training=False)
    metadata_test = metadata[metadata['train'] == 0]

    val_preds = load_val_predictions(model_dir, args.prediction, args.epoch)
    test_preds = load_test_predictions(model_dir, args.prediction, args.epoch)

    # Validation predictions supply both the fallback value for test rows
    # without a prediction and the percentile-based threshold.
    default_pred = val_preds.pred.median()
    threshold = np.percentile(val_preds.pred.values, args.threshold_val_pct)

    # Left merge keeps every test row; rows with no matching prediction get
    # the validation median.
    metadata_test_wpred = metadata_test.merge(
        test_preds, on=['unique_id', 'chain_id', 'tokens'], how='left')
    metadata_test_wpred['pred'] = metadata_test_wpred['pred'].fillna(default_pred)

    if args.record_dir is not None:
        record_dir_full = os.path.join(
            model_dir,
            "latency_traces",
            args.record_dir
        )
    else:
        record_dir_full = None

    baseline_res = policy.baseline(
        metadata_test_wpred,
        args.num_chains,
        args.resample_chains
    )

    shortm_res = policy.short_m(
        metadata_test_wpred,
        args.num_chains,
        args.m_frac,
        args.resample_chains
    )

    # NOTE(review): the literal arguments (-1, 1, 'greedy_prob', True) are
    # passed positionally; their meaning is defined by policy.duchess, which
    # is not visible from this file -- confirm against its signature.
    res = policy.duchess(
        metadata_test_wpred,
        args.num_chains,
        threshold,
        -1,
        args.warm_up,
        args.patience_high,
        1,
        args.interval,
        args.consensus_frac,
        args.voting_frac,
        'greedy_prob',
        True,
        args.resample_chains,
        record_dir_full,
        args.sample_lambda
    )

    print("--------------------------------")
    print("Default SC accuracy: ", baseline_res['accuracy'].mean())
    print("Short-m@k accuracy: ", shortm_res['accuracy'].mean())
    print("DUCHESS accuracy: ", res['accuracy'].mean())


# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger argument parsing.
if __name__ == "__main__":
    main()
