import torch
import torchvision.transforms as transforms

import promptrl.envs.cooking_ops as ck
from promptrl.envs.cooking_task import CookingDataset
from promptrl.envs.alfworld_task import AlfworldFillDataset, _get_alf_env_loader
from promptrl.envs.alfworld_viz import AlfworldVizDataset
from promptrl.envs.virtualhome.dataset import VirtualHomeDataset

def make_dataset(args, model, tokenizer, per_obs_tokens=lambda _: 0):
    """Dispatch dataset construction on the prefix of ``args.task_type``.

    ``'alf*'`` task types go to ``make_alf_dataset`` and ``'virtualhome*'``
    task types go to ``make_virtualhome_dataset``. Anything else builds a
    cooking dataset from ``args.task`` with listed subtasks and three
    distractors.

    Returns:
        The ``(train, eval)`` pair produced by the selected builder.
    """
    task_type = args.task_type
    if task_type.startswith('alf'):
        builder = make_alf_dataset
    elif task_type.startswith('virtualhome'):
        builder = make_virtualhome_dataset
    else:
        return make_cooking_dataset(
            args.task,
            args.obs_type,
            tokenizer,
            args.num_samples,
            args.num_eval_samples,
            subtask_format='listed',
            num_distractors=3,
        )
    return builder(args, model, tokenizer, per_obs_tokens)

def make_virtualhome_dataset(args, model, tokenizer, per_obs_tokens):
    """Build the VirtualHome training dataset and its evaluation datasets.

    Only image observation types are supported: ``'img'``, ``'imgcap'``,
    ``'imgauxcap'`` and ``'imgauxinvdyn'``. Auxiliary task types (caption,
    inverse dynamics) are added to the task mix according to ``args.obs_type``.

    When ``args.mix_id_aux_samples`` / ``args.mix_ood_aux_samples`` is set,
    exactly ``args.num_aux_samples`` non-forward rows are moved from the
    ``novel_tasks`` / ``novel_scenes`` split into the training data.

    Args:
        args: Config namespace. Fields read: obs_type, prompt_arch,
            train_samples, num_aux_samples, eval_demo_samples,
            mix_id_aux_samples, mix_ood_aux_samples, min_action_burn_in,
            max_context_tokens, seed, val_split_ratio.
        model: Forwarded to the dataset constructor.
        tokenizer: Forwarded to the dataset constructor.
        per_obs_tokens: Forwarded to the dataset constructor.

    Returns:
        ``(train_dataset, eval_datasets)``. ``eval_datasets`` contains
        ``'demo_id'`` and ``'demo_ood'`` when ``args.eval_demo_samples > 0``,
        and ``'val'`` when ``args.val_split_ratio`` is not None.

    Raises:
        NotImplementedError: if ``args.obs_type`` is not a supported image type.
    """
    # task name -> (task id, log task id, sample count for that task type)
    task_types = {'forward': (0, 0, args.train_samples)}
    if not args.obs_type.startswith('img'):
        raise NotImplementedError
    if args.obs_type == 'imgauxcap':
        task_types['caption'] = (1, 1, args.num_aux_samples)
    elif args.obs_type == 'imgauxinvdyn':
        task_types['invdyn'] = (1, 1, args.num_aux_samples)
    elif args.obs_type == 'imgcap':
        task_types.pop('forward')
        task_types['caption'] = (1, 1, args.num_aux_samples)
    elif args.obs_type != 'img':
        raise NotImplementedError
    init_cls = VirtualHomeDataset
    data_mode = args.prompt_arch

    def build(mode, num_samples, types):
        # Single construction site so every split shares the same settings.
        return init_cls(args.obs_type, model, tokenizer, num_samples=num_samples, action_burn_in=args.min_action_burn_in, mode=mode, n_positions=args.max_context_tokens, per_obs_tokens=per_obs_tokens, data_mode=data_mode, task_types=types, seed=args.seed)

    def resized(count):
        # Copy of task_types with every per-task sample count replaced by `count`.
        return {name: (task_id, log_id, count) for name, (task_id, log_id, _) in task_types.items()}

    train_dataset = build('train', args.train_samples, task_types)

    eval_datasets = {}
    if args.eval_demo_samples > 0:
        # If aux rows will later be moved into train, over-sample the eval split
        # by num_aux_samples so roughly eval_demo_samples remain afterwards.
        id_extra = args.num_aux_samples if args.mix_id_aux_samples else 0
        eval_datasets['demo_id'] = build('novel_tasks', args.eval_demo_samples, resized(args.eval_demo_samples + id_extra))
        ood_extra = args.num_aux_samples if args.mix_ood_aux_samples else 0
        eval_datasets['demo_ood'] = build('novel_scenes', args.eval_demo_samples, resized(args.eval_demo_samples + ood_extra))

    def move_aux_to_train(key, mode):
        """Move up to num_aux_samples non-forward rows from split `key` into train."""
        add_dataset = eval_datasets.get(key)
        if add_dataset is None:
            # No demo eval split was built (eval_demo_samples <= 0): sample a
            # dedicated pool. Previously this referenced an unbound name.
            add_dataset = build(mode, args.num_aux_samples, resized(args.num_aux_samples))
        add_samples, remain_samples = [], []
        for row in add_dataset.data:
            # `>=` so exactly num_aux_samples rows move (was `>`, i.e. one extra).
            if len(add_samples) >= args.num_aux_samples or row['task'] == 0:
                remain_samples.append(row)
            else:
                assert row['task_name'] != 'forward'
                # Shift the log id past the in-train task ids, presumably so
                # transferred rows are logged separately — confirm with logging code.
                row['log_task'] += len(task_types)
                add_samples.append(row)
        train_dataset.data.extend(add_samples)
        add_dataset.data = remain_samples

    if args.mix_ood_aux_samples:
        move_aux_to_train('demo_ood', 'novel_scenes')
    if args.mix_id_aux_samples:
        move_aux_to_train('demo_id', 'novel_tasks')

    if args.val_split_ratio is not None:
        total_len = len(train_dataset)
        val_len = int(total_len * args.val_split_ratio)
        train_len = total_len - val_len
        train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, (train_len, val_len), generator=torch.Generator().manual_seed(args.seed))
        # random_split returns Subsets; re-expose the underlying collator factory.
        train_dataset.get_collator = train_dataset.dataset.get_collator
        val_dataset.get_collator = val_dataset.dataset.get_collator
        eval_datasets['val'] = val_dataset

    return train_dataset, eval_datasets

def make_alf_dataset(args, model, tokenizer, per_obs_tokens):
    """Build the ALFWorld training dataset and its evaluation datasets.

    The suffix of ``args.task_type`` selects which ALFWorld task ids are used.
    Image observation types (``'img*'``) use ``AlfworldVizDataset``, and
    lang*/oracle/mrcnn observation types use ``AlfworldFillDataset``.
    Auxiliary task types (caption, invdyn, goalp, admissible) are added to the
    mix according to ``args.obs_type``.

    When ``args.mix_id_aux_samples`` / ``args.mix_ood_aux_samples`` is set,
    exactly ``args.num_aux_samples`` non-forward rows are moved from the
    ``eval_in_distribution`` / ``eval_out_of_distribution`` split into training.

    Args:
        args: Config namespace. Fields read: task_type, obs_type, prompt_arch,
            train_samples, num_aux_samples, eval_demo_samples,
            mix_id_aux_samples, mix_ood_aux_samples, min_action_burn_in,
            max_context_tokens, seed, limit_frames, val_split_ratio.
        model: Forwarded to the dataset constructor.
        tokenizer: Forwarded to the dataset constructor.
        per_obs_tokens: Forwarded to the dataset constructor.

    Returns:
        ``(train_dataset, eval_datasets)``. ``eval_datasets`` contains
        ``'demo_id'`` and ``'demo_ood'`` when ``args.eval_demo_samples > 0``,
        and ``'val'`` when ``args.val_split_ratio`` is not None.

    Raises:
        NotImplementedError: if the suffix of ``args.task_type`` is unknown.
        AssertionError: if a non-image ``args.obs_type`` is not lang*/oracle/mrcnn.
    """
    # Suffix of task_type -> (train task ids, eval task ids). Checked in order.
    task_splits = (
        ('pick-place', [1], [1]),
        ('examine', [2], [2]),
        ('clean', [3], [3]),
        ('heat', [4], [4]),
        ('cool', [5], [5]),
        ('pick-two', [6], [6]),
        ('all', [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]),
        ('cross', [1, 2, 3, 4], [5, 6]),
        ('cross2', [3, 4], [5]),
    )
    for suffix, train_ids, eval_ids in task_splits:
        if args.task_type.endswith(suffix):
            train_tasks, eval_tasks = train_ids, eval_ids
            break
    else:
        # Previously formatted an undefined `task`, raising NameError instead.
        raise NotImplementedError(f'Task {args.task_type} not implemented.')
    # NOTE(review): eval_tasks is not used below — demo_id/demo_ood are built
    # from train_tasks (unchanged behavior). Confirm whether cross-task splits
    # ('cross', 'cross2') should evaluate on eval_tasks instead.

    data_mode = None
    if args.obs_type.startswith('img'):
        init_cls = AlfworldVizDataset
        data_mode = args.prompt_arch
    else:
        assert args.obs_type.startswith('lang') or args.obs_type in ['oracle', 'mrcnn']
        init_cls = AlfworldFillDataset

    # task name -> (task id, log task id, sample count for that task type)
    task_types = {'forward': (0, 0, args.train_samples)}
    if args.obs_type == 'imgauxcap':
        task_types['caption'] = (1, 1, args.num_aux_samples)
    elif args.obs_type == 'imgauxinvdyn':
        task_types['invdyn'] = (1, 1, args.num_aux_samples)
    elif args.obs_type == 'imgauxgoalp':
        task_types['goalp'] = (1, 1, args.num_aux_samples)
    elif args.obs_type == 'imgauxc2':
        task_types['caption'] = (1, 1, args.num_aux_samples)
        task_types['invdyn'] = (2, 2, args.num_aux_samples)
    elif args.obs_type == 'imgauxcg':
        task_types['caption'] = (1, 1, args.num_aux_samples)
        task_types['goalp'] = (2, 2, args.num_aux_samples)
    elif args.obs_type == 'imgcap':
        task_types.pop('forward')
        task_types['caption'] = (1, 1, args.num_aux_samples)
    elif args.obs_type == 'imgadm':
        task_types.pop('forward')
        task_types['admissible'] = (1, 1, args.num_aux_samples)

    def build(mode, num_samples, types):
        # Single construction site so every split shares the same settings.
        return init_cls(args.obs_type, train_tasks, model, tokenizer, num_samples=num_samples, action_burn_in=args.min_action_burn_in, mode=mode, n_positions=args.max_context_tokens, per_obs_tokens=per_obs_tokens, data_mode=data_mode, task_types=types, seed=args.seed, limit_frames=args.limit_frames)

    def resized(count):
        # Copy of task_types with every per-task sample count replaced by `count`.
        return {name: (task_id, log_id, count) for name, (task_id, log_id, _) in task_types.items()}

    train_dataset = build('train', args.train_samples, task_types)

    eval_datasets = {}
    if args.eval_demo_samples > 0:
        # If aux rows will later be moved into train, over-sample the eval split
        # by num_aux_samples so roughly eval_demo_samples remain afterwards.
        id_extra = args.num_aux_samples if args.mix_id_aux_samples else 0
        eval_datasets['demo_id'] = build('eval_in_distribution', args.eval_demo_samples, resized(args.eval_demo_samples + id_extra))
        ood_extra = args.num_aux_samples if args.mix_ood_aux_samples else 0
        eval_datasets['demo_ood'] = build('eval_out_of_distribution', args.eval_demo_samples, resized(args.eval_demo_samples + ood_extra))

    def move_aux_to_train(key, mode):
        """Move up to num_aux_samples non-forward rows from split `key` into train."""
        add_dataset = eval_datasets.get(key)
        if add_dataset is None:
            # No demo eval split was built (eval_demo_samples <= 0): sample a
            # dedicated pool. Previously this referenced an unbound name.
            add_dataset = build(mode, args.num_aux_samples, resized(args.num_aux_samples))
        add_samples, remain_samples = [], []
        for row in add_dataset.data:
            # `>=` so exactly num_aux_samples rows move (was `>`, i.e. one extra).
            if len(add_samples) >= args.num_aux_samples or row['task'] == 0:
                remain_samples.append(row)
            else:
                assert row['task_name'] != 'forward'
                # Shift the log id past the in-train task ids, presumably so
                # transferred rows are logged separately — confirm with logging code.
                row['log_task'] += len(task_types)
                add_samples.append(row)
        train_dataset.data.extend(add_samples)
        add_dataset.data = remain_samples

    if args.mix_ood_aux_samples:
        move_aux_to_train('demo_ood', 'eval_out_of_distribution')
    if args.mix_id_aux_samples:
        move_aux_to_train('demo_id', 'eval_in_distribution')

    if args.val_split_ratio is not None:
        total_len = len(train_dataset)
        val_len = int(total_len * args.val_split_ratio)
        train_len = total_len - val_len
        train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, (train_len, val_len), generator=torch.Generator().manual_seed(args.seed))
        # random_split returns Subsets; re-expose the underlying collator factory.
        train_dataset.get_collator = train_dataset.dataset.get_collator
        val_dataset.get_collator = val_dataset.dataset.get_collator
        eval_datasets['val'] = val_dataset

    return train_dataset, eval_datasets

def make_env(task, obs_type, tokenizer):
    """Create ALFWorld evaluation environment loaders.

    The suffix of ``task`` selects the evaluation task ids, and ``obs_type``
    selects which (env class, controller) pairs to load. Each pair is loaded
    for both the in-distribution and the out-of-distribution eval split.

    Args:
        task: Task-type string; its suffix selects the task split.
        obs_type: 'lang', 'oracle', 'mrcnn', 'lang-all', or any 'img*' type.
        tokenizer: Forwarded to ``_get_alf_env_loader``.

    Returns:
        dict mapping a split name (e.g. 'id', 'ood', 'tw_id', ...) to the
        loader returned by ``_get_alf_env_loader``.

    Raises:
        NotImplementedError: for an unknown task suffix or obs_type.
    """
    # Suffix -> (train task ids, eval task ids); only the eval ids matter here.
    suffix_table = (
        ('pick-place', [1], [1]),
        ('examine', [2], [2]),
        ('clean', [3], [3]),
        ('heat', [4], [4]),
        ('cool', [5], [5]),
        ('pick-two', [6], [6]),
        ('all', [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]),
        ('cross', [1, 2, 3, 4], [5, 6]),
        ('cross2', [3, 4], [5]),
    )
    for suffix, _train_ids, eval_ids in suffix_table:
        if task.endswith(suffix):
            eval_tasks = eval_ids
            break
    else:
        raise NotImplementedError(f'Task {task} not implemented.')

    def id_ood(env_cls, controller, id_key='id', ood_key='ood'):
        # In-distribution / out-of-distribution eval pair for one env config.
        return {
            id_key: (env_cls, controller, 'eval_in_distribution'),
            ood_key: (env_cls, controller, 'eval_out_of_distribution'),
        }

    if obs_type == 'lang':
        env_args = id_ood('AlfredTWEnv', 'oracle')
    elif obs_type == 'oracle':
        env_args = id_ood('AlfredThorEnv', 'oracle')
    elif obs_type == 'mrcnn':
        env_args = id_ood('AlfredThorEnv', 'mrcnn')
    elif obs_type == 'lang-all':
        env_args = {}
        env_args.update(id_ood('AlfredTWEnv', 'oracle', 'tw_id', 'tw_ood'))
        env_args.update(id_ood('AlfredThorEnv', 'mrcnn', 'mrcnn_id', 'mrcnn_ood'))
        env_args.update(id_ood('AlfredThorEnv', 'oracle', 'oracle_id', 'oracle_ood'))
    elif obs_type.startswith('img'):
        env_args = id_ood('AlfworldVizEnv', 'viz_sref')
    else:
        raise NotImplementedError

    # Loaders are created in env_args insertion order.
    envs = {}
    for name, (env_type, controller_type, train_eval) in env_args.items():
        envs[name] = _get_alf_env_loader(env_type, controller_type, tokenizer, eval_tasks, 1, train_eval)
    return envs

def make_cooking_dataset(task, obs_type, tokenizer, num_samples, num_eval_samples, subtask_format='listed', num_distractors=3):
    """Build a (train, eval) pair of ``CookingDataset`` objects.

    'novel-objs' trains on potato/tomato and evaluates on lettuce/onion in a
    (1, 1) room grid. 'novel-pos', 'multiroom2' and 'multiroom4' train and
    evaluate on the same list of four cook tasks, differing only in
    ``room_dim``: (1, 1), (2, 1) and (2, 2) respectively.

    Images are converted to tensors and normalized to mean 0.5 / std 0.5 per
    channel.

    Raises:
        NotImplementedError: for an unknown ``task``.
    """
    vegetables = [
        ('cook', 'potato'),
        ('cook', 'tomato'),
        ('cook', 'lettuce'),
        ('cook', 'onion'),
    ]
    # Tasks whose train and eval lists are the same object; value is room_dim.
    shared_split_rooms = {'novel-pos': (1, 1), 'multiroom2': (2, 1), 'multiroom4': (2, 2)}

    if task == 'novel-objs':
        train_tasks, eval_tasks = vegetables[:2], vegetables[2:]
        room_dim = (1, 1)
    elif task in shared_split_rooms:
        train_tasks = eval_tasks = vegetables
        room_dim = shared_split_rooms[task]
    else:
        raise NotImplementedError(f'Task {task} not implemented.')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    shared_kwargs = dict(subtask_format=subtask_format, num_distractors=num_distractors, room_dim=room_dim, transform=transform)
    train_dataset = CookingDataset(obs_type, train_tasks, tokenizer, length=num_samples, **shared_kwargs)
    eval_dataset = CookingDataset(obs_type, eval_tasks, tokenizer, length=num_eval_samples, **shared_kwargs)
    return train_dataset, eval_dataset

def make_vhome_envs(task_type, partition=None):
    """Create VirtualHome environment loaders from a ``'virtualhome-<mode>'`` string.

    ``mode == 'all'`` builds loaders for 'id', 'novel_tasks' and
    'novel_scenes' (in that order). Any other mode builds a single loader
    keyed by that mode.

    Raises:
        ValueError: if ``task_type`` does not contain exactly one '-'.
        AssertionError: if the prefix is not 'virtualhome'.
    """
    from promptrl.envs.virtualhome.env import make_vhome_loader
    prefix, mode = task_type.split('-')
    assert prefix == 'virtualhome'
    modes = ('id', 'novel_tasks', 'novel_scenes') if mode == 'all' else (mode,)
    return {name: make_vhome_loader(name, partition) for name in modes}

if __name__ == '__main__':
    # Ad-hoc smoke test for the cooking task helpers: initialize a
    # 'cook tomato' task, print its state and language observations, apply
    # ck.op_cook, then save and display the rendered observation.
    # NOTE(review): this block cannot run as written. `np`, `init_task`,
    # `get_lang_obs` and `get_obs` are neither imported nor defined in this
    # module (only `ck.op_cook` resolves). It looks copied from a cooking env
    # module; confirm where these helpers live (possibly `ck`) before use.
    from pprint import pprint
    rng = np.random.default_rng(7)  # fixed seed for a reproducible state
    task, state = init_task('cook', 'tomato', rng, 2, (2, 2))
    print(f'Task: {task}')
    pprint(state)

    print(get_lang_obs(state))
    print(get_lang_obs(state, rooms=4))

    pprint(ck.op_cook('tomato', state))

    obs = get_obs(state)
    # presumably a PIL-style image (has .save/.show) — TODO confirm;
    # the output path is relative to the current working directory.
    obs.save('../windows/prompt-test-obs.png')
    obs.show()
