import numpy as np
from pathlib import Path
import copy
import json
import os
from tqdm import tqdm

# Pool of room names that gen_map samples from when building a map.
room_names = ['cellar','basement','attic','foyer','great room','library','kitchen','parlor','bedroom','bathroom','study','living room','family room','dining room', 'sun room','garden room','tea room','mud room','garage','office','sitting room', 'pantry', 'recreation room', 'utility room', 'laundry room', 'ballroom', 'cloakroom', 'den', 'larder', 'nursery', 'antechamber', 'boudoir', 'boiler room', 'conservatory', 'drawing room', 'fitting room', 'game room', 'loft', 'pool room', 'powder room', 'screen porch', 'storeroom', 'solarium', 'wine cellar']

# Number words used by pluralize, which looks up counters[cnt-1]; covers counts 1..20.
counters = ["one","two","three","four","five","six","seven","eight","nine","ten","eleven","twelve","thirteen","fourteen","fifteen","sixteen","seventeen","eighteen","nineteen","twenty"]

# Repeat-count suffixes used by englishify_steps, which looks up mults[count-1].
# Entry 0 is an empty placeholder: a count of 1 is emitted without any suffix.
# Both tables cover repeat counts up to 15.
word_multipliers = [ "", "twice", "three times", "four times", "five times",
                     "six times", "seven times", "eight times", "nine times", "ten times",
                     "eleven times", "twelve times", "thirteen times", "fourteen times", "fifteen times" ]

# Same table as word_multipliers, but with the count written as digits.
int_multipliers = [ "", "2 times", "3 times", "4 times", "5 times",
                    "6 times", "7 times", "8 times", "9 times", "10 times",
                    "11 times", "12 times", "13 times", "14 times", "15 times" ]

# if you modify this, make sure you check the gen_map function
# (gen_map maps each of these phrases to a grid offset).
connectives  = ['to the north','to the south','to the east','to the west']
# Opposite direction for each connective; rec_calc_route uses it when a
# relation has to be walked in the reverse direction.
iconnectives = {
    'to the north':'to the south',
    'to the south':'to the north',
    'to the east':'to the west',
    'to the west':'to the east'}

# override these in opts below
# (the script section at the bottom of the file replaces a list when the
# matching *_names_filename option is not None)
# NOTE(review): person_names is not referenced by any function in this file.
person_names = ['David','Chris','Josh','Connor','Sterling','Iris','Corinne','Savannah','Lexi','Cayla']
object_names = ['apple','pencil','book','block','ball','toy','dish','shirt','shoe']
bin_names = ['bin','box','bag','carton','container','sack','cart']

# Generation options consumed by the functions below.
# Every "<name>_min" / "<name>_max" pair is read by samp_int(opts, "<name>"),
# which treats both bounds as inclusive.
opts = {
#    "qtype":"*",  # should be one of [ 'navroute', 'navresult', 'simpobj', 'hardobj' ] or '*', which samples uniformly from all four
    
    # question type produced by gen_scenarios
    "qtype":"navresult",
    # total number of scenarios; write_train_val splits this into train/validation
    "num_instances":20000,  # validation: 40000
    
    # number of rooms per map (gen_map)
    #Training
    #"room_min":3,  
    #"room_max":8,  
    #Validation
    "room_min":3,  
    "room_max":12, 

    # number of moves in the route of a navigation question
    #Train
    #"step_cnt_min":1,  
    #"step_cnt_max":5, 
    #Validation
    "step_cnt_min":1, 
    "step_cnt_max":9, 
    
    # when True, consecutive identical moves are merged ("to the north twice")
    "collapse_repeated_actions":False,
    "word_multipliers":True,  # set to false to use integer multipliers
   
    # NOTE(review): "people", "people_min" and "people_max" are not read by any
    # function in this file.
    "people":False,
    "people_min":0,
    "people_max":1,
   
    # NOTE(review): the "bins" flag is not read by any function in this file;
    # bincnt_min/max set the number of containers in gen_initial_state.
    "bins":False,
    "bincnt_min":2,
    "bincnt_max":4,
    
    # per container: number of object entries, and the count of each entry
    "obj_type_min":0,
    "obj_type_max":2,
    "obj_cnt_min":1,
    "obj_cnt_max":2,

    "object_names_filename":"data/posdata/2000_most_concrete_davies_nouns_train.txt",  # or None to use the default list above
    "person_names_filename":"data/names.txt", # or None to use the default list above
    "bin_names_filename":None,    

    # output layout used by write_train_val:
    #   <write_dir>/tsvs/<*_tsv_file>
    #   <write_dir>/train/<train_prefix_file>, <train_suffix_file>
    #   <write_dir>/validation/<validation_prefix_file>, <validation_suffix_file>
    "write_dir":'data/navprobs/nrextrap',
    "train_tsv_file":'atr-train.tsv',
    "train_prefix_file":'prefixes.txt',
    "train_suffix_file":'targets.txt',

    "validation_tsv_file":'atr-validation.tsv',
    "validation_prefix_file":'prefixes.txt',
    "validation_suffix_file":'targets.txt',
    
    }

#
# ------------------------------------------------------------------------------------
#

def samp_int( opts, prefix ):
    """Draw a random integer between two bounds stored in `opts`.

    The bounds are read from opts[prefix + '_min'] and opts[prefix + '_max'];
    both ends are included in the range.
    """
    lower = opts[prefix + '_min']
    upper = opts[prefix + '_max']
    # randint excludes its upper bound, so shift it by one to include `upper`
    return np.random.randint( lower, upper + 1 )

def samp_from_set( ze_set, cnt=1 ):
    """Sample `cnt` distinct entries of `ze_set` in random order.

    With cnt == 1 the single chosen entry is returned directly; for any
    other count a list of entries is returned.
    """
    order = np.random.permutation( range(len(ze_set)) )
    if cnt != 1:
        return [ ze_set[idx] for idx in order[:cnt] ]
    # single draw: hand back the bare element rather than a one-item list
    return ze_set[order[0]]

def pluralize( cnt, obj ):
    """Return an English noun phrase for `cnt` copies of `obj`.

    A count of 1 gives "a <obj>"; any other count gives the number word
    from `counters` followed by "<obj>s".
    """
    if cnt != 1:
        # counters is 0-based: counters[0] is "one"
        return f"{counters[cnt-1]} {obj}s"
    return f"a {obj}"

#
# ------------------------------------------------------------------------------------
#

def gen_map( opts ):
    """Build a random map of rooms laid out on an integer grid.

    The number of rooms comes from samp_int(opts, "room") and the names are
    sampled from `room_names`. One room is placed at (0,0); every other room
    is attached next to a randomly chosen, already placed room in a randomly
    chosen direction, retrying until the target grid cell is free.

    Returns a dict with
      'rooms'     -- the sampled room names
      'relations' -- a list of (new_room, connective, anchor_room) tuples,
                     read as "new_room is <connective> of anchor_room".

    NOTE: when only one room is requested, samp_from_set returns a bare
    string instead of a list and the `.copy()` call below fails.
    """
    # grid offset that each connective corresponds to
    offsets = {
        'to the north': ( 0,  1),
        'to the south': ( 0, -1),
        'to the east':  ( 1,  0),
        'to the west':  (-1,  0),
    }

    # choose the rooms for this map
    num_rooms = samp_int( opts, "room" )
    orig_rooms = samp_from_set( room_names, num_rooms )

    # rooms that still have to be placed; the last one seeds the map
    pending = orig_rooms.copy()
    first_room = pending.pop()

    placed = [first_room]
    relations = []
    room_to_coords = { first_room: (0,0) }   # room->coords
    coords_to_room = { (0,0): first_room }   # coordinates->room

    while pending:
        next_room = pending.pop()

        # keep proposing (anchor, direction) pairs until one lands on a free cell
        while True:
            anchor = samp_from_set( placed )
            rel = samp_from_set( connectives )

            anchor_x, anchor_y = room_to_coords[anchor]

            if rel not in offsets:
                raise Exception("wargh") # invalid direction
            dx, dy = offsets[rel]
            coord = ( anchor_x + dx, anchor_y + dy )

            if coord in coords_to_room:
                continue  # cell already occupied -> map would be ambiguous

            coords_to_room[coord] = next_room
            room_to_coords[next_room] = coord
            relations.append( (next_room,rel,anchor) )
            placed.append( next_room )
            break

    return {
        'rooms':orig_rooms,
        'relations':relations
        }

#
# ------------------------------------------------------------------------------------
#

def gen_initial_state( opts, zemap ):
    """Create a random set of containers placed in the rooms of `zemap`.

    Returns a list with one [bin_name, room, contents] entry per container,
    where contents is a list of [object_name, count] pairs. The number of
    containers, of object entries per container and of objects per entry
    are drawn with samp_int from the "bincnt", "obj_type" and "obj_cnt"
    options. No two containers share both name and room.
    """
    state = []

    for _ in range( samp_int( opts, "bincnt" ) ):

        # resample until this (name, room) pair has not been used yet
        while True:
            bname = samp_from_set( bin_names, 1 )      # bin name
            bloc = samp_from_set( zemap['rooms'], 1 )  # bin location
            if not any( s[0]==bname and s[1]==bloc for s in state ):
                break

        # fill the bin: one [type, count] pair per sampled object entry
        type_cnt = samp_int( opts, "obj_type" )
        contents = [
            [ samp_from_set( object_names, 1 ), samp_int( opts, "obj_cnt" ) ]
            for _ in range(type_cnt)
        ]

        state.append( [ bname, bloc, contents ] )

    return state

def gen_trans( opts, zemap, state ):
    """Sample one "move objects between containers" action and apply it.

    Returns (action, new_state). `new_state` is a deep copy of `state` with
    the move applied; `state` itself is left untouched. `action` is
    [src_container_index, object_index, moved_count, dest_container_index],
    or None when no move is possible (fewer than two containers, or the
    chosen source container is empty), in which case new_state is an
    unmodified copy.

    `opts` and `zemap` are accepted but not used here.
    """
    new_state = copy.deepcopy(state)

    # a move needs a source and a different destination
    n_bins = len( state )
    if n_bins <= 1:
        return None, new_state
    src = np.random.randint( low=0, high=n_bins )

    # choose which entry of the source container is moved
    src_items = state[src][2]
    if len( src_items ) == 0:
        return None, new_state
    obj = np.random.randint( low=0, high=len( src_items ) )
    obj_name, obj_cnt = src_items[obj]

    # NOTE: randint's upper bound is exclusive, so when the entry holds more
    # than one object at least one of them always stays behind.
    move_cnt = 1 if obj_cnt == 1 else np.random.randint( low=1, high=obj_cnt )

    # redraw the destination until it differs from the source
    dest = np.random.randint( low=0, high=n_bins )
    while dest == src:
        dest = np.random.randint( low=0, high=n_bins )

    # apply the move to the copy
    remaining = new_state[src][2]
    remaining[obj][1] -= move_cnt
    if remaining[obj][1] == 0:
        del remaining[obj] # nothing left of this entry
    new_state[dest][2].append( [obj_name,move_cnt] )

    return [ src, obj, move_cnt, dest ], new_state

def print_state( state ):
    """Describe every container in `state` as an English sentence.

    `state` is a list of [bin_name, room, contents] entries where contents
    is a list of [object_name, count] pairs. Each container yields
    "There is a <bin> in the <room> containing ...", followed by either
    "nothing. " or the list of its contents; two or more items are written
    as "x, y, and z. ". The sentences are concatenated, each ending in a
    trailing space.

    Every item is rendered with pluralize, so counts are spelled the same
    way ("two apples") regardless of the item's position. Previously the
    last item of a container printed its count as a digit ("2 apples")
    while all earlier items used number words.
    """
    pieces = []
    for bin_name, room, items in state:
        pieces.append( f"There is a {bin_name} in the {room} containing " )

        if len(items) == 0:
            pieces.append( "nothing. " )
            continue

        phrases = [ pluralize(cnt, name) for name, cnt in items ]
        if len(phrases) > 1:
            # all but the last item are comma separated; the last gets "and"
            pieces.append( ", ".join(phrases[:-1]) + ", and " + phrases[-1] + ". " )
        else:
            pieces.append( phrases[0] + ". " )

    return "".join(pieces)
   
def print_action( state, action ):
    """Describe a transport action as an English sentence.

    `action` is [src_container_index, object_index, count,
    dest_container_index] indexing into `state`, or None when nothing was
    moved.
    """
    if action is None:
        return "Didn't take anything. "

    src_cont_id, obj_id, cnt, dest_cont_id = action
    src_bin, src_room, src_obj_list = state[src_cont_id]
    dest_bin, dest_room, _ = state[dest_cont_id]

    # name of the moved object, looked up in the pre-move source container
    moved = pluralize( cnt, src_obj_list[obj_id][0])
    return f"Took {moved} from the {src_bin} in the {src_room} and placed it in the {dest_bin} in the {dest_room}. "

def get_num_bins_in_room( opts, state, room ):
    """Count the containers of `state` that are located in `room`.

    `opts` is accepted but not used.
    """
    return sum( 1 for _bin, where, _objs in state if where == room )

def print_route_action( opts, state, action, steps ):
    """Describe a transport action together with the route that was walked.

    `action` is [src_container_index, object_index, count,
    dest_container_index] indexing into `state`; unlike print_action it
    must not be None. `steps` is the list of moves from the source room to
    the destination room. The destination container is only named when
    its room holds more than one container.
    """
    src_cont_id, obj_id, cnt, dest_cont_id = action
    src_bin, src_room, src_obj_list = state[src_cont_id]
    dest_bin, dest_room, _ = state[dest_cont_id]

    parts = [ f"Took {pluralize( cnt, src_obj_list[obj_id][0])} from the {src_bin} in the {src_room}. Went " ]

    # the route, optionally with runs of identical moves merged
    if opts['collapse_repeated_actions']:
        steps = collapse_repeated_actions(steps)
    if len(steps) == 0:
        parts.append( "nowhere. " )
    else:
        parts.append( ", then ".join( englishify_steps(opts,steps) ) + ". " )

    # the drop-off
    if get_num_bins_in_room( opts, state, dest_room ) == 1:
        parts.append( "Placed it. " )
    else:
        parts.append( f"Placed it in the {dest_bin}. " )  # XXX worried that this is too easy...

    return "".join(parts)

#
# ------------------------------------------------------------------------------------
#

def englishify_steps(opts,steps):
    """Turn a route into a list of English phrases.

    When opts['collapse_repeated_actions'] is false, `steps` is returned
    unchanged. Otherwise `steps` holds (move, repeat_count) pairs and each
    pair becomes the move text, followed by a multiplier ("twice" or
    "2 times", depending on opts['word_multipliers']) when the count is
    not 1.
    """
    mults = word_multipliers if opts['word_multipliers'] else int_multipliers

    if not opts['collapse_repeated_actions']:
        return steps

    return [
        s[0] if s[1] == 1 else s[0]+" "+mults[s[1]-1]  # -1 for 0-based indexing
        for s in steps
    ]

def collapse_repeated_actions( steps ):
    """Run-length encode a list of moves.

    Returns a list of (move, repeat_count) tuples, one per run of equal
    consecutive entries. An empty input is returned as is.
    """
    if len(steps) == 0:
        return steps

    runs = []
    current = steps[0]
    run_len = 1
    for step in steps[1:]:
        if step == current:
            # same move again: extend the open run
            current = step
            run_len += 1
        else:
            # move changed: close the open run and start a new one
            runs.append( (current, run_len) )
            current, run_len = step, 1
    runs.append( (current, run_len) )
    return runs

def rec_calc_route( zemap, start_room, end_room, visited_nodes, cur_room, path ):
    """Depth-first search step from `cur_room` towards `end_room`.

    Walks the (src, rel, dest) tuples in zemap['relations'] in both
    directions. `path` is extended in place with the move taken at each
    level and `visited_nodes` with every room that was expanded. Returns
    True when `end_room` was reached (then `path` holds the moves from the
    initial room), False otherwise (then the moves added at this level
    have been removed again).

    `start_room` is only passed along; it is not otherwise used.
    """
    if cur_room == end_room:
        return True # done!

    visited_nodes.append( cur_room )

    for src,rel,dest in zemap['relations']:
        # work out which end of this relation, if any, we are standing on
        if cur_room == src:
            neighbor = dest
        elif cur_room == dest:
            neighbor = src
        else:
            continue

        if neighbor in visited_nodes:
            continue

        # "src is <rel> of dest": going dest->src is a move in direction rel,
        # going src->dest is a move in the opposite direction
        path.append( iconnectives[rel] if cur_room == src else rel )
        if rec_calc_route( zemap, start_room, end_room, visited_nodes, neighbor, path ):
            return True
        path.pop()  # dead end, undo the move

    return False

def calc_route( zemap, start_room, end_room ):
    """Return the list of moves leading from `start_room` to `end_room`.

    Uses a brute-force depth-first search over zemap['relations'].
    Raises Exception when the two rooms are not connected.
    """
    route = []
    found = rec_calc_route( zemap, start_room, end_room, [], start_room, route )
    if found:
        return route
    raise Exception("wargh!")

#
# ------------------------------------------------------------------------------------
#

def gen_nav_question( opts, qtype ):
    """Generate a navigation question of the requested kind.

    Returns the (prefix, suffix) pair of the matching generator for
    'navroute' or 'navresult', and None for any other `qtype`.
    """
    if qtype == 'navresult':
        return gen_nav_result_question( opts )
    elif qtype == 'navroute':
        return gen_nav_route_question( opts )
    return None

def gen_nav_route_question( opts ):
    """Generate a "how do I get from A to B" question.

    Returns (prefix, suffix): the prefix lists the map relations followed
    by the question stem, the suffix is the route as English text.
    """
    step_cnt = samp_int( opts, "step_cnt" )

    # rejection sampling: draw maps and room pairs until the route between
    # the two rooms has exactly step_cnt moves
    steps = None
    while steps is None or len(steps) != step_cnt:
        zemap = gen_map(opts)
        start_room,end_room = samp_from_set( zemap['rooms'], 2 )
        steps = calc_route( zemap, start_room, end_room )

    # background state followed by the query
    facts = [ f"The {i[0]} is {i[1]} of the {i[2]}. " for i in zemap['relations'] ]
    prefix = "".join(facts) + f"To get from the {start_room} to the {end_room}, you must go "

    # answer
    if opts['collapse_repeated_actions']:
        steps = collapse_repeated_actions(steps)
    if len(steps) == 0:
        suffix = "nowhere."
    else:
        suffix = ", then ".join( englishify_steps(opts,steps) ) + "."

    return prefix, suffix
    
def gen_nav_result_question( opts ):
    """Generate a "where do I end up" question.

    Returns (prefix, suffix): the prefix lists the map relations followed
    by the start room and the moves taken, the suffix is the name of the
    room that is reached.
    """
    step_cnt = samp_int( opts, "step_cnt" )

    # rejection sampling: draw maps and room pairs until the route between
    # the two rooms has exactly step_cnt moves
    steps = None
    while steps is None or len(steps) != step_cnt:
        zemap = gen_map(opts)
        start_room,end_room = samp_from_set( zemap['rooms'], 2 )
        steps = calc_route( zemap, start_room, end_room )

    # background state
    facts = [ f"The {i[0]} is {i[1]} of the {i[2]}. " for i in zemap['relations'] ]

    # the moves, as text
    if opts['collapse_repeated_actions']:
        steps = collapse_repeated_actions(steps)
    if len(steps) == 0:
        steps_str = "nowhere"
    else:
        steps_str = ", then ".join( englishify_steps(opts,steps) )

    # query
    prefix = "".join(facts) + f"If you start in the {start_room} and go " + steps_str + ", you will end in the "

    # answer
    suffix = f"{end_room}."

    return prefix, suffix

#
# ------------------------------------------------------------------------------------
#

def gen_simple_obj_transport( opts ):
    """Generate an object-transport question without any route.

    Returns (prefix, suffix): the prefix describes the initial containers
    and the action taken, the suffix describes the containers afterwards.
    """
    zemap = gen_map(opts)
    state = gen_initial_state(opts,zemap)
    action, new_state = gen_trans( opts, zemap, state )

    prefix = print_state( state ) + print_action( state, action )
    suffix = print_state( new_state )
    return prefix, suffix

def gen_hard_obj_transport( opts, zemap=None ):
    """Generate an object-transport question that includes the route walked.

    Parameters:
      opts  -- generation options.
      zemap -- optional map as produced by gen_map; a fresh one is generated
               when it is None. (The parameter used to be required and was
               then overwritten unconditionally, which made calls that pass
               only `opts` fail with a TypeError.)

    Returns (prefix, suffix): the prefix lists the map relations, the
    initial containers and the action (source, route, drop-off); the suffix
    describes the containers afterwards.
    """
    if zemap is None:
        zemap = gen_map(opts)

    # background: the map
    prefix = ""
    for i in zemap['relations']:
        prefix += f"The {i[0]} is {i[1]} of the {i[2]}. "

    state = gen_initial_state(opts,zemap)
    action, new_state = gen_trans( opts, zemap, state )

    prefix += print_state( state )

    if action is None:
        prefix += "Didn't take anything. "
    else:
        # the route runs between the rooms of the source and destination bins
        src_cont_id, obj_id, cnt, dest_cont_id = action
        src_room = state[src_cont_id][1]
        dest_room = state[dest_cont_id][1]
        steps = calc_route( zemap, src_room, dest_room )
        prefix += print_route_action( opts, state, action, steps )

    suffix = print_state( new_state )

    return prefix, suffix

#
# ------------------------------------------------------------------------------------
#

def gen_scenarios(opts,num_instances):
    """Generate `num_instances` question/answer scenarios.

    The question type is opts['qtype']; the value "*" picks one of
    'navroute', 'navresult', 'simpobj' and 'hardobj' uniformly at random
    for every instance.

    Returns (scenarios, prefixes, suffixes), three parallel lists where
    each scenario is "<prefix>\\t<suffix>".

    Raises Exception for an unknown question type.
    """
    scenarios = []
    prefixes = []
    suffixes = []
    for _ in tqdm( range(num_instances) ):

        # pick a random query type: navigation, object movement+question
        qtype = opts['qtype']
        if qtype == "*":
            # must list exactly the names dispatched on below; this used to
            # contain 'nav', which no branch handles
            qtype = samp_from_set( [ 'navroute', 'navresult', 'simpobj', 'hardobj' ] )

        if qtype=='navroute' or qtype=='navresult':
            prefix, suffix = gen_nav_question( opts, qtype )

        elif qtype=='simpobj':
            prefix, suffix = gen_simple_obj_transport( opts )

        elif qtype=='hardobj':
            # gen_hard_obj_transport declares a map parameter; the call used
            # to omit it, which raises TypeError when that parameter is required
            prefix, suffix = gen_hard_obj_transport( opts, gen_map(opts) )

        else:
            raise Exception("wargh!")

        prefixes.append(prefix)
        suffixes.append(suffix)
        scenarios.append( f"{prefix}\t{suffix}" )

    return scenarios, prefixes, suffixes
        
def make_nonexistent_dirs(filename, overwrite=True):
    """Create the parent directory of `filename`, including missing ancestors.

    Directories that already exist are left alone. `overwrite` is accepted
    but not used.
    """
    parent_dir = Path(filename).parent
    os.makedirs(parent_dir, exist_ok=True)
 
def write_train_val(opts, overwrite = False):
    """Generate the train and validation sets and write them to disk.

    Files written below opts['write_dir']:
      tsvs/<train_tsv_file>, tsvs/<validation_tsv_file>  -- "<prefix>\\t<suffix>" lines
      train/<train_prefix_file>, train/<train_suffix_file>
      validation/<validation_prefix_file>, validation/<validation_suffix_file>
      tsvs/atr-counts.json                               -- instance counts

    Nothing is generated when both tsv files already exist, unless
    `overwrite` is true. The output directories are created in either case.
    """
    num_instances = opts['num_instances']
    write_dir = opts['write_dir']

    train_file = os.path.join(write_dir,'tsvs', opts['train_tsv_file'])
    train_prefix_file = os.path.join(write_dir,'train', opts['train_prefix_file'])
    train_target_file = os.path.join(write_dir,'train', opts['train_suffix_file'])

    validation_file = os.path.join(write_dir, 'tsvs', opts['validation_tsv_file'])
    validation_prefix_file = os.path.join(write_dir, 'validation', opts['validation_prefix_file'])
    validation_target_file = os.path.join(write_dir, 'validation', opts['validation_suffix_file'])

    # Create the directory of every output file. The prefix files used to be
    # skipped (the target files were listed twice instead), which only
    # worked because prefix and target files share a directory.
    for out_path in ( train_file, train_prefix_file, train_target_file,
                      validation_file, validation_prefix_file, validation_target_file ):
        make_nonexistent_dirs(out_path, overwrite=overwrite)

    if os.path.exists(train_file) and os.path.exists(validation_file) and not overwrite:
        return

    # split of num_instances: currently everything goes to validation
    train_n = int(num_instances*0)
    val_n = int(num_instances*1)
    tsc, tp, ts = gen_scenarios(opts, num_instances=train_n)
    vsc, vp, vs = gen_scenarios(opts, num_instances=val_n)

    with open(train_file,'w') as tf,\
        open(train_prefix_file, 'w') as tpf, \
        open(train_target_file, 'w') as ttf:
        for scenario, prefix, suffix in zip(tsc, tp, ts):
            tf.write(scenario + '\n')
            tpf.write(prefix + '\n')
            ttf.write(suffix + '\n')

    with open(validation_file,'w') as vf, \
        open(validation_prefix_file, 'w') as vpf, \
        open(validation_target_file, 'w') as vtf:
        for scenario, prefix, suffix in zip(vsc, vp, vs):
            vf.write(scenario + '\n')
            vpf.write(prefix + '\n')
            vtf.write(suffix + '\n')

    json_counts = {"train": train_n, "validation": val_n}
    with open(os.path.join(write_dir, 'tsvs', 'atr-counts.json'),'w') as f:
        json.dump(json_counts, f)

#
# =======================================================================
# MAIN
# =======================================================================
#

# Optionally replace the built-in vocabularies with lists read from disk:
# one entry per line, surrounding whitespace stripped. The files are read
# inside `with` blocks so that the handles are closed again.
if opts["object_names_filename"] is not None:
    with open(opts["object_names_filename"],"r") as names_file:
        object_names = [ line.strip() for line in names_file ]

if opts["person_names_filename"] is not None:
    with open(opts["person_names_filename"],"r") as names_file:
        person_names = [ line.strip() for line in names_file ]

if opts["bin_names_filename"] is not None:
    with open(opts["bin_names_filename"],"r") as names_file:
        bin_names = [ line.strip() for line in names_file ]

# generate and write the data set, regenerating any existing output
write_train_val(opts, True)
