# -*- coding: utf-8 -*-
"""
Created on 

@author: 
"""
import os
import zipfile
from nltk.tokenize import word_tokenize

from convlab.util.file_util import cached_path
from convlab.e2e.soloist.multiwoz.config import global_config as cfg
from convlab.e2e.soloist.multiwoz.soloist_net import SOLOIST
from convlab.dialog_agent import Agent
from utils import MultiWozReader

# Directory holding this module; the data archive and the model archive are
# unpacked here when missing.
DEFAULT_DIRECTORY = os.path.split(os.path.abspath(__file__))[0]
# Zip archive downloaded when the ./data directory is not present.
DEFAULT_ARCHIVE_FILE_URL = (
    "https://bapengstorage.blob.core.windows.net/fileshare/"
    "soloist_multiwoz_data.zip"
)
# Default source of the trained model archive (SOLOISTAgent's model_file).
DEFAULT_MODEL_URL = (
    "https://bapengstorage.blob.core.windows.net/fileshare/"
    "soloist-model.zip"
)


class SOLOISTAgent(Agent):
    """End-to-end MultiWOZ dialog agent backed by a SOLOIST model.

    On construction the agent ensures the ``data`` and ``soloist-model``
    directories exist next to this module, downloading and unzipping them if
    they are missing. It then builds the SOLOIST generator and a
    ``MultiWozReader`` used to lexicalize generated responses.
    """

    def __init__(self,
                 model_file=DEFAULT_MODEL_URL,
                 name='soloist'):
        """
        soloist agent initialization
        Args:
            model_file (str):
                trained model path or url. Only fetched when the
                ``soloist-model`` directory next to this module is missing.
            name (str):
                agent name forwarded to ``Agent.__init__``.
        Example:
            model = SOLOISTAgent()
        """
        super(SOLOISTAgent, self).__init__(name=name)
        print(DEFAULT_DIRECTORY)
        if not os.path.exists(os.path.join(DEFAULT_DIRECTORY, './data')):
            print('Down load data from', DEFAULT_ARCHIVE_FILE_URL)
            self._download_and_extract(DEFAULT_ARCHIVE_FILE_URL,
                                       os.path.join(DEFAULT_DIRECTORY, './'))
        model_path = os.path.join(DEFAULT_DIRECTORY, 'soloist-model')
        if not os.path.exists(model_path):
            model_dir = os.path.dirname(model_path)
            # exist_ok: do not fail if the directory already exists (or is
            # created concurrently between a check and the call).
            os.makedirs(model_dir, exist_ok=True)
            print('Load from model_file param')
            print('down load data from', model_file)
            self._download_and_extract(model_file, model_dir)

        self.model = SOLOIST()
        self.init_session()
        self.reader = MultiWozReader()

    @staticmethod
    def _download_and_extract(url, target_dir):
        """Fetch a zip archive via ``cached_path`` and extract it into *target_dir*.

        Args:
            url (str): URL or path handed to ``cached_path``.
            target_dir (str): directory the archive contents are extracted into.
        """
        archive_file = cached_path(url)
        # The context manager closes the archive even if extraction raises.
        with zipfile.ZipFile(archive_file, 'r') as archive:
            print('unzip to', target_dir)
            archive.extractall(target_dir)

    def init_session(self):
        """Reset the class variables to prepare for a new session."""
        self.hidden_states = {}
        self.state = {}
        self.history = []
        # Domains collected in a previous dialogue must not influence the
        # lexicalization of responses in a new one.
        self.active_domains = []

    def update_dialog_state(self):
        """Clear the stored dialog state (``self.state``)."""
        self.state = {}

    def prepare_input(self, usr):
        """Record the user turn and build the model prompt.

        Appends *usr* to ``self.history``, then joins all turns with
        ``' EOS '`` and prefixes the result with ``'[e2e] '``.

        Args:
            usr (str): tokenized user utterance.
        Returns:
            str: the prompt passed to the model.
        """
        self.history.append(usr)
        context = ' EOS '.join(self.history)
        context = '[e2e] ' + context

        return context

    def parse_belief_state_and_response(self, response):
        """Split raw model output into a belief state and a system response.

        The output is split on ``'EOS'``. When exactly one marker is present,
        the text before it is the belief state and the text after it is the
        response. Otherwise the whole output is treated as the response and
        the belief state is empty.

        The first two whitespace-separated tokens of the belief text are
        skipped (presumably a ``belief :`` prefix; TODO confirm against the
        model's output format). The rest is cut into per-domain segments at
        every token containing ``'['`` except the first. Each segment is read
        as ``[domain] slot is value , slot is value ...``.

        Args:
            response (str): raw text generated by the model.
        Returns:
            tuple: ``(states, response)``. ``states`` maps
            domain -> {slot: value}. ``response`` is re-tokenized with
            bracketed placeholders rejoined (``[ x ]`` -> ``[x]``).
        """
        parts = response.split('EOS')
        if len(parts) == 2:
            belief, response = parts
        else:
            # No marker, or more than one: keep everything as the response.
            belief = ''
        response = ' '.join(response.split())

        # Split once, outside the loop (previously re-split every iteration).
        tokens = belief.split()[2:]

        segments = []
        start_idx = 0
        for idx, word in enumerate(tokens):
            if idx != 0 and '[' in word:
                segments.append(tokens[start_idx:idx])
                start_idx = idx
        segments.append(tokens[start_idx:])

        states = {}
        for segment in segments:
            if not segment:
                continue
            # Drop the first and last characters (the brackets around the
            # domain name).
            domain = segment[0][1:-1]
            for sv in ' '.join(segment[1:]).split(','):
                slot_value = sv.strip().split(' is ')
                # Skip fragments that are not exactly "<slot> is <value>".
                if len(slot_value) != 2:
                    continue
                s, v = slot_value
                states.setdefault(domain, {})[s.strip()] = v.strip()

        response = ' '.join(word_tokenize(response))
        response = response.replace('[ ', '[').replace(' ]', ']')

        return states, response

    def response(self, usr):
        """
        Generate agent response given user input.

        Args:
            usr (str):
                The user utterance for this turn.
        Returns:
            response (str):
                The lexicalized response generated by the agent.
        """
        usr = ' '.join(word_tokenize(usr))
        inputs = self.prepare_input(usr)
        belief_and_response = self.model.generate(inputs)
        belief_state, response = self.parse_belief_state_and_response(belief_and_response)

        # The response before lexicalization is what goes back into the context.
        self.history.append(response)
        # Deduplicate while keeping first-seen order. Unlike list(set(...)),
        # this order does not vary between runs with string hash randomization.
        self.active_domains = list(dict.fromkeys(self.active_domains + list(belief_state)))

        lexicalized_response = self.reader.restore(response, self.active_domains, belief_state)

        return lexicalized_response

if __name__ == '__main__':
    # Three-turn smoke demo. The comment above each utterance is a sample
    # system reply recorded by the author; actual output depends on the model.
    agent = SOLOISTAgent()
    demo_turns = [
        # "There are 15 cheap restaurants in the centre . What type of food do you want ?"
        "I want to find a cheap restaurant in the center",
        # "There are 3 cheap chinese restaurants in the centre . Would you like me to make a reservation for you at 1 of them ?"
        "I would like to have chinese food",
        # "I have booked you at Charlie Chan . The reference number is 00000010 . Is there anything else i can help you with ?"
        "Yes, please reserve for two people at 6 pm on monday",
    ]
    for user in demo_turns:
        system = agent.response(user)
        print(user)
        print(system)

    