import torch
import torch.nn as nn

import torch.optim as optim
from torch.distributions import Categorical
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
import random
from copy import deepcopy
import gym

from networks import Actor_categorical
from offline_rl.cql.networks import DDQN

from typing import Dict, Union


class BasePolicy(nn.Module):
    """Abstract base for the policy classes.

    Every method except ``__init__`` is a placeholder that raises
    ``NotImplementedError``.

    ``train`` and ``eval`` shadow ``nn.Module.train`` / ``nn.Module.eval``
    with a narrower contract: they take no ``mode`` argument and return
    ``None`` instead of the module itself.
    """

    def __init__(self) -> None:
        super().__init__()

    def train(self) -> None:
        """Put the policy into training mode (placeholder)."""
        raise NotImplementedError

    def eval(self) -> None:
        """Put the policy into evaluation mode (placeholder)."""
        raise NotImplementedError

    def select_action(
        self,
        obs: np.ndarray,
        deterministic: bool = False
    ) -> np.ndarray:
        """Return an action for ``obs`` (placeholder)."""
        raise NotImplementedError

    def learn(self, batch: Dict) -> Dict[str, float]:
        """Run one update on ``batch`` and return named scalars (placeholder)."""
        raise NotImplementedError


class CQLAgent():
    """Q-learning agent for discrete actions with a CQL-style penalty.

    Holds an online ``DDQN`` network, a target ``DDQN`` network that follows
    it through Polyak averaging, and an Adam optimiser (lr ``1e-3``) over the
    online network.

    Only ``state_size``, ``action_size``, ``tau``, ``gamma``, ``hidden_size``
    and ``device`` are used. The remaining constructor arguments (``alpha``,
    ``cql_weight``, ``temperature``, ``max_q_backup``,
    ``deterministic_backup``, ``with_lagrange``, ``lagrange_threshold``,
    ``cql_alpha_lr``, ``num_repeat_actions``) are accepted for signature
    compatibility and ignored; the penalty weight is the ``alpha`` argument
    of ``learn``.
    """

    def __init__(self, state_size, action_size, tau,
                 gamma, alpha, cql_weight, temperature, max_q_backup, deterministic_backup,
                 with_lagrange, lagrange_threshold, cql_alpha_lr, num_repeat_actions,
                 hidden_size=256, device="cpu"):
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        # Use the caller's hyper-parameters; they used to be silently
        # replaced by the constants 1e-3 and 0.99.
        self.tau = tau
        self.gamma = gamma

        self.network = DDQN(state_size=self.state_size,
                            action_size=self.action_size,
                            layer_size=hidden_size
                            ).to(self.device)

        # NOTE(review): the target net starts from its own random init and
        # is not copied from ``self.network`` - confirm this is intended.
        self.target_net = DDQN(state_size=self.state_size,
                               action_size=self.action_size,
                               layer_size=hidden_size
                               ).to(self.device)

        self.optimizer = optim.Adam(params=self.network.parameters(), lr=1e-3)

    def get_action(self, state, epsilon=0.1):
        """Pick an action epsilon-greedily.

        Args:
            state: NumPy array holding one un-batched observation; a batch
                axis is added before the forward pass.
            epsilon: Probability of a uniformly random action.

        Returns:
            ``np.ndarray`` of shape ``(1,)`` holding the action index, on
            both the greedy and the exploratory branch.
        """
        if random.random() > epsilon:
            state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
            self.network.eval()
            with torch.no_grad():
                action_values = self.network(state)
            # The online network is always left in train mode afterwards.
            self.network.train()
            return np.argmax(action_values.cpu().numpy(), axis=1)
        # Wrapped in an array so both branches return the same type.
        return np.asarray(random.choices(np.arange(self.action_size), k=1))

    def train(self):
        """Put both networks into training mode."""
        self.network.train()
        self.target_net.train()

    def eval(self):
        """Put both networks into evaluation mode."""
        self.network.eval()
        self.target_net.eval()

    def cql_loss(self, q_values, current_action):
        """Computes the CQL loss for a batch of Q-values and actions.

        The result is ``mean(logsumexp_a Q(s, a) - Q(s, a_data))``.
        ``current_action`` is used as a ``gather`` index along dim 1, so it
        must be an int64 tensor with the same number of dims as ``q_values``.
        """
        logsumexp = torch.logsumexp(q_values, dim=1, keepdim=True)
        q_a = q_values.gather(1, current_action)
        return (logsumexp - q_a).mean()

    def learn(self, experiences, alpha=10):
        """Take one gradient step and soft-update the target network.

        Args:
            experiences: Mapping with tensors under ``"state"``,
                ``"actions"``, ``"terminals"``, ``"K_state"``, ``"K"`` and
                ``"K_rewards"``. Presumably ``K_state`` / ``K`` /
                ``K_rewards`` describe a K-step look-ahead (state reached,
                step count, accumulated reward) - TODO confirm against the
                replay buffer.
            alpha: Weight of the CQL penalty relative to the Bellman error.

        Returns:
            Dict of Python floats: ``"q1_loss"``, ``"cql1_loss"``,
            ``"bellman_error"``.
        """
        states = experiences["state"]
        actions = experiences["actions"]
        dones = experiences["terminals"]
        K_state = experiences["K_state"]
        K_value = experiences["K"]
        K_rewards = experiences["K_rewards"]

        with torch.no_grad():
            Q_targets_next = self.target_net(K_state).max(1)[0].unsqueeze(1)
            # ``K_value`` gets a trailing axis so the discount broadcasts
            # against the column of bootstrap values.
            discount = (self.gamma ** K_value).unsqueeze(-1)
            Q_targets = K_rewards + discount * Q_targets_next * (1 - dones)

        Q_a_s = self.network(states)
        Q_expected = Q_a_s.gather(1, actions)

        cql1_loss = self.cql_loss(Q_a_s, actions)
        bellman_error = F.mse_loss(Q_expected, Q_targets)
        q1_loss = alpha * cql1_loss + bellman_error

        self.optimizer.zero_grad()
        q1_loss.backward()
        clip_grad_norm_(self.network.parameters(), 1.)
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.network, self.target_net)
        return {
            "q1_loss": q1_loss.item(),
            "cql1_loss": cql1_loss.item(),
            "bellman_error": bellman_error.item(),
        }

    def soft_update(self, local_model, target_model):
        """Polyak update: ``target <- tau * local + (1 - tau) * target``."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)



class BCPolicy(BasePolicy):
    """Behaviour cloning: regress the actor's output onto dataset actions."""

    def __init__(
        self,
        actor: nn.Module,
        actor_optim: torch.optim.Optimizer,
        device
    ) -> None:
        """Store the actor, its optimiser and the device used for inference."""
        super().__init__()
        self.actor = actor
        self.actor_optim = actor_optim
        self.device = device

    def train(self) -> None:
        """Put the actor into training mode."""
        self.actor.train()

    def eval(self) -> None:
        """Put the actor into evaluation mode."""
        self.actor.eval()

    def get_action(self, state: np.ndarray, context: bool = False) -> np.ndarray:
        """Run the actor on a single observation.

        ``state`` is cast to float, given a leading batch axis and moved to
        ``self.device``. ``context`` is not used. The actor is switched to
        eval mode for the forward pass and is always left in train mode
        afterwards.

        Returns:
            The actor output as a NumPy array, batch axis included.
        """
        obs = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.actor.eval()
        with torch.no_grad():
            prediction = self.actor(obs)
            action = prediction.cpu().numpy()
        self.actor.train()
        return action

    def learn(self, batch: Dict) -> Dict[str, float]:
        """Take one optimiser step on the mean squared action error.

        Args:
            batch: Mapping with ``"state"`` and ``"actions"`` tensors.

        Returns:
            ``{"loss/actor": <float>}``, the loss measured before the step.
        """
        obss = batch["state"]
        actions = batch["actions"]

        residual = self.actor(obss) - actions
        actor_loss = residual.pow(2).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        return {"loss/actor": actor_loss.item()}


class IQLPolicy(BasePolicy):
    """
    Implicit Q-Learning <Ref: https://arxiv.org/abs/2110.06169>

    Holds an actor, two Q critics with slowly-tracking copies
    (``critic_q*_old``), and a state-value critic. The actor is conditioned
    on the observation concatenated with a context vector.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic_q1: nn.Module,
        critic_q2: nn.Module,
        critic_v: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic_q1_optim: torch.optim.Optimizer,
        critic_q2_optim: torch.optim.Optimizer,
        critic_v_optim: torch.optim.Optimizer,
        action_space: gym.spaces.Space,
        tau: float = 0.005,
        gamma: float  = 0.99,
        expectile: float = 0.8,
        temperature: float = 0.1
    ) -> None:
        """Register the networks and optimisers and store hyper-parameters.

        The ``*_old`` critics are deep copies made here and put in eval mode.
        """
        super().__init__()

        self.actor = actor

        self.critic_q1 = critic_q1
        self.critic_q1_old = deepcopy(critic_q1)
        self.critic_q1_old.eval()

        self.critic_q2 = critic_q2
        self.critic_q2_old = deepcopy(critic_q2)
        self.critic_q2_old.eval()

        self.critic_v = critic_v

        self.actor_optim = actor_optim
        self.critic_q1_optim = critic_q1_optim
        self.critic_q2_optim = critic_q2_optim
        self.critic_v_optim = critic_v_optim

        self.action_space = action_space
        self._tau = tau
        self._gamma = gamma
        self._expectile = expectile
        self._temperature = temperature

    def train(self) -> None:
        """Put the actor and the three online critics into training mode."""
        for net in (self.actor, self.critic_q1, self.critic_q2, self.critic_v):
            net.train()

    def eval(self) -> None:
        """Put the actor and the three online critics into evaluation mode."""
        for net in (self.actor, self.critic_q1, self.critic_q2, self.critic_v):
            net.eval()

    def _sync_weight(self) -> None:
        """Polyak-average each online Q critic into its ``*_old`` copy."""
        critic_pairs = (
            (self.critic_q1_old, self.critic_q1),
            (self.critic_q2_old, self.critic_q2),
        )
        for old_net, new_net in critic_pairs:
            for o, n in zip(old_net.parameters(), new_net.parameters()):
                o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def get_action(self, obs: np.ndarray, context, deterministic: bool = False) -> np.ndarray:
        """Sample (or take the mode of) the actor's distribution and clip it.

        ``obs`` and ``context`` are concatenated on the last axis with
        ``torch.cat``, so both must already be tensors despite the
        annotation. The result is clipped to the first entry of the action
        space's ``low`` / ``high`` bounds.
        """
        with torch.no_grad():
            actor_input = torch.cat([obs, context], dim=-1)
            dist = self.actor(actor_input)
            action = dist.mode() if deterministic else dist.sample()
        lower = torch.tensor(self.action_space.low[0])
        upper = torch.tensor(self.action_space.high[0])
        return torch.clip(action, lower, upper)

    def _expectile_regression(self, diff: torch.Tensor) -> torch.Tensor:
        """Element-wise asymmetric squared error.

        Positive ``diff`` is weighted by the expectile, the rest by its
        complement.
        """
        upper_weight = self._expectile
        lower_weight = 1 - self._expectile
        weight = torch.where(diff > 0, upper_weight, lower_weight)
        return weight * diff.pow(2)

    def _min_old_q(self, obss, low_actions) -> torch.Tensor:
        """Element-wise minimum of the two ``*_old`` critics, without gradients."""
        with torch.no_grad():
            q1_old = self.critic_q1_old(obss, low_actions)
            q2_old = self.critic_q2_old(obss, low_actions)
            return torch.min(q1_old, q2_old)

    @staticmethod
    def _apply_step(optimizer, loss) -> None:
        """Zero the gradients, back-propagate ``loss`` and step ``optimizer``."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def learn(self, batch: Dict, embed_index_set, embed_codebook, project_out) -> Dict[str, float]:
        """Run one IQL update: value net, both Q nets, actor, then target sync.

        Args:
            batch: Mapping with ``"state"``, ``"actions"``, ``"next_state"``,
                ``"rewards"``, ``"terminals"`` and ``"low_actions"``. The
                critics and the actor's log-probability use
                ``"low_actions"``; ``"actions"`` is only used as an index
                into ``embed_index_set``.
            embed_index_set: NumPy array indexed by the ``"actions"`` values
                to obtain codebook row indices.
            embed_codebook: Embedding weight tensor passed to ``F.embedding``.
            project_out: Callable mapping the looked-up embeddings to the
                context vector fed to the actor.

        Returns:
            Dict of floats: ``"loss/actor"``, ``"loss/q1"``, ``"loss/q2"``,
            ``"loss/v"``.
        """
        obss = batch["state"]
        actions = batch["actions"]
        next_obss = batch["next_state"]
        rewards = batch["rewards"]
        # NOTE(review): looked up but never applied; the Q target below
        # bootstraps through terminal transitions - confirm intended.
        terminals = batch["terminals"]
        low_actions = batch["low_actions"]

        # update value net
        q_old = self._min_old_q(obss, low_actions)
        state_value = self.critic_v(obss)
        critic_v_loss = self._expectile_regression(q_old - state_value).mean()
        self._apply_step(self.critic_v_optim, critic_v_loss)

        # update critic
        q1_pred = self.critic_q1(obss, low_actions)
        q2_pred = self.critic_q2(obss, low_actions)
        with torch.no_grad():
            next_v = self.critic_v(next_obss)
            target_q = rewards + self._gamma * next_v  # (1 - terminals) mask disabled

        critic_q1_loss = (q1_pred - target_q).pow(2).mean()
        critic_q2_loss = (q2_pred - target_q).pow(2).mean()
        self._apply_step(self.critic_q1_optim, critic_q1_loss)
        self._apply_step(self.critic_q2_optim, critic_q2_loss)

        # update actor
        q_old = self._min_old_q(obss, low_actions)
        with torch.no_grad():
            advantage = q_old - self.critic_v(obss)
            # Exponentiated-advantage weight, capped at 100.
            exp_a = torch.clip(torch.exp(advantage * self._temperature), None, 100.0)

            code_rows = embed_index_set[actions.detach().cpu().numpy()]
            code_index = torch.from_numpy(code_rows).to(embed_codebook.device).squeeze()
            action_quantize = F.embedding(code_index, embed_codebook)
            action_context = project_out(action_quantize)

        dist = self.actor(torch.cat([obss, action_context], dim=-1))
        log_probs = dist.log_prob(low_actions)
        actor_loss = -(exp_a * log_probs).mean()
        self._apply_step(self.actor_optim, actor_loss)

        self._sync_weight()

        return {
            "loss/actor": actor_loss.item(),
            "loss/q1": critic_q1_loss.item(),
            "loss/q2": critic_q2_loss.item(),
            "loss/v": critic_v_loss.item(),
        }
    


class bc_goal_policy(BasePolicy):
    """Context-conditioned behaviour cloning with a squared-error loss.

    The actor is called as ``actor(state, context)`` and is trained to
    reproduce the batch's ``"low_actions"``.
    """

    def __init__(
        self,
        actor: nn.Module,
        actor_optim: torch.optim.Optimizer,
        device
    ) -> None:
        """Store the actor, its optimiser and the device handle."""
        super().__init__()
        self.actor = actor
        self.actor_optim = actor_optim
        self.device = device

    def train(self) -> None:
        """Put the actor into training mode."""
        self.actor.train()

    def eval(self) -> None:
        """Put the actor into evaluation mode."""
        self.actor.eval()

    def get_action(self, state: np.ndarray, context: bool = False, deterministic = False) -> np.ndarray:
        """Return the actor's output for ``(state, context)``.

        When ``deterministic`` is set and the output is one-dimensional, its
        single value is rounded to the nearest integer and returned as a new
        one-element tensor. Otherwise the raw actor output is returned.
        No ``no_grad`` context or mode switch is applied here.
        """
        raw_action = self.actor(state, context)
        if not deterministic or len(raw_action.shape) != 1:
            return raw_action
        return torch.tensor([round(raw_action.item())])

    def learn(self, batch: Dict, embed_index_set, embed_codebook, project_out) -> Dict[str, float]:
        """Take one optimiser step on the mean squared low-level action error.

        Args:
            batch: Mapping with ``"state"``, ``"actions"`` and
                ``"low_actions"``. ``"actions"`` is used only as an index
                into ``embed_index_set``.
            embed_index_set: NumPy array indexed by the ``"actions"`` values
                to obtain codebook row indices.
            embed_codebook: Embedding weight tensor passed to ``F.embedding``.
            project_out: Callable mapping the looked-up embeddings to the
                actor's context input. It runs with gradients enabled.

        Returns:
            ``{"loss/actor": <float>}``.
        """
        obss = batch["state"]
        actions = batch["actions"]
        low_actions = batch["low_actions"]

        code_rows = embed_index_set[actions.detach().cpu().numpy()]
        code_index = torch.from_numpy(code_rows).to(embed_codebook.device).squeeze()
        action_quantize = F.embedding(code_index, embed_codebook)
        action_context = project_out(action_quantize)

        predicted = self.actor(obss, action_context)
        actor_loss = (predicted - low_actions).pow(2).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        return {"loss/actor": actor_loss.item()}
    



class bc_goal_categorical_policy(BasePolicy):
    """Context-conditioned behaviour cloning over discrete actions.

    The actor is called as ``actor(state, context)``; its output is treated
    as unnormalised class scores and trained with cross-entropy against the
    batch's ``"low_actions"``.
    """

    def __init__(
        self,
        actor: Actor_categorical,
        actor_optim: torch.optim.Optimizer,
        device
    ) -> None:
        """Store the actor, its optimiser and the device handle."""
        super().__init__()
        self.actor = actor
        self.actor_optim = actor_optim
        self.device = device

    def train(self) -> None:
        """Put the actor into training mode."""
        self.actor.train()

    def eval(self) -> None:
        """Put the actor into evaluation mode."""
        self.actor.eval()

    def get_action(self, state: np.ndarray, context: bool = False, deterministic = True) -> np.ndarray:
        """Choose a discrete action from the actor's scores.

        With ``deterministic`` set, returns the arg-max over the last axis
        with a trailing singleton axis appended. Otherwise the scores are
        soft-maxed and one index is sampled; that result has no trailing
        axis, so the two branches differ in shape.
        """
        scores = self.actor(state, context)
        if deterministic:
            return scores.argmax(-1).unsqueeze(-1)
        probabilities = F.softmax(scores, dim=-1)
        return Categorical(probabilities).sample()

    def learn(self, batch: Dict, embed_index_set, embed_codebook, project_out) -> Dict[str, float]:
        """Take one optimiser step on the cross-entropy to the low-level actions.

        Args:
            batch: Mapping with ``"state"``, ``"actions"`` and
                ``"low_actions"``. ``"actions"`` is used only as an index
                into ``embed_index_set``.
            embed_index_set: NumPy array indexed by the ``"actions"`` values
                to obtain codebook row indices.
            embed_codebook: Embedding weight tensor passed to ``F.embedding``.
            project_out: Callable mapping the looked-up embeddings to the
                actor's context input. It runs with gradients enabled.

        Returns:
            ``{"loss/actor": <float>}``.
        """
        obss = batch["state"]
        actions = batch["actions"]
        low_actions = batch["low_actions"]

        code_rows = embed_index_set[actions.detach().cpu().numpy()]
        code_index = torch.from_numpy(code_rows).to(embed_codebook.device).squeeze()
        action_quantize = F.embedding(code_index, embed_codebook)
        action_context = project_out(action_quantize)

        scores = self.actor(obss, action_context)
        # Targets with more than one dim have axis 1 squeezed out before
        # being cast to int64 class indices.
        if len(low_actions.shape) > 1:
            targets = low_actions.squeeze(1).long()
        else:
            targets = low_actions.long()
        actor_loss = F.cross_entropy(scores, targets)

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        return {"loss/actor": actor_loss.item()}


