from offlinerl.algo.modelbase.maple_st import algo_init
from offlinerl.algo.modelbase.maple_st import AlgoTrainer as BaseTrainer
from random import random
import torch
import numpy as np

class AlgoTrainer(BaseTrainer):
    """Trainer extending the ``maple_st`` ``AlgoTrainer`` with a state-prior loss."""

    def get_s_prob_loss(self, env_batch, value_state, act_now):
        """Return the binary-classification loss used to train ``self.s_prior``.

        The state prior is learned in a GAN-like style: last-step states are
        split into a "positive" set (label 1) and an "other" set (label 0),
        and ``self.s_prior`` is trained to tell them apart.

        Partitioning:

        1. The batch is split into a low-Q part and a high-Q part. Unless
           ``self.args['random_partition']`` is set, samples are ranked by
           ``min(q1, q2)`` at the last step and the lowest
           ``self.args['low_q_portion']`` fraction is the low-Q part. With
           ``random_partition`` the split is a NumPy random shuffle instead.
        2. High-Q samples are ranked by the norm of the transition model's
           predicted stddev (max over dim 0 of the model output). The lowest
           ``self.args['low_uncertainty_portion']`` fraction is labelled 1.
        3. Low-Q states and high-Q/high-uncertainty states are labelled 0.

        Args:
            env_batch: Mapping with an ``'observations'`` tensor. Assumed to
                be ``(batch, time, obs_dim)`` -- TODO confirm against caller.
            value_state: Passed through unchanged to ``self.q1`` / ``self.q2``.
            act_now: Action tensor, indexed as ``act_now[:, -1:]``; assumed
                to share the batch/time layout of the observations.

        Returns:
            Scalar tensor: sum of the two mean-reduced BCE terms.

        Side effects:
            Calls ``self.transition.requires_grad_(False)``, which persists
            after this method returns.
        """
        observations = env_batch['observations']
        last_obs = observations[:, -1:]
        batch_size = len(observations)
        n_low_q = int(batch_size * self.args['low_q_portion'])

        if not self.args['random_partition']:
            # The critics are only used for ranking, so no autograd graph is
            # needed (the original detached the result afterwards).
            with torch.no_grad():
                q1_s_prob = self.q1(value_state, act_now, observations)[:, -1:]
                q2_s_prob = self.q2(value_state, act_now, observations)[:, -1:]
                min_q_s_prob = torch.min(q1_s_prob, q2_s_prob)
            # reshape(-1) instead of squeeze() so a batch of one stays 1-D.
            # Assumes one Q value per sample at the last step -- TODO confirm.
            argsorted_q = torch.argsort(min_q_s_prob.reshape(-1))
            low_q_indexes = argsorted_q[:n_low_q]
            high_q_indexes = argsorted_q[n_low_q:]
        else:
            # NumPy RNG is kept so existing np.random seeding still applies.
            indexes = np.arange(batch_size)
            np.random.shuffle(indexes)
            low_q_indexes = indexes[:n_low_q]
            high_q_indexes = indexes[n_low_q:]

        s_low_q = last_obs[low_q_indexes]
        s_high_q = last_obs[high_q_indexes]

        # Rank the high-Q states by the transition model's predicted stddev.
        # NOTE: requires_grad_(False) permanently freezes the transition
        # parameters; kept because callers may rely on it.
        self.transition.requires_grad_(False)
        obs_action_high_q = torch.cat(
            [s_high_q, act_now[:, -1:][high_q_indexes]], dim=-1
        )
        # Only the ordering is used, so skip building a graph through act_now.
        with torch.no_grad():
            next_obs_dists = self.transition(obs_action_high_q.squeeze(1))
            # Max over dim 0 -- presumably the ensemble dimension (TODO confirm).
            aleatoric_uncertainty = torch.max(
                torch.norm(next_obs_dists.stddev, dim=-1, keepdim=True), dim=0
            )[0]
        argsorted_uncertainty = torch.argsort(aleatoric_uncertainty.reshape(-1))

        n_low_uncertainty = int(
            len(s_high_q) * self.args['low_uncertainty_portion']
        )
        low_uncertainty_indexes = argsorted_uncertainty[:n_low_uncertainty]
        high_uncertainty_indexes = argsorted_uncertainty[n_low_uncertainty:]
        s_high_q_low_uncertainty = s_high_q[low_uncertainty_indexes]
        s_high_q_high_uncertainty = s_high_q[high_uncertainty_indexes]

        # Everything that is not high-Q AND low-uncertainty is a negative.
        s_others = torch.vstack([s_low_q, s_high_q_high_uncertainty])

        # The first element of the s_prior output is treated as raw logits.
        low_preds = self.s_prior(s_others)[0]
        high_preds = self.s_prior(s_high_q_low_uncertainty)[0]

        # BCEWithLogitsLoss fuses sigmoid + BCE and stays stable for
        # saturated logits. Targets are built from the predictions so that
        # shape, dtype and device always match.
        bce_with_logits = torch.nn.BCEWithLogitsLoss()
        s_prob_loss = (
            bce_with_logits(low_preds, torch.zeros_like(low_preds))
            + bce_with_logits(high_preds, torch.ones_like(high_preds))
        )

        return s_prob_loss