import gymnasium as gym
import numpy as np
import torch

# Torch device for tensors created in this module: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class BaseEnv(gym.Env):
    """Base environment with helpers that roll out one episode under a controller.

    Subclasses implement ``reset``, ``transit`` and ``step``. ``deploy`` and
    ``deploy2`` then run a full episode, querying ``ctrl.act`` for actions, and
    return the visited transitions.
    """

    # Size of the state vector; set by subclasses (used to size deploy2's window).
    state_dim: int
    # Size of the action vector; set by subclasses (used to size deploy2's window).
    action_dim: int

    def reset(self):
        """Reset the environment and return the initial observation.

        NOTE(review): the rollout helpers below use the return value directly
        as the observation, whereas Gymnasium's ``Env.reset`` returns
        ``(obs, info)``. Confirm that subclasses return only the observation.
        """
        raise NotImplementedError

    def transit(self, state, action):
        """Transition hook for an explicit ``(state, action)`` pair.

        Must be implemented by subclasses and is not used by the helpers in
        this class. It presumably returns the successor state without
        advancing the environment; TODO confirm.
        """
        raise NotImplementedError

    def step(self, action):
        """Advance the environment by one action.

        Must return the Gymnasium 5-tuple
        ``(obs, reward, terminated, truncated, info)``.
        """
        raise NotImplementedError

    def render(self, mode="human"):
        """Rendering is a no-op by default."""
        pass

    def deploy_eval(self, ctrl) -> tuple:
        """Run one evaluation episode; currently identical to ``deploy``."""
        return self.deploy(ctrl)

    def deploy(self, ctrl) -> tuple:
        """Roll out one episode, choosing actions with ``ctrl.act(obs)``.

        The episode ends when ``step`` reports either ``terminated`` or
        ``truncated``.

        Returns:
            ``(obs, acts, next_obs, rews)``: NumPy arrays with one entry per
            step, in chronological order.
        """
        ob = self.reset()
        obs, acts, next_obs, rews = [], [], [], []
        done = False

        while not done:
            act = ctrl.act(ob)

            obs.append(ob)
            acts.append(act)

            ob, rew, terminated, truncated, _ = self.step(act)
            # Gymnasium ends an episode on either flag; ignoring ``truncated``
            # would loop forever on environments that only truncate.
            done = terminated or truncated

            rews.append(rew)
            next_obs.append(ob)

        return np.array(obs), np.array(acts), np.array(next_obs), np.array(rews)

    def deploy2(self, ctrl, batch: torch.Tensor | None = None) -> torch.Tensor:
        """Roll out one episode while feeding ``ctrl`` a sliding window of transitions.

        Each window row is ``[s, a, s_next, r]`` flattened into
        ``2 * state_dim + action_dim + 1`` columns. Rows are kept in
        chronological order (oldest first). After each step the window is
        trimmed to the ``ctrl.horizon`` most recent rows.

        Args:
            ctrl: Object exposing ``act(batch, state)`` and an integer ``horizon``.
            batch: Optional initial window. Defaults to an empty tensor on the
                module-level ``device``. New rows are cast to this tensor's
                dtype and device.

        Returns:
            The window as it stands when the episode ends.
        """
        if batch is None:
            batch = torch.zeros((0, 2 * self.state_dim + self.action_dim + 1), device=device)

        # torch.cat preserves dtype/device, so these stay valid for the whole episode.
        row_dtype, row_device = batch.dtype, batch.device

        def _flat(x) -> torch.Tensor:
            # Flatten to 1-D so scalar rewards/actions (0-dim tensors) can be
            # concatenated, and match the window's dtype/device.
            return torch.as_tensor(x, dtype=row_dtype, device=row_device).reshape(-1)

        s = self.reset()
        done = False

        while not done:
            a = ctrl.act(batch, torch.tensor(s, device=row_device))
            s_next, r, terminated, truncated, _ = self.step(a)
            done = terminated or truncated

            new_row = torch.cat((_flat(s), _flat(a), _flat(s_next), _flat(r)))
            # Append along the row axis so the newest transition is last;
            # the trim below then discards the oldest rows.
            batch = torch.cat((batch, new_row[None, :]), dim=0)

            # Explicit start index instead of batch[-horizon:], which keeps
            # everything when horizon == 0.
            excess = batch.shape[0] - ctrl.horizon
            if excess > 0:
                batch = batch[excess:]

            s = s_next

        return batch
