import numpy as np
import torch
from .networks_pytorch import (
    TanhNormalPolicy,
    ValueNetwork,
    MaskedValue
)
from .disdain import EnsembleOfDiscriminators
from collections import defaultdict
import time
import matplotlib.pyplot as plt

import wandb
from .doi_abstract import AbstractDOI
from .utils import  stable_softmax
import functools
import tqdm

# Small positive constant added inside logarithms in this module to avoid log(0).
EPS = 1e-16

class DOI(AbstractDOI):

    def _process_config(self, config):

        self._disc_type = config[
            "disc_type"
        ]  # bc (behavior cloning) or doi (diverse offline imitation)
        self._gamma = config["gamma"]
        self._env_name = config["env_name"]
        self._total_iterations = config["total_iterations"]
        self._use_policy_entropy_constraint = config["use_policy_entropy_constraint"]
        self._target_entropy = config["target_entropy"]
        self._hidden_sizes = config["hidden_sizes"]
        self._batch_size = config["batch_size"]
        self._f = config["f"]  # f-divergence function
        self._lr = config["lr"]  # learning rate for value and discriminator
        self._actor_lr = config["actor_lr"]  # learning rate for actor
        self._disc_lr  = config["disc_lr"]  # learning rate for discriminator
        self._v_l2_reg = config["v_l2_reg"]  # l2 regularization for value network

        # our additions
        self._num_skills = config["num_skills"]  # number of skills
        self._lambda_lr = config["lambda_lr"]  # learning rate for lambda
        self._epsilon = config["epsilon"]  # the constraint on the f-divergence
        self._init_frac = config["init_frac"]  # initial value of sigmoid(lambda)
        self._tau = config[
            "tau"
        ]  # for soft update of target parameters (not needed right now)
        self._learn_lambda = config["learn_lambda"]  # if we want to learn lambda or not
        self._w_disdain = config[
            "w_disdain"
        ]  # weighting on the disdain part of disc reward
        self._w_disc_reward = config[
            "w_disc_reward"
        ]  # weighting on the discriminator reward
        self._burnin_steps = config[
            "burnin_steps"
        ]  # number of steps that we don't learn the lambdas

        self._temperature = config["temperature"]  # temperature for softmax transform

        self._assertions = config['assertions'] # if we want to assert things or not

        self._v_stage_steps = config['v_stage_steps'] # number of steps to train the value function for
        self._l_stage_freq  = config['l_stage_freq'] # frequency of training lambda
        self._l_stage_batches = config['l_stage_batches'] # number of batches to train lambda with (for better phi estimate)
        self._pi_stage_steps = config['pi_stage_steps'] # number of steps to train the policy for
        self._d_stage_steps = config['d_stage_steps'] # number of steps to train the discriminator for

        self._div_lambda = config['div_lambda'] # after burnin steps we set lambda to this
        self._smodice_div = config['smodice_div'] # if we want to use smodice divergence
        self._disc_entropy_penalty = config['disc_entropy_penalty'] # entropy penalty for policy


        self._multiple_experts = config['multiple_experts'] # if we want to use multiple expert w_e
        self._temp_disc = config['temp_disc'] # temperature for discriminator
        self._masked_value_net = config['masked_value_net'] # if we want to mask the value network
        self._disc_lr = config['disc_lr'] # learning rate for discriminator

        self._w_disdain = config['w_disdain'] # weighting on the disdain part of disc reward
        self._disc_hidden = config['disc_hidden'] # hidden size for discriminator
        self._p_drop = config['p_drop'] # dropout probability for discriminator
        self._expert_only = config['expert_only']
        self._reward_type = config.get('reward_type', 'log_w_e')

        self._untruncated_w_e = None
        self._target_e_v = None


    def __init__(self, disc_idxs, div_idxs, observation_spec, action_spec, behavior_policy, config):
        """Build the value, policy and skill-discriminator networks and their optimizers.

        Args:
            disc_idxs: Observation indices used to size (and, elsewhere in this
                class, to slice the inputs of) the value network.
            div_idxs: Observation indices fed to the skill discriminator.
            observation_spec: Ignored; the input width is taken from
                ``len(disc_idxs)`` instead.
            action_spec: Forwarded to ``TanhNormalPolicy``.
            behavior_policy: Not used by this constructor.
            config: Hyper-parameter mapping (see ``_process_config``);
                ``device``, ``mean_range`` and ``logstd_range`` are read here.

        Raises:
            NotImplementedError: If ``config["f"]`` is neither ``"chi"`` nor ``"kl"``.
        """
        super().__init__(config)

        self.disc_idxs = disc_idxs
        self.div_idxs = div_idxs

        self._process_config(config)

        self.device = config["device"]

        self._iteration = 0
        self._optimizers = dict()

        # value sees disc_idxs (what we want to imitate), discriminator sees div_idxs (what we want to be diverse)
        observation_spec = len(disc_idxs)

        if self._masked_value_net:
            value_constructor = functools.partial(
                MaskedValue,
                num_inputs=observation_spec,
                hidden_sizes=self._hidden_sizes,
                num_skills=self._num_skills,
                p_drop=self._p_drop,
            )
        else:
            # The skill vector is concatenated to the observation, hence the wider input.
            value_constructor = functools.partial(
                ValueNetwork,
                hidden_sizes=self._hidden_sizes,
                num_inputs=observation_spec + self._num_skills,
            )

        self._v_network = value_constructor().to(self.device)
        self._target_v_network = value_constructor().to(self.device)
        self._target_v_network.load_state_dict(self._v_network.state_dict())
        self._optimizers["v"] = torch.optim.Adam(
            self._v_network.parameters(), self._lr, weight_decay=self._v_l2_reg
        )

        self._skill_discriminator = EnsembleOfDiscriminators(
            num_inputs=len(div_idxs),
            num_classes=self._num_skills,
            n=4,
            w_disdain=self._w_disdain,
            hidden_shape=[self._disc_hidden] * 3,
            entropy_penalty=self._disc_entropy_penalty,
        ).to(self.device)
        self._optimizers["discriminator"] = torch.optim.Adam(
            self._skill_discriminator.parameters(), self._disc_lr
        )

        # f-divergence functions. Both variants of ``_f_star_prime`` accept an
        # optional, ignored second argument so one- and two-argument call sites
        # work regardless of the chosen divergence.
        if self._f == "chi":
            self._f_fn = lambda x: 0.5 * (x - 1) ** 2
            self._f_star_prime = lambda x, _=None: torch.relu(x + 1)
            self._f_star = lambda x, _=None: 0.5 * x**2 + x
        elif self._f == "kl":
            self._f_fn = lambda x: x * torch.log(x + 1e-10)
            self._f_star_prime = lambda x, _=None: torch.exp(x - 1)
        else:
            raise NotImplementedError()

        # policy
        self._policy_network = TanhNormalPolicy(
            observation_spec + self._num_skills,
            action_spec,
            hidden_sizes=self._hidden_sizes,
            mean_range=config["mean_range"],
            logstd_range=config["logstd_range"],
        ).to(self.device)

        # NOTE(review): the policy optimizer uses ``_lr``; ``_actor_lr`` is read
        # from the config but not used in this file -- confirm this is intended.
        self._optimizers["policy"] = torch.optim.Adam(
            self._policy_network.parameters(), self._lr
        )

        if self._use_policy_entropy_constraint:
            self._log_ent_coeff = torch.nn.Parameter(
                torch.zeros(1, device=self.device), requires_grad=True
            )
            self._optimizers["ent_coeff"] = torch.optim.Adam(
                [self._log_ent_coeff], self._lr
            )

        # inverse sigmoid, so that sigmoid(lambda) starts at ``init_frac``
        self._init_lambda = np.log(np.divide(self._init_frac, (1 - self._init_frac)))

        # One entry per skill plus one extra trailing entry.
        self._lambda = torch.nn.Parameter(
            torch.ones(self._num_skills + 1, requires_grad=True, device=self.device)
            * self._init_lambda,
            requires_grad=True,
        )
        self._optimizers["lambda"] = torch.optim.Adam([self._lambda], self._lambda_lr)

        # One-hot encoding of every skill.
        self._skills = torch.eye(self._num_skills, device=self.device)
        # set default eval skill
        self._eval_skill = self._skills[0]


    def get_lambdas(self, skills):
        return self._lambda[torch.max(skills, -1)[1]][..., None]

    
    def get_state_dict(self):
        optimizer_states = {k: v.state_dict() for k, v in self._optimizers.items()}
        network_states = {
            "policy": self._policy_network.state_dict(),
            "v": self._v_network.state_dict(),
            "w_e_expert": self.expert_w_e if self.expert_w_e is not None else None,
            "w_e_skills": self._untruncated_w_e if self._untruncated_w_e is not None else None,
            'discriminator':  self._skill_discriminator.state_dict(),
        }

        return {
            "network_states": network_states,
            "optimizer_states": optimizer_states,
            "lambda": self._lambda.data,
            "log_ent_coeff": self._log_ent_coeff.data if self._use_policy_entropy_constraint else None,
        }

    def load_state_dict(self, state_dict):

        network_states = state_dict["network_states"]
        optimizer_states = state_dict["optimizer_states"]

        self._policy_network.load_state_dict(network_states["policy"])
        self._v_network.load_state_dict(network_states["v"])     
        self._skill_discriminator.load_state_dict(network_states["discriminator"])

        for k, v in self._optimizers.items():
            v.load_state_dict(optimizer_states[k])
        self._lambda.data = state_dict["lambda"]
        self._log_ent_coeff.data = state_dict["log_ent_coeff"]

        print('Loaded state dicts!')


    def eval_discriminator(self, data_loader):
        """Summarise the skill discriminator's predictions over the dataset.

        Runs one unshuffled, gradient-free pass over ``data_loader``, logs a
        pairwise-L1 heatmap to wandb (``commit=False``) and returns scalar
        statistics.

        Args:
            data_loader: Iterable of batches exposing ``set_shuffle``; each
                batch must provide ``batch['observations']``.

        Returns:
            dict with ``dataset_entropy``, ``dataset_max_p``,
            ``dataset_max_logp``, ``dataset_min_logp``,
            ``dataset_avg_skill_l1_distance`` and
            ``dataset_min_skill_l1_distance``.
        """
        # compute discriminator logps
        data_loader.set_shuffle(False)
        logps = []
        try:
            with torch.no_grad():
                for batch in iter(data_loader):
                    _, _, inf = self._skill_discriminator.skill_reward(
                        (batch['observations'][:, self.div_idxs],)
                    )
                    logps.append(inf['logp'].cpu())
        finally:
            # Always hand the loader back in shuffling mode.
            data_loader.set_shuffle(True)

        # presumably (num_samples, num_skills) -- TODO confirm against skill_reward
        logps = torch.cat(logps, dim=0).detach().cpu()
        skill_p = logps.exp()

        # Per-skill distribution over the dataset, then L1 distance between
        # every pair of those distributions.
        skill_w_e = stable_softmax(logps.T, 1)
        pairwise_distance_of_skills = torch.cdist(skill_w_e[None], skill_w_e[None], p=1)[0]

        # create heatmap
        pairwise_l1_distance = pairwise_distance_of_skills.cpu().numpy()
        wandb.log(
            {"pairwise_l1_distance": wandb.Image(pairwise_l1_distance, caption="pairwise_l1_distance")},
            commit=False,
        )

        entropy = -(skill_p * torch.log(skill_p + 1e-8)).sum(-1)
        avg_entropy = entropy.mean()
        max_p = torch.max(skill_p, dim=1)[0]
        min_p = torch.min(skill_p, dim=1)[0]
        max_logp = torch.log(max_p + 1e-8).mean()
        min_logp = torch.log(min_p + 1e-8).mean()

        # Push the zero diagonal out of the way so min() picks an off-diagonal pair.
        n_rows = pairwise_distance_of_skills.shape[0]
        diagonal_mask = torch.eye(n_rows, device=pairwise_distance_of_skills.device) * 1e8
        pairwise_distance_of_skills_aug = pairwise_distance_of_skills + diagonal_mask

        # Number of ordered off-diagonal skill pairs.
        n_pairs = self._num_skills**2 - self._num_skills
        return {'dataset_entropy': avg_entropy,
                'dataset_max_p': max_p.mean(),
                'dataset_max_logp': max_logp,
                'dataset_min_logp': min_logp,
                'dataset_avg_skill_l1_distance': pairwise_distance_of_skills.sum() / n_pairs,
                'dataset_min_skill_l1_distance': pairwise_distance_of_skills_aug.min(),
                }
    
    def discriminator_loss(self, logits, skills, w_e, result={}):

        discriminator_loss = self._skill_discriminator.loss(logits, skills, w_e, reduction='mean')
        result["discriminator_loss"] = discriminator_loss
        return result

    def train_step_discriminator(self, batch, iteration, w_e=None):
        """Train discriminator on batch of data."""
        observation = batch['observations'].to(self.device)
        skills = batch['skills']
        e_v = batch['e_v']

        discriminator_reward, disc_logits = self.discriminator_reward(
            self._skill_discriminator,
            observation, skills
        )
        res = {}
        res['discriminator_reward'] = discriminator_reward.mean()
        idxs = torch.all(skills==0, dim=1)
        skills = skills[~idxs]
        e_v = e_v[~idxs]
        res = self.discriminator_loss(
                        disc_logits, skills, e_v.exp().detach(), result={}
                    )
        self._optimizers["discriminator"].zero_grad()
        res["discriminator_loss"].backward()
        self._optimizers["discriminator"].step()

        if self._assertions:
            assert torch.all(torch.isfinite(res["discriminator_loss"]))
        return res
    
    def train_discriminator_stage(self, data_loader, iteration):
        """Loop to train discriminator.

        Runs ``_d_stage_steps`` updates (restarting the loader when it is
        exhausted), then makes one unshuffled pass over the data to cache the
        skill rewards in ``self._disc_rewards``.

        Returns:
            dict with the step-averaged training metrics plus
            ``dataset_entropy``.
        """
        t_start = time.time()
        res_acc = defaultdict(lambda: 0)
        it = iter(data_loader)
        for _ in range(self._d_stage_steps):
            try:
                batch = next(it)
            except StopIteration:
                it = iter(data_loader)
                batch = next(it)
            batch = self._unpack_batch(batch)
            res = self.train_step_discriminator(batch, iteration)
            for k, v in res.items():
                # Detach so the running sums do not hold on to autograd graphs.
                res_acc[k] += v.detach() if torch.is_tensor(v) else v
        for k, v in res_acc.items():
            res_acc[k] = v / self._d_stage_steps  # average over steps
        print(f"Discriminator training took {time.time() - t_start:.3f}s")

        # compute discriminator rewards and logps over the whole dataset
        rewards = []
        logps = []
        data_loader.set_shuffle(False)
        try:
            with torch.no_grad():
                for batch in iter(data_loader):
                    skill_reward, _, inf = self._skill_discriminator.skill_reward(
                        (batch['observations'][:, self.div_idxs],)
                    )
                    rewards.append(skill_reward)
                    logps.append(inf['logp'].cpu())
        finally:
            # Always hand the loader back in shuffling mode.
            data_loader.set_shuffle(True)
        self._disc_rewards = torch.cat(rewards, dim=0).detach().cpu()
        logps = torch.cat(logps, dim=0).detach().cpu()

        entropy = -(logps.exp() * logps).sum(-1)
        avg_entropy = entropy.mean()
        print(f"Discriminator stage (training + evaluation) took {time.time() - t_start:.3f}s")
        return dict(dataset_entropy=avg_entropy, **res_acc)
    
    

    def get_expert_state(self):
        return {
            'v_network': self._expert_v_network.state_dict()
        }
    def load_expert_state(self, state):
        self._expert_v_network.load_state_dict(state['v_network'])
    

    def pre_train_stage_expert(self, data_loader, iteration):
        """Pre-train the expert value function, then compute the expert weights.

        Runs ``_pretrain_expert_epochs`` full passes over ``data_loader`` and
        afterwards calls ``compute_expert_w_e`` with shuffling disabled.

        Returns:
            dict mapping each metric reported by ``train_step_value_expert`` to
            its average over all steps (empty if no step was run).
        """
        t_start = time.time()
        res_acc = defaultdict(lambda: 0)
        n_steps = 0
        for _ in tqdm.tqdm(range(self._pretrain_expert_epochs), total=self._pretrain_expert_epochs):
            for batch in iter(data_loader):
                res = self.train_step_value_expert(batch, iteration)
                n_steps += 1
                for k, v in res.items():
                    # Detach so the running sums do not hold on to autograd graphs.
                    res_acc[k] += v.detach() if torch.is_tensor(v) else v

        # compute expert w_e on the unshuffled data
        data_loader.set_shuffle(False)
        try:
            self.compute_expert_w_e(data_loader)
        finally:
            data_loader.set_shuffle(True)
        print(f"Pretraining expert took {time.time() - t_start:.3f}s")
        return {k: v / n_steps for k, v in res_acc.items()}



            
    def train_value_stage(self, data_loader, iteration):
        """Loop to train value network."""
        t_start = time.time()
        it = iter(data_loader)
        res_acc = defaultdict(lambda: 0)
        for _ in range(self._v_stage_steps):
            try:
                pre_unpack_batch = next(it)
                batch = self._unpack_batch(pre_unpack_batch)
            except StopIteration:
                it = iter(data_loader)
            res = self.train_step_value(batch, iteration)
            for k, v in res.items():
                res_acc[k] += v
            self.polyak_update(self._target_v_network, self._v_network)
        for k, v in res_acc.items():
            res_acc[k] = v / self._v_stage_steps            
        print(f"\nValue training took {time.time() - t_start:.3f}s")
        return res_acc


    def _compute_e_v(self,batch, expert_log_ratio, observation, next_observation, terminal, skills, iteration, v_network):
        with torch.no_grad():
             # Discriminator reward
            if iteration > self._burnin_steps:
                discriminator_reward, _ = self.discriminator_reward(self._skill_discriminator, skills=skills, observations=observation)
                lambdas = self.get_lambdas(skills)
                frac = torch.sigmoid(lambdas)
            else:
                discriminator_reward = torch.tensor(0.0)
                frac = torch.tensor(1.0)
            
            reward = (
                frac * self.get_expert_reward(batch)
                + (1 - frac) * discriminator_reward * self._w_disc_reward
            )

            if self._assertions:
                assert torch.all(torch.isfinite(reward))

            n = len(observation)

            # Shared network values
            # v_values, _ = v_network((observation[:, self.disc_idxs], skills))
            # next_v_values, _ = v_network((next_observation[:, self.disc_idxs], skills))

            stacked_obs = torch.cat((observation, next_observation), dim=0)
            stacked_skills = torch.cat((skills, skills), dim=0)
            stacked_v_values, _ = v_network((stacked_obs[:, self.disc_idxs], stacked_skills))
            v_values, next_v_values = stacked_v_values[:n], stacked_v_values[n:]
            e_v = reward + (1 - terminal) * self._gamma * next_v_values - v_values
        return e_v  
        
    def discriminator_reward(self, disc, observations, skills):
        skill_reward, logits, info = disc.skill_reward(
            (observations[:, self.div_idxs],), skills
        )
        return skill_reward.detach(), logits
    

    def train_step_value(self, batch, iteration):
        
        batch_size = batch['observations'].shape[0]
        if not 'skills' in batch:
            skill_idxs = torch.randint(0, len(self._skills), (batch_size,), device=self.device)
            batch['skills'] = self._skills[skill_idxs]
        
        
        expert_ratio  =  self.expert_w_e[:, batch['idxs'].to(self.expert_w_e.device).flatten()].to(self.device).T
        expert_log_ratio = torch.log((expert_ratio+1e-24)*self._num_skills)/self._num_skills
        if iteration >= self._burnin_steps:
            with torch.no_grad():
                discriminator_reward, disc_logits = self.discriminator_reward(self._skill_discriminator,
                                                                              observations=batch['observations'], skills=batch['skills'])
                assert discriminator_reward.shape == batch['classifier_rewards'].shape
                lambdas = self.get_lambdas(batch['skills'])
                frac = torch.sigmoid(lambdas)
        else:
            frac = torch.tensor(1.0)
            discriminator_reward = torch.tensor(0.0)
            
        
        # Shared network values
        initial_v_values, _ = self._v_network((batch['initial_observations'][:, self.disc_idxs], batch['skills']))
        v_values, _ = self._v_network((batch['observations'][:, self.disc_idxs], batch['skills']))
        next_v_values, _ = self._v_network((batch['next_observations'][:, self.disc_idxs], batch['skills']))

        imitation_reward = self.get_expert_reward(batch)

        if self._assertions:
            assert torch.all(torch.isfinite(initial_v_values))
            assert torch.all(torch.isfinite(v_values))
            assert torch.all(torch.isfinite(discriminator_reward))
            
        reward = (
                frac * imitation_reward
                + (1 - frac) * discriminator_reward * self._w_disc_reward
            )
        if self._assertions:
            assert torch.all(torch.isfinite(reward))        
        # gradient needs to propagate through next_v_values also
        e_v = reward.detach()  + (1 - batch['terminals']) * self._gamma * next_v_values - v_values
        
        # changing batch in place!
        batch['e_v'] = e_v.detach()

        # compute value function loss (Equation 20 in the paper)
        loss_result = self.v_loss(initial_v_values, e_v, result={})
        
        self._optimizers["v"].zero_grad()
        loss_result["v_loss"].backward()
        self._optimizers["v"].step()

        loss_result["discriminator_reward"] = discriminator_reward.mean()
        loss_result["expert_log_ratio"] = expert_log_ratio.mean()

        return loss_result
    
    def get_expert_reward(self, batch):

        reward_type = self._reward_type
        if reward_type == 'e_v_exp':
            return batch['expert_e_v'].exp()
        elif reward_type == 'e_v':
            return batch['expert_e_v']
        elif reward_type == 'log_w_e':
            return  torch.log(batch['expert_w_e']*self._num_skills+EPS)/self._num_skills
        elif reward_type == 'classifier_reward':
            return batch['classifier_rewards']
    
    def compute_violation(self, data_loader, iteration):
        """Measure how far each skill's weights diverge from the expert's.

        For every skill row of ``_untruncated_w_e`` this computes
        ``phi = sum(w * (log(w + EPS) - log(w_expert + EPS)))`` along dim 1.
        ``data_loader`` and ``iteration`` are not used.

        Returns:
            ``(violation, phi)`` where ``violation = phi - epsilon``.
        """
        # just compute divergence
        skill_weights = self._untruncated_w_e[:self._num_skills]
        expert_weights = self.expert_w_e.to(skill_weights.device)
        log_ratio = torch.log(skill_weights + EPS) - torch.log(expert_weights + EPS)
        phi = torch.sum(skill_weights * log_ratio, dim=1)
        return phi - self._epsilon, phi
    
    def train_lambda_stage(self, data_loader, iteration, violation, phi):
        """Train lambda on batch of data."""
        t_start = time.time()
        loss = torch.mean(-self._lambda[:self._num_skills]*violation.to(self.device))
        lmbda = self._lambda[:-1]

        self._optimizers["lambda"].zero_grad()
        loss.backward()
        self._optimizers["lambda"].step()
        print(f"Lambda training took {time.time() - t_start:.3f}s, phi={phi.mean().item():.3f}, violation={violation.mean().item():.3f}, lambda={lmbda.mean().item():.3f}")
        
        assert torch.all(torch.isfinite(self._lambda[:self._num_skills]))

    def compute_all_w_e(self, data_loader, iteration):
        """Recompute the per-skill importance weights over the whole dataset.

        Makes one unshuffled, gradient-free pass with batch size 20,000, stores
        every sample's residual in ``self.e_vs``, turns it into per-skill
        weights via ``stable_softmax`` and pushes the clipped, rescaled weights
        back into the loader with ``set_w_e``. The unclipped weights are kept
        in ``self._untruncated_w_e``.

        Args:
            data_loader: Loader exposing ``set_shuffle``, ``set_batch_size``,
                ``set_w_e``, ``_batch_size`` and ``__len__``.
            iteration: Current iteration, forwarded to ``_compute_e_v``.

        Returns:
            The clipped, rescaled weights; one row per skill.
        """
        t_start = time.time()
        data_loader.set_shuffle(False)
        train_batch_size = data_loader._batch_size
        data_loader.set_batch_size(20_000)
        if self._untruncated_w_e is None or not hasattr(self, 'e_vs'):
            self.expert_w_e = self.expert_w_e.to(self.device)
            # presumably len(data_loader) is the number of samples -- TODO confirm
            buffer_shape = (len(data_loader), len(self._skills))
            if self.device == 'cuda':
                self.e_vs = torch.zeros(buffer_shape).pin_memory()
            else:
                self.e_vs = torch.zeros(buffer_shape).to(self.device)
        # TODO this can break if batch size is too large
        with torch.no_grad():
            loop_start = time.time()
            for pre_unpacked_batch in iter(data_loader):
                batch = self._unpack_batch(pre_unpacked_batch)
                terminal = batch['terminals'].unsqueeze(1).reshape(-1, 1)

                e_v = self._compute_e_v(batch, self.get_expert_reward(batch), batch['observations'],
                                        batch['next_observations'], terminal, batch['skills'],
                                        iteration, v_network=self._v_network)

                # One column per skill, one row per sample of the raw batch.
                e_v = e_v.reshape(len(self._skills), -1).T
                # The buffer may live on another device than the network.
                self.e_vs[pre_unpacked_batch['idxs'].flatten()] = e_v.to(self.e_vs.device)
            print('loop time', time.time() - loop_start)
            rest_start = time.time()

            # extracting importance weight (Equation 21 in the paper)
            # NOTE(review): ``_temperature`` is not applied to ``e_vs`` before
            # the softmax -- confirm whether that is intended.
            w_es = stable_softmax(self.e_vs.T, dim=-1)
            self._untruncated_w_e = w_es

            # rescale: clip each skill's weights at its 0.95 quantile
            q = torch.quantile(w_es, torch.tensor([0.95], device=w_es.device), dim=1, keepdim=True)[0]
            w_es = torch.clamp(w_es, max=q)
            self._min_ws = 1./w_es.shape[-1]  # divide by uniform probability
            w_es /= self._min_ws
            print('rest finish ', time.time() - rest_start)

            # Restore the loader's training configuration before publishing the weights.
            data_loader.set_batch_size(train_batch_size)
            data_loader.set_shuffle(True)
            data_loader.set_w_e({'w_e': w_es.T.cpu().numpy(), 'e_v': self.e_vs.cpu().numpy()})
            print(f"Computing w_e took {time.time() - t_start:.3f}s")
            return w_es


    def policy_train_step(self, batch):
        # policy learning (Equation 22 in the paper)
        if self._use_expert_importances:
            """Can we fit the expert just based on smodice importances?"""
            w_e = batch['expert_e_v'].detach().exp()
        else:
            w_e = batch['e_v'].detach().exp()

        pi_actions, pi_log_prob, loss_result = self.policy_loss(
            batch['observations'], batch['actions'], batch['skills'], 
            w_e, result={} # TODO changed this line!
        )
        self._optimizers["policy"].zero_grad()
        loss_result["policy_loss"].backward()
        self._optimizers["policy"].step()
        if self._use_policy_entropy_constraint:
            self._optimizers["ent_coeff"].zero_grad()
            loss_result["ent_coeff_loss"].backward()
            self._optimizers["ent_coeff"].step()
        return loss_result



    def train_step(
        self,
        batch,
        iteration,
        data_loader
    ):
        
        v_res = self.train_value_stage(data_loader, iteration)
        if not self._expert_only:
            w_e = self.compute_all_w_e(data_loader, iteration)
        else:
            w_e = self.expert_w_e
        pi_res = self.train_policy_stage(data_loader, iteration)


        if self._expert_only:
            return dict(**pi_res, **v_res)

        d_res = l_res = {}
        if iteration >= self._burnin_steps:
            print('Training discriminator..')
            d_res = self.train_discriminator_stage(data_loader, iteration)

            violation, phi = self.compute_violation(data_loader, iteration)

            if self._learn_lambda:
                if iteration == self._burnin_steps:
                    self._lambda.data[:self._num_skills] = torch.sign(violation)
                    print(f'Setting lambda automatically to ', self._lambda.data[:self._num_skills].cpu().numpy())
                self.train_lambda_stage(data_loader, iteration, violation, phi)

            for i in range(len(self._lambda[:-1])):
                l_res[f'lambda_stage/phi_{i}'] = phi[i]
                l_res[f'lambda_stage/violation_{i}'] = violation[i]
                l_res[f'lambda_stage/lmbda_{i}'] = self._lambda[:-1].data[i]
                l_res[f'lambda_stage/sigmoid_lmbda_{i}'] = torch.sigmoid(self._lambda.data[i])

            d_eval = self.eval_discriminator(data_loader)
            for k, v in d_eval.items():
                print(f"{k}: {v:.3f}")
            d_res.update(d_eval)

        res = dict(**v_res, **pi_res, **d_res, **l_res)

        # assert losses are finite
        if self._assertions:
            assert torch.all(torch.isfinite(w_e))
            assert torch.all(torch.isfinite(res["v_loss"]))
            assert torch.all(torch.isfinite(res["policy_loss"]))

        self._iteration += 1
        w_e = w_e.cpu()
        if w_e.shape[-1] > int(2e6):
            # subsample w_e
            w_e = w_e[:, ::10]
        quantiles = torch.tensor([0.5, 0.75, 0.9], device=w_e.device)
        w_e_quantiles = torch.quantile(w_e.detach(), quantiles)

        for i, wq in zip(
            quantiles, w_e_quantiles
        ):
            res[f"quantiles/w_e_q_{i:.2f}"] = wq
        res["quantiles/w_e_q_1.0"] = torch.max(w_e)
        res["quantiles/w_e_q_0.0"] = torch.min(w_e)
        
        return res



class DOISingleStep(DOI):
    """DOI variant that updates value, policy and discriminator once per call.

    Every ``_l_stage_freq`` iterations it additionally recomputes the
    importance weights, optionally updates lambda, evaluates the discriminator
    and logs those metrics to wandb under the ``lagrange/`` prefix.
    """

    def __init__(self, disc_idxs, div_idxs, observation_spec, action_spec, behavior_policy, config):
        """Build the parent learner and disable expert importances for the policy."""
        super().__init__(disc_idxs, div_idxs, observation_spec, action_spec, behavior_policy, config)
        self._use_expert_importances = False

    def train_step(self, batch, iteration, data_loader):
        """Run one value, policy and discriminator update on ``batch``.

        Returns:
            dict of metrics from the three updates, extended with the lambda
            and discriminator-evaluation metrics on lambda-stage iterations.
        """
        batch = self._unpack_stochastic_batch(batch)

        # The value step writes batch['e_v'], which the two later steps read,
        # so the order of these three calls matters.
        res = {
            **self.train_step_value(batch, iteration),
            **self.policy_train_step(batch),
            **self.train_step_discriminator(batch, iteration),
        }

        if (iteration+1) % self._l_stage_freq != 0:
            return res

        with torch.no_grad():
            self.compute_all_w_e(data_loader, iteration)
            violation, phi = self.compute_violation(data_loader, iteration)

        if self._learn_lambda:
            # NOTE(review): this branch is only reached on lambda-stage
            # iterations, so the automatic initialisation fires only if
            # ``_burnin_steps`` is one of them -- confirm this is intended.
            if iteration == self._burnin_steps:
                self._lambda.data[:self._num_skills] = torch.sign(violation)
                print(f'Setting lambda automatically to ', self._lambda.data[:self._num_skills].cpu().numpy())
            self.train_lambda_stage(data_loader, iteration, violation, phi)

        # Per-skill logging; the trailing lambda entry is left out.
        l_res = {}
        skill_lambdas = self._lambda[:-1]
        for i in range(len(skill_lambdas)):
            l_res[f'lambda_stage/phi_{i}'] = phi[i]
            l_res[f'lambda_stage/violation_{i}'] = violation[i]
            l_res[f'lambda_stage/lmbda_{i}'] = skill_lambdas.data[i]
            l_res[f'lambda_stage/sigmoid_lmbda_{i}'] = torch.sigmoid(self._lambda.data[i])

        d_eval = self.eval_discriminator(data_loader)
        for k, v in d_eval.items():
            print(f"{k}: {v:.3f}")

        lagrange_metrics = dict(**l_res, **d_eval)
        wandb.log({f'lagrange/{k}': v for k, v in lagrange_metrics.items()})
        res.update(d_eval)
        res.update(l_res)
        return res
        