from typing import Any, Callable, Optional

import gym
import numpy as np
import omegaconf
import torch

import common.cdil.models
import d4rl
from common.cond.transformer_modules import (TransformerPredictor,
                                             generate_square_subsequent_mask)
from common.dail.models import DAILAgent
from common.ours.models import Policy, interpret_discrete_actions
from common.utils.metaworld_utils import MW_TASKS
from common.utils.process_dataset import (ActionConverter,
                                          ObservationConverter,
                                          PairedTrajDataset)


def get_success(obs, target, env_id, checkpt: bool = False):
    """Check whether the agent's planar position is close enough to a target.

    Args:
        obs: Observation whose first two entries are compared with `target`.
        target: Position compared against ``obs[:2]``.
        env_id: Environment id; ids containing ``"maze2d"`` use the tighter
            pair of tolerances, every other id uses the looser pair.
        checkpt: When True, the enlarged (checkpoint) tolerance is used.

    Returns:
        Whether the Euclidean distance between ``obs[:2]`` and `target` is
        within the selected tolerance.
    """
    # (regular, checkpoint) tolerances. The second pair is for ant: 4x scale,
    # looser goal.
    regular, relaxed = (0.3, 0.5) if "maze2d" in env_id else (1.2, 2.0)
    tolerance = relaxed if checkpt else regular
    distance = np.linalg.norm(obs[:2] - target)
    return distance <= tolerance


def get_action_dim(env_id: str):
    """Return the action dimensionality associated with an environment id.

    Substring markers are tried in a fixed order and the first one found in
    `env_id` decides the result; ids without any marker are then looked up in
    ``MW_TASKS``.

    Args:
        env_id: Environment identifier.

    Returns:
        The action dimension for the environment.

    Raises:
        ValueError: If `env_id` matches no marker and is not in ``MW_TASKS``.
    """
    # Order matters: the first marker contained in env_id wins.
    marker_dims = (
        ("maze2d", 2),
        ("point", 2),
        ("ant", 8),
        ("Lift", 4),
        ("Stack", 4),
    )
    for marker, dim in marker_dims:
        if marker in env_id:
            return dim

    if env_id in MW_TASKS:
        return 4

    raise ValueError(f"{env_id} is not valid env_id.")


class BasePolicy(object):
    """Common interface of the evaluation-time policy wrappers.

    Calling an instance forwards to `forward`. The base implementations of
    `init` and `forward` do nothing and return None.
    """

    def __init__(self, **kwargs):
        """Accept and ignore arbitrary keyword arguments.
        """

    def __call__(self, obs: np.ndarray, task_id: int):
        """Delegate to `forward` and return its result.
        """
        action = self.forward(obs, task_id)
        return action

    def init(self, task_id: int):
        """It should be called at the beginning of each episode.
        """
        return None

    def forward(self, obs: np.ndarray, task_id: int):
        """Produce an action for `obs`; the base version returns None.
        """
        return None


class OursPolicy(BasePolicy):
    """Evaluation wrapper around a `Policy` model.

    The model is queried with the converted observation plus one-hot task and
    domain encodings, and its output is mapped back through the action
    converter.
    """

    def __init__(
        self,
        model: Policy,
        env_id: str,
        domain_id: int,
        obs_converter: ObservationConverter,
        action_converter: ActionConverter,
        device: str = "cuda:0",
    ):
        """Store the converters and move `model` to `device`.

        Args:
            model: Policy network to evaluate.
            env_id: Environment id; also determines the action dimension.
            domain_id: Index used to build the one-hot domain encoding.
            obs_converter: Callable applied to a batch of observations.
            action_converter: Object whose ``inv`` maps model actions back.
            device: Torch device the model and its inputs are placed on.
        """
        super().__init__()
        self.env_id = env_id
        self.domain_id = domain_id
        self.device = device
        self.obs_converter = obs_converter
        self.action_converter = action_converter
        self.action_dim = get_action_dim(env_id)
        self.model = model.to(device)

        self.prev_action = None

    def forward(
        self,
        obs: np.ndarray,
        task_id: int,
    ):
        """Compute the action for a single observation.

        Args:
            obs: A single (unbatched) observation.
            task_id: Index used to build the one-hot task encoding.

        Returns:
            The action after applying ``action_converter.inv``.
        """
        net = self.model
        state = self.obs_converter(obs[None])[0]

        # Zero-pad the converted observation up to the model's state size.
        n_missing = net.state_dim - len(state)
        if n_missing > 0:
            state = np.hstack((state, np.zeros((n_missing, ))))

        task_row = torch.eye(net.cond_dim)[task_id]

        state_in = torch.from_numpy(state).unsqueeze(0).float().to(self.device)
        task_in = task_row.unsqueeze(0).float().to(self.device)
        domain_row = torch.eye(net.domain_dim)[self.domain_id]
        domain_in = domain_row.unsqueeze(0).float().to(self.device)

        with torch.no_grad():
            prediction = net(state_in, task_in, domain_in)[0].squeeze()
        act = prediction.cpu().numpy()[:self.action_dim]

        if net.discrete:
            # Per-dimension scores over bins: pick the best bin for each one.
            assert len(act.shape) == 2  # (action_dim, n_bins)
            best_bins = act.argmax(axis=-1)
            act = interpret_discrete_actions(best_bins,
                                             n_bins=net.discrete_bins)

        return self.action_converter.inv(action=act[None])[0]


class DAILPolicy(BasePolicy):
    """Evaluation wrapper around a `DAILAgent`.

    In the source domain (``domain_id == 0``) the agent's source policy is
    applied directly. In the target domain (``domain_id == 1``) the
    observation is first passed through ``state_map``, the source policy is
    applied, and the result is passed through ``action_map``.
    """

    def __init__(
        self,
        model: DAILAgent,
        env_id: str,
        domain_id: int,
        obs_converter: ObservationConverter,
        action_converter: ActionConverter,
        device: str = "cuda:0",
    ):
        """Store the converters and move `model` to `device`.

        Args:
            model: DAIL agent to evaluate.
            env_id: Environment id.
            domain_id: 0 for the source domain, 1 for the target domain.
            obs_converter: Callable applied to a batch of observations.
            action_converter: Object whose ``inv`` maps model actions back.
            device: Torch device the model and its inputs are placed on.
        """
        super().__init__()
        self.model = model.to(device)
        self.env_id = env_id
        self.domain_id = domain_id
        self.obs_converter = obs_converter
        self.action_converter = action_converter
        self.device = device

        self.prev_action = None

    def forward(
        self,
        obs: np.ndarray,
        task_id: int,
    ):
        """Compute the action for a single observation.

        Args:
            obs: A single (unbatched) observation.
            task_id: Index used to build the one-hot task encoding.

        Returns:
            The action after applying ``action_converter.inv``.

        Raises:
            ValueError: If ``self.domain_id`` is neither 0 nor 1.
        """
        # Fail early with a clear message; otherwise no branch below would
        # assign `act` and the method would die with an UnboundLocalError.
        if self.domain_id not in (0, 1):
            raise ValueError(
                f"domain_id must be 0 (source) or 1 (target), "
                f"got {self.domain_id}.")

        obs = self.obs_converter(obs[None])[0]

        task_id_onehot = torch.eye(self.model.cond_dim)[task_id]

        obs = torch.Tensor(obs).unsqueeze(0).to(self.device)
        task_id_onehot = task_id_onehot.unsqueeze(0).to(self.device)

        with torch.inference_mode():
            if self.domain_id == 0:  # source domain
                obs_k = torch.cat((obs, task_id_onehot), dim=-1)
                act = self.model.source_policy(obs_k).squeeze()
            else:  # target domain
                # Map the observation into the source domain, act there, and
                # map the action back.
                source_obs = self.model.state_map(obs)
                source_obs_k = torch.cat((source_obs, task_id_onehot), dim=-1)
                source_act = self.model.source_policy(source_obs_k)
                if self.model.decode_with_state:
                    # The action map additionally receives the observation.
                    source_act = torch.cat((source_act, obs), dim=-1)
                act = self.model.action_map(source_act).squeeze()
        act = act.cpu().numpy()

        act = self.action_converter.inv(action=act[None])[0]

        return act


class CDILPolicy(BasePolicy):
    """Evaluation wrapper for a policy queried with the observation only.

    `create_eval_policy` selects this wrapper for models that are instances
    of ``common.cdil.models.Policy``.
    """

    def __init__(
        self,
        model: Policy,
        env_id: str,
        domain_id: int,
        obs_converter: ObservationConverter,
        action_converter: ActionConverter,
        device: str = "cuda:0",
    ):
        """Store the converters and move `model` to `device`.

        Args:
            model: Policy network to evaluate.
            env_id: Environment id; also determines the action dimension.
            domain_id: Domain index (stored, not used by `forward`).
            obs_converter: Callable applied to a batch of observations.
            action_converter: Object whose ``inv`` maps model actions back.
            device: Torch device the model and its inputs are placed on.
        """
        super().__init__()
        self.env_id = env_id
        self.domain_id = domain_id
        self.device = device
        self.obs_converter = obs_converter
        self.action_converter = action_converter
        self.action_dim = get_action_dim(env_id)
        self.model = model.to(device)

        self.prev_action = None

    def forward(
        self,
        obs: np.ndarray,
        task_id: int,
    ):
        """Compute the action for a single observation.

        Args:
            obs: A single (unbatched) observation.
            task_id: Unused by this wrapper.

        Returns:
            The action after applying ``action_converter.inv``.
        """
        converted = self.obs_converter(obs[None, :])[0]
        model_in = torch.Tensor(converted).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            prediction = self.model(model_in)[0].squeeze()

        raw_action = prediction.cpu().numpy()
        return self.action_converter.inv(action=raw_action[None])[0]


class CondPolicy(BasePolicy):
    """Evaluation wrapper for a transformer conditioned on a source demo.

    `init` samples a source demonstration for the episode's task and clears
    the observation history. `forward` feeds that demonstration together with
    all observations seen so far in the episode to the model and returns the
    action predicted for the latest step.
    """

    def __init__(
        self,
        model: TransformerPredictor,
        env_id: str,
        domain_id: int,
        obs_converter: ObservationConverter,
        action_converter: ActionConverter,
        traj_dataset: Optional[PairedTrajDataset] = None,
        device: str = "cuda:0",
        amp: bool = False,
        image_observation: bool = False,
        sa_demo: bool = False,
    ):
        """Store the configuration and precompute the look-ahead mask.

        Args:
            model: Transformer predictor to evaluate.
            env_id: Gym environment id; its ``max_episode_steps`` sets the
                size of the look-ahead mask.
            domain_id: Domain index (stored only).
            obs_converter: Observation converter (stored only).
            action_converter: Object whose ``inv`` maps model actions back.
            traj_dataset: Dataset that source demonstrations are sampled from.
            device: Torch device used for the mask and the model inputs.
            amp: Whether the model is run under CUDA autocast.
            image_observation: Whether demonstrations also carry images that
                are encoded and appended to the demonstration features.
            sa_demo: Whether demonstration actions are concatenated to the
                demonstration observations.
        """
        super().__init__()
        # NOTE(review): unlike the other wrappers in this file, the model is
        # not moved to `device` here; presumably the caller has already placed
        # it there -- confirm.
        self.model = model
        self.env_id = env_id
        self.domain_id = domain_id
        # NOTE(review): obs_converter is stored but never applied by this
        # class; confirm that raw observations are what the model expects.
        self.obs_converter = obs_converter
        self.action_converter = action_converter
        self.traj_dataset = traj_dataset
        self.device = device
        self.image_observation = image_observation
        self.sa_demo = sa_demo

        self.history = []

        # The environment is only needed to read the episode time limit, so
        # close it right away instead of leaking it.
        env = gym.make(self.env_id)
        try:
            timeout_len = env.spec.max_episode_steps
        finally:
            env.close()
        self.look_ahead_mask = generate_square_subsequent_mask(
            sz=timeout_len).to(device)

        self.source_demo = None
        self.amp = amp

    def init(self, task_id: int):
        """Sample a new source demonstration and reset the history.

        It should be called at the beginning of each episode.
        """
        self.source_demo = self.sample_source_demo(task_id=task_id)
        self.history = []

    def sample_source_demo(self, task_id: int):
        """Sample a source-domain demonstration for `task_id`.

        Returns:
            The demonstration observations (with actions concatenated on the
            last axis when ``sa_demo`` is set). When ``image_observation`` is
            set, a tuple of that array and the demonstration images.

        Raises:
            ValueError: If no ``traj_dataset`` was given at construction.
        """
        if self.traj_dataset is None:
            raise ValueError(
                "traj_dataset is required to sample a source demonstration.")

        source_demo = self.traj_dataset.sample_seq_of_task_id(task_id=task_id,
                                                              source=True)
        if self.sa_demo:
            ret = np.concatenate((source_demo.obs, source_demo.actions),
                                 axis=-1)
        else:
            ret = source_demo.obs

        if self.image_observation:
            return ret, source_demo.images
        return ret

    def forward(
        self,
        obs: np.ndarray,
        task_id: int,
    ):
        """Append `obs` to the history and predict the next action.

        Args:
            obs: The current observation.
            task_id: Unused here; the task is fixed by the last `init` call.

        Returns:
            The action for the latest step after ``action_converter.inv``.

        Raises:
            RuntimeError: If `init` has not been called yet.
        """
        # Check before touching the history so a failed call has no effect.
        if self.source_demo is None:
            raise RuntimeError(
                "CondPolicy.init(task_id) must be called before forward().")

        self.history.append(obs)
        n_steps = len(self.history)

        with torch.no_grad():
            t_obs = np.array(self.history, dtype=np.float32)[None]
            if self.image_observation:
                s_obs, s_img = self.source_demo
                s_obs = torch.from_numpy(s_obs).to(self.device)[None]
                s_img = torch.from_numpy(s_img).to(self.device)[None]
            else:
                s_obs = torch.from_numpy(self.source_demo).to(
                    self.device)[None]
            t_obs = torch.from_numpy(t_obs).to(self.device)
            with torch.cuda.amp.autocast(enabled=self.amp):
                if self.image_observation:
                    # Encode every demo frame at once, then restore the two
                    # leading axes and append the codes to the demo features.
                    img_shape = s_img.shape
                    s_img = s_img.float()
                    s_img_encoded = self.model.source_image_encoder(
                        s_img.reshape(-1, *img_shape[2:]))
                    s_obs = torch.cat(
                        (s_obs, s_img_encoded.reshape(*img_shape[:2], -1)),
                        dim=-1)

                # Crop the precomputed mask to the current history length.
                out = self.model(
                    source_obs=s_obs,
                    target_obs=t_obs,
                    tgt_look_ahead_mask=self.look_ahead_mask[:n_steps,
                                                             :n_steps],
                )

        # Only the prediction for the most recent step is used.
        action = out.detach().cpu().numpy()[0, -1]
        action = self.action_converter.inv(action=action[None])[0]

        return action


def create_eval_policy(
    args: omegaconf.DictConfig,
    model: Any,
    env_id: str,
    domain_id: int,
    obs_converter: ObservationConverter,
    action_converter: ActionConverter,
    traj_dataset: Optional[PairedTrajDataset] = None,
):
    """Wrap a given policy with different wrapper classes depending on the type.

    The model type is checked in this order: `Policy`,
    ``common.cdil.models.Policy``, `DAILAgent`, `TransformerPredictor`.

    Args:
        args: Config providing ``device`` and, for `TransformerPredictor`
            models, ``amp``, ``image_observation`` and ``sa_demo``.
        model: The trained model to wrap.
        env_id: Environment id passed to the wrapper.
        domain_id: Domain index passed to the wrapper.
        obs_converter: Observation converter passed to the wrapper.
        action_converter: Action converter passed to the wrapper.
        traj_dataset: Trajectory dataset; must not be None for
            `TransformerPredictor` models.

    Returns:
        The wrapper instance matching the type of `model`.

    Raises:
        ValueError: If `model` is of none of the supported types.
    """
    if isinstance(model, Policy):
        wrapper_cls = OursPolicy
    elif isinstance(model, common.cdil.models.Policy):
        wrapper_cls = CDILPolicy
    elif isinstance(model, DAILAgent):
        wrapper_cls = DAILPolicy
    elif isinstance(model, TransformerPredictor):
        # The conditioned policy needs extra settings and a demo dataset.
        assert traj_dataset is not None
        return CondPolicy(
            model=model,
            env_id=env_id,
            domain_id=domain_id,
            traj_dataset=traj_dataset,
            obs_converter=obs_converter,
            action_converter=action_converter,
            device=args.device,
            amp=args.amp,
            image_observation=args.image_observation,
            sa_demo=args.sa_demo,
        )
    else:
        raise ValueError(f"{type(model)} is not valid model type.")

    # The remaining wrappers share one constructor signature.
    return wrapper_cls(
        model=model,
        env_id=env_id,
        domain_id=domain_id,
        obs_converter=obs_converter,
        action_converter=action_converter,
        device=args.device,
    )
