"""
This conventional pipeline uses DeepSpeech for speech recognition (ASR) and Rasa NLU for intent classification.
It assumes the DeepSpeech model and scorer files have been saved in data/deepspeech.
"""
import deepspeech
from Envs.audioLoader import audioLoader
from cfg import main_config, gym_register
from rasa.nlu.model import Interpreter
from Envs.vec_env.envs import make_vec_envs
from models.ppo.model import Policy
import torch
import numpy as np
import os
import pandas as pd

# Module-level setup, executed at import time: build the project configuration
# object and pass it to gym_register.
# NOTE(review): presumably gym_register registers the custom gym environment(s)
# that make_vec_envs later instantiates by name -- confirm against cfg.
config=main_config()
gym_register(config)


def rasa_output(interpreter, text):
    """
        Run the NLU interpreter on a piece of text and return its raw output.
        Args:
          interpreter  --  object exposing ``parse(message)``; the callers in this
                           file pass a loaded rasa ``Interpreter``
          text         --  input to classify; it is converted with ``str()`` and
                           stripped of surrounding whitespace before parsing
        Returns:
          whatever ``interpreter.parse`` returns; callers in this file read
          ``result['intent']['name']`` from it
    """
    return interpreter.parse(str(text).strip())


def main():
    """
    Evaluate the trained RL policy with a conventional ASR + NLU front end.

    At the start of every episode the environment's ``goal_audio`` is transcribed
    with DeepSpeech and the transcript is classified by the Rasa NLU interpreter.
    The predicted intent is turned into a one-hot vector of length
    ``config.taskNum`` and assigned to ``obs['soundLabel'][:config.taskNum]``
    before the policy acts. If the intent cannot be mapped to a task index below
    ``config.taskNum`` (``nlu_fallback``, ``None``, an unknown name, or an
    out-of-range index), the episode is recorded as a failure and the
    environment is reset.

    When ``config.render`` is false, the per-episode results are written to a
    CSV file next to the loaded checkpoint and the success rate is printed.
    The vectorized environment is closed even if the evaluation raises.
    """
    # Fall back to the CPU when CUDA is unavailable instead of failing on "cuda:0".
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_processes = 1
    # create the env
    eval_envs = make_vec_envs(env_name=config.RLEnvName,
                              seed=config.RLEnvSeed,
                              num_processes=num_processes,
                              gamma=None,
                              device=device,
                              randomCollect=False,
                              config=config)
    try:
        baseEnv = eval_envs.venv.unwrapped.envs[0]

        # load the trained RL policy
        actor_critic = Policy(
            eval_envs.venv.observation_space.spaces,
            eval_envs.action_space,
            config=config,
            base=config.RLPolicyBase,
            base_kwargs={'recurrent': config.RLRecurrentPolicy,
                         'recurrentInputSize': config.RLRecurrentInputSize,
                         'recurrentSize': config.RLRecurrentSize,
                         'actionHiddenSize': config.RLActionHiddenSize
                         })

        assert config.RLModelLoadDir is not None, "config.RLModelLoadDir must point to a checkpoint"
        print("Load the weights from", config.RLModelLoadDir)
        # map_location lets a checkpoint saved on a GPU be loaded on the selected device.
        actor_critic.load_state_dict(torch.load(config.RLModelLoadDir, map_location=device))

        actor_critic.eval()
        print("Weights Loaded!")
        actor_critic.to(device)

        # ASR
        model_file_path = '../data/deepspeech/deepspeech-0.9.3-models.pbmm'
        scorer_file_path = '../data/deepspeech/deepspeech-0.9.3-models.scorer'
        model = deepspeech.Model(model_file_path)
        model.enableExternalScorer(scorer_file_path)
        lm_alpha = 0.75
        lm_beta = 1.85
        model.setScorerAlphaBeta(lm_alpha, lm_beta)
        beam_width = 500
        model.setBeamWidth(beam_width)

        # NLU
        rasa_model_path = "NLU_RSI3/models/20220719-134456/nlu"

        # create an interpreter object
        interpreter = Interpreter.load(rasa_model_path)

        eval_episode_rewards = []
        eval_env_rewards = 0.
        eval_recurrent_hidden_states = torch.zeros(
            num_processes, actor_critic.recurrent_hidden_state_size, device=device)
        eval_masks = torch.zeros(num_processes, 1, device=device)

        results = []
        goal_area_count_list = []
        objs = np.arange(config.taskNum, dtype=np.int64)
        objs = np.repeat(objs, baseEnv.size_per_class)

        intent2idx = {'lamp_on': 0, 'lamp_off': 1, 'music_on': 2, 'music_off': 3, 'pickup_shoes': 4, 'wrong': 5}

        def get_predicted_intent():
            """Transcribe the current goal audio and return the NLU intent name."""
            txt = model.stt(baseEnv.goal_audio)
            print(txt)
            nlu_out = rasa_output(interpreter, txt)
            ret = nlu_out['intent']['name']
            print(ret)
            return ret

        # evaluate
        obs = eval_envs.reset()
        intent = get_predicted_intent()

        episode_num = baseEnv.size_per_class_cumsum[-1]

        while baseEnv.episodeCounter < episode_num:
            # 'nlu_fallback', None and any other name missing from intent2idx map
            # to None here; indices >= taskNum (e.g. 'wrong') are rejected as well,
            # so the one-hot write below can never go out of range.
            task_idx = intent2idx.get(intent)

            if task_idx is not None and task_idx < config.taskNum:
                one_hot_intent = np.zeros((config.taskNum,))
                one_hot_intent[task_idx] = 1.

                obs['soundLabel'][:config.taskNum] = torch.from_numpy(one_hot_intent)
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Obser reward and next obs
                obs, _, done, infos = eval_envs.step(action)
                if config.render:
                    eval_envs.render()

                    print('step reward', eval_envs.venv.origStepReward)
                eval_env_rewards = eval_env_rewards + eval_envs.venv.origStepReward

                eval_masks = torch.tensor(
                    [[0.0] if done_ else [1.0] for done_ in done],
                    dtype=torch.float32,
                    device=device)

                # Only one environment is evaluated (num_processes == 1), so look
                # at its flag explicitly, matching the infos[0] access below.
                if done[0]:
                    goal_area_count = infos[0]['goal_area_count']
                    goal_area_count_list.append(goal_area_count)
                    results.append(int(goal_area_count >= config.success_threshold))
                    eval_episode_rewards.append(float(eval_env_rewards))
                    eval_env_rewards = 0.
                    # Predict the intent for the next episode's goal audio.
                    intent = get_predicted_intent()

            else:  # ASR+NLU failure
                goal_area_count_list.append(0)
                results.append(0)
                eval_episode_rewards.append(float(eval_env_rewards))
                eval_env_rewards = 0.
                print('goal area count------------------------- 0')

                # start a new episode
                obs = eval_envs.reset()  # episode counter will +1
                intent = get_predicted_intent()

        # save the results
        if not config.render:
            df = pd.DataFrame({'objIdx': objs, 'goal area count': goal_area_count_list,
                               'rewards': eval_episode_rewards, 'results': results})
            save_path = os.path.join(os.path.dirname(config.RLModelLoadDir),
                                     'test_conventional' + os.path.splitext(os.path.basename(config.RLModelLoadDir))[0] + '.csv')
            df.to_csv(save_path, mode='w', header=True, index=False)
            print('results saved to', save_path)
            print('success rate', sum(results) * 1. / episode_num)
    finally:
        # Release the environment even when the evaluation fails part-way.
        eval_envs.close()


# NOTE(review): this call is not protected by an `if __name__ == "__main__":`
# guard, so the full evaluation runs whenever the module is executed or imported.
main()


def ASR_NLU_only():
    """
    Measure the intent accuracy of the ASR (DeepSpeech) + NLU (Rasa) front end alone.

    Four groups of audio clips are built from ``audio.words['none']``: lights/lamp
    activate, lights/lamp deactivate, music activate and music deactivate. Every
    clip is transcribed, the transcript is classified, and the predicted intent
    is compared with the intent expected for its group ('lamp_on', 'lamp_off',
    'music_on', 'music_off'). The overall success rate is printed; nothing is
    returned.
    """
    audio = audioLoader(config=config)
    model_file_path = 'data/deepspeech/deepspeech-0.9.3-models.pbmm'
    scorer_file_path = 'data/deepspeech/deepspeech-0.9.3-models.scorer'
    model = deepspeech.Model(model_file_path)
    model.enableExternalScorer(scorer_file_path)
    lm_alpha = 0.75
    lm_beta = 1.85
    model.setScorerAlphaBeta(lm_alpha, lm_beta)
    beam_width = 500
    model.setBeamWidth(beam_width)

    audio_size = config.soundSource['size']
    half = audio_size // 2
    # NOTE(review): the `+` below assumes each entry is a list of clips, so that
    # it concatenates the two halves -- confirm against audioLoader.
    task_audio = [
        audio.words['none']['lights']['activate'][:half] + audio.words['none']['lamp']['activate'][:half],
        audio.words['none']['lights']['deactivate'][:half] + audio.words['none']['lamp']['deactivate'][:half],
        audio.words['none']['music']['activate'],
        audio.words['none']['music']['deactivate'],
    ]
    # Expected intent name for each group in task_audio, in the same order.
    expected_intents = ['lamp_on', 'lamp_off', 'music_on', 'music_off']

    # path of your model
    rasa_model_path = "NLU/models/nlu-20220414-142356/nlu"

    # create an interpreter object
    interpreter = Interpreter.load(rasa_model_path)

    intent_success_count = [0] * len(task_audio)
    total_clips = 0
    for i, clips in enumerate(task_audio):
        print("intent", i)
        for clip in clips:
            total_clips += 1
            txt = model.stt(clip)
            print(txt)
            nlu_out = rasa_output(interpreter, txt)
            if nlu_out['intent']['name'] == expected_intents[i]:
                intent_success_count[i] += 1

    # Divide by the number of clips actually evaluated. The previous denominator,
    # audio_size * config.taskNum, is only correct when taskNum is 4 and every
    # group holds exactly audio_size clips.
    success_rate = sum(intent_success_count) / total_clips if total_clips else float('nan')
    print("success rate", success_rate)
