
import dataclasses as dc
import math
import numpy as np
import torch
import torch.nn.functional as F
import typing as ty

from copy import copy, deepcopy
from gym import spaces
from pathlib import Path

from .base import ReplayBuffer, polyak_update, check_gradient, check_param_norm
from .networks import ActionValueNetwork4ContinuousAction
from .policy import SquashedGaussianPolicy, MixturedSquashedGaussianPolicy, RealNVPPolicy, SequantialGaussRealNVPPolicy

from .sac_clean import SoftQNetwork, Actor


@dc.dataclass
class Param4SAC:
    '''
        Hyper-parameter bundle whose field names match the keyword arguments
        of ``SoftActorCritic.__init__``.

        NOTE: several defaults here differ from the defaults declared on
        ``SoftActorCritic.__init__`` itself (e.g. ``batch_size``, ``lr``,
        ``policy_type``, ``n_hidden``).
        NOTE(review): presumably unpacked into the constructor by a caller
        outside this file -- confirm.
    '''

    # ---- common ----
    batch_size      : int = 256
    buffer_size     : int = int(1e7)
    lr              : float = 3e-3             # critic / entropy-coefficient learning rate
    gamma           : float = 0.99             # discount factor
    seed            : int = 7
    grad_clipping   : ty.Optional[float] = None  # None disables gradient-norm clipping
    cuda            : bool = True
    # NOTE: defaults to an EMPTY list; an earlier revision used [128, 128].
    n_hidden        : list = dc.field(default_factory=list)
    layer_norm      : bool = False
    use_double      : bool = False             # use a second critic and take the minimum
    polyak          : float = 0.995            # target-network averaging coefficient

    # ---- actor ----
    lr_policy: float = 3e-4
    policy_type: str = 'rnvp'                  # 'gaussian' | 'rnvp' | 'sgrnvp'
    n_mixtures: int = 1
    single_head_gmn: bool = False
    autotune_ent_coeff: bool = True
    coeff_entropy: float = 0.1                 # fixed coefficient when not auto-tuned
    max_coeff_entropy: float = 1e6             # upper clamp for the auto-tuned coefficient
    log_std_min: float = -20.
    log_std_max: float = 2.
    min_action: ty.Optional[float] = None
    max_action: ty.Optional[float] = None
    action_scaler: str = 'spinningup'

    # Boolean switch (was mis-annotated as ``str``).
    qf_emulate_clean_rl: bool = True

    n_projection_samples: int = 1
    # When set, overrides n_projection_samples with (value * action dimension).
    n_projection_samples_per_dim: ty.Optional[int] = None



class SoftActorCritic(torch.nn.Module):

    '''
        Soft Actor Critic
    '''

    def __init__(
        self,
        observation_space, action_space,
        batch_size: int = 16,
        lr: float = 1e-2,
        gamma: float = 0.99,
        seed: int = 13,
        buffer_size: int = int(1e4),
        grad_clipping: ty.Optional[float] = None,
        cuda: bool = False,
        n_hidden: ty.Optional[list] = None,  # None -> [256, 256]
        layer_norm: bool = False,
        use_double: bool = False,
        polyak: float = 0.995,
        # actor
        lr_policy: float = 1e-3,
        policy_type: str = 'gaussian',
        n_projection_samples: int = 1,
        n_projection_samples_per_dim: ty.Optional[int] = None,
        n_mixtures: int = 1,
        single_head_gmn: bool = False,
        autotune_ent_coeff: bool = True,
        coeff_entropy: float = 0.9,
        max_coeff_entropy: float = 1e6,
        log_std_min: float = -20.,
        log_std_max: float = 2.,
        min_action: float = -1.,
        max_action: float = 1.,
        action_scaler: str = 'spinningup',
        qf_emulate_clean_rl: bool = True,
    ):
        '''
            Build the replay buffer, critic(s), target critic(s), policy,
            optimizers and entropy coefficient of the agent.

            Args:
                observation_space: space whose ``shape[0]`` is taken as the
                    observation dimension.
                action_space: space providing ``shape``, ``low`` and ``high``.
                batch_size: mini-batch size used by ``get_mini_batch``.
                lr: learning rate of the critic optimizer; also used for the
                    entropy-coefficient optimizer.
                gamma: discount factor used in the soft Bellman target.
                seed: seed of the NumPy ``RandomState`` stored as ``self.dice``.
                buffer_size: capacity handed to ``ReplayBuffer``.
                grad_clipping: max gradient norm; ``None`` disables clipping.
                cuda: use CUDA when it is available, otherwise CPU.
                n_hidden: hidden sizes forwarded to the policy and, when
                    ``qf_emulate_clean_rl`` is false, to the critic.
                    ``None`` means ``[256, 256]``.
                layer_norm: forwarded to the critic when
                    ``qf_emulate_clean_rl`` is false.
                use_double: build a second critic / target critic.
                polyak: coefficient handed to ``polyak_update``.
                lr_policy: learning rate of the policy optimizer.
                policy_type: ``'gaussian'``, ``'rnvp'`` or ``'sgrnvp'``.
                n_projection_samples: number of policy samples per
                    observation in the actor loss.
                n_projection_samples_per_dim: when given, overrides
                    ``n_projection_samples`` with ``value * dim_ac``.
                n_mixtures: with ``policy_type='gaussian'``, a value > 1
                    selects the mixture policy.
                single_head_gmn: forwarded as ``single_head`` to the mixture
                    policy.
                autotune_ent_coeff: learn ``log(alpha)`` instead of using the
                    fixed ``coeff_entropy``.
                coeff_entropy: fixed entropy coefficient (used only when
                    ``autotune_ent_coeff`` is false).
                max_coeff_entropy: upper clamp of the learned coefficient.
                log_std_min, log_std_max, min_action, max_action,
                action_scaler: forwarded to the Gaussian policies only.
                qf_emulate_clean_rl: use ``SoftQNetwork`` for the critics
                    instead of ``ActionValueNetwork4ContinuousAction``.

            Raises:
                ValueError: if ``policy_type`` is not a supported name.
        '''
        super(SoftActorCritic, self).__init__()

        # Fail fast with a clear message; otherwise an unsupported name would
        # only surface later as a missing ``self.policy`` attribute.
        if policy_type not in ('gaussian', 'rnvp', 'sgrnvp'):
            raise ValueError(
                f"unknown policy_type {policy_type!r}; "
                "expected 'gaussian', 'rnvp' or 'sgrnvp'"
            )
        # A fresh list per instance: a list literal as the default would be
        # shared by every agent created without an explicit value.
        if n_hidden is None:
            n_hidden = [256, 256]

        self.seed = seed
        self.dice = np.random.RandomState(seed)
        self.device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.batch_size = batch_size
        self.lr = lr
        self.lr_policy = lr_policy
        self.grad_clipping = grad_clipping
        self.n_hidden = n_hidden
        self.layer_norm = layer_norm
        self.use_double = use_double
        self.polyak = polyak
        self.policy_type = policy_type

        self.observation_space = observation_space
        self.action_space = action_space
        self.qf_emulate_clean_rl = qf_emulate_clean_rl
        self.dim_ob = observation_space.shape[0]
        self.dim_ac = action_space.shape[0]
        # Half-range and midpoint of the action box, kept as module buffers.
        self.register_buffer(
            "action_scale", torch.tensor((action_space.high - action_space.low) / 2.0,
            dtype=torch.float32, device=self.device)
        )
        self.register_buffer(
            "action_bias", torch.tensor((action_space.high + action_space.low) / 2.0,
            dtype=torch.float32, device=self.device)
        )

        self.buffer = ReplayBuffer(self.observation_space, self.action_space, buffer_size)

        # ---- critic(s): one optimizer over all critic parameters ----
        self.q_net = self._build_network()
        self.q_params = []
        self.q_params += list(self.q_net.parameters())
        print(f'q_net = {self.q_net}')

        if self.use_double:
            self.q_net2 = self._build_network()
            self.q_params += list(self.q_net2.parameters())
            print(f'q_net2 = {self.q_net2}')
        self.optimizer = torch.optim.Adam(self.q_params, lr=lr)

        # ---- target critic(s), initialised with the online weights ----
        if self.qf_emulate_clean_rl:
            self.target_q_net = SoftQNetwork(observation_space, action_space).to(self.device)
            self.target_q_net.load_state_dict(self.q_net.state_dict())
            if self.use_double:
                self.target_q_net2 = SoftQNetwork(observation_space, action_space).to(self.device)
                self.target_q_net2.load_state_dict(self.q_net2.state_dict())
        else:
            # A deep-copied module is on the same device as the original.
            self.target_q_net = deepcopy(self.q_net)
            for p in self.target_q_net.parameters():
                p.requires_grad = False
            if self.use_double:
                self.target_q_net2 = deepcopy(self.q_net2)
                for p in self.target_q_net2.parameters():
                    p.requires_grad = False

        # ---- policy ----
        self.n_mixtures = n_mixtures
        if self.policy_type == 'gaussian':
            if self.n_mixtures > 1:
                self.policy = MixturedSquashedGaussianPolicy(
                    observation_space, action_space,
                    n_mixtures=n_mixtures, single_head=single_head_gmn,
                    action_scaler=action_scaler,
                    min_action=min_action, max_action=max_action,
                    n_hidden=n_hidden, device=self.device,
                    representation='NN',
                    log_std_min=log_std_min, log_std_max=log_std_max,
                ).to(self.device)
            else:
                self.policy = SquashedGaussianPolicy(
                    observation_space, action_space,
                    action_scaler=action_scaler,
                    min_action=min_action, max_action=max_action,
                    n_hidden=n_hidden, device=self.device,
                    representation='NN',
                    log_std_min=log_std_min, log_std_max=log_std_max,
                ).to(self.device)
        elif self.policy_type == 'rnvp':
            self.policy = RealNVPPolicy(
                observation_space, action_space,
                n_hidden=n_hidden, device=self.device,
                noise_type='gaussian',
            ).to(self.device)
        else:  # 'sgrnvp' (validated above)
            self.policy = SequantialGaussRealNVPPolicy(
                observation_space, action_space,
                n_hidden=n_hidden, device=self.device,
            ).to(self.device)

        self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=lr_policy)
        print(f'policy = {self.policy}')

        # ---- number of policy samples per observation in the actor loss ----
        self.n_projection_samples_per_dim = n_projection_samples_per_dim
        if n_projection_samples_per_dim is None:
            self.n_projection_samples = n_projection_samples
        else:
            self.n_projection_samples = n_projection_samples_per_dim * self.dim_ac

        # ---- entropy coefficient ----
        self.autotune_ent_coeff = autotune_ent_coeff
        # Target entropy is -dim(A), based on [Haarnoja 19; Appendix D.].
        self.target_entropy = -float(np.prod(self.action_space.shape))
        if self.autotune_ent_coeff:
            # log(alpha) is the optimised quantity; it is exponentiated on use.
            self.max_coeff_entropy = max_coeff_entropy
            self.log_coeff_entropy = torch.zeros(1, requires_grad=True, device=self.device)
            self.optimizer_entropy = torch.optim.Adam([self.log_coeff_entropy], lr=lr)
            print(f'target_entropy    = {self.target_entropy}')
            print(f'log_coeff_entropy = {self.log_coeff_entropy}')
        else:
            self.coeff_entropy = torch.tensor([coeff_entropy], dtype=torch.float32, device=self.device)


    def _build_network(self):
        '''
            Create one critic network and move it to ``self.device``.

            ``SoftQNetwork`` is used when ``qf_emulate_clean_rl`` is truthy,
            otherwise ``ActionValueNetwork4ContinuousAction`` configured with
            ``n_hidden`` and ``layer_norm``.
        '''
        if not self.qf_emulate_clean_rl:
            critic = ActionValueNetwork4ContinuousAction(
                self.observation_space, self.action_space,
                n_hidden=self.n_hidden, layer_norm=self.layer_norm,
                device=self.device,
            )
        else:
            critic = SoftQNetwork(self.observation_space, self.action_space)
        return critic.to(self.device)


    def act(
        self,
        observation: ty.Union[torch.tensor, np.ndarray],
        greedy: bool=False, as_tensor=False,
    ) -> ty.Union[torch.tensor, np.ndarray]:
        '''
            Query the policy for an action.

            Non-tensor observations are converted to a float32 tensor on
            ``self.device`` (a bare ``int`` is first wrapped in a length-1
            array).  For non-Discrete observation spaces, an observation
            whose shape equals ``observation_space.shape`` gets a leading
            batch axis.

            The policy output has ALL singleton axes squeezed; a resulting
            0-d value is reshaped to shape ``(1,)``.
        '''
        obs = observation
        # A tensor is never an ``int``, so this check can safely come first.
        if isinstance(obs, int):
            obs = np.array([obs])
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)

        is_unbatched = (
            not isinstance(self.observation_space, spaces.Discrete)
            and obs.shape == self.observation_space.shape
        )
        if is_unbatched:
            obs = obs.unsqueeze(0)

        chosen = self.policy.act(obs, greedy=greedy, as_tensor=as_tensor).squeeze()
        # Never hand back a 0-d action.
        return chosen.reshape((1,)) if chosen.shape == () else chosen


    def predict_as_np(self,
                      observation: ty.Union[torch.Tensor, np.ndarray],
                      action: ty.Union[torch.Tensor, np.ndarray],
        ) -> torch.Tensor:
        '''
            Evaluate the first critic ``q_net`` on ``(observation, action)``.

            Non-tensor inputs are converted to float32 tensors on
            ``self.device``.  For non-Discrete spaces, an input whose shape
            equals the corresponding space shape gets a leading batch axis.

            NOTE: despite the method name, the critic output is returned
            as-is: there is no NumPy conversion and no ``torch.no_grad``
            here.  Callers needing an array must detach/convert themselves.
            NOTE(review): the return annotation assumes ``q_net`` yields a
            tensor, as the tensor ops applied to it elsewhere in this class
            suggest -- confirm.

            Returns:
                Whatever ``self.q_net(observation, action)`` returns.
        '''
        if not isinstance(observation, torch.Tensor):
            observation = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        if not isinstance(action, torch.Tensor):
            action = torch.as_tensor(action, dtype=torch.float32, device=self.device)

        # Promote single (unbatched) samples to a batch of one.
        if not isinstance(self.observation_space, spaces.Discrete) \
            and observation.shape == self.observation_space.shape:
            observation = torch.unsqueeze(observation, 0)
        if not isinstance(self.action_space, spaces.Discrete) \
            and action.shape == self.action_space.shape:
            action = torch.unsqueeze(action, 0)

        return self.q_net(observation, action)


    def compute_entropy(self, ob: torch.tensor):
        with torch.no_grad():
            log_pi = self.policy.log_prob(ob)
            entropy = - log_pi.mean()
        return entropy


    def get_coeff_entropy(self):
        if self.autotune_ent_coeff:
            coeff_entropy = self.log_coeff_entropy.exp().clamp(0, self.max_coeff_entropy)
        else:
            coeff_entropy = self.coeff_entropy
        return coeff_entropy


    def compute_target(self,
                       ob: torch.tensor,
                       ac: torch.tensor,
                       rew: torch.tensor,
                       next_ob: torch.tensor,
                       done: torch.tensor):

        batch_size = ob.shape[0]

        with torch.no_grad():
            next_ac_pi, next_log_pi = self.policy.act(next_ob, as_tensor=True, with_log_prob=True)
            next_log_pi = next_log_pi.view(batch_size, 1)

            next_q_sa = self.target_q_net(next_ob, next_ac_pi)
            if self.use_double:
                next_q_sa2 = self.target_q_net2(next_ob, next_ac_pi)
                next_q_sa = torch.min(next_q_sa, next_q_sa2)
            if not self.qf_emulate_clean_rl:
                next_q_sa = next_q_sa.unsqueeze(-1)
            assert next_q_sa.shape == (batch_size, 1), f'{next_q_sa.shape} != {(batch_size, 1)}'

            ''' entropy bonus for soft policy iteration '''
            coeff_entropy = self.get_coeff_entropy()
            ''' soft target '''
            next_soft_q_sa = next_q_sa - coeff_entropy * next_log_pi

            target = rew + self.gamma * (1.0 - done) * next_soft_q_sa
            assert target.shape == (batch_size, 1)

        return target


    def compute_critic_loss(
            self,
            ob: torch.tensor,
            ac: torch.tensor,
            rew: torch.tensor,
            next_ob: torch.tensor,
            done: torch.tensor,
            reward_min: float = None,
            reward_max: float = None,
        ):

        batch_size = ob.shape[0]

        q_sa = self.q_net(ob, ac)
        if not self.qf_emulate_clean_rl:
            q_sa = q_sa.unsqueeze(-1)
        target = self.compute_target(ob, ac, rew, next_ob, done)
        assert target.shape == q_sa.shape
        td_errors = target - q_sa
        soft_pe_loss = F.mse_loss(target, q_sa)

        if self.use_double:
            q_sa2 = self.q_net2(ob, ac)
            if not self.qf_emulate_clean_rl:
                q_sa2 = q_sa2.unsqueeze(-1)
            soft_pe_loss += F.mse_loss(target, q_sa2)

        return soft_pe_loss, td_errors



    def compute_entropy_loss(self, ob: torch.tensor):
        with torch.no_grad():
            log_pi = self.policy.log_prob(ob)

        coeff_entropy = self.log_coeff_entropy.exp().clamp(0, self.max_coeff_entropy)
        entropy_loss = - (coeff_entropy * (log_pi + self.target_entropy).detach()).mean()

        return entropy_loss


    def compute_actor_loss(
            self,
            ob: torch.tensor,
            policy: torch.nn.Module,
        ):
        batch_size = ob.shape[0]

        if self.n_projection_samples > 1:
            ob_repeat = ob.unsqueeze(1).repeat(1, self.n_projection_samples, 1).view(batch_size * self.n_projection_samples, self.dim_ob)
        else:
            ob_repeat = ob

        ''' policy evaluation '''
        ac_pi, log_pi = policy.act(ob_repeat, as_tensor=True, with_log_prob=True)

        q_sa = self.q_net(ob_repeat, ac_pi)
        if self.use_double:
            q_sa2 = self.q_net2(ob_repeat, ac_pi)
            q_sa = torch.min(q_sa, q_sa2)
        with torch.no_grad():
            coeff_entropy = self.get_coeff_entropy()

        if self.n_projection_samples > 1:
            log_pi = log_pi.view(batch_size, self.n_projection_samples)
            q_sa = q_sa.view(batch_size, self.n_projection_samples)
            soft_greedy_loss = (coeff_entropy * log_pi - q_sa).mean(-1).mean()
        else:
            log_pi = log_pi.view(batch_size, 1)
            if not self.qf_emulate_clean_rl:
                q_sa = q_sa.view(batch_size, 1)
            soft_greedy_loss = (coeff_entropy * log_pi - q_sa).mean()

        return soft_greedy_loss



    def update_target_network(self, update_policy: bool = True):
        '''
            Move each target critic towards its online critic by calling
            ``polyak_update(online, target, self.polyak)`` under
            ``torch.no_grad()``.

            Intended rule (the arithmetic itself lives in ``polyak_update``,
            which is not defined in this file -- confirm there):
                theta' <- polyak * theta' + (1 - polyak) * theta

            ``update_policy`` is accepted but not used.
        '''
        critic_pairs = [(self.q_net, self.target_q_net)]
        if self.use_double:
            critic_pairs.append((self.q_net2, self.target_q_net2))

        with torch.no_grad():
            for online_net, target_net in critic_pairs:
                polyak_update(online_net, target_net, self.polyak)


    def grad_clip_critic(self):
        if self.grad_clipping is not None:
            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.grad_clipping)
            if self.use_double:
                torch.nn.utils.clip_grad_norm_(self.q_net2.parameters(), self.grad_clipping)


    def grad_clip_actor(self):
        if self.grad_clipping is not None:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clipping)


    def grad_clip_entropy(self):
        if self.grad_clipping is not None:
            torch.nn.utils.clip_grad_norm_([self.log_coeff_entropy], self.grad_clipping)


    def get_mini_batch(self, use_all_samples: bool = False):
        if use_all_samples:
            batch = self.buffer.sample_uniformly()
            # print(batch)
        else:
            batch = self.buffer.sample(batch_size=self.batch_size)
        ob, ac, rew, ob2, done = \
            batch['ob'], batch['ac'], batch['rew'], batch['ob2'], batch['done']
        ob   = torch.as_tensor(ob, dtype=torch.float32, device=self.device)
        ac	 = torch.as_tensor(ac, dtype=torch.float32, device=self.device)
        ob2  = torch.as_tensor(ob2, dtype=torch.float32, device=self.device)
        rew	 = torch.as_tensor(rew, dtype=torch.float32, device=self.device).view(-1, 1)
        done = torch.as_tensor(done, dtype=torch.float32, device=self.device).view(-1, 1)
        return ob, ac, ob2, rew, done


    def update_critic(
        self,
        ob: torch.tensor,
        ac: torch.tensor,
        rew: torch.tensor,
        next_ob: torch.tensor,
        done: torch.tensor,
    ):
        self.optimizer.zero_grad()
        critic_loss, error = self.compute_critic_loss(ob, ac, rew, next_ob, done)
        critic_loss.backward()
        self.grad_clip_critic()
        self.optimizer.step()
        error = error.detach().abs().mean().item()
        critic_loss = critic_loss.detach().item()
        return critic_loss, error


    def update_actor(
        self,
        ob: torch.tensor,
    ):
        self.optimizer_policy.zero_grad()
        actor_loss = self.compute_actor_loss(ob, self.policy)
        actor_loss.backward()
        self.grad_clip_actor()
        self.optimizer_policy.step()
        actor_loss = actor_loss.detach().item()
        return actor_loss


    def update_entropy_coeff(
        self,
        ob: torch.tensor,
    ):
        self.optimizer_entropy.zero_grad()
        entropy_loss = self.compute_entropy_loss(ob)
        entropy_loss.backward()
        self.grad_clip_entropy()
        self.optimizer_entropy.step()
        entropy_loss = entropy_loss.detach().item()
        return entropy_loss


    def update(self, use_all_samples: bool = False, update_policy: bool = True):
        '''
            One full training step on a freshly sampled batch.

            Order: critic step; then, if ``update_policy``, the actor step
            followed by the entropy-coefficient step (only when auto-tuning
            is enabled); finally the target-network update, which runs
            regardless of ``update_policy``.

            Returns:
                ``(critic_loss, error)`` from ``update_critic``.  The actor
                and entropy losses are not returned.
        '''
        # get_mini_batch yields the next observation BEFORE the reward.
        ob, ac, next_ob, rew, done = self.get_mini_batch(use_all_samples)

        critic_loss, error = self.update_critic(ob, ac, rew, next_ob, done)

        if update_policy:
            self.update_actor(ob)
            if self.autotune_ent_coeff:
                self.update_entropy_coeff(ob)

        self.update_target_network(update_policy)

        return critic_loss, error


    def save(self, model_dir: Path, tag: str, verbose=False):

        def save_one_model(module, save_path):
            torch.save(module.state_dict(), save_path)
            if verbose:
                print(f'saved as {save_path}')

        save_one_model(self.policy, model_dir/f'policy_{tag}.pt')
        save_one_model(self.q_net, model_dir/f'q_net_{tag}.pt')
        save_one_model(self.target_q_net, model_dir/f'target_q_net_{tag}.pt')
        if self.use_double:
            save_one_model(self.q_net2, model_dir/f'q_net2_{tag}.pt')
            save_one_model(self.target_q_net2, model_dir/f'target_q_net2_{tag}.pt')


    def load(self, model_dir: Path, tag: str, verbose=False):

        def load_one_model(module, load_path):
            module.load_state_dict(torch.load(load_path))
            if verbose:
                print(f'loaded from {load_path}')

        load_one_model(self.policy, model_dir/f'policy_{tag}.pt')
        load_one_model(self.q_net, model_dir/f'q_net_{tag}.pt')
        load_one_model(self.target_q_net, model_dir/f'target_q_net_{tag}.pt')
        if self.use_double:
            load_one_model(self.q_net2, model_dir/f'q_net2_{tag}.pt')
            load_one_model(self.target_q_net2, model_dir/f'target_q_net2_{tag}.pt')
