import babyai
import gym
from babyai.bot import GoNextToSubgoal, PickupSubgoal, DropSubgoal, OpenSubgoal, LanguageObj
from babyai.utils.agent import FullyObsBotAgent, BotAgent

import numpy as np
import time
import logging
from collections import defaultdict
import itertools

from lang_hrl.envs.babyai_wrappers import FullyObsLanguageWrapper, WORD_TO_IDX, LanguageWrapper
from .datasets import Dataset, BabyAITrajectoryDataset

# Configure the root logger at import time (INFO level, timestamped format).
# Per the logging docs, basicConfig does nothing if the root logger already
# has handlers attached.
logging.basicConfig(level=logging.INFO, format="%(asctime)s: %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

def merge_ids(*args):
    """Flatten subgoal ids into one tuple.

    Each argument is either a single integer id (a Python ``int`` or a NumPy
    integer scalar) or an iterable of ids, e.g. the tuple returned by an
    earlier ``merge_ids`` call.

    >>> merge_ids(1, (2, 3), 4)
    (1, 2, 3, 4)
    """
    final_ids = []
    for arg in args:
        if isinstance(arg, (int, np.integer)):
            final_ids.append(arg)
        else:
            # Anything that is not a scalar id is treated as a collection of ids.
            final_ids.extend(arg)
    return tuple(final_ids)

def process_subgoals(subgoals, aggressive_mask=False):
    """Compress a per-timestep bot subgoal trace into a short instruction plan.

    The bot records one subgoal object per environment step. This function
    de-duplicates them by ``id``, merges recurring patterns into single
    higher-level instructions, drops uninformative ones, and then determines
    which compressed instruction is pending at every timestep.

    Args:
        subgoals: list with one ``LanguageObj`` per timestep (as built by
            ``generate_demos``). Items are read for ``id``, ``text``,
            ``subgoal_type``, ``obj_type``, ``obj_color`` and ``pos``.
        aggressive_mask: if True, each mask column also blocks the instruction
            pending at that timestep, and column 0 is forced fully True.

    Returns:
        ``(subgoals_per_timestep, mask)``:
        ``subgoals_per_timestep[t]`` is the result of
        ``FullyObsLanguageWrapper.convert_subgoals_to_data`` on the texts of the
        instructions still pending at step ``t`` (presumably a token array);
        ``mask`` is a boolean array of shape
        ``(len(subgoals_per_timestep[0]), number_of_input_timesteps)``.
    """
    seen_ids = set()
    used_subgoals = []
    # Subgoal id active at each timestep, in time order.
    time_to_id = []
    for subgoal in subgoals:
        time_to_id.append(subgoal.id)

    # De-duplicate by id, keeping only the LAST occurrence of each id, ordered
    # by those last occurrences.
    for subgoal in reversed(subgoals):
        if not subgoal.id in seen_ids:
            used_subgoals.append(subgoal)
            seen_ids.add(subgoal.id)
    subgoals = list(reversed(used_subgoals))
    combine_inds = []
    combine_instrs = []
    invalid_inds = set()  # indices already claimed by a merge

    # Remove doubles: "go next to <obj>" directly followed by "pick up" at the
    # same position becomes one "pick up the <color> <type>" instruction.
    for i in range(len(subgoals) - 1):
        if subgoals[i].subgoal_type is GoNextToSubgoal and \
            not subgoals[i].obj_type is None and \
            subgoals[i+1].subgoal_type is PickupSubgoal and \
            not i in invalid_inds and \
            not i+1 in invalid_inds and \
            subgoals[i].pos == subgoals[i+1].pos:

            # Removing duplicate
            combine_inds.append([i, i+1])
            text = "pick up the " + subgoals[i+1].obj_color + " " + subgoals[i+1].obj_type
            combine_instrs.append(LanguageObj(text=text, obj_type=None, 
                                         obj_color=subgoals[i+1].obj_color, subgoal_type=PickupSubgoal, 
                                         id=merge_ids(subgoals[i].id, subgoals[i+1].id), pos=subgoals[i+1].pos))
            invalid_inds.add(i)
            invalid_inds.add(i+1)

    # Remove triples: pickup -> go next to -> drop of the same object, with a
    # total Manhattan path of at most 5, becomes "move the <color> <type>".
    for i in range(len(subgoals) - 2):
        if subgoals[i].subgoal_type is PickupSubgoal and \
            subgoals[i+1].subgoal_type is GoNextToSubgoal and \
            subgoals[i+2].subgoal_type is DropSubgoal and \
            not i in invalid_inds and \
            not i+1 in invalid_inds and \
            not i+2 in invalid_inds and \
            subgoals[i].obj_color == subgoals[i+2].obj_color and \
            subgoals[i].obj_type == subgoals[i+2].obj_type:
            # Moving an object out of the way.
            # Compute the distance traveled:
            dist_traveled = 0
            dist_traveled += abs(subgoals[i].pos[0] - subgoals[i+1].pos[0]) + abs(subgoals[i].pos[1] - subgoals[i+1].pos[1])
            dist_traveled += abs(subgoals[i+1].pos[0] - subgoals[i+2].pos[0]) + abs(subgoals[i+1].pos[1] - subgoals[i+2].pos[1])
            if dist_traveled <= 5:
                invalid_inds.add(i)
                invalid_inds.add(i+1)
                invalid_inds.add(i+2)
                combine_inds.append([i, i+1, i+2])
                text = "move the " + subgoals[i].obj_color + " " + subgoals[i].obj_type
                combine_instrs.append(
                    LanguageObj(text=text, obj_type=None, 
                                         obj_color=None, subgoal_type=None, 
                                         id=merge_ids(subgoals[i].id, subgoals[i+1].id, subgoals[i+2].id), pos=None)
                )

    # Apply the merges from the back so earlier indices stay valid. After the
    # inner loop ``ind`` is the smallest index of the span, where the merged
    # instruction is inserted.
    for inds, instr in reversed(sorted(zip(combine_inds, combine_instrs), key=lambda x: x[0][0])):
        for ind in reversed(sorted(inds)):
            del subgoals[ind]
        subgoals.insert(ind, instr)

    # Now clear the subgoals that are blank (GoNextTo without an object).
    for i in reversed(range(len(subgoals))):
        if subgoals[i].subgoal_type is GoNextToSubgoal and subgoals[i].obj_type is None:
            del subgoals[i]

    # Combine any remaining pickup-drop pairs within Manhattan distance 5; the
    # merged entry keeps the pickup's text/type/position and both ids.
    for i in reversed(range(len(subgoals) - 1)):
        
        if subgoals[i].subgoal_type is PickupSubgoal and \
            subgoals[i+1].subgoal_type is DropSubgoal:
            dist_traveled = abs(subgoals[i].pos[0] - subgoals[i+1].pos[0]) + abs(subgoals[i].pos[1] - subgoals[i+1].pos[1])
            if dist_traveled <= 5:
                text, obj_type, obj_color, subgoal_type = subgoals[i].text, subgoals[i].obj_type, subgoals[i].obj_color, subgoals[i].subgoal_type
                subgoals[i] = LanguageObj(text=text, obj_type=obj_type, obj_color=obj_color, subgoal_type=subgoal_type,
                                             id=merge_ids(subgoals[i].id, subgoals[i+1].id), pos=subgoals[i].pos)
                del subgoals[i+1] # delete the drop instr, will implicitly work if we pick up another obj
    
    # Combine any remaining gotodoor and open door instructions at the same pos.
    # NOTE(review): the removed GoNextTo's id is not merged into the Open entry.
    inds_to_del = []
    for i in reversed(range(len(subgoals) - 1)):
        if subgoals[i].subgoal_type is GoNextToSubgoal and \
                subgoals[i+1].subgoal_type is OpenSubgoal and \
                subgoals[i].pos == subgoals[i+1].pos: 
            # We can remove the the GoNextToSubgoal
            inds_to_del.append(i)
    for ind in inds_to_del:
        del subgoals[ind]  # inds_to_del is descending, so pending indices don't shift

    # For each timestep record the index of the pending compressed instruction,
    # then advance by at most one once every id of that instruction has
    # appeared. new_subgoal_index[0] is always 0.
    # NOTE(review): an id counts as "seen" on its first appearance, so the
    # pointer moves on right after the first step of a multi-step subgoal —
    # confirm this is intended.
    cur_subgoal_index = 0
    seen = defaultdict(lambda: False)
    new_subgoal_index = []
    for t, subgoal_id in enumerate(time_to_id):
        new_subgoal_index.append(cur_subgoal_index)
        seen[subgoal_id] = True
        cur_subgoal = subgoals[cur_subgoal_index]
        if isinstance(cur_subgoal.id, int) and seen[cur_subgoal.id]:
            cur_subgoal_index += 1
        elif isinstance(cur_subgoal.id, tuple) and all(seen[_id] for _id in cur_subgoal.id):
            cur_subgoal_index += 1
        cur_subgoal_index = min(cur_subgoal_index, len(subgoals) - 1)
    
    # Convert the text of all still-pending instructions for every timestep.
    subgoals_per_timestep = []
    for ind in new_subgoal_index:
        text = [subgoal.text for subgoal in subgoals[ind:]]
        s_t = FullyObsLanguageWrapper.convert_subgoals_to_data(text, max_len=-1, pad=False)
        subgoals_per_timestep.append(s_t)

    # Build a (T, S) mask: rows follow the converted full plan at timestep 0,
    # columns are timesteps. In column i, rows belonging to instructions before
    # the pending one are set False (the original intent: block
    # cross-attention to instructions that are already done).
    # ``bool`` replaces ``np.bool``, which NumPy 1.24-1.26 no longer provides.
    mask = np.ones((len(subgoals_per_timestep[0]), len(new_subgoal_index)), dtype=bool)
    for i, ind in enumerate(new_subgoal_index):
        # NOTE(review): ``sg`` is a LanguageObj, not its converted text, so
        # ``len(sg)`` is whatever LanguageObj defines (a field count if it is a
        # namedtuple) — confirm it matches the per-instruction token count.
        before_size = sum([len(sg) for sg in subgoals[:ind + (1 if aggressive_mask else 0)]])
        mask[:before_size, i] = False
    if aggressive_mask:
        mask[:, 0] = True

    return subgoals_per_timestep, mask

def generate_demos(env_name, seed, n_episodes, max_mission_len=None, log_interval=500, fully_obs=True, **kwargs):
    """Collect ``n_episodes`` successful demonstrations from the BabyAI bot.

    After a successful (or the very first) attempt the environment is seeded
    with ``seed + len(demos)``, so runs are reproducible. When an attempt
    fails (final reward <= 0) or raises, the next attempt resets the
    environment without reseeding to draw a different mission.

    Args:
        env_name: gym id of the environment.
        seed: base seed for the deterministic demos.
        n_episodes: number of successful demos to collect.
        max_mission_len: forwarded as ``max_len`` to the language wrapper.
        log_interval: log throughput every this many collected demos;
            values <= 0 disable progress logging.
        fully_obs: use ``FullyObsBotAgent``/``FullyObsLanguageWrapper`` if
            True, otherwise ``BotAgent``/``LanguageWrapper``.
        **kwargs: ignored; lets callers pass dataset-builder options through.

    Returns:
        List of ``(mission, images, inventorys, actions, subgoals, mask)``.
        ``images`` (uint8) and ``inventorys`` (uint8, or ``None`` when the
        observations carry no inventory) include the final observation, so
        they have one more entry than ``actions`` (int32). ``subgoals`` and
        ``mask`` come from ``process_subgoals``.
    """
    env = gym.make(env_name)
    if fully_obs:
        agent = FullyObsBotAgent(env)
        env = FullyObsLanguageWrapper(env, max_len=max_mission_len, pad=False)
    else:
        # Use the regular language wrapper
        agent = BotAgent(env)
        env = LanguageWrapper(env, max_len=max_mission_len, pad=False)

    demos = []
    checkpoint_time = time.time()
    just_crashed = False

    try:
        while len(demos) < n_episodes:
            if just_crashed:
                logger.info("reset the environment to find a mission that the bot can solve")
                env.reset()
            else:
                env.seed(seed + len(demos))  # deterministically seed the demos

            obs = env.reset()
            agent.on_reset()

            # Containers for this episode's data.
            subgoals, actions, images, inventorys = [], [], [], []
            mission = obs['mission']
            done = False
            collected = False

            try:
                while not done:
                    agent_action = agent.act(obs)
                    action = agent_action['action']
                    subgoals.append(agent_action['subgoal'].as_language())
                    new_obs, reward, done, _ = env.step(action)
                    agent.analyze_feedback(reward, done)

                    # Store the observation the action was taken from.
                    actions.append(action)
                    images.append(obs['image'])
                    if 'inventory' in obs:
                        inventorys.append(obs['inventory'])
                    obs = new_obs

                # Append the final obs
                images.append(obs['image'])
                if 'inventory' in obs:
                    inventorys.append(obs['inventory'])

                if reward > 0:
                    images = np.array(images, dtype=np.uint8)
                    inventorys = np.array(inventorys, dtype=np.uint8) if inventorys else None
                    actions = np.array(actions, dtype=np.int32)
                    assert len(subgoals) == len(actions)
                    subgoals, mask = process_subgoals(subgoals)
                    demos.append((mission, images, inventorys, actions, subgoals, mask))
                    just_crashed = False
                    collected = True
                else:
                    just_crashed = True
                    logger.info("mission failed")

            except Exception:
                # Any bot/env failure just means "try another mission".
                just_crashed = True
                logger.exception("error while generating demo #{}".format(len(demos)))
                continue

            # Only report when a demo was actually added; otherwise a failed
            # attempt at a multiple of log_interval would log again with a
            # near-zero elapsed time.
            if collected and log_interval > 0 and len(demos) % log_interval == 0:
                now = time.time()
                demos_per_second = log_interval / (now - checkpoint_time)
                to_go = (n_episodes - len(demos)) / demos_per_second
                logger.info("demo #{}, {:.3f} demos per second, {:.3f} seconds to go".format(
                    len(demos) - 1, demos_per_second, to_go))
                checkpoint_time = now
    finally:
        env.close()

    return demos

def create_bc_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build a flat behaviour-cloning dataset with one row per action.

    The zero-padded mission is repeated for every step of its episode so it
    stays row-aligned with the images, inventories and actions.

    Args:
        demos: tuples ``(mission, images, inventorys, actions, subgoals, mask)``
            as returned by ``generate_demos``; ``images``/``inventorys`` hold
            one more entry than ``actions`` and ``inventorys`` may be ``None``.
        seed: unused; kept for a uniform builder signature.
        max_mission_len: length each mission is zero-padded to.
        max_subgoals_len: unused.

    Returns:
        ``Dataset`` with states ``image``, ``mission`` and, when the demos carry
        inventories, ``inventory``; actions are the concatenated actions.
    """
    missions, imgs, invs, ac = [], [], [], []
    for demo in demos:
        mission, images, inventories, actions = demo[0], demo[1], demo[2], demo[3]
        # One row per action. ``images`` also contains the final observation,
        # so using its length would add an extra mission row per episode and
        # misalign everything after concatenation.
        ep_length = len(actions)
        mission_pad = np.zeros(max_mission_len, dtype=np.int32)
        mission_pad[:len(mission)] = mission
        missions.append(np.tile(mission_pad, (ep_length, 1)))

        imgs.append(images[:-1])  # Remove the final state.
        if inventories is not None:
            invs.append(inventories[:-1])  # Remove the final state.
        ac.append(actions)

    states = {
        'image': np.concatenate(imgs, axis=0),
        'mission': np.concatenate(missions, axis=0),
    }
    if invs:
        states['inventory'] = np.concatenate(invs, axis=0)
    return Dataset(states=states, actions=np.concatenate(ac, axis=0))

def create_gpt_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build a next-token-prediction dataset over "mission + plan" text.

    For each demo the mission tokens are concatenated with the plan at
    timestep 0 (``demo[4][0]``) and truncated to ``max_subgoals_len + 1``
    tokens. The result is split into a shifted input row and target row, each
    of length ``max_subgoals_len``. Targets are -100 for padding and wherever
    the target would still be a mission token, so only plan tokens are
    predicted (assuming -100 is the trainer's ignore index — TODO confirm).

    Args:
        demos: demo tuples as returned by ``generate_demos``.
        seed: unused.
        max_mission_len: unused.
        max_subgoals_len: length of the input/target rows.

    Returns:
        ``Dataset`` with states ``image`` (first frame of each demo) and
        ``text`` (inputs); actions are the targets.
    """
    imgs, text, labels = [], [], []
    for demo in demos:
        imgs.append(demo[1][0:1, :, :, :])  # Get only the first image
        mission, subgoals = demo[0], demo[4][0]
        combined = np.concatenate((mission, subgoals), axis=0)[:max_subgoals_len + 1]
        n_tokens = combined.shape[0]

        # np.long is not available in NumPy 1.24-1.26; use an explicit
        # 64-bit integer dtype.
        text_x = np.zeros(max_subgoals_len, dtype=np.int64)
        text_x[:n_tokens - 1] = combined[:-1]

        text_y = np.full(max_subgoals_len, -100, dtype=np.int64)
        text_y[:n_tokens - 1] = combined[1:]
        # Don't train on predicting mission tokens; the clamp stops an empty
        # mission from turning ``[:-1]`` into "everything but the last".
        text_y[:max(len(mission) - 1, 0)] = -100

        text.append(text_x)
        labels.append(text_y)

    imgs = np.concatenate(imgs, axis=0)
    text = np.vstack(text)
    labels = np.vstack(labels)

    states = {'image': imgs, 'text': text}
    return Dataset(states=states, actions=labels)

def create_gpt_bc_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build a per-step BC dataset that also carries a next-token text target.

    For every step, the text row holds the mission followed by the plan
    pending at that step (``demo[4][i]``). Rows are zero-padded to
    ``max_mission_len + max_subgoals_len`` tokens and then shifted by one:
    inputs drop the last token and labels drop the first. Labels are -100
    everywhere except on plan tokens (assumed to be the loss ignore index —
    TODO confirm).

    Args:
        demos: demo tuples as returned by ``generate_demos``; inventories may
            be ``None``.
        seed: unused.
        max_mission_len, max_subgoals_len: together give the padded text length.

    Returns:
        ``Dataset`` with states ``image``, ``mission`` (text inputs),
        ``inventory`` (only when the demos carry inventories) and ``label``;
        actions are the concatenated actions.
    """
    texts, imgs, invs, ac, labels = [], [], [], [], []
    seq_len = max_mission_len + max_subgoals_len
    for demo in demos:
        ep_length = len(demo[3])  # one row per action
        mission_length = len(demo[0])
        # maximum length is the sum of the mission and the subgoals
        text = np.zeros(seq_len, dtype=np.int32)
        text[:mission_length] = demo[0]
        text = np.tile(text, (ep_length, 1))
        # Now create the label block
        label = np.full((ep_length, seq_len), -100, dtype=np.int32)
        for i, subgoal in enumerate(demo[4]):
            end_index = min(mission_length + len(subgoal), text.shape[1])
            text[i, mission_length:end_index] = subgoal[:end_index - mission_length]
            label[i, mission_length:end_index] = subgoal[:end_index - mission_length]
        # Now trim the texts to the appropriate length
        text = text[:, :-1]  # skip the last token for x
        label = label[:, 1:]  # skip the first token for the y

        imgs.append(demo[1][:-1])  # Remove the final state.
        if demo[2] is not None:
            invs.append(demo[2][:-1])  # Remove the final state.
        ac.append(demo[3])
        texts.append(text)
        labels.append(label)

    texts = np.concatenate(texts, axis=0)
    labels = np.concatenate(labels, axis=0)
    imgs = np.concatenate(imgs, axis=0)
    ac = np.concatenate(ac, axis=0)
    # Must be a dict: the original trailing comma made this a 1-tuple.
    states = {'image': imgs, 'mission': texts}
    if invs:
        states['inventory'] = np.concatenate(invs, axis=0)
    states['label'] = labels
    return Dataset(states=states, actions=ac)

def create_seq2seq_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build an encoder-decoder dataset: (first frame, mission) -> full plan.

    Each target row starts with the ``END_MISSION`` token (used as the
    decoder start token), followed by the plan at timestep 0 (``demo[4][0]``),
    truncated and zero-padded to ``max_subgoals_len``. Unlike the other
    builders, padding here is 0 rather than -100.

    Args:
        demos: demo tuples as returned by ``generate_demos``.
        seed: unused.
        max_mission_len: length missions are zero-padded to.
        max_subgoals_len: length of each target row.

    Returns:
        ``Dataset`` with states ``image`` (first frame) and ``missions``;
        actions are the target rows.
    """
    first_frames, mission_rows, target_rows = [], [], []
    for demo in demos:
        first_frames.append(demo[1][0:1, :, :, :])
        mission_tokens = demo[0]
        plan_tokens = demo[4][0]

        padded_mission = np.zeros(max_mission_len, dtype=np.int32)
        padded_mission[:len(mission_tokens)] = mission_tokens

        target = np.zeros(max_subgoals_len, dtype=np.int32)
        target[0] = WORD_TO_IDX['END_MISSION']
        # Room for at most max_subgoals_len - 1 plan tokens after the start token.
        n_copy = min(max_subgoals_len - 1, len(plan_tokens))
        target[1:1 + n_copy] = plan_tokens[:n_copy]

        mission_rows.append(padded_mission)
        target_rows.append(target)

    states = {
        'image': np.concatenate(first_frames, axis=0),
        'missions': np.vstack(mission_rows),
    }
    return Dataset(states=states, actions=np.vstack(target_rows))


def create_seq2seq_bc_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build a per-step BC dataset with a decoder target for the pending plan.

    Row ``i`` of an episode pairs its image and action with the padded mission
    and a label row. The label row is the ``END_MISSION`` start token followed
    by the plan pending at step ``i`` (``demo[4][i]``), padded with -100
    (presumably the loss ignore index).

    Args:
        demos: demo tuples as returned by ``generate_demos``; inventories may
            be ``None``.
        seed: unused.
        max_mission_len: length missions are zero-padded to.
        max_subgoals_len: length of each label row.

    Returns:
        ``Dataset`` with states ``image``, ``mission``, ``label`` and (when the
        demos carry inventories) ``inventory``; actions are the concatenated
        actions.
    """
    missions, imgs, invs, ac, subgoals = [], [], [], [], []
    for demo in demos:
        ep_length = len(demo[3])  # one row per action
        mission_length = len(demo[0])
        # pad and expand the mission
        mission_pad = np.zeros(max_mission_len, dtype=np.int32)
        mission_pad[:mission_length] = demo[0]
        mission_expanded = np.tile(mission_pad, (ep_length, 1))
        # Start token, then as much of the pending plan as fits.
        label = np.full((ep_length, max_subgoals_len), -100, dtype=np.int32)
        label[:, 0] = WORD_TO_IDX['END_MISSION']
        for i, sg in enumerate(demo[4]):
            n_copy = min(max_subgoals_len - 1, len(sg))
            label[i, 1:1 + n_copy] = sg[:n_copy]

        missions.append(mission_expanded)
        imgs.append(demo[1][:-1])  # Remove the ending state
        if demo[2] is not None:
            invs.append(demo[2][:-1])
        ac.append(demo[3])
        subgoals.append(label)

    states = {
        'image': np.concatenate(imgs, axis=0),
        'mission': np.concatenate(missions, axis=0),
        'label': np.concatenate(subgoals, axis=0),
    }
    if invs:
        states['inventory'] = np.concatenate(invs, axis=0)
    return Dataset(states=states, actions=np.concatenate(ac, axis=0))

def create_pretrained_bc_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build a per-step BC dataset with missions re-tokenized by the GPT-2 tokenizer.

    Mission token ids are mapped back to words via ``WORD_TO_IDX`` and joined
    with spaces. The text is then tokenized with ``GPT2Tokenizer``
    (padded/truncated to ``max_mission_len``), and the resulting ids and
    attention mask are repeated for every step of the episode.

    Args:
        demos: demo tuples as returned by ``generate_demos``.
        seed: unused.
        max_mission_len: GPT-2 token length of each mission.
        max_subgoals_len: unused.

    Returns:
        ``Dataset`` with states ``image``, ``mission`` (GPT-2 ids) and
        ``mask`` (bool attention mask); actions are the concatenated actions.
        Inventories are not included.
    """
    missions, imgs, ac, masks = [], [], [], []
    IDX_TO_WORD = {v: k for k, v in WORD_TO_IDX.items()}
    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='pad')

    for demo in demos:
        ep_length = len(demo[3])  # one row per action
        mission_text = ' '.join([IDX_TO_WORD[idx] for idx in demo[0]])
        tokenized_mission = tokenizer(mission_text, max_length=max_mission_len, padding='max_length', truncation=True)
        mission = np.array(tokenized_mission['input_ids'], dtype=np.int32)
        missions.append(np.tile(mission, (ep_length, 1)))

        # ``bool`` instead of ``np.bool``, which NumPy 1.24-1.26 lacks.
        mask = np.array(tokenized_mission['attention_mask'], dtype=bool)
        masks.append(np.tile(mask, (ep_length, 1)))

        imgs.append(demo[1][:-1])  # remove the final state
        ac.append(demo[3])

    missions = np.concatenate(missions, axis=0)
    masks = np.concatenate(masks, axis=0)
    imgs = np.concatenate(imgs, axis=0)
    ac = np.concatenate(ac, axis=0)
    states = {'image': imgs, 'mission': missions, 'mask': masks}
    return Dataset(states=states, actions=ac)

def create_seq2seq_dt_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build a trajectory-level seq2seq dataset (one entry per episode).

    Each entry holds the episode's images (terminal frame dropped), actions,
    zero-padded mission, and the timestep-0 plan prefixed with the
    ``END_MISSION`` start token. It also holds the plan's attention mask,
    whose first row is duplicated so the mask gains a row for that start
    token. Variable-length pieces are passed through unpadded.

    Args:
        demos: demo tuples as returned by ``generate_demos``.
        seed: unused.
        max_mission_len: length missions are zero-padded to.
        max_subgoals_len: unused.

    Returns:
        ``BabyAITrajectoryDataset`` built from the per-episode lists.
    """
    frame_seqs, padded_missions, plans, action_seqs, plan_masks = [], [], [], [], []
    for demo in demos:
        frame_seqs.append(demo[1][:-1])
        action_seqs.append(demo[3])

        mission_tokens = demo[0]
        padded = np.zeros(max_mission_len, dtype=np.int32)
        padded[:len(mission_tokens)] = mission_tokens
        padded_missions.append(padded)

        start_token = np.array([WORD_TO_IDX['END_MISSION']], dtype=np.int32)
        plan_mask = demo[5]
        plans.append(np.concatenate((start_token, demo[4][0]), axis=0))
        plan_masks.append(np.concatenate((plan_mask[0:1], plan_mask), axis=0))

    return BabyAITrajectoryDataset(frame_seqs, padded_missions, plans, action_seqs, masks=plan_masks)

def create_seq2seq_dt_contrastive_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None, skip=1):
    """Trajectory-level seq2seq dataset that also pairs each frame with a future frame.

    Used for the ``"seq2seq_dt_moco"`` dataset type. Per episode it stores the
    same fields as ``create_seq2seq_dt_dataset``, plus ``next_images``: for
    step ``t`` this is the frame ``skip`` steps later. Where no such frame
    exists, the final observation is repeated. Episodes with fewer than
    ``skip + 1`` frames are dropped.

    Args:
        demos: demo tuples as returned by ``generate_demos``.
        seed: unused.
        max_mission_len: length missions are zero-padded to.
        max_subgoals_len: unused.
        skip: how many steps ahead the paired frame is; must be >= 1.

    Returns:
        ``BabyAITrajectoryDataset`` with ``masks`` and ``next_images``.

    Raises:
        ValueError: if ``skip`` < 1. Zero yields no future frame, and negative
            values would index from the end of the episode.
    """
    if skip < 1:
        raise ValueError("skip must be >= 1, got {!r}".format(skip))

    missions, imgs, ac, subgoals, masks, next_imgs = [], [], [], [], [], []
    for demo in demos:
        if len(demo[1]) < skip + 1:
            continue  # no frame ``skip`` steps ahead exists; drop the episode

        # The image sequence is shape (L, C, H, W)
        imgs.append(demo[1][:-1])  # all but the last image
        ac.append(demo[3])

        # Frame ``skip`` steps ahead of each step, padded with the final frame.
        next_image = demo[1][skip:]
        if next_image.shape[0] < imgs[-1].shape[0]:  # imgs[-1] is what we just appended
            pad = np.expand_dims(next_image[-1], 0).repeat(imgs[-1].shape[0] - next_image.shape[0], axis=0)
            next_image = np.concatenate((next_image, pad), axis=0)
        assert next_image.shape[0] == imgs[-1].shape[0]
        next_imgs.append(next_image)

        # Pad the mission out to the entire length.
        mission_pad = np.zeros(max_mission_len, dtype=np.int32)
        mission_pad[:len(demo[0])] = demo[0]
        missions.append(mission_pad)

        # Prefix the plan with the start-generation token, and duplicate the
        # first mask row so the mask gains a row for it.
        subgoal = np.concatenate((np.array([WORD_TO_IDX['END_MISSION']], dtype=np.int32), demo[4][0]), axis=0)
        mask = np.concatenate((demo[5][0:1], demo[5]), axis=0)
        # Construct the subgoals, this can be done without padding.
        subgoals.append(subgoal)
        masks.append(mask)

    return BabyAITrajectoryDataset(imgs, missions, subgoals, ac, masks=masks, next_images=next_imgs)

def create_inverse_model_dataset(demos, seed, max_mission_len=None, max_subgoals_len=None):
    """Build single-step transitions (state_t, action_t, state_{t+1}).

    Used for the ``"inverse"`` dataset type. Images, and inventories when a
    demo has them, are split into "current" (all but the last entry) and
    "next" (all but the first entry) arrays, aligned with the actions.

    Args:
        demos: demo tuples as returned by ``generate_demos``; inventories may
            be ``None`` and are then skipped for that demo.
        seed, max_mission_len, max_subgoals_len: unused.

    Returns:
        ``Dataset`` with ``states``/``next_states`` holding ``image`` (and
        ``inventory`` if any demo had one), and the concatenated actions.
    """
    cur_frames, nxt_frames = [], []
    cur_invs, nxt_invs = [], []
    acts = []

    for demo in demos:
        frames, inventories, actions = demo[1], demo[2], demo[3]
        cur_frames.append(frames[:-1])
        nxt_frames.append(frames[1:])
        acts.append(actions)
        if inventories is None:
            continue
        cur_invs.append(inventories[:-1])
        nxt_invs.append(inventories[1:])

    image = np.concatenate(cur_frames, axis=0)
    next_image = np.concatenate(nxt_frames, axis=0)
    all_actions = np.concatenate(acts, axis=0)
    states, next_states = {'image': image}, {'image': next_image}
    if cur_invs:
        states['inventory'] = np.concatenate(cur_invs, axis=0)
        next_states['inventory'] = np.concatenate(nxt_invs, axis=0)
    return Dataset(states=states, actions=all_actions, next_states=next_states)

def create_dataset(path, dataset_type, env_name, seed, n_episodes, fully_obs=True, **kwargs):
    """Generate bot demos once and save every requested dataset variant.

    Args:
        path: output prefix; each variant is saved to ``path + "_" + <type>``.
        dataset_type: one type name, or a list/tuple of names, drawn from the
            builder table below.
        env_name, seed, n_episodes, fully_obs: forwarded to ``generate_demos``.
        **kwargs: forwarded to ``generate_demos`` and to every builder, except
            that ``skip`` is passed only to the ``"seq2seq_dt_moco"`` builder.

    Raises:
        KeyError: if a requested type is unknown. This is checked before the
            potentially slow demo generation.
    """
    builders = {
        "bc": create_bc_dataset,
        "gpt": create_gpt_dataset,
        "seq2seq": create_seq2seq_dataset,
        "gpt_bc": create_gpt_bc_dataset,
        "seq2seq_bc": create_seq2seq_bc_dataset,
        "pretrained_bc": create_pretrained_bc_dataset,
        "seq2seq_dt": create_seq2seq_dt_dataset,
        "inverse": create_inverse_model_dataset,
        "seq2seq_dt_moco": create_seq2seq_dt_contrastive_dataset,
    }
    if isinstance(dataset_type, (list, tuple)):
        dataset_types = list(dataset_type)
    else:
        dataset_types = [dataset_type]
    unknown = [name for name in dataset_types if name not in builders]
    if unknown:
        raise KeyError("unknown dataset type(s): {}".format(unknown))

    demos = generate_demos(env_name, seed, n_episodes, fully_obs=fully_obs, **kwargs)
    for name in dataset_types:
        # Fresh copy per builder: deleting ``skip`` from the shared kwargs broke
        # repeated types, a later moco builder, and calls without ``skip``.
        builder_kwargs = dict(kwargs)
        if name != "seq2seq_dt_moco":
            builder_kwargs.pop('skip', None)  # only the moco builder accepts it
        dataset = builders[name](demos, seed, **builder_kwargs)
        dataset.save(path + "_" + name)

