import babyai
import gym
from babyai.bot import GoNextToSubgoal, PickupSubgoal, DropSubgoal, OpenSubgoal, LanguageObj
from babyai.utils.agent import FullyObsBotAgent, BotAgent

import numpy as np
import time
import logging
from collections import defaultdict
import itertools

from lang_hrl.envs.babyai_wrappers import FullyObsLanguageWrapper, WORD_TO_IDX, LanguageWrapper
from .datasets import Dataset, BabyAITrajectoryDataset

# Module-level logger used by the demo generation code below.
logger = logging.getLogger(__name__)
# NOTE: this configures the root logger as a side effect of importing this module.
logging.basicConfig(level='INFO', format="%(asctime)s: %(levelname)s: %(message)s")

def merge_ids(*args):
    """Flatten subgoal ids into a single tuple.

    Args:
        *args: Each argument is either a single ``int`` id or a ``tuple``/``list``
            of ids (for example the id of a subgoal that was already merged).

    Returns:
        tuple: Every id in argument order, with sequences expanded one level.

    Raises:
        TypeError: If an argument is neither an int nor a tuple/list of ids.
    """
    final_ids = []
    for arg in args:
        if isinstance(arg, int):
            final_ids.append(arg)
        elif isinstance(arg, (tuple, list)):
            # Check the element itself; ``args`` is always a tuple, so testing
            # it would accept (and try to iterate) any kind of value.
            final_ids.extend(arg)
        else:
            raise TypeError(
                "subgoal ids must be ints or tuples of ints, got {!r}".format(arg)
            )
    return tuple(final_ids)

def process_subgoals(subgoals):
    """Condense per-timestep bot subgoals into a short list of language subgoals.

    Args:
        subgoals: Sequence of subgoal descriptions, one per timestep. Each item
            is read through its ``id``, ``text``, ``subgoal_type``, ``obj_type``,
            ``obj_color`` and ``pos`` attributes.

    Returns:
        A pair ``(subgoals, new_subgoal_index)``:
          * ``subgoals`` - one entry per remaining (possibly merged) subgoal,
            each being the result of
            ``FullyObsLanguageWrapper.convert_subgoals_to_data`` for the subgoal
            text with its last element removed.
          * ``new_subgoal_index`` - a list with one int per input timestep that
            indexes into the returned ``subgoals`` list.

    NOTE(review): if every subgoal is removed by the filtering below while the
    input is non-empty, the alignment loop at the end indexes an empty list and
    raises ``IndexError``.
    """
    seen_ids = set()
    used_subgoals = []
    time_to_id = []
    # Remember which subgoal id was active at each timestep before merging.
    for subgoal in subgoals:
        time_to_id.append(subgoal.id)

    # Keep a single entry per id. Walking backwards keeps the last occurrence
    # of every id; the list is flipped back to forward order afterwards.
    for subgoal in reversed(subgoals):
        if not subgoal.id in seen_ids:
            used_subgoals.append(subgoal)
            seen_ids.add(subgoal.id)
    subgoals = list(reversed(used_subgoals))
    # combine_inds[k] holds the indices replaced by combine_instrs[k];
    # invalid_inds tracks indices already claimed by some merge.
    combine_inds = []
    combine_instrs = []
    invalid_inds = set()

    # Remove doubles
    # A GoNextTo with an object type, directly followed by a Pickup at the same
    # position, becomes a single "pick up the <color> <type>" instruction.
    for i in range(len(subgoals) - 1):
        if subgoals[i].subgoal_type is GoNextToSubgoal and \
            not subgoals[i].obj_type is None and \
            subgoals[i+1].subgoal_type is PickupSubgoal and \
            not i in invalid_inds and \
            not i+1 in invalid_inds and \
            subgoals[i].pos == subgoals[i+1].pos:
            # print("COMBINING DOUBLE", subgoals[i].text, subgoals[i].id, subgoals[i+1].text, subgoals[i+1].id)

            # Removing duplicate
            combine_inds.append([i, i+1])
            text = "pick up the " + subgoals[i+1].obj_color + " " + subgoals[i+1].obj_type
            # The merged instruction carries both ids and is created with obj_type=None.
            combine_instrs.append(LanguageObj(text=text, obj_type=None, 
                                         obj_color=subgoals[i+1].obj_color, subgoal_type=PickupSubgoal, 
                                         id=merge_ids(subgoals[i].id, subgoals[i+1].id), pos=subgoals[i+1].pos))
            invalid_inds.add(i)
            invalid_inds.add(i+1)

    # remove triples
    # Pickup -> GoNextTo -> Drop of the same color/type object is treated as
    # moving that object, provided the total Manhattan distance is small.
    for i in range(len(subgoals) - 2):
        if subgoals[i].subgoal_type is PickupSubgoal and \
            subgoals[i+1].subgoal_type is GoNextToSubgoal and \
            subgoals[i+2].subgoal_type is DropSubgoal and \
            not i in invalid_inds and \
            not i+1 in invalid_inds and \
            not i+2 in invalid_inds and \
            subgoals[i].obj_color == subgoals[i+2].obj_color and \
            subgoals[i].obj_type == subgoals[i+2].obj_type:
            # Moving an object out of the way.
            # Compute the distance traveled:
            # print("COMBINING TRIPLE", subgoals[i].text, subgoals[i].id, subgoals[i+1].text, subgoals[i+1].id, subgoals[i+2].text, subgoals[i+2].id)

            # Manhattan distance pickup -> go-to target plus go-to target -> drop.
            dist_traveled = 0
            dist_traveled += abs(subgoals[i].pos[0] - subgoals[i+1].pos[0]) + abs(subgoals[i].pos[1] - subgoals[i+1].pos[1])
            dist_traveled += abs(subgoals[i+1].pos[0] - subgoals[i+2].pos[0]) + abs(subgoals[i+1].pos[1] - subgoals[i+2].pos[1])
            if dist_traveled <= 5:
                invalid_inds.add(i)
                invalid_inds.add(i+1)
                invalid_inds.add(i+2)
                combine_inds.append([i, i+1, i+2])
                text = "move the " + subgoals[i].obj_color + " " + subgoals[i].obj_type
                # The "move" instruction has no subgoal_type/pos, so the later
                # type-based passes below never match it.
                combine_instrs.append(
                    LanguageObj(text=text, obj_type=None, 
                                         obj_color=None, subgoal_type=None, 
                                         id=merge_ids(subgoals[i].id, subgoals[i+1].id, subgoals[i+2].id), pos=None)
                )

    # Apply the merges from the highest starting index to the lowest so that
    # deleting a group does not shift the indices of groups still to be handled.
    for inds, instr in reversed(sorted(zip(combine_inds, combine_instrs), key=lambda x: x[0][0])):
        for ind in reversed(sorted(inds)):
            del subgoals[ind]
        # After the inner loop ``ind`` is the smallest index of the group, which
        # is where the merged instruction is inserted.
        subgoals.insert(ind, instr)

    # Now clear the subgoals that are blank
    # (GoNextTo entries without an object type). Iterating in reverse keeps the
    # remaining indices valid while deleting.
    for i in reversed(range(len(subgoals))):
        if subgoals[i].subgoal_type is GoNextToSubgoal and subgoals[i].obj_type is None:
            del subgoals[i]

    # Combine any remaining pickup-drop instruction combinations
    for i in reversed(range(len(subgoals) - 1)):
        
        if subgoals[i].subgoal_type is PickupSubgoal and \
            subgoals[i+1].subgoal_type is DropSubgoal:
            dist_traveled = abs(subgoals[i].pos[0] - subgoals[i+1].pos[0]) + abs(subgoals[i].pos[1] - subgoals[i+1].pos[1])
            if dist_traveled <= 5:
                # Rebuild the pickup so that its id also covers the drop's id.
                text, obj_type, obj_color, subgoal_type = subgoals[i].text, subgoals[i].obj_type, subgoals[i].obj_color, subgoals[i].subgoal_type
                subgoals[i] = LanguageObj(text=text, obj_type=obj_type, obj_color=obj_color, subgoal_type=subgoal_type,
                                             id=merge_ids(subgoals[i].id, subgoals[i+1].id), pos=subgoals[i].pos)
                del subgoals[i+1] # delete the drop instr, will implicitly work if we pick up another obj
    
    # Combine any remaining gotodoor and open door instructions
    inds_to_del = []
    for i in reversed(range(len(subgoals) - 1)):
        if subgoals[i].subgoal_type is GoNextToSubgoal and \
                subgoals[i+1].subgoal_type is OpenSubgoal and \
                subgoals[i].pos == subgoals[i+1].pos: 
            # We can remove the the GoNextToSubgoal
            inds_to_del.append(i)
    # inds_to_del was filled while iterating in reverse, so it is already in
    # descending order and deleting in this order keeps later indices valid.
    for ind in inds_to_del:
        del subgoals[ind] # Should delete in reverse order

    # figure out the best way to sort the subgoals
    # For every timestep, record the index of the current merged subgoal, then
    # advance to the next one once all ids it is made of have been seen.
    cur_subgoal_index = 0
    seen = defaultdict(lambda: False)
    new_subgoal_index = []
    for t, subgoal_id in enumerate(time_to_id):
        new_subgoal_index.append(cur_subgoal_index)
        seen[subgoal_id] = True
        cur_subgoal = subgoals[cur_subgoal_index]
        if isinstance(cur_subgoal.id, int) and seen[cur_subgoal.id]:
            cur_subgoal_index += 1
        elif isinstance(cur_subgoal.id, tuple) and all(seen[_id] for _id in cur_subgoal.id):
            cur_subgoal_index += 1
        # Never step past the last merged subgoal.
        cur_subgoal_index = min(cur_subgoal_index, len(subgoals) - 1)
    
    # Convert the subgoals to text. Remove the end token that is actually the END token instead of the end subgoal token.
    subgoals = [FullyObsLanguageWrapper.convert_subgoals_to_data([subgoal.text], max_len=-1, pad=False)[:-1] for subgoal in subgoals]
    return subgoals, new_subgoal_index
    
def generate_demos(env_name, seed, n_episodes, max_mission_len=None, log_interval=500, fully_obs=True, **kwargs):
    """Roll out the bot agent and collect successful demonstrations.

    Args:
        env_name: Environment id passed to ``gym.make``.
        seed: Base seed. Before each attempt that does not follow a failure the
            environment is seeded with ``seed + len(demos)``.
        n_episodes: Number of successful demonstrations to collect.
        max_mission_len: Forwarded as ``max_len`` to the language wrapper.
        log_interval: A throughput message is logged every time this many
            additional demonstrations have been collected (0/None disables it).
        fully_obs: If True use ``FullyObsBotAgent`` with
            ``FullyObsLanguageWrapper``; otherwise ``BotAgent`` with
            ``LanguageWrapper``.
        **kwargs: Accepted but not used.

    Returns:
        list: One tuple per successful episode,
        ``(mission, images, inventorys, actions, subgoals, subgoal_inds)``.
        ``images`` is a uint8 array holding every observation image including
        the final one, ``inventorys`` is a uint8 array or None when the
        observations carry no ``'inventory'`` key, ``actions`` is an int32
        array, and ``subgoals`` / ``subgoal_inds`` are the outputs of
        ``process_subgoals``.
    """
    env = gym.make(env_name)
    if fully_obs:
        agent = FullyObsBotAgent(env)
        env = FullyObsLanguageWrapper(env, max_len=max_mission_len, pad=False)
    else:
        # Use the regular language wrapper
        agent = BotAgent(env)
        env = LanguageWrapper(env, max_len=max_mission_len, pad=False)

    demos = []
    checkpoint_time = time.time()
    just_crashed = False

    while len(demos) < n_episodes:
        if just_crashed:
            # After a failure, do an extra unseeded reset so that the next
            # attempt gets a different mission than the one that failed.
            logger.info("reset the environment to find a mission that the bot can solve")
            env.reset()
        else:
            env.seed(seed + len(demos))  # deterministically seed the demos

        obs = env.reset()
        agent.on_reset()

        # Per-episode buffers.
        subgoals, actions, images, inventorys = [], [], [], []
        mission = obs['mission']
        done = False

        try:
            while not done:
                agent_action = agent.act(obs)
                action = agent_action['action']
                subgoals.append(agent_action['subgoal'].as_language())
                new_obs, reward, done, _ = env.step(action)
                agent.analyze_feedback(reward, done)

                actions.append(action)
                images.append(obs['image'])
                if 'inventory' in obs:
                    inventorys.append(obs['inventory'])
                obs = new_obs

            # Append the final obs
            images.append(obs['image'])
            if 'inventory' in obs:
                inventorys.append(obs['inventory'])

            if not reward > 0:
                # The bot finished the episode without solving the mission.
                just_crashed = True
                logger.info("mission failed")
                continue

            images = np.array(images, dtype=np.uint8)
            inventorys = np.array(inventorys, dtype=np.uint8) if len(inventorys) > 0 else None
            actions = np.array(actions, dtype=np.int32)
            assert len(subgoals) == len(actions)
            subgoals, subgoal_inds = process_subgoals(subgoals)

            demos.append((mission, images, inventorys, actions, subgoals, subgoal_inds))
            just_crashed = False

        except Exception:
            # AssertionError is a subclass of Exception, so bot/consistency
            # assertions are handled here as well.
            just_crashed = True
            logger.exception("error while generating demo #{}".format(len(demos)))
            continue

        # Only reached right after a demo was added, so the progress message is
        # emitted once per interval instead of again after every failed attempt.
        if log_interval and len(demos) % log_interval == 0:
            now = time.time()
            elapsed = max(now - checkpoint_time, 1e-9)  # avoid division by zero
            demos_per_second = log_interval / elapsed
            to_go = (n_episodes - len(demos)) / demos_per_second
            logger.info("demo #{}, {:.3f} demos per second, {:.3f} seconds to go".format(
                len(demos) - 1, demos_per_second, to_go))
            checkpoint_time = now

    return demos


def create_dataset(path, dataset_type, env_name, seed, n_episodes, fully_obs=False, max_mission_len=None, max_subgoals_len=None, skip=-1):
    """Generate demonstrations and save them as a per-subgoal trajectory dataset.

    Every demonstration is split at the timesteps where its subgoal index
    changes. Each segment yields one sample made of the padded mission, the
    images and actions from the start of the episode up to the end of the
    segment (actions before the segment start are replaced by -100), and the
    padded subgoal tokens prefixed with ``WORD_TO_IDX['END_MISSION']``.

    Args:
        path: Output path prefix; the dataset is saved to ``path + "_hier"``.
        dataset_type: Not used by this function.
        env_name, seed, n_episodes, fully_obs: Forwarded to ``generate_demos``.
        max_mission_len: Length every mission is zero-padded to; also forwarded
            to ``generate_demos``. Required.
        max_subgoals_len: Length every subgoal (including the start token) is
            zero-padded to. Required.
        skip: Not used by this function.

    Raises:
        ValueError: If ``max_mission_len`` or ``max_subgoals_len`` is None.
    """
    # Fail before the (expensive) demo generation rather than while padding.
    if max_mission_len is None or max_subgoals_len is None:
        raise ValueError("max_mission_len and max_subgoals_len must both be provided")

    demos = generate_demos(env_name, seed, n_episodes, fully_obs=fully_obs, max_mission_len=max_mission_len)

    missions, imgs, ac, subgoals = [], [], [], []
    for mission, demo_images, _inventorys, demo_actions, demo_subgoals, subgoal_inds in demos:

        # Pad the mission out to the entire length. This can be done once per
        # demo; a copy is appended for every sample below so that ``missions``
        # stays aligned with ``imgs`` / ``ac`` / ``subgoals``.
        mission_pad = np.zeros(max_mission_len, dtype=np.int32)
        mission_pad[:len(mission)] = mission

        # Segment boundaries: timestep 0, every timestep where the subgoal
        # index changes, and the end of the episode (number of actions).
        # Unlike a wrap-around comparison this also yields a segment when the
        # demo has a single subgoal.
        subgoal_inds = np.asarray(subgoal_inds)
        change_points = np.flatnonzero(np.diff(subgoal_inds)) + 1
        bounds = np.concatenate(([0], change_points, [len(demo_images) - 1]))

        for sg_start, sg_end in zip(bounds[:-1], bounds[1:]):
            images = demo_images[:sg_end]
            # Copy before masking: a plain slice is a view, so writing -100
            # into it would overwrite the demo's actions and every sample
            # that was already stored for earlier segments.
            actions = demo_actions[:sg_end].copy()
            actions[:sg_start] = -100  # This is the pad index for actions

            tokens = demo_subgoals[subgoal_inds[sg_start]]
            subgoal = np.zeros(max_subgoals_len, dtype=np.int32)
            subgoal[0] = WORD_TO_IDX['END_MISSION']  # add the start token
            subgoal[1:len(tokens) + 1] = tokens

            missions.append(mission_pad.copy())
            imgs.append(images)
            ac.append(actions)
            subgoals.append(subgoal)

    dataset = BabyAITrajectoryDataset(imgs, missions, subgoals, ac, masks=None)  # we have no masks
    dataset.save(path + "_hier")

