import numpy as np
import torch
import hydra
import scipy as sc

from utils.utils_fn import calculate_means_dict
from utils.utils_fn import segment_data_fn

# Small positive constant. It is not referenced by the code visible in this
# part of the file; presumably a generic numerical-stability epsilon
# (TODO confirm against other users).
FLOAT_EPSILON = 1e-8
# Lower / upper bounds applied to the exponent argument in weights_fn and
# weights_fn_np before exp() is taken. Presumably chosen so that exp() does
# not overflow in float32 (exp(80) is roughly 5.5e34) -- TODO confirm.
EXP_MIN = -100
EXP_MAX = 80


def weights_fn(r, eta, bias=False):
    """Return exponential weights ``exp(r / eta)`` with ``r`` detached.

    Args:
        r: Tensor of scores.
        eta: Temperature that divides the scores (number or tensor).
        bias: When true, the maximum of ``r`` along dim 0 is subtracted from
            ``r`` before dividing by ``eta``.

    Returns:
        Tensor shaped like ``r``. The exponent is clipped to
        ``[EXP_MIN, EXP_MAX]`` before ``exp`` is applied.
    """
    if bias:
        col_max, _ = torch.max(r, dim=0, keepdim=True)
        shifted = r - col_max
    else:
        shifted = r
    # ``r`` is detached so no gradient flows back into the scores.
    exponent = torch.clamp(shifted.detach() / eta, EXP_MIN, EXP_MAX)
    return exponent.exp()


def weights_fn_np(r, eta, bias=False):
    """NumPy counterpart of ``weights_fn``: return ``exp(r / eta)``.

    Args:
        r: Array of scores.
        eta: Temperature that divides the scores.
        bias: When true, the maximum of ``r`` along axis 0 is subtracted from
            ``r`` before dividing by ``eta``.

    Returns:
        Array of weights. The exponent is clipped to
        ``[EXP_MIN, EXP_MAX]`` before ``exp`` is applied.
    """
    shifted = r
    if bias:
        shifted = r - np.max(r, axis=0, keepdims=True)
    exponent = np.clip(shifted / eta, EXP_MIN, EXP_MAX)
    return np.exp(exponent)


def dual_np(eta, adv, eps):
    """Objective in ``eta`` that ``Loss_functions.kl_loss`` minimises.

    Evaluates
    ``eta * eps + max(adv) + eta * log(mean(exp((adv - max(adv)) / eta)))``
    where the max is taken along axis 0 and the exponent is clipped by
    ``weights_fn_np``.

    Args:
        eta: Temperature being optimised.
        adv: NumPy array of scores.
        eps: Constant multiplying ``eta`` in the first term.

    Returns:
        The value of the expression above.
    """
    shifted_weights = weights_fn_np(adv, eta, bias=True)
    log_mean_weight = np.log(shifted_weights.mean())
    adv_max = np.max(adv, axis=0)
    return eta * eps + adv_max + eta * log_mean_weight


def sample_kl(w):
    """Return ``sum(p * log(p)) + log(N)`` for weights normalised along dim 0.

    ``p`` is ``w`` clipped from below at 1e-45 and divided by its sum over
    dim 0; ``N`` is ``w.shape[0]``. Up to the small constants used for
    numerical safety this is the KL divergence between ``p`` and the uniform
    distribution over the ``N`` entries.

    Args:
        w: Tensor of non-normalised weights.

    Returns:
        Tensor with dim 0 reduced away.
    """
    positive = torch.clamp(w, min=1e-45, max=torch.inf)
    probs = positive / positive.sum(dim=0)
    neg_entropy = (probs * (probs + 1e-20).log()).sum(dim=0)
    log_count = torch.log(torch.tensor(w.shape[0]))
    return neg_entropy + log_count


class Loss_functions:
    def __init__(self, params, device):
        """Select the loss implementation named in ``params`` and set it up.

        Args:
            params: Mapping with a ``"name"`` entry (``"kl_bounded_loss"`` or
                ``"eff_sample"``) and a ``"params"`` entry holding the loss
                options (``calculate_global_eta``, ``norm_weights`` are read
                here with defaults ``True`` and ``False``).
            device: Device used for the ``log_eta`` parameter.

        Raises:
            NotImplementedError: If the loss name is not one of the two above.
        """
        self.loss_name = params["name"]
        self.params = params["params"]

        if self.loss_name == "kl_bounded_loss":
            loss_method = self.kl_loss
        elif self.loss_name == "eff_sample":
            loss_method = self.eff_sample_loss
        else:
            raise NotImplementedError

        # Both supported losses share exactly the same state.
        self.device = device
        self.loss = loss_method
        self.calculate_global_eta = self.params.get("calculate_global_eta", True)
        self.norm_weights = self.params.get("norm_weights", False)
        self.calculate_eta = True
        self.log_eta = torch.nn.Parameter(torch.as_tensor([2.0], device=device))
        self.eta = torch.exp(self.log_eta)

    def __call__(self, rewards, log_policy, step=None):
        """Store ``step`` on the instance and evaluate the selected loss.

        Args:
            rewards: Forwarded to the loss as first argument.
            log_policy: Forwarded to the loss as second argument.
            step: Value stored in ``self.step`` before the loss is called.

        Returns:
            Whatever the selected loss callable returns.
        """
        self.step = step
        loss_fn = self.loss
        return loss_fn(rewards, log_policy)

    def awac_like_loss(self, rewards, log_policy, beta=None):
        """Weighted negative log-likelihood with exponential reward weights.

        The weights are ``weights_fn(rewards, beta)``; the loss is
        ``-(log_policy * weights).mean()``. When ``self.norm_weights`` is set
        the weights are first divided by their sum along dim 0.

        Args:
            rewards: Tensor of rewards. The mean weight is extracted with
                ``.item()``, so the per-column mean must hold a single
                element (presumably rewards of shape ``(N, 1)`` -- confirm
                against callers).
            log_policy: Tensor of log-probabilities, multiplied elementwise
                with the weights.
            beta: Temperature; falls back to ``self.params["beta"]`` when
                ``None``.

        Returns:
            Tuple ``(loss, stats)``; ``stats`` holds ``"mean_weights"``
            (float, from the un-normalised weights) and ``"eff_samples"``
            (0-d NumPy array, ``sum(w)**2 / sum(w**2)`` of the un-normalised
            weights).
        """
        beta = self.params["beta"] if beta is None else beta
        weights = weights_fn(rewards, beta, bias=False)

        with torch.no_grad():
            mean_weights = weights.mean(dim=0, keepdim=True).item()
            # Computed from the raw weights before any normalisation, so the
            # weights do not have to be evaluated a second time.
            eff_samples = (
                (weights.sum() ** 2 / ((weights**2).sum() + 1e-12))
                .cpu()
                .detach()
                .numpy()
            )

        if self.norm_weights:
            # Out of place: the raw weights are the output of exp(), which
            # autograd may still need unmodified if eta carries a gradient.
            weights = weights / (weights.sum(dim=0, keepdim=True) + 1e-12)

        loss = -(log_policy * weights).mean()

        return loss, {"mean_weights": mean_weights, "eff_samples": eff_samples}

    def kl_loss(self, rewards, log_policy):
        """Loss whose temperature ``eta`` is found by minimising ``dual_np``.

        When ``self.calculate_eta`` is true, ``eta`` is re-fitted with
        L-BFGS-B (bounds ``[1e-6, 1e8]``), starting from the current
        ``self.eta`` and using ``self.params["beta"]`` as the ``eps`` term of
        ``dual_np``; rewards with more than two dimensions are summed over
        dim 1 first. The fitted value replaces ``self.eta``.

        Args:
            rewards: Tensor of rewards.
            log_policy: Tensor of log-probabilities, or ``None`` to only
                (re)compute ``eta`` and report statistics.

        Returns:
            ``(loss, stats)``. With ``log_policy is None`` the loss is a zero
            tensor and ``stats`` uses the ``policy_update_global/*`` keys;
            otherwise the loss comes from ``awac_like_loss`` with
            ``beta=self.eta`` and ``stats`` uses the ``policy_update/*`` keys.
        """
        if self.calculate_eta:
            eps = self.params["beta"]
            if len(rewards.shape) > 2:
                rewards = rewards.sum(dim=1)
            # detach(): .numpy() raises on tensors that require grad.
            # atleast_1d(): squeeze() turns a single-reward batch into a 0-d
            # array, on which dual_np's reduction along axis 0 would fail.
            adv_np = np.atleast_1d(rewards.detach().squeeze().cpu().numpy())
            res = sc.optimize.minimize(
                dual_np,
                self.eta.cpu().detach().numpy(),
                args=(adv_np, eps),
                method="L-BFGS-B",
                bounds=((1e-6, 1e8),),
            )
            self.eta = torch.tensor(res.x, device=self.device)
        if log_policy is None:
            return torch.tensor(0.0), {
                "policy_update_global/eta": self.eta.item(),
                "policy_update_global/sample_kl": sample_kl(
                    weights_fn(rewards, self.eta, bias=False)
                ),
            }
        else:
            loss, awac_stats = self.awac_like_loss(rewards, log_policy, beta=self.eta)

        return loss, {
            "policy_update/eta": self.eta,
            "policy_update/sample_kl": sample_kl(
                weights_fn(rewards, self.eta, bias=False)
            ),
            "policy_update/eff_samples": awac_stats["eff_samples"],
            "policy_update/mean_weights": awac_stats["mean_weights"],
        }

    def eff_sample_loss(self, rewards, log_policy):
        """Loss whose temperature ``eta`` targets an effective sample size.

        When ``self.calculate_eta`` is true, ``eta`` is searched in
        ``[1e-2, 1e2]`` with ``find_eta_eff_sample`` so that the normalised
        effective sample size is close to ``self.params["target_eff"]``, and
        the result replaces ``self.eta``.

        Args:
            rewards: Tensor of rewards.
            log_policy: Tensor of log-probabilities, or ``None`` to only
                (re)compute ``eta`` and report statistics.

        Returns:
            ``(loss, stats)``. With ``log_policy is None`` the loss is a zero
            tensor and ``stats`` uses the ``policy_update_global/*`` keys;
            otherwise the loss comes from ``awac_like_loss`` with
            ``beta=self.eta`` and ``stats`` uses the ``policy_update/*`` keys.
        """
        if self.calculate_eta:
            best_eta, _ = self.find_eta_eff_sample(
                rewards, self.params["target_eff"], low=1e-2, high=1e2
            )
            self.eta = torch.tensor(best_eta, device=self.device)

        if log_policy is None:
            # Statistics-only call: nothing to optimise.
            global_stats = {
                "policy_update_global/eta": self.eta.item(),
                "policy_update_global/sample_kl": sample_kl(
                    weights_fn(rewards, self.eta, bias=False)
                ),
            }
            return torch.tensor(0.0), global_stats

        loss, awac_stats = self.awac_like_loss(rewards, log_policy, beta=self.eta)
        batch_stats = {
            "policy_update/eta": self.eta,
            "policy_update/sample_kl": sample_kl(
                weights_fn(rewards, self.eta, bias=False)
            ),
            "policy_update/eff_samples": awac_stats["eff_samples"],
            "policy_update/mean_weights": awac_stats["mean_weights"],
        }
        return loss, batch_stats

    def get_eff_ss(self, reward, eta=-1):
        """Summarise the exponential weights of ``reward`` for a temperature.

        Args:
            reward: Tensor of rewards passed to ``weights_fn``.
            eta: Temperature; the sentinel ``-1`` selects ``self.eta``.

        Returns:
            Dict with ``weights_mean``, ``weights_std``, ``weights_min``,
            ``weights_max`` (floats), ``eff_samples_abs``
            (``sum(w)**2 / sum(w**2)`` as a 0-d NumPy array) and
            ``eff_samples`` (the former divided by ``weights.shape[0]``).
            When ``self.norm_weights`` is set, all values are computed from
            the weights divided by their sum along dim 0.
        """
        temperature = self.eta if eta == -1 else eta
        weights = weights_fn(r=reward, eta=temperature, bias=False)

        if self.norm_weights:
            weights /= weights.sum(dim=0, keepdim=True) + 1e-12

        squared_sum = weights.sum() ** 2
        sum_of_squares = (weights**2).sum() + 1e-12
        ess_abs = (squared_sum / sum_of_squares).cpu().detach().numpy()

        return {
            "weights_mean": weights.mean().item(),
            "weights_std": weights.std().item(),
            "weights_min": weights.min().item(),
            "weights_max": weights.max().item(),
            "eff_samples_abs": ess_abs,
            "eff_samples": ess_abs / weights.shape[0],
        }

    def find_eta_eff_sample(self, reward, target_eff, low=1e-2, high=1e2):
        """Search ``[low, high]`` for the eta matching a target sample ratio.

        Minimises ``|get_eff_ss(reward, eta)["eff_samples"] - target_eff|``
        with SciPy's bounded scalar minimiser.

        Args:
            reward: Tensor of rewards forwarded to ``get_eff_ss``.
            target_eff: Desired value of the ``"eff_samples"`` statistic.
            low: Lower bound of the search interval.
            high: Upper bound of the search interval.

        Returns:
            Tuple ``(eta, gap)``: the eta found and the absolute distance to
            ``target_eff`` re-evaluated at that eta.
        """

        def eff_gap(candidate_eta):
            stats = self.get_eff_ss(reward, eta=candidate_eta)
            return abs(stats["eff_samples"] - target_eff)

        result = sc.optimize.minimize_scalar(
            eff_gap, bounds=(low, high), method="bounded"
        )
        best_eta = result.x
        return best_eta, eff_gap(best_eta)


class PolicyUpdates:
    """Policy Updates."""

    def __init__(
        self,
        env,
        device,
        actor_cfg,
        actor_lr,
        loss_fn,
        actor_betas,
        batch_size,
        actor_lr_min=-1,
        actor_lr_step_max=-1,
        sub_batch_size=-1,
        train_gradient_update=1,
        log_steps=1,
        weight_decay=0.0,
        normalize_reward=False,
        sa_policy=True,
    ):
        """Build the actor, its optimiser/scheduler and the loss object.

        Args:
            env: Object exposing ``action_range``, ``action_dim``, ``obs_dim``.
            device: Device specifier; converted with ``torch.device``.
            actor_cfg: Hydra config of the actor. Its ``obs_dim`` and
                ``action_dim`` fields are overwritten from ``env``.
            actor_lr: Learning rate for AdamW.
            loss_fn: Parameters handed to ``Loss_functions``.
            actor_betas: ``betas`` for AdamW.
            batch_size: Number of rows used per update round.
            actor_lr_min: ``eta_min`` of the cosine schedule. Used only when
                both this and ``actor_lr_step_max`` are positive; otherwise
                ``eta_min`` is ``actor_lr`` and ``T_max`` is 1.
            actor_lr_step_max: ``T_max`` of the cosine schedule (see above).
            sub_batch_size: Mini-batch size; a negative value means
                ``batch_size``.
            train_gradient_update: Number of rounds per ``update`` call.
            log_steps: Metrics are logged when ``step % log_steps == 0``.
            weight_decay: Weight decay for AdamW.
            normalize_reward: Whether ``update`` standardises the rewards.
            sa_policy: Selects state-action data (true) or segment data.
        """
        # Environment description
        self.action_range = env.action_range
        self.action_dim = env.action_dim
        self.obs_dim = env.obs_dim

        self.device = torch.device(device)

        # Batching
        self.batch_size = batch_size
        self.sub_batch_size = batch_size if sub_batch_size < 0 else sub_batch_size

        # Actor network: dimensions are injected into the config first.
        actor_cfg.obs_dim = self.obs_dim
        actor_cfg.action_dim = self.action_dim
        self.actor_cfg = actor_cfg
        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        # Optimiser hyper-parameters
        self.actor_lr = actor_lr
        self.actor_betas = actor_betas
        self.weight_decay = weight_decay
        self.loss_fn_params = loss_fn

        use_schedule = actor_lr_min > 0 and actor_lr_step_max > 0
        self.actor_lr_min = actor_lr_min if use_schedule else self.actor_lr
        self.actor_lr_step_max = actor_lr_step_max if use_schedule else 1

        self.actor_optimizer = torch.optim.AdamW(
            self.actor.parameters(),
            lr=self.actor_lr,
            betas=self.actor_betas,
            weight_decay=self.weight_decay,
        )
        self.actor_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.actor_optimizer,
            T_max=self.actor_lr_step_max,
            eta_min=self.actor_lr_min,
        )

        # Loss and remaining options
        self.actor_loss_fn = Loss_functions(self.loss_fn_params, device=device)
        self.train_gradient_update = train_gradient_update
        self.normalize_reward = normalize_reward
        self.sa_policy = sa_policy
        self.log_steps = log_steps

        self.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)

    def act(self, obs, sample=False):
        self.actor.eval()
        with torch.no_grad():
            if isinstance(obs, np.ndarray):
                obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            dist = self.actor(obs)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
        assert action.shape[0] == 1
        self.actor.train()
        return action[0]

    def update_actor(self, obs, action, reward, step):

        dist = self.actor(obs)
        dist_mean = dist.mean
        entropy = dist.estimate_entropy()

        if self.sa_policy:
            log_prob = dist.log_prob(action).unsqueeze(-1)
        else:
            log_prob = dist.log_prob(action).sum(-1, keepdim=True)

            reward = reward.sum(1)

        actor_loss, loss_infos = self.actor_loss_fn(
            reward,
            log_prob,
            step=step,
        )

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        dist_stats = {
            "mean/mean": dist_mean.mean().item(),
            "mean/min": dist_mean.min().item(),
            "mean/max": dist_mean.max().item(),
            "mean/std": dist_mean.std().item(),
            "entropy": entropy.item(),
        }

        log_prob_metrics = {
            "mean": log_prob.mean().item(),
            "abs_mean": reward.abs().mean().item(),
            "min": log_prob.min().item(),
            "max": log_prob.max().item(),
            "std": log_prob.std().item(),
        }
        return actor_loss.item(), log_prob_metrics, dist_stats, loss_infos

    def update(
        self, replay_buffer, reward_model, feedback_collector, env, logger, step
    ):
        """Run ``self.train_gradient_update`` rounds of actor updates.

        Each round pulls data from ``feedback_collector``, computes rewards
        with ``reward_model``, optionally standardises them, optionally fits a
        global ``eta`` on the whole reward set, then iterates over shuffled
        mini-batches calling ``update_actor``. Metrics are logged on the last
        round when ``step % self.log_steps == 0``.

        Args:
            replay_buffer: Not referenced in this method.
            reward_model: Provides ``reward_policy_update(...)`` and the
                ``reward_train_step`` attribute.
            feedback_collector: Provides ``get_pref_policy_data()``, which is
                unpacked into six values here.
            env: Not referenced in this method.
            logger: Receives ``logger.log("metrics", metrics, step)``.
            step: Current training step.

        Returns:
            Always ``1``.
        """

        for index in range(self.train_gradient_update):
            # Placeholder; overwritten by the unpacking right below.
            unique_rate = 1.0

            # Only unique_rate and all_data_segmented are used afterwards;
            # all_data_sa is replaced below by the output of segment_data_fn.
            (
                all_data_sa,
                ground_truth_score_sa,
                unique_rate,
                indices,
                all_data_segmented,
                ground_truth_score_segments,
            ) = feedback_collector.get_pref_policy_data()

            _, _, segments, sa = segment_data_fn(
                all_data_segmented, obs_dim=self.obs_dim
            )

            # The last axis holds obs_dim observation entries followed by the
            # action entries.
            all_data_sa = sa.to(device=self.device)
            obs_sa = all_data_sa[..., : self.obs_dim]
            actions_sa = all_data_sa[..., self.obs_dim :]

            segments = segments.to(device=self.device)
            obs_segments = segments[..., : self.obs_dim]
            actions_segments = segments[..., self.obs_dim :]

            # Choose which view of the data the actor is trained on.
            if self.sa_policy:
                all_data = all_data_sa
                obs = obs_sa
                actions = actions_sa
            else:
                all_data = segments
                obs = obs_segments
                actions = actions_segments

            # Rewards are always computed from the segment tensors; return_sa
            # is presumed to select a per-state-action layout -- TODO confirm
            # against reward_model.
            reward = reward_model.reward_policy_update(
                obs_segments, actions_segments, return_sa=self.sa_policy
            )

            reward_statics = {}

            if self.normalize_reward:
                reward_statics["unnormalized_stats"] = {
                    "mean": reward.mean().item(),
                    "min": reward.min().item(),
                    "max": reward.max().item(),
                    "std": reward.std().item(),
                }

                # Standardise over all elements of the reward tensor.
                reward = (reward - reward.mean()) / (reward.std() + 1e-8)

            actor_losses = []
            log_prob_metrics = []
            dist_stats = []
            all_loss_stats = []

            # Fall back to the full data set when it is smaller than the
            # configured batch size or when batch_size is non-positive.
            if all_data.shape[0] < self.batch_size or self.batch_size <= 0:
                batch_size = all_data.shape[0]
                if self.batch_size <= 0:
                    sub_batch_size = (
                        self.sub_batch_size if self.sub_batch_size > 0 else batch_size
                    )  # if we use whole dataset, then use sub_batch_size
                else:
                    sub_batch_size = all_data.shape[0]
            else:
                batch_size = self.batch_size
                sub_batch_size = self.sub_batch_size

            # NOTE(review): the permutation covers indices 0..batch_size-1, so
            # when the data set has more rows than batch_size only its first
            # batch_size rows are ever visited -- confirm this is intended.
            total_batch_index = np.random.permutation(batch_size)
            num_batches = int(np.ceil(batch_size / sub_batch_size))

            global_loss_metrics = {}
            # While step has not passed reward_model.reward_train_step, fit eta
            # once on the full reward set, then switch calculate_eta off so the
            # per-batch loss calls below reuse that eta.
            if step <= reward_model.reward_train_step:
                if self.actor_loss_fn.calculate_global_eta:
                    self.actor_loss_fn.calculate_eta = True
                    if self.sa_policy:
                        reward_ss = reward.view(-1, 1)
                    else:
                        reward_ss = reward.sum(1)

                    # log_policy=None makes the loss return statistics only;
                    # the returned loss value is not used.
                    loss, global_loss_metrics = self.actor_loss_fn(
                        rewards=reward_ss, log_policy=None
                    )
                    self.actor_loss_fn.calculate_eta = False

                    global_eff_metrics = self.actor_loss_fn.get_eff_ss(reward_ss)
                    global_loss_metrics["global_eff_updates"] = global_eff_metrics

            for batch_id in range(num_batches):
                # The last mini-batch may be shorter than sub_batch_size.
                batch_start = batch_id * sub_batch_size
                batch_end = min((batch_id + 1) * sub_batch_size, batch_size)
                batch_index = total_batch_index[batch_start:batch_end]

                obs_batch = obs[batch_index]
                reward_batch = reward[batch_index]
                action_batch = actions[batch_index]

                actor_loss, log_prob_metric, dist_stat, loss_stats = self.update_actor(
                    obs=obs_batch, action=action_batch, reward=reward_batch, step=step
                )
                actor_losses.append(actor_loss)
                log_prob_metrics.append(log_prob_metric)
                dist_stats.append(dist_stat)
                all_loss_stats.append(loss_stats)

            # The learning-rate schedule advances once per round, not per
            # mini-batch.
            self.actor_scheduler.step()

            # Log only on the final round of this call.
            if index == self.train_gradient_update - 1 and step % self.log_steps == 0:

                mean_actor_loss = np.mean(actor_losses)
                mean_log_prob_metrics = calculate_means_dict(log_prob_metrics)
                mean_dist_stats = calculate_means_dict(dist_stats)
                mean_loss_stats = calculate_means_dict(all_loss_stats)

                reward_learned_stats = {
                    "mean": reward.mean().item(),
                    "min": reward.min().item(),
                    "max": reward.max().item(),
                    "std": reward.std().item(),
                }

                reward_statics["reward_learned"] = reward_learned_stats

                # NOTE(review): this rebinding discards the "unnormalized_stats"
                # and "reward_learned" entries stored above, so only the flat
                # statistics of the (possibly normalised) reward are logged --
                # confirm which of the two layouts is intended.
                reward_statics = {
                    "mean": reward.mean().item(),
                    "min": reward.min().item(),
                    "max": reward.max().item(),
                    "std": reward.std().item(),
                }

                metrics = {
                    "reward_statics": reward_statics,
                    "loss": mean_actor_loss,
                    "log_prob_metrics": mean_log_prob_metrics,
                    "dist_stats": mean_dist_stats,
                    "unique_rate": unique_rate,
                    "batch_size": batch_size,
                    "loss_stats": mean_loss_stats,
                    "actor_lr": self.actor_scheduler.get_last_lr()[0],
                    **global_loss_metrics,
                }
                logger.log("metrics", metrics, step)

        return 1

    def start_training(self, env):
        """Reset the updater before training begins.

        Args:
            env: Accepted but not used by this method.
        """
        self.reset()

    def reset(self):
        """No-op reset hook; there is no state to clear and it returns ``None``."""
        return None
