import os
import json
import argparse
import requests
from tqdm import tqdm
from functools import partial
from copy import deepcopy
from models import NLUModel, NLGModel, POLModel


def read_cli(argv=None):
    """
    Parse the command line arguments.

    Options:
    - -ids / --ids_path: path to a JSON file containing the dialog ids to run.
    - -tar / --tar_path: output directory; one ``<did>.json`` log is written per dialog.
    - -pol_model / --pol_model: policy model to use (llama3, bayesnet or rag).
    - -server_url / --server_url: base URL of the patient server.

    :param argv: argument strings to parse. ``None`` (the default) parses
        ``sys.argv[1:]``, which matches the previous behavior.
    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Run the model.')
    parser.add_argument('-ids', '--ids_path', type=str, help='Ids path.', required=True)
    parser.add_argument('-tar', '--tar_path', type=str, help='Tar path.', required=True)
    parser.add_argument('-pol_model', '--pol_model', type=str, help='POL MODEL', required=True, choices=["llama3", "bayesnet", "rag"])
    parser.add_argument('-server_url', '--server_url', type=str, help='SERVER URL', required=True)

    return parser.parse_args(argv)


def save_json(obj, path):
    """
    Serialize ``obj`` as indented JSON to ``path``, replacing the file atomically.

    The data is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. A reader of ``path`` therefore never sees a truncated
    file, even if the process is interrupted mid-write or ``obj`` turns out
    not to be JSON-serializable. In that case the previous contents of
    ``path`` (if any) are left untouched and the exception propagates.

    :param obj: any JSON-serializable object.
    :param path: destination file path.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w") as fp:
            json.dump(obj, fp, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then re-raise (this also covers
        # KeyboardInterrupt).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Statuses that do NOT count as "already processed" when a log file exists:
# dialogs that previously ended with one of these are re-run from scratch.
_RETRYABLE_STATUSES = ('action_generation_error',)


def _read_existing_log(log_path):
    """
    Load a previously written dialog log, if a usable one exists.

    :param log_path: path of the ``<did>.json`` log file.
    :return: the parsed JSON content, or ``None`` when the file does not exist
        or is not valid JSON (e.g. truncated by an interrupted earlier run),
        so that the dialog is simply re-run instead of crashing the batch.
    """
    if not os.path.exists(log_path):
        return None
    try:
        with open(log_path, 'r') as fp:
            return json.load(fp)
    except json.JSONDecodeError:
        return None


def _failed_doctor_turn(turn_id, actions, response):
    """Build the doctor log entry recorded when a turn's actions could not be generated or executed."""
    return {
        "turn_id": turn_id,
        "actions": actions,
        "text": None,
        "speaker": "doctor",
        "response": response,
    }


def run_single(did, server_url, nlu_client, nlg_client, pol_client, tar_path):
    """
    Simulate one dialog against the patient server and log it to ``<tar_path>/<did>.json``.

    Flow:
    1. POST ``/patient/create_session`` (mode "actions") and fetch the initial
       dialog via ``/patient/<session_id>/get_dialog``.
    2. Each turn:
       - NLU runs on the dialog history.
       - The policy produces actions.
       - NLG renders the doctor utterance.
       - ``/patient/<session_id>/execute_and_commit`` returns the patient's
         reply and the server-side status.
    The log is rewritten after every step so partial progress is kept.

    Final status:
    - the server's status (loop ends when it is 'terminated');
    - 'action_generation_error' when the policy returns no actions;
    - 'action_execution_error' when execute_and_commit returns an HTTP error;
    - 'successful' when, after 50 turns, the NLU output contains the intent
      'nod_prompt_salutations', or, after 30 turns, the doctor's actions
      contain 'nod_prompt_salutations';
    - 'too_many_turns' after 100 turns.

    If a log with a non-retryable status already exists, the dialog is skipped
    and that recorded status is returned.

    :return: the final (or previously recorded) status.
    :raises requests.HTTPError: if session creation or the initial dialog fetch fails.
    """
    print()
    print(f"Running for {did}")
    print()

    log_path = os.path.join(tar_path, f"{did}.json")
    previous = _read_existing_log(log_path)
    if (isinstance(previous, dict) and "status" in previous
            and previous['status'] not in _RETRYABLE_STATUSES):
        print(f"Already processed {did}. Skipping.")
        return previous['status']

    logs = {'did': did}

    # NOTE(review): none of the requests calls below set a timeout, so an
    # unresponsive server stalls the whole batch indefinitely.

    # 1. Create a session.
    session_resp = requests.post(
        f"{server_url}/patient/create_session",
        json={"did": did, "mode": "actions"},
    )
    session_resp.raise_for_status()
    session_id = session_resp.json()['session_id']
    logs['session_id'] = session_id

    # 2. Get the initial dialog.
    dialog_resp = requests.post(f"{server_url}/patient/{session_id}/get_dialog")
    dialog_resp.raise_for_status()
    dialog_history = dialog_resp.json()['dialog']
    dialog_state = {}

    turn_id = 0
    status = "successful"
    logs['utterances'] = []
    last_action = []
    # History lines appended below look like "Doctor: ..." / "Patient: ...";
    # presumably the server uses the same "Speaker: text" form, so keep only the text.
    patient_uttr = dialog_history[-1].split(":", 1)[-1].strip()
    while status != 'terminated':
        nlu, dialog_state = nlu_client.predict(
            dialog_history=dialog_history,
            dialog_state=dialog_state,
            last_action=last_action,
        )
        logs['utterances'].append({
            "turn_id": turn_id,
            "text": patient_uttr,
            "speaker": "patient",
            "nlu": nlu,
            "dialog_state": dialog_state
        })
        save_json(logs, log_path)

        if turn_id >= 50 and any(x.get('intent') == "nod_prompt_salutations" for x in nlu):
            status = "successful"
            break

        actions, response = pol_client.predict(
            dialog_history=dialog_history,
            dialog_state=dialog_state,
            last_action=last_action,
            return_gen=True,
        )
        if not actions:
            # Error in actions generation
            logs['utterances'].append(_failed_doctor_turn(turn_id, actions, response))
            status = 'action_generation_error'
            break

        last_action = deepcopy(actions)

        # Predict nlg
        doctor_uttr = nlg_client.predict(dialog_history=dialog_history, actions=actions)
        logs['utterances'].append({
            "turn_id": turn_id,
            "text": doctor_uttr,
            "speaker": "doctor",
            "actions": actions
        })
        save_json(logs, log_path)
        dialog_history.append(f"Doctor: {doctor_uttr}")

        commit_resp = requests.post(
            f"{server_url}/patient/{session_id}/execute_and_commit",
            json={'doctor_uttr': doctor_uttr, "actions": actions},
        )
        try:
            commit_resp.raise_for_status()
        except requests.HTTPError as err:
            print(f"execute_and_commit failed for {did}: {err}")
            print(actions)
            # The failure belongs to this turn's actions, so it reuses the
            # same turn_id as the doctor entry logged just above.
            logs['utterances'].append(_failed_doctor_turn(turn_id, actions, response))
            status = 'action_execution_error'
            break

        # The turn is complete only once the server has committed it.
        turn_id += 1
        obj = commit_resp.json()
        patient_uttr, status = obj['patient_uttr'], obj['status']
        dialog_history.append(f"Patient: {patient_uttr}")
        # NOTE(review): if the loop exits below (or via status == 'terminated'),
        # this last patient_uttr is never appended to logs['utterances'].
        # Confirm whether that is intended.

        if turn_id >= 30 and any(x.get('action') == "nod_prompt_salutations" for x in actions):
            status = "successful"
            break

        if turn_id >= 100:
            # Too many turns
            status = 'too_many_turns'
            break

    logs['status'] = status
    save_json(logs, log_path)

    return status


def main(args):
    """
    Run ``run_single`` for every dialog id listed in ``args.ids_path``.

    Creates ``args.tar_path`` if needed and builds the NLU, NLG and policy
    clients once, sharing them across all dialogs.

    :param args: namespace with ``ids_path``, ``tar_path``, ``pol_model`` and
        ``server_url`` (as produced by ``read_cli``).
    :return: the value returned by ``run_single`` for each id, in file order.
        A per-value count is also printed as a summary.
    """
    from collections import Counter

    with open(args.ids_path, 'r') as fp:
        dids = json.load(fp)
    print(f"Loaded {len(dids)} dids from {args.ids_path}")

    os.makedirs(args.tar_path, exist_ok=True)

    nlu_client = NLUModel("nlu")
    nlg_client = NLGModel("nlg")
    pol_client = POLModel(args.pol_model)

    print("Running in simulation mode")
    func = partial(
        run_single,
        server_url=args.server_url,
        nlu_client=nlu_client,
        nlg_client=nlg_client,
        pol_client=pol_client,
        tar_path=args.tar_path
    )

    results = []
    for did in tqdm(dids, desc="Processing Dids"):
        results.append(func(did))

    print("Status summary:", dict(Counter(results)))
    return results


if __name__ == '__main__':
    # Script entry point: parse the CLI options and hand them to main().
    main(read_cli())
