import argparse
import os
from collections import namedtuple
from functools import partial as bind

import cv2
import gymnasium as gym
import numpy as np
import ruamel.yaml as yaml
from omegaconf import OmegaConf
import torch

import dreamerv3
import embodied
from embodied import OrderedSlotsBatchWrapper
import shapes_envs.shapes2d
from embodied.envs.from_gymnasium import FromGymnasium
from ocr.slate.slate import SLATE
from ocr.tools import SlotExtractor


def load_slot_extractor(ocr_config_path, obs_size, checkpoint_path):
    """Build a frozen, pretrained SLATE model and wrap it in a ``SlotExtractor``.

    Args:
        ocr_config_path: Path to the OCR config file, loaded with OmegaConf.
        obs_size: Value handed to SLATE as ``obs_size`` of the environment
            config (``obs_channels`` is fixed to 3).
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            stores the weights under the ``"ocr_module_state_dict"`` key.

    Returns:
        A ``SlotExtractor`` created with ``device='cuda'``. A
        ``get_slots_dim()`` callable is attached to it which returns
        ``(num_slots, slot_size)`` as read from the OCR config.
    """
    config_ocr = OmegaConf.load(ocr_config_path)
    config_env = namedtuple('EnvConfig', ['obs_size', 'obs_channels'])(obs_size, 3)
    slate = SLATE(config_ocr, config_env, observation_space=None, preserve_slot_order=True)
    slate = slate.cuda()

    # Deserialize onto the CPU so that a checkpoint saved on a CUDA device that
    # does not exist on this machine (e.g. ``cuda:1``) still loads;
    # ``load_state_dict`` then copies the values into the module's own
    # parameters, which already live on the GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    slate._module.load_state_dict(checkpoint["ocr_module_state_dict"])
    slate.eval()

    # The extractor is used for inference only: no gradients are needed.
    for param in slate.parameters():
        param.requires_grad = False

    slot_extractor = SlotExtractor(slate, device='cuda')

    def get_slots_dim():
        return config_ocr.slotattr.num_slots, config_ocr.slotattr.slot_size

    # Attached to the instance so callers can query the slot layout without
    # access to the OCR config.
    slot_extractor.get_slots_dim = get_slots_dim

    return slot_extractor


class ResizeWrapper(gym.ObservationWrapper):
    """Resize image observations and ``info['masks']`` to a square resolution.

    The wrapped environment must expose an ``(H, W, 3)`` ``uint8`` observation
    space and must return a ``'masks'`` entry in the ``info`` dict of both
    ``reset`` and ``step``.

    Args:
        env: Environment to wrap.
        image_size: Side length, in pixels, of the resized square output.
    """

    def __init__(self, env, image_size):
        super().__init__(env)
        self._image_size = image_size
        assert len(env.observation_space.shape) == 3
        assert env.observation_space.shape[2] == 3
        assert env.observation_space.dtype == np.uint8
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(self._image_size, self._image_size, 3), dtype=np.uint8
        )

    def observation(self, observation):
        """Return ``observation`` resized to ``image_size x image_size`` (bicubic)."""
        return cv2.resize(observation, dsize=(self._image_size, self._image_size), interpolation=cv2.INTER_CUBIC)

    def _resize_masks(self, info):
        """Resize every mask in ``info['masks']`` and binarize the result.

        The masks are resized with the same routine as the observations and
        then thresholded at ``> 0``, so the stored array holds only 0/1 values
        of dtype ``uint8``. ``info`` is modified in place and returned.
        """
        # NOTE(review): masks go through the same cubic interpolation as the
        # images before thresholding - confirm this is intended rather than
        # nearest-neighbour resizing.
        resized = np.asarray([self.observation(mask) for mask in info['masks']])
        info['masks'] = np.asarray(resized > 0, dtype=np.uint8)
        return info

    def step(self, action):
        observation, reward, terminated, truncated, info = super().step(action)
        return observation, reward, terminated, truncated, self._resize_masks(info)

    def reset(self, *args, **kwargs):
        observation, info = super().reset(*args, **kwargs)
        return observation, self._resize_masks(info)


def make_shapes2d(env_id, image_size, config, seed):
    """Create a single environment and apply the full wrapper stack.

    The Gymnasium environment ``env_id`` is created with ``return_state=False``
    and the given ``seed``, resized to ``image_size`` by ``ResizeWrapper``,
    adapted with ``FromGymnasium`` and finally passed through
    ``dreamerv3.wrap_env`` together with ``config``.
    """
    resized_env = ResizeWrapper(
        gym.make(env_id, return_state=False, seed=seed),
        image_size,
    )
    # Or obs_key='vector'.
    adapted_env = FromGymnasium(resized_env, obs_key='image', masks_key='masks')
    return dreamerv3.wrap_env(adapted_env, config)


def make_batch_env(env_id, image_size, config, seed, n_envs, parallel=None):
    """Create ``n_envs`` environments and group them into an ``embodied.BatchEnv``.

    Environment ``i`` is built by ``make_shapes2d`` with seed ``seed + i``.
    ``parallel`` is the strategy handed to ``embodied.Parallel``; the string
    ``'none'`` disables the parallel wrapper, and ``None`` means "use
    ``config.envs.parallel``".
    """
    strategy = config.envs.parallel if parallel is None else parallel
    use_parallel = strategy != 'none'

    batch = []
    for offset in range(n_envs):
        # Bind the per-environment seed now so each constructor keeps its own.
        constructor = bind(make_shapes2d, env_id, image_size, config, seed=seed + offset)
        if use_parallel:
            batch.append(embodied.Parallel(constructor, strategy))
        else:
            batch.append(constructor())

    return embodied.BatchEnv(batch, parallel=use_parallel)


def parse_args():
    """Parse the command-line options of this script.

    Required: ``--config_path``, ``--env_id``, ``--image_size`` (int),
    ``--ocr_config_path`` and ``--ocr_checkpoint_path``.
    Optional: ``--n_envs`` (int, default 1), ``--checkpoint_path``,
    ``--param_prefixes`` (zero or more strings) and the boolean switch
    ``--ordered`` / ``--no-ordered`` (default ``True``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()

    # Mandatory experiment settings.
    for flag, value_type in (('--config_path', str), ('--env_id', str), ('--image_size', int)):
        cli.add_argument(flag, required=True, type=value_type)

    cli.add_argument('--n_envs', type=int, default=1)

    # Optional checkpoint-related settings.
    cli.add_argument('--checkpoint_path', required=False, type=str)
    cli.add_argument('--param_prefixes', required=False, type=str, nargs='*')

    # Mandatory OCR model settings.
    for flag in ('--ocr_config_path', '--ocr_checkpoint_path'):
        cli.add_argument(flag, required=True, type=str)

    cli.add_argument('--ordered', action=argparse.BooleanOptionalAction, default=True)
    return cli.parse_args()


def main():
    """Set up logging, environments and the agent, then start training.

    Steps, in order:
      1. Parse the CLI and build the config from the ``defaults`` section of
         the YAML file, updated with ``dreamerv3.configs['medium']``.
      2. Create the logger (terminal, JSONL and, if ``config.wandb.project``
         is set, Weights & Biases).
      3. Load the slot extractor and wrap the training / evaluation batch
         environments with ``OrderedSlotsBatchWrapper``.
      4. Create the agent and replay buffer and call
         ``embodied.run.train_eval_no_replay``.
    """
    cli_args = parse_args()
    print('Args:', vars(cli_args))
    checkpoint_path = cli_args.checkpoint_path
    param_prefixes = cli_args.param_prefixes
    configs = yaml.YAML(typ='safe').load(embodied.Path(cli_args.config_path).read())
    config = embodied.Config(configs['defaults'])
    config = config.update(dreamerv3.configs['medium'])

    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    outputs = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
    ]
    if config.wandb.project is not None:
        os.makedirs(config.wandb.dir, exist_ok=True)
        outputs.append(embodied.logger.WandBOutput(config))

    logger = embodied.Logger(step, outputs)

    slot_extractor = load_slot_extractor(
        cli_args.ocr_config_path, cli_args.image_size, cli_args.ocr_checkpoint_path)
    env = make_batch_env(cli_args.env_id, image_size=cli_args.image_size, config=config,
                         seed=config['seed'], n_envs=cli_args.n_envs)
    env = OrderedSlotsBatchWrapper(env, slot_extractor, is_ordered=cli_args.ordered)

    # Training environments use seeds ``seed .. seed + n_envs - 1``; starting
    # the evaluation seed at ``seed + n_envs`` keeps it distinct from them.
    eval_env = make_batch_env(cli_args.env_id, image_size=cli_args.image_size, config=config,
                              seed=config['seed'] + cli_args.n_envs, n_envs=1, parallel='none')
    eval_env = OrderedSlotsBatchWrapper(eval_env, slot_extractor, is_ordered=cli_args.ordered)

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    replay = embodied.replay.Uniform(
        config.batch_length, config.replay_size, logdir / 'replay')
    # Kept separate from ``cli_args`` so the CLI namespace is not shadowed.
    run_args = embodied.Config(
        **config.run, logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length)

    if checkpoint_path is None:
        wandb_parts = (config.wandb.project, config.wandb.name)
        if all(part is not None for part in wandb_parts):
            checkpoint_path = os.path.join(*wandb_parts)
        else:
            # W&B may be disabled (project is None); ``os.path.join`` would
            # raise TypeError on None, so store checkpoints under the logdir.
            checkpoint_path = str(logdir / 'checkpoints')
    os.makedirs(checkpoint_path, exist_ok=True)

    embodied.run.train_eval_no_replay(
        agent, train_env=env, eval_env=eval_env, train_replay=replay, logger=logger,
        args=run_args, checkpoint_path=checkpoint_path, param_prefixes=param_prefixes)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
