import json, random, sqlite3
from .ontology import all_domains, db_domains

class MultiWozDB(object):
    def __init__(self, db_paths):
        self.dbs = {}
        self.sql_dbs = {}
        for domain in all_domains:
            with open(db_paths[domain], 'r') as f:
                self.dbs[domain] = json.loads(f.read().lower())


    def oneHotVector(self, domain, num):
        """Return number of available entities for particular domain."""
        vector = [0,0,0,0]
        if num == '':
            return vector
        if domain != 'train':
            if num == 0:
                vector = [1, 0, 0, 0]
            elif num == 1:
                vector = [0, 1, 0, 0]
            elif num <=3:
                vector = [0, 0, 1, 0]
            else:
                vector = [0, 0, 0, 1]
        else:
            if num == 0:
                vector = [1, 0, 0, 0]
            elif num <= 5:
                vector = [0, 1, 0, 0]
            elif num <=10:
                vector = [0, 0, 1, 0]
            else:
                vector = [0, 0, 0, 1]
        return vector


    def addBookingPointer(self, turn_da):
        """Add information about availability of the booking option."""
        # Booking pointer
        # Do not consider booking two things in a single turn.
        vector = [0, 0]
        if turn_da.get('booking-nobook'):
            vector = [1, 0]
        if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
            vector = [0, 1]
        return vector


    def addDBPointer(self, domain, match_num, return_num=False):
        """Create database pointer for all related domains."""
        # if turn_domains is None:
        #     turn_domains = db_domains
        if domain in db_domains:
            vector = self.oneHotVector(domain, match_num)
        else:
            vector = [0, 0, 0 ,0]
        return vector

    def addDBIndicator(self, domain, match_num, return_num=False):
        """Create database indicator for all related domains."""
        # if turn_domains is None:
        #     turn_domains = db_domains
        if domain in db_domains:
            vector = self.oneHotVector(domain, match_num)
        else:
            vector = [0, 0, 0 ,0]
        
        # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
        if vector == [0,0,0,0]:
            indicator = '[db_nores]'
        else:
            indicator = '[db_%s]' % vector.index(1)
        return indicator

    def get_match_num(self, constraints, return_entry=False):
        """Create database pointer for all related domains."""
        match = {'general': ''}
        entry = {}
        # if turn_domains is None:
        #     turn_domains = db_domains
        for domain in all_domains:
            match[domain] = ''
            if domain in db_domains and constraints.get(domain):
                matched_ents = self.queryJsons(domain, constraints[domain])
                match[domain] = len(matched_ents)
                if return_entry :
                    entry[domain] = matched_ents
        if return_entry:
            return entry
        return match


    def pointerBack(self, vector, domain):
        # multi domain implementation
        # domnum = cfg.domain_num
        if domain.endswith(']'):
            domain = domain[1:-1]
        if domain != 'train':
            nummap = {
                0: '0',
                1: '1',
                2: '2-3',
                3: '>3'
            }
        else:
            nummap = {
                0: '0',
                1: '1-5',
                2: '6-10',
                3: '>10'
            }
        if vector[:4] == [0,0,0,0]:
            report = ''
        else:
            num = vector.index(1)
            report = domain+': '+nummap[num] + '; '

        if vector[-2] == 0 and vector[-1] == 1:
            report += 'booking: ok'
        if vector[-2] == 1 and vector[-1] == 0:
            report += 'booking: unable'

        return report


    def queryJsons(self, domain, constraints, exactly_match=True, return_name=False):
        """Returns the list of entities for a given domain
        based on the annotation of the belief state
        constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
        """
        # query the db
        if domain == 'taxi':
            return [{'taxi_colors': random.choice(self.dbs[domain]['taxi_colors']),
            'taxi_types': random.choice(self.dbs[domain]['taxi_types']),
            'taxi_phone': [random.randint(1, 9) for _ in range(10)]}]
        if domain == 'police':
            return self.dbs['police']
        if domain == 'hospital':
            if constraints.get('department'):
                for entry in self.dbs['hospital']:
                    if entry.get('department') == constraints.get('department'):
                        return [entry]
            else:
                return []

        valid_cons = False
        for v in constraints.values():
            if v not in ["not mentioned", ""]:
                valid_cons = True
        if not valid_cons:
            return []

        match_result = []


        if 'name' in constraints:
            for db_ent in self.dbs[domain]:
                if 'name' in db_ent:
                    cons = constraints['name']
                    dbn = db_ent['name']
                    if cons == dbn:
                        db_ent = db_ent if not return_name else db_ent['name']
                        match_result.append(db_ent)
                        return match_result

        for db_ent in self.dbs[domain]:
            match = True
            for s, v in constraints.items():
                if s == 'name':
                    continue
                if s in ['people', 'stay'] or(domain == 'hotel' and s == 'day') or \
                (domain == 'restaurant' and s in ['day', 'time']):
                    continue

                skip_case = {"don't care":1, "do n't care":1, "dont care":1, "not mentioned":1, "dontcare":1, "":1}
                if skip_case.get(v):
                    continue

                if s not in db_ent:
                    # logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
                    match = False
                    break

                # v = 'guesthouse' if v == 'guest house' else v
                # v = 'swimmingpool' if v == 'swimming pool' else v
                v = 'yes' if v == 'free' else v

                if s in ['arrive', 'leave']:
                    try:
                        h,m = v.split(':')   # raise error if time value is not xx:xx format
                        v = int(h)*60+int(m)
                    except:
                        match = False
                        break
                    time = int(db_ent[s].split(':')[0])*60+int(db_ent[s].split(':')[1])
                    if s == 'arrive' and v>time:
                        match = False
                    if s == 'leave' and v<time:
                        match = False
                else:
                    if exactly_match and v != db_ent[s]:
                        match = False
                        break
                    elif v not in db_ent[s]:
                        match = False
                        break

            if match:
                match_result.append(db_ent)

        if not return_name:
            return match_result
        else:
            if domain == 'train':
                match_result = [e['id'] for e in match_result]
            else:
                match_result = [e['name'] for e in match_result]
            return match_result


    def querySQL(self, domain, constraints):
        if not self.sql_dbs:
            for dom in db_domains:
                db = 'db/{}-dbase.db'.format(dom)
                conn = sqlite3.connect(db)
                c = conn.cursor()
                self.sql_dbs[dom] = c

        sql_query = "select * from {}".format(domain)


        flag = True
        for key, val in constraints.items():
            if val == "" or val == "dontcare" or val == 'not mentioned' or val == "don't care" or val == "dont care" or val == "do n't care":
                pass
            else:
                if flag:
                    sql_query += " where "
                    val2 = val.replace("'", "''")
                    # val2 = normalize(val2)
                    if key == 'leaveAt':
                        sql_query += r" " + key + " > " + r"'" + val2 + r"'"
                    elif key == 'arriveBy':
                        sql_query += r" " + key + " < " + r"'" + val2 + r"'"
                    else:
                        sql_query += r" " + key + "=" + r"'" + val2 + r"'"
                    flag = False
                else:
                    val2 = val.replace("'", "''")
                    # val2 = normalize(val2)
                    if key == 'leaveAt':
                        sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
                    elif key == 'arriveBy':
                        sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
                    else:
                        sql_query += r" and " + key + "=" + r"'" + val2 + r"'"

        try:  # "select * from attraction  where name = 'queens college'"
            print(sql_query)
            return self.sql_dbs[domain].execute(sql_query).fetchall()
        except:
            return []  # TODO test it

