import numpy as np
import torch
import torch.nn as nn

from typing import Dict, Union, Tuple, Callable
from copy import deepcopy
from offlinerlkit.policy import BasePolicy
import torch.nn.functional as F

class EPICPolicy(BasePolicy):
    """
    Deterministic actor with an EPIC neighbour regularizer, trained against an
    ensemble of critics that carries the EDAC gradient-diversity penalty.

    Critic side follows Ensemble-Diversified Actor Critic
    <Ref: https://arxiv.org/abs/2110.01548>: ensemble-min TD targets from
    soft-updated target critics, plus ``eta`` times the mean pairwise cosine
    similarity between the critics' action-gradients.

    Actor side maximizes the ensemble-min Q of the actor's action, plus
    ``epic_lambd`` times an MSE pull toward a per-sample target chosen among
    ``epic_k`` pre-computed dataset neighbours (``data`` / ``data_act``,
    indexed by ``batch["batch_indexes"]``) and the actor's own action.
    """
    def __init__(
        self,
        actor: nn.Module,
        critics: nn.ModuleList,
        actor_optim: torch.optim.Optimizer,
        critics_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float  = 0.99,
        max_q_backup: bool = False,
        deterministic_backup: bool = True,
        eta: float = 1.0,

        data = None,
        data_act = None,
        epic_k: int = 1,
        epic_alpha: float = 1.0,
        epic_lambd: float = 0.8,
        action_dim = None,
    ) -> None:
        """
        Args:
            actor: policy network called as ``actor(obs)`` returning actions.
            critics: Q ensemble called as ``critics(obs, act)``; must expose
                ``_num_ensemble`` and return ``[num_critics, batch_size, 1]``.
            actor_optim: optimizer for the actor.
            critics_optim: optimizer for the critics.
            tau: soft-update rate for the target critics.
            gamma: discount factor.
            max_q_backup: back up the max over 10 repeated next-state
                evaluations instead of a single one.
            deterministic_backup: stored but not read by this class.
            eta: weight of the gradient-diversity penalty (<= 0 disables it).
            data: neighbour states indexed by ``batch_indexes``; expected as
                ``[N, obs_dim]`` (one neighbour) or ``[N, k, obs_dim]``.
                ``None`` disables the EPIC regularizer.
            data_act: matching neighbour actions, ``[N, act_dim]`` or
                ``[N, k, act_dim]``.
            epic_k: number of neighbours considered per sample (>= 1).
            epic_alpha: weight of the state-distance term in the EPIC loss.
            epic_lambd: weight of the EPIC loss in the actor objective.
            action_dim: stored but not read by this class.

        Raises:
            ValueError: if ``data`` is given and ``epic_k < 1``.
        """
        super().__init__()
        self.actor = actor
        # Kept for parity with other policies; learn() does not use it.
        self.actor_old = deepcopy(actor)
        self.actor_old.eval()
        self.critics = critics
        self.critics_old = deepcopy(critics)
        self.critics_old.eval()

        self.actor_optim = actor_optim
        self.critics_optim = critics_optim

        self._tau = tau
        self._gamma = gamma
        self._max_q_backup = max_q_backup
        self._deterministic_backup = deterministic_backup
        self._eta = eta
        self._num_critics = self.critics._num_ensemble

        if data is not None and epic_k < 1:
            raise ValueError(f"epic_k must be >= 1, got {epic_k}")
        # Always define the EPIC attributes so learn() can test ``self.data``
        # instead of failing with AttributeError when no neighbours are given.
        self.data = data
        self.data_act = data_act
        self.k = epic_k
        self.alpha = epic_alpha
        self.epic_lambd = epic_lambd
        self.action_dim = action_dim

    def train(self) -> None:
        """Put the online actor and critics in training mode."""
        self.actor.train()
        self.critics.train()

    def eval(self) -> None:
        """Put the online actor and critics in evaluation mode."""
        self.actor.eval()
        self.critics.eval()

    def _sync_weight(self) -> None:
        """Polyak-average the online critics into the target critics (rate ``tau``)."""
        for o, n in zip(self.critics_old.parameters(), self.critics.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Return the actor's action for ``obs`` as a numpy array, optionally noised.

        NOTE(review): the non-deterministic path reads ``self.exploration_noise``
        and ``self._max_action``, neither of which is set in this class --
        confirm they are provided elsewhere before using ``deterministic=False``.
        """
        with torch.no_grad():
            action = self.actor(obs).cpu().numpy()
        if not deterministic:
            action = action + self.exploration_noise(action.shape)
            action = np.clip(action, -self._max_action, self._max_action)
        return action

    def _select_epic_targets(
        self,
        batch_indexes,
        obss: torch.Tensor,
        a: torch.Tensor,
        qas: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick, per sample, the (state, action) target for the EPIC regularizer.

        Candidates are the first ``self.k`` neighbours of each sample plus the
        actor's own ``(obss, a)`` pair; each candidate action is scored by the
        ensemble-min of ``self.critics(obss, action)``.

        NOTE(review): the candidate with the *lowest* score is kept (strict
        ``<`` comparison), although the results are called "best" -- confirm
        this pessimistic choice is what EPIC intends.

        Returns:
            ``(best_states, best_actions)`` shaped like ``obss`` / ``a``,
            detached from the autograd graph.
        """
        with torch.no_grad():
            # Follow the batch's device/dtype instead of assuming CUDA, and cast
            # numpy data (often float64) to match the network tensors.
            states = torch.as_tensor(
                self.data[batch_indexes], dtype=obss.dtype, device=obss.device
            )
            acts = torch.as_tensor(
                self.data_act[batch_indexes], dtype=a.dtype, device=a.device
            )
            # Single-neighbour layout [B, dim] -> [B, 1, dim] so both layouts
            # share the same indexing below.
            if states.dim() == obss.dim():
                states = states.unsqueeze(1)
            if acts.dim() == a.dim():
                acts = acts.unsqueeze(1)

            best_states = best_actions = best_values = None
            for i in range(self.k):
                state_i, act_i = states[:, i], acts[:, i]
                # [batch_size, 1]; broadcasts over the feature dim in torch.where
                value_i = self.critics(obss, act_i).min(0)[0]
                if best_values is None:
                    best_states, best_actions, best_values = state_i, act_i, value_i
                else:
                    mask = value_i < best_values
                    best_values = torch.where(mask, value_i, best_values)
                    best_actions = torch.where(mask, act_i, best_actions)
                    best_states = torch.where(mask, state_i, best_states)

            # Finally compare against the actor's own action.
            mask = qas.min(0)[0] < best_values
            best_actions = torch.where(mask, a, best_actions)
            best_states = torch.where(mask, obss, best_states)
        return best_states, best_actions

    def _next_q_values(self, next_obss: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the bootstrap value from the target critics.

        Returns:
            ``(next_q, next_qs)``: the ensemble-min ``[batch_size, 1]`` used in
            the TD target, and the per-critic ``[num_critics, batch_size, 1]``
            values it was reduced from (used for logging).
        """
        with torch.no_grad():
            if self._max_q_backup:
                # NOTE(review): if the actor is deterministic, the 10 copies
                # produce identical actions and the max is a no-op.
                batch_size = next_obss.shape[0]
                tmp_next_obss = next_obss.unsqueeze(1).repeat(1, 10, 1) \
                    .view(batch_size * 10, next_obss.shape[-1])
                tmp_next_actions = self.actor(tmp_next_obss)
                next_qs = self.critics_old(tmp_next_obss, tmp_next_actions) \
                    .view(self._num_critics, batch_size, 10, 1).max(2)[0] \
                    .view(self._num_critics, batch_size, 1)
            else:
                next_actions = self.actor(next_obss)
                next_qs = self.critics_old(next_obss, next_actions)
        return next_qs.min(0)[0], next_qs

    def _diversity_loss(self, obss: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """EDAC penalty: mean pairwise cosine similarity of critics' action-gradients."""
        obss_tile = obss.unsqueeze(0).repeat(self._num_critics, 1, 1)
        # Detach so the caller's batch tensor is not turned into a grad leaf
        # (which would also accumulate .grad on it every step).
        actions_tile = actions.detach().unsqueeze(0) \
            .repeat(self._num_critics, 1, 1).requires_grad_(True)
        qs_preds_tile = self.critics(obss_tile, actions_tile)
        qs_pred_grads, = torch.autograd.grad(
            qs_preds_tile.sum(), actions_tile, retain_graph=True, create_graph=True
        )
        qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
        # [batch_size, num_critics, act_dim]
        qs_pred_grads = qs_pred_grads.transpose(0, 1)

        cos = torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads)
        # Zero the diagonal (self-similarity); broadcasts over the batch.
        off_diag = 1 - torch.eye(self._num_critics, device=obss.device).unsqueeze(0)
        cos = off_diag * cos
        return torch.mean(torch.sum(cos, dim=(1, 2))) / (self._num_critics - 1)

    def learn(self, batch: Dict) -> Dict:
        """Run one actor update, one critic update and a target soft-update.

        Args:
            batch: dict with ``observations``, ``actions``, ``next_observations``,
                ``rewards`` and ``terminals`` tensors, plus ``batch_indexes``
                (rows of ``data`` / ``data_act``; read only when neighbour data
                was supplied).

        Returns:
            Scalar logs: actor and critic losses, and mean/min/std of the
            target critics' next-state Q-values.
        """
        obss = batch["observations"]
        actions = batch["actions"]
        next_obss = batch["next_observations"]
        rewards = batch["rewards"]
        terminals = batch["terminals"]

        # ---- actor update ----
        a = self.actor(obss)
        # qas: [num_critics, batch_size, 1]
        qas = self.critics(obss, a)
        actor_loss = -torch.min(qas, 0)[0].mean()

        if self.data is not None:
            best_states, best_actions = self._select_epic_targets(
                batch["batch_indexes"], obss, a, qas
            )
            # NOTE(review): neither obss nor best_states carries gradient to the
            # actor, so this term changes the reported loss, not the update.
            state_distance = F.mse_loss(obss, best_states)
            # Reuse ``a``: the targets were chosen relative to this exact output,
            # so a second actor forward pass is unnecessary.
            action_distance = F.mse_loss(a, best_actions)
            epic_loss = self.alpha * state_distance + action_distance
            actor_loss = actor_loss + self.epic_lambd * epic_loss

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # ---- critic update ----
        # Gradients actor_loss.backward() left on the critics are cleared by
        # critics_optim.zero_grad() below (assuming it owns the critic params).
        next_q, next_qs = self._next_q_values(next_obss)
        # target_q: [batch_size, 1]
        target_q = rewards + self._gamma * (1 - terminals) * next_q
        # qs: [num_critics, batch_size, 1]
        qs = self.critics(obss, actions)
        critics_loss = ((qs - target_q.unsqueeze(0)).pow(2)).mean(dim=(1, 2)).sum()

        # The diversity term averages over num_critics - 1 partners and is 0/0
        # (NaN) for a single critic, so it is skipped in that case.
        if self._eta > 0 and self._num_critics > 1:
            critics_loss = critics_loss + self._eta * self._diversity_loss(obss, actions)

        self.critics_optim.zero_grad()
        critics_loss.backward()
        self.critics_optim.step()

        self._sync_weight()

        return {
            "loss/actor": actor_loss.item(),
            "loss/critics": critics_loss.item(),
            "q_mean": next_qs.mean().item(),
            "q_min": next_qs.min().item(),
            "q_std": next_qs.std().item(),
        }



