import numpy as np
import torch


from src.policies.policy import Policy


class GDTPolicy(Policy):
    """Two-stage policy: a subgoal model feeding a subgoal-conditioned model.

    On every call to :meth:`forward`, ``bcq_model`` is queried with the
    observation history; its most recent ``'subgoals'`` output, multiplied by
    ``SUBGOAL_SCALE``, is appended to the subgoal history. ``model`` is then
    queried with the subgoal, observation and action histories and its most
    recent ``'actions'`` output is returned.

    The policy is stateful: it accumulates per-step histories that must be
    cleared with :meth:`reset` (e.g. between episodes).
    """

    # Factor applied to the predicted subgoal before it is fed to ``model``.
    # NOTE(review): the rationale for 1.05 is not documented in this file --
    # confirm with the original author before changing it.
    SUBGOAL_SCALE = 1.05

    def __init__(
            self,
            model,
            bcq_model,
            observation_dim,
            action_dim,
            discount,
            bs=1,
            max_history=0,
            device='cuda',
    ):
        """Store the models and configuration and start with empty histories.

        Args:
            model: Callable taking a dict with ``'subgoals'``,
                ``'observations'`` and ``'actions'`` tensors and returning a
                mapping with an ``'actions'`` entry indexed as
                ``[batch, step]``.
            bcq_model: Callable taking a dict with an ``'observations'``
                tensor and returning a mapping with a ``'subgoals'`` entry
                indexed as ``[batch, step]``.
            observation_dim: Stored on the instance; not read by this class.
            action_dim: Size of the zero placeholder used for the action that
                has not been predicted yet.
            discount: Stored on the instance; not read by this class.
            bs: Batch size; used for the placeholder action and required to
                be 1 when unbatched inputs are passed.
            max_history: Number of most recent steps kept in each history.
                ``0`` (the default) keeps the entire history.
            device: Device the history tensors are moved to before each model
                call.

        Raises:
            ValueError: If ``max_history`` is negative.
        """
        super().__init__()
        if max_history < 0:
            # A negative window used to empty the buffers through slice
            # semantics and fail later inside ``np.stack``; fail early instead.
            raise ValueError(
                f'max_history must be >= 0, got {max_history}'
            )
        self.model = model
        self.bcq_model = bcq_model
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.discount = discount
        self.bs = bs
        self.max_history = max_history
        self.device = device

        # Per-step histories; each entry is one step's batch-first array.
        self.observations = []
        self.subgoals = []
        self.actions = []

    def _trim(self, buffer):
        """Return ``buffer`` limited to its last ``max_history`` entries.

        A ``max_history`` of 0 means "no limit" and returns ``buffer`` as is.
        """
        if self.max_history > 0:
            return buffer[-self.max_history:]
        return buffer

    def _stack(self, buffer):
        """Stack per-step arrays along axis 1 and move them to ``device``."""
        return torch.Tensor(np.stack(buffer, axis=1)).to(device=self.device)

    @torch.inference_mode()
    def forward(self, observation):
        """Predict the next action for ``observation`` and record the step.

        Args:
            observation: Array for the current step. A 1-D input is treated
                as a single unbatched observation (requires ``bs == 1``);
                otherwise the first axis is taken to be the batch axis.

        Returns:
            A tuple ``(action, subgoal)``. ``action`` is the latest action
            from ``model`` as a NumPy array, with the batch axis removed when
            the input was unbatched. ``subgoal`` is the latest subgoal from
            ``bcq_model`` as a NumPy array; it always keeps its batch axis.

            NOTE(review): the returned subgoal is the raw model output, while
            the value fed to ``model`` is scaled by ``SUBGOAL_SCALE`` --
            confirm that callers expect the unscaled value.
        """
        unbatched = observation.ndim == 1
        if unbatched:
            assert self.bs == 1, 'unbatched observations require bs == 1'
            observation = observation[None]

        self.observations.append(observation)
        self.observations = self._trim(self.observations)
        # Built once and shared by both model calls below.
        observations = self._stack(self.observations)

        bcq_outputs = self.bcq_model({'observations': observations})
        raw_subgoal = bcq_outputs['subgoals'][:, -1].cpu().numpy()

        self.subgoals.append(raw_subgoal * self.SUBGOAL_SCALE)
        self.subgoals = self._trim(self.subgoals)

        # Zero placeholder for the action that is about to be predicted, so
        # that all three histories have the same number of steps.
        self.actions.append(np.zeros((self.bs, self.action_dim)))
        self.actions = self._trim(self.actions)

        history = {
            'subgoals': self._stack(self.subgoals),
            'observations': observations,
            'actions': self._stack(self.actions),
        }
        outputs = self.model(history)

        action = outputs['actions'][:, -1].cpu().numpy()
        # Replace the placeholder with the action actually chosen.
        self.actions[-1] = action

        if unbatched:
            return action[0], raw_subgoal
        return action, raw_subgoal

    def update_context(self, observation, action, reward):
        """Check the batch-size contract for a transition.

        This policy keeps no reward context: ``observation`` and ``action``
        are ignored and ``reward`` is only inspected. A scalar (0-d) reward is
        accepted only when ``bs == 1``.
        """
        if np.ndim(reward) == 0:
            assert self.bs == 1, 'scalar rewards require bs == 1'

    def reset(self):
        """Clear the subgoal, observation and action histories."""
        self.subgoals = []
        self.observations = []
        self.actions = []

