from rlf.algos.il.base_il import BaseILAlgo
import torch
from functools import *
import operator
import numpy as np
from collections import defaultdict, deque
from abc import ABC, abstractmethod
import rlf.rl.utils as rutils

class BaseIRLAlgo(BaseILAlgo):
    """
    Base class for imitation algorithms that replace the stored rewards.

    On every `update` the rewards held in the rollout storage are zeroed and
    then refilled, step by step, with the values returned by `_get_reward`
    (optionally augmented with a bonus computed from a behaviour-cloning
    policy). Subclasses must implement `_get_reward` and may override
    `_update_reward_func`.
    """

    def __init__(self):
        super().__init__()
        self.traj_log_stats = defaultdict(list)

    def init(self, policy, args):
        """
        Initialise the algorithm and the per-episode logging buffers.

        Reads `args.log_smooth_len` (window length of the per-key deques) and
        `args.num_processes` (number of running totals kept per key).
        """
        super().init(policy, args)
        # Totals of finished episodes, one bounded window per logged key.
        self.ep_log_vals = defaultdict(
            lambda: deque(maxlen=args.log_smooth_len))
        # Running total of the episode in progress, one slot per process.
        self.culm_log_vals = defaultdict(
            lambda: [0.0 for _ in range(args.num_processes)])

    @abstractmethod
    def _get_reward(self, step, storage, add_info):
        """
        Compute the rewards for one step of the rollout.

        `update` expects a pair `(rewards, ep_log_vals)`: `rewards` is stored
        in `storage.rewards[step]`, and `ep_log_vals` is a dict whose values
        are indexed per process (`ep_log_vals[k][i].item()`).
        """

    def _update_reward_func(self, storage):
        """Update the reward function; returns a dict of values to log."""
        return {}

    def _get_reward_bonus_value(self):
        """Return the constant bonus used by `_get_bc_model_reward_bonus`."""
        # NOTE(review): this attribute is not assigned anywhere in this class;
        # presumably a subclass or the caller sets it — confirm.
        return self.bc_model_regularizer_reward_bonus

    def _get_bc_model_reward_bonus(self, bc_policy, step, storage):
        """
        Compute a per-transition reward bonus from a BC policy.

        Args
        ----
        bc_policy : callable
            Behaviour-cloning policy. It is called as
            `bc_policy(state=..., rnn_hxs={}, mask=...)` and its first return
            value is treated as logits over discrete actions.
        step : int
            Index of the step to read from `storage`.
        storage : rollout buffer
            Must expose `.obs`, `.actions` and `.masks`, each indexable by
            `step`.

        Returns
        -------
        torch.Tensor
            float32 tensor with the same shape as `storage.actions[step]`
            (assumed to be (batch, 1) — TODO confirm). Each entry is either
            the bonus value or 0, following the MiniGrid rule below.

        Raises
        ------
        NotImplementedError
            If `self.args.env_name` does not contain 'MiniGrid'.
        """
        env_name = self.args.env_name
        # Only the MiniGrid rule exists; fail before doing any forward pass.
        if 'MiniGrid' not in env_name:
            raise NotImplementedError(
                f"BC-model reward bonus is not implemented for env '{env_name}'")

        bc_reward_bonus = self._get_reward_bonus_value()  # constant bonus
        with torch.no_grad():
            # Collect the batch data for this step.
            state = storage.obs[step]
            mask = storage.masks[step]
            action = storage.actions[step]

            # Forward through the BC policy; only the logits are needed.
            bc_ac_logits, *_ = bc_policy(state=state, rnn_hxs={}, mask=mask)
            bc_esa_labels = bc_ac_logits.argmax(dim=1, keepdim=True)

            # MiniGrid rule: give the bonus only if
            #  1) the taken action equals the BC policy's greedy action, OR
            #  2) the taken action index is >= 4 (these actions are always
            #     rewarded, i.e. excluded from the comparison).
            match_mask = (action == bc_esa_labels)
            ignore_mask = (action >= 4)
            valid_mask = match_mask | ignore_mask

            return torch.where(
                valid_mask,
                torch.full_like(action, bc_reward_bonus, dtype=torch.float32),
                torch.zeros_like(action, dtype=torch.float32))

    def update(self, storage):
        """
        Replace the stored rewards with computed ones and log episode totals.

        For every step in `range(self.args.num_steps)` the reward returned by
        `_get_reward` (plus the BC-model bonus when
        `self.args.add_bc_model_regularizer_reward_bonus` is set) is written
        to `storage.rewards[step]`. Per-process running totals of every logged
        value are accumulated and flushed into a smoothing window whenever
        `storage.masks[step, i] == 0.0`.

        Returns
        -------
        dict
            The dict returned by `_update_reward_func`, extended with one
            `culm_irl_<key>` entry per logged key holding the mean of the
            episode totals currently in the window.

        Raises
        ------
        AssertionError
            If the bonus is enabled but `self.policy.bc_policy` is None.
        """
        super().update(storage)
        # CLEAR ALL REWARDS so no environment rewards can leak to the IRL method.
        for step in range(self.args.num_steps):
            storage.rewards[step] = 0

        log_vals = self._update_reward_func(storage)
        add_info = {k: storage.get_add_info(k) for k in storage.get_extract_info_keys()}
        for k in storage.ob_keys:
            if k is not None:
                add_info[k] = storage.obs[k]

        for step in range(self.args.num_steps):
            rewards, ep_log_vals = self._get_reward(step, storage, add_info)

            if self.args.add_bc_model_regularizer_reward_bonus:
                bc_policy = self.policy.bc_policy
                # Explicit raise (not `assert`) so the check survives `-O`;
                # the exception type and message are unchanged.
                if bc_policy is None:
                    raise AssertionError("bc model not loaded!")
                bc_bonus = self._get_bc_model_reward_bonus(bc_policy, step, storage)
                # Out-of-place add: do not mutate the tensor handed back by
                # `_get_reward`, which may alias subclass state or another
                # entry of `ep_log_vals`.
                rewards = rewards + bc_bonus
                ep_log_vals['bc_model_reward_bonus'] = bc_bonus
            ep_log_vals['reward'] = rewards
            storage.rewards[step] = rewards

            for i in range(self.args.num_processes):
                for k in ep_log_vals:
                    self.culm_log_vals[k][i] += ep_log_vals[k][i].item()

                # A zero mask is treated as an episode boundary for process i.
                # NOTE(review): whether `masks[step]` or `masks[step + 1]`
                # marks the end of the transition at `step` depends on the
                # storage convention — confirm.
                if storage.masks[step, i] == 0.0:
                    for k in ep_log_vals:
                        self.ep_log_vals[k].append(self.culm_log_vals[k][i])
                        self.culm_log_vals[k][i] = 0.0

        for k, vals in self.ep_log_vals.items():
            log_vals[f"culm_irl_{k}"] = np.mean(vals)

        return log_vals

    def on_traj_finished(self, trajs):
        """Hook taking a collection of trajectories; a no-op in this class."""
        pass
