import logging
import argparse
from argparse import Namespace
from tqdm import tqdm

from src.state_handlers import STATE_HANDLERS

from .datasets.hotpot_qa.hotpot_evaluate_v1 import get_init_metrics

from .constants import *
from .utils import *
from .agent import LLMAgent

# Module-level logger shared by this script. LOGGER_NAME is not defined in this
# file; presumably it comes from one of the star imports above (.constants or
# .utils) — confirm.
logger = logging.getLogger(LOGGER_NAME)

def _load_prior_predictions(path):
    """Return the prediction records already saved at ``path``, keyed by ``'id'``.

    The file is read as a stream of JSON values written back to back (one per
    ``write_json`` call). Any mix of whitespace and commas between records is
    accepted, so blank lines or records glued together (``}{``) do not break
    resuming. Falsy records (``{}``, ``null``) are skipped; a later record with
    a repeated id overwrites an earlier one. A missing file yields ``{}``.

    Raises:
        json.JSONDecodeError: if a record is not valid JSON (for example, one
            truncated by an interrupted run).
        KeyError: if a non-empty record has no ``'id'`` field.
    """
    # Local imports pin the stdlib modules regardless of what the star imports
    # at the top of the file may bind to these names.
    import json
    import os

    if not os.path.exists(path):
        return {}
    with open(path, 'r') as f:
        text = f.read()

    decoder = json.JSONDecoder()
    separators = ' \t\n\r,'
    prior, pos, end = {}, 0, len(text)
    while True:
        # raw_decode() does not skip leading whitespace, so step over any
        # separators before decoding the next record.
        while pos < end and text[pos] in separators:
            pos += 1
        if pos == end:
            return prior
        record, pos = decoder.raw_decode(text, pos)
        if record:
            prior[record['id']] = record


def _write_results(path, metrics, total):
    """Write a two-row tab-separated summary to ``path``: header row, then values.

    Columns are ``total`` followed by the keys of ``metrics`` in dict order.
    """
    keys = list(metrics)
    header = ['total'] + keys
    values = [str(total)] + [str(metrics[k]) for k in keys]
    with open(path, 'w') as rf:
        rf.write('\t'.join(header) + '\n')
        rf.write('\t'.join(values) + '\n')


def main(args : Namespace):
    """Evaluate an LLM agent on a dataset, resuming from any saved predictions.

    Examples whose id already appears in ``config.predictions_file`` are not
    re-run: their stored ``generated_text`` is re-scored instead, so the final
    metrics still cover every example. Each new prediction is written with
    ``write_json`` (presumably appended — confirm in ``.utils``), and a
    tab-separated summary is written to ``config.results_file`` at the end.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` section).
    """
    config = setup_experiment_files(args)
    dataset = load_data(args)
    agent = LLMAgent(config)

    # HotpotQA variants start from the hotpot evaluator's metric dict; other
    # datasets only count correct answers. The dict is passed to
    # check_correctness and is never reassigned here, so it is presumably
    # updated in place.
    metrics = get_init_metrics() if args.dataset.startswith('hotpot') else {'correct': 0}

    prior_predictions = _load_prior_predictions(config.predictions_file)

    for input_text, answer, idx in tqdm(dataset, desc='evaluation'):
        resumed = idx in prior_predictions
        if resumed:
            prediction = prior_predictions[idx]['generated_text']
        else:
            if 'reflexion' in config.agent_type:
                # Reflexion only: give every evaluator handler the gold answer
                # before the agent starts predicting.
                for name in agent.state_handlers:
                    if 'evaluator' in name:
                        agent.state_handlers[name].update_answer(answer)
            prediction = agent.predict(input_text).text

        # Scored for resumed examples too, so metrics cover the whole dataset.
        is_correct, extracted_answer = check_correctness(prediction, answer, metrics, config)

        if not resumed:
            json_data = {
                "id": idx,
                "input_text": input_text,
                "generated_text": prediction,
                "prediction": extracted_answer,
                "answer": answer,
                "is_correct": is_correct,
            }
            write_json(json_data, config.predictions_file)

    _write_results(config.results_file, metrics, len(dataset))

if __name__ == '__main__':

    """
    How to call:
        python -m src.evaluate --model_type <model-type> --dataset <dataset-name>

    Example:
        python -m src.evaluate --model_type t5-small --agent_type react --dataset fever_v1.0 --dataset_size 5 --restart --debug
        python -m src.evaluate --model_type tiiuae/falcon-40b --agent_type react --dataset fever_v1.0 --dataset_size 5 --restart --debug

        # Fever experiment
        python -m src.evaluate --model_type mosaicml/mpt-7b-instruct --agent_type react --dataset fever_v1.0 --dataset_size 100 --split paper_dev --restart --debug
        python -m src.evaluate --model_type google/flan-t5-xxl --agent_type react --dataset fever_v1.0 --dataset_size 100 --split paper_dev --restart --debug
        python -m src.evaluate --model_type tiiuae/falcon-40b --agent_type react --dataset fever_v1.0 --split paper_dev --restart --debug

        # Hotpot experiment
        python -m src.evaluate --model_type mosaicml/mpt-7b-instruct --agent_type react --dataset hotpot_qa --dataset_size 100 --split dev --restart --debug
        python -m src.evaluate --model_type mosaicml/mpt-7b-instruct --agent_type reflexion --dataset hotpot_qa --dataset_size 10 --split dev --restart --debug
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True, help="Model type to use")
    parser.add_argument("--agent_type", type=str, required=True, help="Agent type to use")
    parser.add_argument("--dataset", type=str, required=True, choices=['gsm8k_main', 'gsm8k_socratic', 'fever_v1.0', 'hotpot_qa'], help="Name of the dataset to evaluate on")
    parser.add_argument("--model_name", default=None, help="Save model with name")
    parser.add_argument("--split", default="train", help="Split to test prompting on")
    parser.add_argument("--dataset_size", default=-1, type=int, help="Dataset size to evaluate on")
    parser.add_argument("--few_shot_k", default=1, type=int, help="K few-shot examples")
    parser.add_argument("--dataset_range", default=None, type=str, help="Sub-range of the dataset to evaluate, formatted as '<start>-<end>' (two integers)")

    parser.add_argument("--restart", action="store_true", help="Restart experiment even if partial results exist")
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    # hyperparams
    parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
    parser.add_argument("--max_state_tokens", type=int, default=50, help="Maximum number of tokens to generate in a state")
    parser.add_argument("--temperature", type=float, default=-1, help="Temperature for decoding")

    args = parser.parse_args()

    # Validate --dataset_range with parser.error (exit status 2 plus usage text)
    # rather than assert, which `python -O` strips. Unpacking raises ValueError
    # when there are not exactly two parts, and int() raises ValueError on a
    # non-integer part.
    if args.dataset_range is not None:
        try:
            _range_start, _range_end = (int(part) for part in args.dataset_range.split('-'))
        except ValueError:
            parser.error(f'Malformed dataset_range argument {args.dataset_range}')

    # for fairness, we give this more tokens
    if 'ablation' in args.agent_type: args.max_state_tokens += 50

    # set up logging
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    log_handler = logging.StreamHandler()
    log_formatter = logging.Formatter(fmt=' %(name)s :: %(levelname)s :: %(message)s')
    log_handler.setFormatter(log_formatter)
    logger.addHandler(log_handler)

    # Default the saved model name to the canonicalised model type.
    if args.model_name is None: args.model_name = canon_model_name(args.model_type)

    main(args)
