import imageio
import numpy as np
import scipy.special
import scipy.stats
import torch
from torchvision.transforms import transforms

from bandit_env import Controller
from darkroom_env import DarkroomEnvVec

# Module-wide torch device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class MiniworldOptPolicy(Controller):
    """Controller that always plays the action each environment reports as optimal.

    ``env`` must expose ``envs`` (indexable, one entry per batch row, each
    with an ``opt_a(state, pose, angle)`` method) and ``action_space.n``.
    """

    def __init__(self, env, batch_size=1, save_video=False, filename_template=''):
        super().__init__()
        self.env = env
        self.batch_size = batch_size
        self.save_video = save_video
        self.filename_template = filename_template

        # Image preprocessing: to tensor, per-channel normalization (the
        # constants match the widely used ImageNet statistics), then a
        # resize to 25x25.
        pipeline = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize((25, 25), antialias=True),
        ]
        self.transform = transforms.Compose(pipeline)

    def reset(self):
        """No per-episode state to clear; subclasses may override."""
        return None

    def act(self, state, pose, angle):
        """Return a ``(batch_size, n_actions)`` one-hot array of optimal actions.

        ``state``, ``pose`` and ``angle`` are indexed per batch row and
        forwarded to the matching environment's ``opt_a``.
        """
        n_rows = self.batch_size
        chosen = np.array([
            self.env.envs[row].opt_a(state[row], pose[row], angle[row])
            for row in range(n_rows)
        ])

        one_hot = np.zeros((n_rows, self.env.action_space.n))
        one_hot[np.arange(n_rows), chosen] = 1
        return one_hot


class MiniworldRandCommit(MiniworldOptPolicy):
    """Explore uniformly at random, then commit to the optimal action.

    A batch row switches from random to optimal actions once a reward equal
    to 1 has been seen in that row's ``rollin_rewards``; the switch is
    latched until ``reset`` is called.
    """

    def __init__(self, env, batch_size=1):
        super().__init__(env, batch_size=batch_size)
        self.reset()

    def reset(self):
        """Clear the per-row "goal already found" flags."""
        self.found_goal = np.zeros(self.batch_size)

    def set_batch(self, batch):
        """Store ``batch`` and latch rows whose rollin rewards contain a 1."""
        self.batch = batch
        rewards = batch['rollin_rewards'].cpu().detach().numpy()
        rewards = rewards.reshape(self.batch_size, -1)
        goal_seen = np.any(rewards == 1, axis=1)

        # Only rows that have not found the goal yet are updated, so a flag
        # that is already set is never cleared here.
        undecided = self.found_goal == 0
        self.found_goal[undecided] = goal_seen[undecided]

    def act(self, state, pose, angle):
        """Return one-hot actions: optimal for latched rows, random otherwise."""
        n_rows = self.batch_size
        best = np.argmax(super().act(state, pose, angle), axis=-1)
        explore = np.random.randint(self.env.action_space.n, size=n_rows)
        picked = np.where(self.found_goal, best, explore)

        one_hot = np.zeros((n_rows, self.env.action_space.n))
        one_hot[np.arange(n_rows), picked] = 1
        return one_hot


class MiniworldRandPolicy(MiniworldOptPolicy):
    """Controller that picks a uniformly random action for every batch row."""

    def __init__(self, env, batch_size=1):
        super().__init__(env, batch_size=batch_size)

    def act(self, state, pose, angle):
        """Return a ``(batch_size, n_actions)`` one-hot array of random actions.

        The observation arguments are accepted for interface compatibility
        and are not used.
        """
        n_rows = self.batch_size
        picked = np.random.randint(self.env.action_space.n, size=n_rows)

        one_hot = np.zeros((n_rows, self.env.action_space.n))
        one_hot[np.arange(n_rows), picked] = 1
        return one_hot


class MiniworldTransformerController(Controller):
    """Controller that queries a model for action logits on image observations.

    The model is called with ``self.batch`` after the current query
    (``'states'``, ``'poses'``, ``'angles'``) has been written into it, so
    ``self.batch`` must have been set before ``act`` is called (presumably
    through a ``set_batch`` provided by ``Controller`` -- confirm).
    """

    def __init__(self, model, batch_size=1, sample=False, save_video=False, filename_template=''):
        # Initialise the base class the same way the other controllers do.
        super().__init__()
        self.model = model
        self.du = 4  # number of discrete actions (hard-coded)
        self.H = model.H  # copied from the model; presumably its horizon -- confirm
        # Image preprocessing: to tensor, per-channel normalization, resize
        # to the 25x25 input checked in ``act``.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize((25, 25), antialias=True),
        ])
        self.sample = sample  # True: sample from softmax; False: greedy argmax
        self.temp = 1.0  # softmax temperature used when sampling
        self.batch_size = batch_size
        self.save_video = save_video
        self.filename_template = filename_template

    def act(self, state, pose, angle):
        """Choose one action per batch row and return it one-hot encoded.

        Args:
            state: image observation(s); a single image when
                ``batch_size == 1``, otherwise one image per row.
            pose: agent pose(s), batched like ``state``.
            angle: agent direction(s), batched like ``state``.

        Returns:
            A float array of shape ``(batch_size, du)``, or ``(du,)`` when
            ``batch_size == 1``.
        """
        frames = np.array(state)
        if self.batch_size == 1:
            # Unbatched inputs: add the leading batch axis ourselves.
            frames = frames[None]
            pose = [pose]
            angle = [angle]

        assert frames.ndim == 4
        obs = torch.stack([self.transform(frame) for frame in frames])
        obs = obs.float().to(device)
        assert tuple(obs.shape[1:]) == (3, 25, 25)

        self.batch['states'] = obs
        self.batch['poses'] = torch.tensor(np.array(pose)).float().to(device)
        self.batch['angles'] = torch.tensor(np.array(angle)).float().to(device)

        # Pure inference: skip building an autograd graph that would be
        # thrown away immediately.
        with torch.no_grad():
            logits = self.model(self.batch).detach().cpu().numpy()
        # ``logits`` is expected to carry one row of ``du`` scores per batch
        # row, including when batch_size == 1 -- confirm against the model.

        if self.sample:
            probs = scipy.special.softmax(logits / self.temp, axis=-1)
            action_indices = [
                np.random.choice(self.du, p=row_probs) for row_probs in probs
            ]
        else:
            action_indices = np.argmax(logits, axis=-1)

        actions = np.zeros((self.batch_size, self.du))
        actions[np.arange(self.batch_size), action_indices] = 1.0
        if self.batch_size == 1:
            actions = actions[0]
        return actions


class MiniworldEnvVec(DarkroomEnvVec):
    """Steps a list of Miniworld environments in lockstep.

    Relies on ``self._envs`` and ``self.num_envs`` (presumably set up by
    ``DarkroomEnvVec.__init__`` -- confirm). The action space of the first
    environment is exposed as ``self.action_space``.
    """

    # Keys of the controller batch that ``deploy`` extends with each new
    # transition when ``include_partial_hist`` is set.
    _ROLLIN_KEYS = (
        'rollin_obs',
        'rollin_poses',
        'rollin_angles',
        'rollin_actions',
        'rollin_next_obs',
        'rollin_next_poses',
        'rollin_next_angles',
        'rollin_rewards',
    )

    def __init__(self, envs):
        super().__init__(envs)
        self.action_space = envs[0].action_space

    def reset(self):
        """Reset every environment and return the list of first observations."""
        return [env.reset()[0] for env in self._envs]

    def step(self, actions):
        """Apply one action per environment.

        Returns:
            ``(next_obs, rewards, dones, truncs, {})`` where the first four
            are per-environment lists. ``dones`` and ``truncs`` hold the
            third and fourth elements of each environment's step result
            (terminated / truncated under the Gymnasium step API).
        """
        next_obs, rews, dones, truncs = [], [], [], []
        for action, env in zip(actions, self._envs):
            next_ob, rew, done, trunc, _ = env.step(action)
            next_obs.append(next_ob)
            rews.append(rew)
            dones.append(done)
            truncs.append(trunc)
        # Return an explicit per-env list rather than a leaked loop variable,
        # which was undefined whenever there were no environments.
        return next_obs, rews, dones, truncs, {}

    def opt_a(self, x):
        """Return ``env.opt_a(x)`` for every environment.

        NOTE(review): MiniworldOptPolicy.act calls the per-environment
        ``opt_a`` with three arguments (state, pose, angle); confirm that
        the single-argument form used here is still supported.
        """
        return [env.opt_a(x) for env in self._envs]

    def _agent_poses(self):
        """Return components 0 and -1 of each agent's ``pos``."""
        return [env.agent.pos[[0, -1]] for env in self._envs]

    def _agent_angles(self):
        """Return components 0 and -1 of each agent's ``dir_vec``."""
        return [env.agent.dir_vec[[0, -1]] for env in self._envs]

    @staticmethod
    def _step_tensor(values):
        """Turn per-env values into a float tensor with a length-1 axis at dim 1."""
        return torch.tensor(np.array(values)[:, None]).float().to(device)

    @staticmethod
    def _video_path(template, env_id):
        """Build the output path from a callable or a ``str.format`` template."""
        if callable(template):
            return template(env_id=env_id)
        return template.format(env_id=env_id)

    def deploy(self, ctrl, include_partial_hist=False, grow_context=False):
        """Roll out ``ctrl`` until every environment reports done on the same step.

        Args:
            ctrl: controller exposing ``act(state, pose, angle)``,
                ``transform`` and ``save_video``; also ``batch`` and
                ``set_batch`` when ``include_partial_hist`` is set, and
                ``filename_template`` when ``save_video`` is set.
            include_partial_hist: after every step, append the new
                transition to the controller's ``rollin_*`` context.
            grow_context: when appending, keep the whole context (True) or
                drop its oldest entry so the length stays fixed (False).

        Returns:
            ``(obs, poses, angles, actions, next_obs, next_poses,
            next_angles, rewards)``, each stacked over time along axis 1;
            the two observation entries are torch tensors, the rest NumPy
            arrays.

        NOTE(review): only the ``done`` flags stop the loop; the fourth
        element of ``step`` is ignored. Confirm episodes cannot end solely
        through that flag, otherwise this loop would not terminate.
        """
        state = self.reset()
        pose = self._agent_poses()
        angle = self._agent_angles()

        xs, poses, angles, us = [], [], [], []
        xps, next_poses, next_angles, rs = [], [], [], []

        if ctrl.save_video:
            images = [[] for _ in range(self.num_envs)]

        done = False
        while not done:
            u = ctrl.act(state, pose, angle)
            action_ids = np.argmax(u, axis=-1)

            state_tensor = torch.stack([ctrl.transform(s) for s in state])
            xs.append(state_tensor)
            poses.append(pose)
            angles.append(angle)
            us.append(u)

            state, r, done, _, _ = self.step(action_ids)
            pose = self._agent_poses()
            angle = self._agent_angles()
            done = all(done)

            rs.append(r)
            next_state_tensor = torch.stack([ctrl.transform(s) for s in state])
            xps.append(next_state_tensor)
            next_poses.append(pose)
            next_angles.append(angle)

            if ctrl.save_video:
                for frames, env, action_id in zip(images, self._envs, action_ids):
                    frames.append(
                        env.unwrapped.render(goal_text=True, action=action_id))

            if include_partial_hist:
                new_r = self._step_tensor(r)
                if len(ctrl.batch['rollin_rewards'].shape) == 3:
                    # Match a context that stores rewards with a trailing axis.
                    new_r = new_r[:, :, None]

                newest = dict(zip(self._ROLLIN_KEYS, (
                    state_tensor[:, None].float().to(device),
                    self._step_tensor(poses[-1]),
                    self._step_tensor(angles[-1]),
                    self._step_tensor(us[-1]),
                    next_state_tensor[:, None].float().to(device),
                    self._step_tensor(next_poses[-1]),
                    self._step_tensor(next_angles[-1]),
                    new_r,
                )))

                # Only the rollin_* keys are forwarded to the controller.
                batch = {}
                for key, addition in newest.items():
                    history = ctrl.batch[key]
                    if not grow_context:
                        # Fixed-length context: drop the oldest transition.
                        history = history[:, 1:]
                    batch[key] = torch.cat((history, addition), dim=1)
                ctrl.set_batch(batch)

        if ctrl.save_video:
            for env_id, frames in enumerate(images):
                imageio.mimsave(
                    self._video_path(ctrl.filename_template, env_id),
                    np.array(frames))

        return (
            torch.stack(xs, dim=1),
            np.stack(poses, axis=1),
            np.stack(angles, axis=1),
            np.stack(us, axis=1),
            torch.stack(xps, dim=1),
            np.stack(next_poses, axis=1),
            np.stack(next_angles, axis=1),
            np.stack(rs, axis=1),
        )
