from dataclasses import dataclass
from copy import deepcopy

import PIL
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

import promptrl.envs.cooking_ops as ck

# Probability that a dirtyable object starts out dirty (see init_task).
DIRTY_CHANCE = 0.5
# Probability that a sliceable object starts out already sliced; init_task
# never pre-slices the target of a 'slice' task.
SLICE_CHANCE = 0.5
# Task kinds handled by init_task. NOTE(review): not referenced in the visible
# part of this file -- presumably used by callers; confirm.
TASKS = ['slice', 'cook']
# NOTE(review): not referenced in the visible part of this file; the
# num_distractors defaults below use a literal 5 instead -- confirm intent.
DISTRACTORS = 5

# Grid / rendering parameters.
# Size of the rendered observation image, as passed to PIL.Image.new.
WIDTH = 80
HEIGHT = 80
# Number of grid cells along each axis; every object occupies one cell.
GRID_WIDTH = 6
GRID_HEIGHT = 6
# Size of one cell (and of every icon). Integer division: 80 // 6 == 13, so
# the last 2 pixels along each axis of the image are never painted.
UNIT_WIDTH = WIDTH // GRID_WIDTH
UNIT_HEIGHT = HEIGHT // GRID_HEIGHT
# Maps icon file path -> resized PIL image; filled lazily by get_icon.
ICON_CACHE = {}

# Stored as CookingDataset.max_length; not otherwise used in the visible code.
MAX_TOKEN_LENGTH = 150
# Re-exported from cooking_ops. NOTE(review): unused in the visible code.
EOS_TOKEN = ck.EOS_TOKEN

def get_icon(a, state):
    """Return the grid icon for object ``a`` in its current condition.

    The icon path is ``icons/<a>[_dirty][_sliced][_cooked].png`` (relative to
    the current working directory), with one suffix appended, in that order,
    for every flag that is truthy in ``state[a]``. Loaded icons are resized to
    one grid cell and memoised in ``ICON_CACHE``, so each file is opened at
    most once per process.

    Args:
        a: Name of the object; must be a key of ``state``.
        state: Mapping from object name to a dict holding at least the
            ``'dirty'``, ``'sliced'`` and ``'cooked'`` entries.

    Returns:
        The cached PIL image of size ``(UNIT_WIDTH, UNIT_HEIGHT)``. The very
        same object is returned on every call, so callers must not modify it
        in place.

    Raises:
        KeyError: If ``a`` or one of the condition flags is missing.
        FileNotFoundError: If no icon file exists for this combination.
    """
    conds = state[a]
    name = a
    if conds['dirty']:
        name += '_dirty'
    if conds['sliced']:
        name += '_sliced'
    if conds['cooked']:
        name += '_cooked'

    name += '.png'
    name = 'icons/' + name

    icon = ICON_CACHE.get(name)
    if icon is None:
        # A bare ``import PIL`` does not guarantee that the ``PIL.Image``
        # submodule is loaded, so import it explicitly where it is needed.
        from PIL import Image

        # ``Image.open`` is lazy and keeps the file open; ``resize`` returns
        # an independent in-memory image, so the source can be closed at once
        # instead of lingering until garbage collection.
        with Image.open(name) as source:
            icon = source.resize((UNIT_WIDTH, UNIT_HEIGHT))
        ICON_CACHE[name] = icon
    return icon

def get_obs(state):
    """Render ``state`` as a ``WIDTH`` x ``HEIGHT`` RGB image.

    Every entry of ``state`` is drawn with the icon returned by ``get_icon``;
    its ``'loc'`` entry is read as an ``(x, y)`` grid cell and the icon is
    pasted with its top-left corner at that cell's pixel origin.

    Args:
        state: Mapping from object name to its condition dict, which must
            contain ``'loc'`` in addition to the flags ``get_icon`` reads.

    Returns:
        The rendered PIL image.
    """
    canvas = PIL.Image.new('RGB', size=(WIDTH, HEIGHT))
    for name in state:
        sprite = get_icon(name, state)
        cell_x, cell_y = state[name]['loc']
        # Convert the grid cell into the pixel position of its top-left corner.
        top_left = (UNIT_WIDTH * cell_x, UNIT_HEIGHT * cell_y)
        canvas.paste(sprite, box=top_left)
    return canvas

def _objs_to_text(objs, state):
    article = lambda n: 'an' if n[0] in {'a', 'o', 'i', 'u', 'e'} else 'a'
    if len(objs) > 1:
        obj_text = ', '.join(f'{article(o)} {o}' for o in objs[:-1])
        obj_text += f', and {article(objs[-1])} {objs[-1]}.'
    elif len(objs) == 1:
        obj_text = f'{article(objs[0])} {objs[0]}.'
    else:
        raise ValueError

    cond_text = []
    for obj in objs:
        if state[obj]['sliced'] and state[obj]['dirty']:
            cond_text.append(f'The {obj} is sliced and dirty.')
        elif state[obj]['sliced']:
            cond_text.append(f'The {obj} is sliced.')
        elif state[obj]['dirty']:
            cond_text.append(f'The {obj} is dirty.')
    cond_text = ' '.join(cond_text)

    obs = obj_text
    if len(cond_text) > 0:
        obs += ' ' + cond_text
    return obs

def get_lang_obs(state, rooms=0):
    """Build the natural-language observation for ``state``.

    With ``rooms == 0`` the whole scene is described as one kitchen. Otherwise
    the text first names the agent's room and then lists, room by room (rooms
    are lettered ``'A'``, ``'B'``, ... up to ``rooms`` letters), the objects
    whose ``'room'`` entry matches; rooms without objects are left out. The
    ``'agent'`` entry is never listed as an object.

    Args:
        state: Mapping from object name to its condition dict.
        rooms: Number of rooms to describe, or 0 for the single-kitchen text.

    Returns:
        The observation as a single string.

    Raises:
        ValueError: In single-kitchen mode, if ``state`` holds no object
            other than the agent.
    """
    if rooms == 0:
        visible = [name for name in state if name != 'agent']
        intro = 'You are in a kitchen. Looking around you, you see'
        return intro + ' ' + _objs_to_text(visible, state)

    intro = f'You are in room {state["agent"]["room"]}.'
    sentences = []
    for offset in range(rooms):
        label = chr(ord('A') + offset)
        occupants = [
            name
            for name, conds in state.items()
            if name != 'agent' and conds['room'] == label
        ]
        if occupants:
            listing = _objs_to_text(occupants, state)
            sentences.append(f'In room {label}, there is {listing}')
    return intro + ' ' + ' '.join(sentences)


def init_task(task_kind, target, rng, num_distractors=5, room_dim=(1, 1)):
    """Sample a random initial world state for one task.

    The state contains the objects the task requires, ``num_distractors``
    additional objects drawn from ``ck.ALL_OBJS``, and a random, collision-free
    grid placement for all of them.

    Args:
        task_kind: ``'slice'`` or ``'cook'``.
        target: Name of the object to slice or cook.
        rng: A ``numpy.random.Generator``; it is the only source of
            randomness used here.
        num_distractors: Number of extra, task-irrelevant objects to add.
        room_dim: ``(rooms_x, rooms_y)`` -- how many rooms the grid is split
            into along each axis. Rooms are lettered row by row from ``'A'``.

    Returns:
        A ``(task, state)`` tuple. ``task`` is the instruction sentence and
        ``state`` maps each object name to a copy of ``ck.OBJ_DEFAULT_STATE``
        with ``'dirty'`` / ``'sliced'`` possibly set to True and with
        ``'loc'`` (an ``(x, y)`` grid cell) and ``'room'`` (a letter) added.

    Raises:
        NotImplementedError: If ``task_kind`` is not supported.
        ValueError: Raised by ``rng.choice`` when more distractors are
            requested than are available, or when there are more objects
            than grid cells.
    """
    if task_kind == 'slice':
        required = ['knife', 'agent', target]
        task = f'Slice the {target}.'
        excluded = set(required)
    elif task_kind == 'cook':
        cooker = rng.choice(ck.COOKER)
        required = ['knife', 'stove', 'agent', 'sink', target, cooker]
        task = f'Cook the {target}.'
        # No second cooker may show up among the distractors.
        excluded = set(required) | set(ck.COOKER)
    else:
        # Format ``task_kind``: ``task`` is not bound on this path, so using
        # it here would raise UnboundLocalError instead of this error.
        raise NotImplementedError(f'Task {task_kind} not implemented.')

    # Sort the candidates: the iteration order of a set of strings varies
    # between interpreter runs (hash randomisation), which would make the
    # draw irreproducible even with a fixed seed.
    # NOTE(review): assumes the names in ck.ALL_OBJS are mutually orderable
    # (plain strings) -- confirm.
    candidates = sorted(set(ck.ALL_OBJS) - excluded)
    distractors = list(rng.choice(candidates, num_distractors, replace=False))

    state = {}
    for obj in (required + distractors):
        state[obj] = deepcopy(ck.OBJ_DEFAULT_STATE)
        if obj in ck.DIRTYABLE:
            if rng.random() < DIRTY_CHANCE:
                state[obj]['dirty'] = True
        if obj in ck.SLICEABLE:
            # The target of a slice task must start out unsliced.
            if task_kind != 'slice' or obj != target:
                if rng.random() < SLICE_CHANCE:
                    state[obj]['sliced'] = True

    # One distinct grid cell per object.
    locs = list(range(GRID_WIDTH * GRID_HEIGHT))
    loc_samples = rng.choice(locs, len(state), replace=False)

    rooms_x, rooms_y = room_dim[0], room_dim[1]
    # Cells per room along each axis; at least 1 so the divisions below
    # cannot divide by zero when there are more rooms than cells.
    room_width = max(GRID_WIDTH // rooms_x, 1)
    room_height = max(GRID_HEIGHT // rooms_y, 1)

    for obj, loc in zip(state, loc_samples):
        # Use the same divisor for both coordinates so that loc -> (x, y) is
        # one-to-one on any grid shape; on a square grid this is exactly the
        # mapping produced by dividing by the width.
        coord_x = loc // GRID_HEIGHT
        coord_y = loc % GRID_HEIGHT
        state[obj]['loc'] = (coord_x, coord_y)
        # Clamp so that leftover cells of an unevenly divided grid fall into
        # the last room instead of a room that does not exist.
        room_x = min(coord_x // room_width, rooms_x - 1)
        room_y = min(coord_y // room_height, rooms_y - 1)
        # Row-major room index: the row stride is the number of rooms per
        # row (rooms_x), which keeps indices unique for non-square layouts.
        room_idx = room_y * rooms_x + room_x
        state[obj]['room'] = chr(room_idx + ord('A'))

    return task, state

class CookingDataset(Dataset):
    """Map-style dataset of randomly generated cooking tasks.

    All ``length`` samples are generated eagerly in ``__init__`` from a seeded
    NumPy generator and kept in memory; the image and the language
    observation of a sample are rendered on access in ``__getitem__``.

    Args:
        obs_type: Which field the collator exposes as ``batch['obs']``:
            ``'img'``, ``'state'``, ``'simple'`` or ``'lang'``.
        tasks: Sequence of ``(task_kind, target_obj)`` pairs to sample from.
        tokenizer: Callable tokenizer; it is called with ``padding=True`` and
            ``return_tensors="pt"`` and its result is indexed with
            ``'input_ids'`` / ``'attention_mask'`` (presumably a Hugging Face
            tokenizer -- confirm).
        length: Number of samples to generate.
        subtask_format: ``'plain'``, ``'formatted'`` or ``'listed'``; see
            ``format_text``.
        num_distractors: Forwarded to ``init_task``.
        room_dim: Forwarded to ``init_task``; the product of its two entries
            is the number of rooms used for the state vector and the
            language observation.
        seed: Seed of the generator used for task sampling.
        transform: Callable applied to the rendered image array. ``None``
            (the default) selects ``transforms.ToTensor()``.
    """

    def __init__(self, obs_type, tasks, tokenizer, length=100000, subtask_format='plain', num_distractors=5, room_dim=(1, 1), seed=42, transform=None):
        self.obs_type = obs_type
        self.tasks = tasks
        self.tokenizer = tokenizer
        self.num_distractors = num_distractors
        self.room_dim = room_dim
        self.rooms = room_dim[0] * room_dim[1]
        self.seed = seed
        self.subtask_format = subtask_format
        self.rng = np.random.default_rng(seed)

        # NOTE(review): this stores a token *id* in ``eos_token`` and mutates
        # the tokenizer object passed in by the caller -- confirm that
        # ``eos_token_id`` was not the intended attribute.
        self.tokenizer.eos_token = self.tokenizer('Done')['input_ids'][0]
        self.max_length = MAX_TOKEN_LENGTH

        # Build the default transform at call time rather than once in the
        # signature, so defining the class has no side effects.
        if transform is None:
            transform = transforms.ToTensor()
        self.transform = transform

        self._init_tasks(length)

    @staticmethod
    def _goal_to_indices(goal_plain, vocab_to_idx):
        """Encode a goal sentence as an int64 array of vocabulary indices.

        Words are split on whitespace, lower-cased and stripped of ``'.'``
        before being looked up in ``vocab_to_idx``.
        """
        words = (word.lower().replace('.', '') for word in goal_plain.split())
        # ``np.int`` was removed from NumPy; use an explicit 64-bit dtype,
        # which ``torch.from_numpy`` turns into a long tensor.
        return np.array([vocab_to_idx[word] for word in words], dtype=np.int64)

    def _init_tasks(self, length):
        """Generate ``length`` samples and store them in ``self.buffer``."""
        self.buffer = []
        for _ in range(length):
            task_kind, target_obj = self.rng.choice(self.tasks)
            goal_plain, state = init_task(task_kind, target_obj, self.rng, num_distractors=self.num_distractors, room_dim=self.room_dim)
            goal_v = self._goal_to_indices(goal_plain, ck.VOCAB_TO_IDX)
            # The operator gets its own copy so the stored initial state
            # cannot be affected by whatever it does to its argument.
            subtask_seq = ck.OPTION_TO_OP[task_kind](target_obj, deepcopy(state))

            vec_state = ck.state_to_vec(state, rooms=self.rooms)
            vec_code = ck.state_to_vec(state, target_obj=target_obj, simplified=True)

            goal, subtask, all_text = self.format_text(goal_plain, subtask_seq)

            self.buffer.append({
                '_state_info': state,
                'goals': goal,
                'goals_v': goal_v,
                'subtasks': subtask,
                'all_text': all_text,
                'states': vec_state,
                'codes': vec_code
            })

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.buffer)

    def _get_obs(self, sample):
        """Pick the observation selected by ``self.obs_type``.

        Returns:
            An ``(obs, mask)`` pair; ``mask`` is the attention mask for
            ``'lang'`` observations and ``None`` for every other type.

        Raises:
            NotImplementedError: If ``self.obs_type`` is not supported.
        """
        if self.obs_type == 'img':
            return sample['imgs'], None
        elif self.obs_type == 'state':
            return sample['states'], None
        elif self.obs_type == 'simple':
            return sample['codes'], None
        elif self.obs_type == 'lang':
            return sample['lang_obs']['input_ids'], sample['lang_obs']['attention_mask']
        else:
            raise NotImplementedError(f'Obs type {self.obs_type} not implemented')

    def get_collator(self):
        """Return a collate function that batches samples of this dataset.

        The returned function stacks the array-valued fields into tensors,
        tokenizes the text fields with padding, and adds ``'obs'`` /
        ``'obs_mask'`` according to ``self.obs_type``.
        """
        def cooking_collator(batch):
            imgs = torch.stack([sample['imgs'] for sample in batch])
            states = torch.from_numpy(np.stack([sample['states'] for sample in batch]))
            codes = torch.from_numpy(np.stack([sample['codes'] for sample in batch]))
            goals_v = torch.from_numpy(np.stack([sample['goals_v'] for sample in batch]))

            goals = self.tokenizer([sample['goals'] for sample in batch], return_tensors="pt", padding=True)
            subtasks = self.tokenizer([sample['subtasks'] for sample in batch], return_tensors="pt", padding=True)
            all_text = self.tokenizer([sample['all_text'] for sample in batch], return_tensors="pt", padding=True)
            lang_obs = self.tokenizer([sample['lang_obs'] for sample in batch], return_tensors="pt", padding=True)

            batch = {
                'imgs': imgs,
                'lang_obs': lang_obs,
                'goals': goals,
                'goals_v': goals_v,
                'subtasks': subtasks,
                'all_text': all_text,
                'states': states,
                'codes': codes
            }
            batch['obs'], batch['obs_mask'] = self._get_obs(batch)
            return batch

        return cooking_collator

    def format_text(self, goal_plain, subtask_plain):
        """Format a goal and its subtasks according to ``self.subtask_format``.

        * ``'plain'``: the goal is used as is, subtasks are joined by spaces.
        * ``'formatted'``: the goal becomes ``'How to <goal>:'`` (lower-cased,
          last character dropped), subtasks are joined by spaces.
        * ``'listed'``: like ``'formatted'``, with subtasks numbered
          ``'1. ...'``, ``'2. ...'`` and so on.

        Args:
            goal_plain: The goal sentence; its last character is dropped in
                the non-plain formats.
            subtask_plain: Sequence of subtask sentences.

        Returns:
            A ``(goal, subtasks, all_text)`` tuple of strings, where
            ``all_text`` is the goal and the subtasks joined by one space.

        Raises:
            NotImplementedError: If the subtask format is not supported.
        """
        fmt = self.subtask_format
        if fmt not in ('plain', 'formatted', 'listed'):
            raise NotImplementedError(f'Subtask format {self.subtask_format} not implemented')

        if fmt == 'listed':
            subtasks = ' '.join(f'{i+1}. {subtask}' for i, subtask in enumerate(subtask_plain))
        else:
            subtasks = ' '.join(subtask_plain)

        if fmt == 'plain':
            goal = goal_plain
        else:
            goal = 'How to ' + goal_plain.lower()[:-1] + ':'

        all_text = goal + ' ' + subtasks
        return goal, subtasks, all_text

    def __getitem__(self, idx):
        """Return a copy of sample ``idx`` with its observations rendered.

        Adds ``'imgs'`` (the rendered grid passed through ``self.transform``)
        and ``'lang_obs'`` (the untokenized text observation) to the stored
        fields.
        """
        # Copy so callers cannot modify the buffered sample.
        result = deepcopy(self.buffer[idx])
        result['imgs'] = np.array(get_obs(result['_state_info']))
        result['imgs'] = self.transform(result['imgs'])

        result['lang_obs'] = get_lang_obs(result['_state_info'], rooms=self.rooms)
        return result

if __name__ == '__main__':
    from pprint import pprint

    def _demo():
        """Generate one 'cook' task, print its descriptions and render it."""
        demo_rng = np.random.default_rng(7)
        goal, world = init_task('cook', 'tomato', demo_rng, 2, (2, 2))
        print(f'Task: {goal}')
        pprint(world)

        # Same state described as a single kitchen and as four rooms.
        print(get_lang_obs(world))
        print(get_lang_obs(world, rooms=4))

        # NOTE(review): the operator receives the live state, not a copy, so
        # the image rendered below reflects whatever it did to it -- confirm
        # that this is intended.
        pprint(ck.op_cook('tomato', world))

        rendered = get_obs(world)
        rendered.save('../windows/prompt-test-obs.png')
        rendered.show()

    _demo()
