from typing import Tuple, Dict
from omegaconf import DictConfig
import numpy as np
import torch
from collections import deque
from torch.distributions import Categorical

class DRIS:
    """Re-weights squared TD errors with a clipped, temperature-scaled softmax.

    ``loss`` divides the signed TD error by the temperature ``t``, clips the
    result to ``[-clip, clip]``, applies a softmax over the last dimension and
    multiplies by ``len(td_error)``. The resulting weights are applied to the
    squared TD error and cached so that ``update_actor_loss`` can reuse them.

    Config keys read in ``__init__`` (defaults in parentheses): ``enable``
    (True), ``actor_strategy`` ('none'), ``t`` (1e8), ``clip`` (log 2),
    ``auto_tune`` (True), ``clip_r_th`` (0.1), ``t_mul`` (1.01), ``is_flip``
    (True), ``update_freq`` (5), ``verbose`` (True).
    """

    _ACTOR_STRATEGIES = ('none', 'dris', 'flip_dris')

    def __init__(self, cfg: "DictConfig") -> None:
        """Read hyper-parameters from ``cfg`` via ``cfg.get(key, default)``.

        Raises:
            ValueError: if ``actor_strategy`` is unknown, or if a strategy
                other than 'none' is requested while ``enable`` is false.
        """
        self.use_dris = cfg.get('enable', True)
        self.actor_strategy = cfg.get('actor_strategy', 'none')
        strategy = self.actor_strategy.lower()
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if strategy not in self._ACTOR_STRATEGIES:
            raise ValueError(
                f"unknown actor_strategy {self.actor_strategy!r}; "
                f"expected one of {list(self._ACTOR_STRATEGIES)}"
            )
        if strategy != 'none' and not self.use_dris:
            raise ValueError(
                f"actor_strategy {self.actor_strategy!r} requires 'enable' to be true"
            )

        self.t = cfg.get('t', 1e8)
        self.clip = cfg.get('clip', np.log(2.0))
        self.auto_tune = cfg.get('auto_tune', True)
        self.clip_ratio_threshold = cfg.get('clip_r_th', 0.1)
        self.temp_mul = cfg.get('t_mul', 1.01)
        # TODO: specify base_loss_fn, uses MSE td_error directly
        self.is_flip = cfg.get('is_flip', True)
        self.update_freq = cfg.get('update_freq', 5)
        # Sliding window of the most recent clip ratios (used by the
        # auto-tune draft only).
        self.update_clip_r_info = deque([0] * self.update_freq, maxlen=self.update_freq)
        # Placeholder until ``loss`` stores a weight tensor here.
        self.cached_weight = 1.0
        self.verbose = cfg.get('verbose', True)

    def __str__(self):
        return (
            f'DRIS(enabled={self.use_dris}, T_0={self.t}, clip={self.clip}, '
            f'auto_tune={self.auto_tune}, clip_r_th={self.clip_ratio_threshold}, '
            f't_mul={self.temp_mul}, is_flip={self.is_flip})'
        )

    @staticmethod
    def _log_weight_stats(info: Dict[str, float], dris_weight: torch.Tensor,
                          clipped_logits: torch.Tensor, batch_size: int) -> None:
        """Record effective sample size, entropy and summary stats of the weights."""
        n = len(dris_weight)
        # Entropy of softmax(clipped_logits); built from logits so it does not
        # rely on Categorical re-normalising the (sum == n) weights.
        entropy = Categorical(logits=clipped_logits).entropy()
        eff = 1.0 / torch.sum(dris_weight ** 2) * n ** 2
        eff_ratio = eff / batch_size

        info['eff/num'] = eff.item()
        info['eff/ratio'] = eff_ratio.item()
        info['weight/entropy'] = entropy.item()
        info['weight/std'] = dris_weight.std().item()
        info['weight/max'] = torch.max(dris_weight).item()
        info['weight/min'] = torch.min(dris_weight).item()
        info['weight/mean'] = torch.mean(dris_weight).item()
        for percentage in [0.25, 0.50, 0.75]:
            info[f'weight/quantile/{percentage}'] = torch.quantile(dris_weight, percentage).item()

    @staticmethod
    def _log_outlier_stats(info: Dict[str, float], logits: torch.Tensor,
                           clipped_logits: torch.Tensor) -> None:
        """Record stats of the logits that were changed by clipping (if any)."""
        outlier_err = logits[logits != clipped_logits]
        if outlier_err.numel():
            abs_outlier = torch.abs(outlier_err)
            info["outliers/min"] = torch.min(outlier_err).item()
            info["outliers/max"] = torch.max(outlier_err).item()
            info["outliers/abs_mean"] = abs_outlier.mean().item()
            info["outliers/abs_median"] = abs_outlier.median().item()
            info["outliers/median"] = outlier_err.median().item()
            outliers_std, outliers_mean = torch.std_mean(outlier_err, unbiased=False)
            info["outliers/std"] = outliers_std.item()
            info["outliers/mean"] = outliers_mean.item()

    def loss(self, curr: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return the (optionally re-weighted) mean squared TD error and a log dict.

        Args:
            curr: current value estimates.
            target: target values, same shape as ``curr``.
                NOTE(review): ``update_actor_loss`` compares the cached weights
                against flattened losses, so both are presumably 1-D -- confirm.

        Returns:
            ``(loss, info)``. ``info`` always holds the ``clip/*`` ratios when
            DRIS is enabled; every other entry is only added when ``verbose``.

        Side effects:
            When DRIS is enabled the weights are stored in ``self.cached_weight``.
        """
        # The base loss is squared, so the sign only affects the weights: with
        # ``is_flip`` samples whose target exceeds ``curr`` get larger weights.
        td_error = target - curr if self.is_flip else curr - target
        td_error_quad = td_error ** 2
        info: Dict[str, float] = {}
        if self.use_dris:
            # Weights are treated as constants; no gradient flows through them.
            with torch.no_grad():
                logits = td_error / self.t
                upper_clip_ratio = (logits > self.clip).float().mean().item()
                lower_clip_ratio = (logits < -self.clip).float().mean().item()
                clipped_logits = torch.clip(logits, -self.clip, self.clip)
                clipped_probs = torch.softmax(clipped_logits, dim=-1)
                dris_weight = clipped_probs * len(td_error)

                if self.verbose:
                    self._log_weight_stats(info, dris_weight, clipped_logits, curr.shape[0])
                    self._log_outlier_stats(info, logits, clipped_logits)

                    # study effect of clipping
                    # https://en.wikipedia.org/wiki/Bhattacharyya_distance
                    bh_coeff = torch.sqrt(
                        clipped_probs * torch.softmax(logits, dim=-1)
                    ).sum().item()

                    info["clip/bh_coeff"] = bh_coeff
                    info["clip/bh_dist"] = -np.log(bh_coeff)

                info['clip/upper_r'] = upper_clip_ratio
                info['clip/lower_r'] = lower_clip_ratio
                info['clip/all'] = upper_clip_ratio + lower_clip_ratio

                self.cached_weight = dris_weight

            loss = (dris_weight * td_error_quad).mean()
        else:
            loss = td_error_quad.mean()

        if self.verbose:
            with torch.no_grad():
                loss_value = loss.item()
                abs_td = torch.abs(td_error)
                info['q_loss'] = loss_value
                info['td/loss'] = loss_value
                info['td/quad_loss'] = td_error_quad.mean().item()
                info['td/abs_mean'] = abs_td.mean().item()
                info['td/abs_max'] = abs_td.max().item()
                info['td/abs_min'] = abs_td.min().item()
                info['td/abs_std'] = abs_td.std().item()
                info['td/mean'] = td_error.mean().item()
                info['td/std'] = td_error.std().item()
                info['hyper/t'] = self.t
                info['hyper/clip'] = self.clip
        return loss, info

    def _auto_tune_temperature(self, info: Dict[str, float], global_step: int) -> None:
        """Draft temperature auto-tuning; intentionally not wired into ``update`` yet.

        Pushes ``info['clip/all']`` into the sliding window and, every
        ``update_freq`` steps, multiplies ``t`` by ``temp_mul`` when the mean of
        the upper half of the sorted window exceeds the threshold, otherwise
        divides ``t`` by ``temp_mul``.
        """
        self.update_clip_r_info.appendleft(info['clip/all'])
        if global_step % self.update_freq == 0:
            clip_r = np.array(self.update_clip_r_info)
            clip_r.sort()
            if clip_r[clip_r.size // 2:].mean() > self.clip_ratio_threshold:
                self.t = self.t * self.temp_mul
            else:
                self.t = self.t / self.temp_mul

    def update(self, info: Dict[str, float], global_step: int):
        """Hook for temperature auto-tuning; a no-op unless DRIS and auto_tune are on.

        Raises:
            NotImplementedError: when both ``use_dris`` and ``auto_tune`` are
                true. It subclasses ``RuntimeError``, so existing handlers for
                the previous ``RuntimeError`` keep working.
        """
        if self.use_dris and self.auto_tune:
            # See ``_auto_tune_temperature`` for the draft that is not enabled yet.
            raise NotImplementedError("not supported yet")

    def update_actor_loss(self, losses):
        """Reduce per-sample actor losses according to ``actor_strategy``.

        * 'none': plain mean of ``losses``.
        * 'dris': mean of ``cached_weight * losses``.
        * 'flip_dris': inverse cached weights, normalised to sum to one, then
          the mean of ``weight * losses``.

        Raises:
            RuntimeError: if a weighted strategy is used before ``loss`` has
                cached a weight tensor, or the strategy is unsupported.
            ValueError: if the cached weights and the flattened losses differ
                in shape.
        """
        strategy = self.actor_strategy.lower()
        if strategy == 'none':
            return losses.mean()
        if not torch.is_tensor(self.cached_weight):
            raise RuntimeError(
                "no cached DRIS weights: call loss() before update_actor_loss()"
            )
        losses = losses.flatten()
        if self.cached_weight.shape != losses.shape:
            raise ValueError(
                f"cached weight shape {tuple(self.cached_weight.shape)} does not "
                f"match flattened losses shape {tuple(losses.shape)}"
            )
        if strategy == 'dris':
            return (self.cached_weight * losses).mean()
        elif strategy == 'flip_dris':
            unnormal_weight = 1.0 / self.cached_weight
            # NOTE(review): these weights sum to 1 and are then averaged, while
            # the 'dris' weights are scaled by the batch length in ``loss``; the
            # two strategies therefore differ in scale by the number of
            # weights -- confirm this is intended.
            weight = unnormal_weight / unnormal_weight.sum()
            return (weight * losses).mean()
        else:
            raise RuntimeError(f"{self.actor_strategy} not support!!!")
