import torch
from .utils import stable_log_mean_exp
import time
from collections import defaultdict

EPS = 1e-6

class AbstractDOI:

    def __init__(self, config) -> None:
        """Initialize the state defined by this base class.

        Args:
            config: Configuration object. It is not read in this method
                (NOTE(review): presumably consumed by subclasses - confirm).
        """
        # When True, policy_loss() and step() rescale skills with
        # ``skills * 2 - 1`` before they are fed to the policy network.
        self._normalize_z = True

    def get_eval_skills(self):
        """Return the leading ``self._num_skills`` entries of ``self._skills``."""
        skill_count = self._num_skills
        all_skills = self._skills
        return all_skills[:skill_count]

    def v_loss(self, initial_v_values, e_v, result={}):
        
        # Compute v loss
        v_loss0 = (1 - self._gamma) * initial_v_values

        if self._f in ["kl", "doi_kl"]:
            v_loss1 = stable_log_mean_exp(e_v.flatten(), 0)
        elif self._f == 'simple':
            v_loss1 = torch.mean(e_v**2)
            v_loss0 = v_loss0*0
        else:
            v_loss1 = torch.mean(self._f_star(e_v))

        assert torch.all(torch.isfinite(v_loss1)), "v_loss1 is not finite"
        v_loss = v_loss0 + v_loss1

        v_loss = torch.mean(v_loss)

        result.update(
            {
                "v_loss0": torch.mean(v_loss0),
                "v_loss1": torch.mean(v_loss1),
                "v_loss": v_loss,
            }
        )

        return result
    
    
    def _unpack_stochastic_batch(self, batch):
        """Unpacks batch into separate tensors for each skill while sampling skills."""
        # randomly assign skills to batch
        batch_size = batch['observations'].shape[0]
        skill_idxs = torch.randint(0, self._num_skills, (batch_size,))
        batch['skills'] = self._skills[skill_idxs].to(self.device).float()
        return batch


    def _unpack_batch(self, batch):
        """Unpacks batch into separate tensors for each skill."""


        new_batch = {}
        new_batch['skills'] = torch.repeat_interleave(self._skills, batch['observations'].shape[0], 0).to(self.device)
        
        if 'w_e' in batch:
            assert batch['w_e'].shape[0] == batch['observations'].shape[0]

        for k, v in batch.items():
            if k not in ['w_e', 'skills']:
                if v.ndim == 1:
                    v = v.unsqueeze(1)
                new_batch[k] = torch.tile(v, (len(self._skills),1)).to(self.device)
        if 'w_e' in batch:
            # because last dimension is number of skills
            new_batch['w_e'] = batch['w_e'].T.reshape(-1, 1).to(self.device) # (batch_size * num_skills, 1)
            assert new_batch['w_e'].shape[0] == new_batch['observations'].shape[0]
        
        return new_batch

    def policy_loss(self, observation, action, skills, w_e, result={}):
        if self._normalize_z:
            skills = skills*2 - 1
        
        # Compute policy loss
        (
            sampled_action,
            sampled_pretanh_action,
            sampled_action_log_prob,
            sampled_pretanh_action_log_prob,
            pretanh_action_dist,
        ), _ = self._policy_network((observation[:, self.disc_idxs], skills))

        # Entropy is estimated on newly sampled action.
        negative_entropy_loss = torch.mean(sampled_action_log_prob)

        # Weighted BC
        action_log_prob, _ = self._policy_network.log_prob(
            pretanh_action_dist, action, is_pretanh_action=False
        )

        f_red = torch.mean

        if self._disc_type == "bc":
            policy_loss = -torch.mean(action_log_prob)
        else:
            policy_loss = -f_red(w_e * action_log_prob)
            if self._assertions:
                assert torch.all(torch.isfinite(policy_loss)), "policy_loss is not finite"

        if self._use_policy_entropy_constraint:
            ent_coeff = torch.exp(self._log_ent_coeff).squeeze(0)
            policy_loss += ent_coeff * negative_entropy_loss

            ent_coeff_loss = (
                -self._log_ent_coeff
                * (sampled_action_log_prob + self._target_entropy).detach()
            )

            result.update(
                {
                    "ent_coeff_loss": torch.mean(ent_coeff_loss),
                    "ent_coeff": ent_coeff,
                }
            )

        result.update(
            {
                "w_e": w_e,
                "mu_w_e": torch.mean(w_e),
                "policy_loss": policy_loss,
                "negative_entropy_loss": negative_entropy_loss,
            }
        )

        return sampled_action, sampled_action_log_prob, result


    
            
    def _compute_e_v(self,expert_log_ratio, observation, next_observation, terminal, skills, iteration):
        with torch.no_grad():
             # Discriminator reward
            if iteration > self._burnin_steps:
                discriminator_reward, _ = self.discriminator_reward(
                    self._target_skill_discriminator,
                    observation, skills
                )
                lambdas = self.get_lambdas(skills)
                frac = torch.sigmoid(lambdas)
            else:
                discriminator_reward = 0
                frac = 1.0


            reward = (
                frac * (expert_log_ratio)
                + (1 - frac) * discriminator_reward * self._w_disc_reward
            )

            if self._assertions:
                assert torch.all(torch.isfinite(reward))

            # Shared network values
            v_values, _ = self._v_network((observation, skills))
            next_v_values, _ = self._v_network((next_observation, skills))

            e_v = reward + (1 - terminal) * self._gamma * next_v_values - v_values
        return e_v
    

    def train_policy_stage(self, data_loader, iteration):
        t_start = time.time()
        res_acc = defaultdict(lambda: 0)
        it = iter(data_loader)
        for _ in range(self._pi_stage_steps):
            try:
                batch = next(it)
            except StopIteration:
                it = iter(data_loader)
                batch = next(it)
            batch = self._unpack_batch(batch)
            res = self.policy_train_step(batch)
            for k, v in res.items():
                if k != 'w_e':
                    res_acc[k] += v
        for k, v in res_acc.items():
            res_acc[k] = v / self._pi_stage_steps
        print(f"Policy training took {time.time() - t_start:.3f}s")
        return res_acc        

    def policy_train_step(self, batch):
        batch = self._unpack_batch(batch)
        # policy learning (Equation 22 in the paper)
        pi_actions, pi_log_prob, loss_result = self.policy_loss(
            batch['observations'], batch['actions'], batch['skills'], 
            batch['e_v'].detach().exp(), result={} # TODO changed this line!
        )
        self._optimizers["policy"].zero_grad()
        loss_result["policy_loss"].backward()
        self._optimizers["policy"].step()
        if self._use_policy_entropy_constraint:
            self._optimizers["ent_coeff"].zero_grad()
            loss_result["ent_coeff_loss"].backward()
            self._optimizers["ent_coeff"].step()
        return loss_result
    
    def set_eval_skill(self, skill):
        """Store ``skill`` as the skill that ``step`` uses when ``gpu`` is False."""
        self._eval_skill = skill

    
    def step(self, observation, gpu=False, eval_skills=None):
        """
        For case of gpu=True, we are using Isaac gym and parallel eval.
        observation: batch_size x obs_dim
        """
        if not gpu:
            if self._normalize_z:
                eval_skill = self._eval_skill*2 - 1
            with torch.no_grad():
                observation = torch.from_numpy(observation).to(self.device)
                x = torch.cat((observation, eval_skill[None]), dim=1)
                action = self._policy_network.deterministic_action(x)
            return action.detach().cpu(), None
        else:
            if self._normalize_z:
                eval_skills = eval_skills*2 - 1
            with torch.no_grad():
                x = torch.cat((observation, eval_skills), dim=1)
                action = self._policy_network.deterministic_action(x)
            return action.detach(), None
    
    def polyak_update(self, target, current, tau=None):
        if tau is None:
            tau = self._tau
        for param, target_param in zip(
            current.parameters(), target.parameters()
        ):
            target_param.data.copy_(
                self._tau * param.data + (1 - tau) * target_param.data
            )
