
import dataclasses as dc
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import typing as ty

from copy import copy, deepcopy
from gym import spaces
from pathlib import Path

from .base import ReplayBuffer, polyak_update, check_gradient, check_param_norm
from .policy import RealNVPPolicy



@dc.dataclass
class Param4SAC:
    '''
        Hyper-parameter bundle for a soft actor-critic agent.

        The field names match the keyword arguments accepted by
        ``SoftActorCritic.__init__``; the default values declared here are
        not the same as the defaults of that constructor.
    '''

    # --- common settings ---
    batch_size: int = 256         # transitions per sampled mini-batch
    buffer_size: int = int(1e6)   # replay buffer capacity
    q_lr: float = 1e-3            # learning rate for the Q-function optimizer
    gamma: float = 0.99           # discount factor
    seed: int = 7                 # random seed
    cuda: bool = True             # request CUDA (used only when available)
    # NOTE(review): SoftActorCritic.update_target_network weights the *online*
    # parameters by tau, so 0.995 makes the targets follow the online networks
    # almost immediately -- confirm this default is intended.
    tau: float = 0.995

    # --- actor settings ---
    policy_lr: float = 3e-4               # learning rate for the actor optimizer
    policy_type: str = 'gaussian'         # 'gaussian' or 'rnvp'
    autotune_ent_coeff: bool = True       # learn the entropy coefficient
    coeff_entropy: float = 0.1            # entropy coefficient value
    action_scaler: str = 'spinningup'
    bonus_squash: bool = False            # tanh-squash the entropy bonus in targets



class SoftQNetwork(torch.nn.Module):
    '''
        State-action value network: a two-hidden-layer MLP (256 units each,
        ReLU) that maps a concatenated (observation, action) pair to one
        scalar per row.
    '''

    def __init__(self, observation_space, action_space):
        super().__init__()
        # Input width = flattened observation size + flattened action size.
        obs_dim = np.prod(observation_space.shape)
        act_dim = np.prod(action_space.shape)
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        '''
            Return Q(x, a); ``x`` and ``a`` are joined along dimension 1, so
            both must be 2-D with the same number of rows.
        '''
        joint = torch.cat((x, a), dim=1)
        hidden = F.relu(self.fc1(joint))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Range that the policy's log standard deviation is squashed into.
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(torch.nn.Module):
    '''
        Tanh-squashed diagonal Gaussian policy.

        A two-hidden-layer MLP (256 units each, ReLU) produces the mean and
        the log standard deviation of a Normal distribution; samples are
        passed through tanh and then rescaled to the bounds of the action
        space.
    '''

    def __init__(self, observation_space, action_space):
        super().__init__()
        obs_dim = np.array(observation_space.shape).prod()
        act_dim = np.prod(action_space.shape)
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, act_dim)
        self.fc_logstd = nn.Linear(256, act_dim)
        # Affine map taking tanh outputs in [-1, 1] onto [low, high]; stored as
        # buffers so they follow the module across devices and state dicts.
        half_range = (action_space.high - action_space.low) / 2.0
        midpoint = (action_space.high + action_space.low) / 2.0
        self.register_buffer("action_scale", torch.tensor(half_range, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(midpoint, dtype=torch.float32))

    def forward(self, x):
        '''
            Return ``(mean, log_std)`` of the pre-squash Gaussian for ``x``.
        '''
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        mean = self.fc_mean(hidden)
        squashed = torch.tanh(self.fc_logstd(hidden))
        # Map tanh output from [-1, 1] onto [LOG_STD_MIN, LOG_STD_MAX]
        # (from SpinUp / Denis Yarats).
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (squashed + 1)
        return mean, log_std

    def get_action(self, x):
        '''
            Sample an action for every row of ``x``.

            Returns ``(action, log_prob, mean_action)``: the rescaled sample,
            its log-probability summed over dimension 1 (kept as a column),
            and the rescaled tanh of the Gaussian mean.
        '''
        mean, log_std = self(x)
        gaussian = torch.distributions.Normal(mean, log_std.exp())
        # rsample keeps the sample differentiable (reparameterization trick).
        pre_tanh = gaussian.rsample()
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias
        # Change-of-variables correction for the tanh squashing and rescaling;
        # the 1e-6 keeps the logarithm finite when tanh saturates.
        correction = torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6)
        log_prob = (gaussian.log_prob(pre_tanh) - correction).sum(1, keepdim=True)
        mean_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean_action


class SoftActorCritic(torch.nn.Module):

    '''
        Soft Actor-Critic (SAC) agent.
    '''

    def __init__(self,
                 observation_space, action_space,
                 batch_size: int = 16,
                 q_lr: float = 1e-2,
                 gamma: float = 0.99,
                 seed: int = 13,
                 buffer_size: int = int(1e4),
                 cuda: bool = False,
                 tau: float = 0.005,
                 # actor
                 policy_lr: float = 1e-3,
                 policy_type: str = 'gaussian',
                 autotune_ent_coeff: bool = True,
                 coeff_entropy: float = 0.9,
                 action_scaler: str = 'spinningup',
                 bonus_squash: bool = False,
             ):
        '''
            Build the replay buffer, the actor, the twin critics with their
            target copies, and the optimizers.

            Args:
                observation_space: observation space; passed to the replay
                    buffer and to the networks.
                action_space: action space; passed to the replay buffer and
                    to the networks, and its ``shape`` sets the target entropy.
                batch_size: mini-batch size used by ``update``.
                q_lr: learning rate of the critic optimizer (also used for the
                    entropy-coefficient optimizer).
                gamma: discount factor.
                seed: seed of the NumPy ``RandomState`` stored in ``self.dice``.
                buffer_size: size argument passed to ``ReplayBuffer``.
                cuda: use CUDA when it is requested *and* available.
                tau: weight of the online parameters in the target update.
                policy_lr: learning rate of the actor optimizer.
                policy_type: ``'gaussian'`` (``Actor``) or ``'rnvp'``
                    (``RealNVPPolicy``).
                autotune_ent_coeff: learn the entropy coefficient instead of
                    keeping it fixed.
                coeff_entropy: fixed entropy coefficient, used only when
                    ``autotune_ent_coeff`` is false.
                action_scaler: accepted but not used by this constructor.
                bonus_squash: when true, ``compute_target`` squashes the
                    entropy bonus with tanh.

            Raises:
                ValueError: if ``policy_type`` is not a supported value.
        '''
        super(SoftActorCritic, self).__init__()

        # Fail early with a clear message instead of an AttributeError on
        # ``self.actor`` further down.
        if policy_type not in ('gaussian', 'rnvp'):
            raise ValueError(
                f"unknown policy_type {policy_type!r}; expected 'gaussian' or 'rnvp'"
            )

        self.observation_space = observation_space
        self.action_space = action_space

        self.seed = seed
        self.dice = np.random.RandomState(seed)
        self.device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.batch_size = batch_size
        self.q_lr = q_lr
        self.policy_lr = policy_lr
        self.policy_type = policy_type
        self.tau = tau
        self.autotune_ent_coeff = autotune_ent_coeff
        self.bonus_squash = bonus_squash

        self.buffer = ReplayBuffer(self.observation_space, self.action_space, buffer_size)

        if self.policy_type == 'gaussian':
            self.actor = Actor(observation_space, action_space).to(self.device)
        else:  # 'rnvp'
            self.actor = RealNVPPolicy(
                observation_space, action_space,
                n_hidden=[256, 256], device=self.device,
                noise_type='gaussian',
            ).to(self.device)

        # Twin critics plus target copies that start from identical weights.
        self.qf1 = SoftQNetwork(observation_space, action_space).to(self.device)
        self.qf2 = SoftQNetwork(observation_space, action_space).to(self.device)
        self.qf1_target = SoftQNetwork(observation_space, action_space).to(self.device)
        self.qf2_target = SoftQNetwork(observation_space, action_space).to(self.device)
        self.qf1_target.load_state_dict(self.qf1.state_dict())
        self.qf2_target.load_state_dict(self.qf2.state_dict())
        self.q_optimizer = optim.Adam(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.q_lr)
        self.actor_optimizer = optim.Adam(list(self.actor.parameters()), lr=self.policy_lr)

        if self.autotune_ent_coeff:
            # Automatic entropy tuning: the target entropy is minus the number
            # of action dimensions, and log(alpha) is the learned quantity.
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha = self.log_alpha.exp().item()
            self.a_optimizer = optim.Adam([self.log_alpha], lr=self.q_lr)
        else:
            # Fixed entropy coefficient.  (The previous ``self.alpha = self.alpha``
            # raised AttributeError because the attribute did not exist yet.)
            self.alpha = coeff_entropy

        print(f'qf1 = {self.qf1}')
        print(f'actor = {self.actor}')


    def act(self,
            observation: ty.Union[torch.Tensor, np.ndarray],
            greedy: bool=False, as_tensor=False) -> np.ndarray:
        '''
            Choose an action for ``observation``.

            Args:
                observation: a tensor, an array-like, or a plain ``int``.
                    Unless the observation space is ``spaces.Discrete``, an
                    input whose shape equals the observation-space shape gets
                    a leading batch dimension added.
                greedy: return the policy's mean action instead of a sample.
                    For the ``'rnvp'`` policy the sampled action is returned
                    in both cases.
                as_tensor: accepted but currently ignored; the result is
                    always a NumPy array.

            Returns:
                The action as a NumPy array with all size-1 dimensions
                squeezed out.

            Raises:
                ValueError: if ``self.policy_type`` is not supported.
        '''
        if not isinstance(observation, torch.Tensor):
            if isinstance(observation, int):
                observation = np.array([observation])
            observation = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        if not isinstance(self.observation_space, spaces.Discrete) \
                and observation.shape == self.observation_space.shape:
            observation = torch.unsqueeze(observation, 0)

        # Acting never needs gradients; skipping the autograd graph saves
        # memory and time without changing the returned values.
        with torch.no_grad():
            if self.policy_type == 'gaussian':
                action, _, mean = self.actor.get_action(observation)
            elif self.policy_type == 'rnvp':
                action = self.actor.act(observation, with_log_prob=False)
                mean = action
            else:
                raise ValueError(f'unsupported policy_type: {self.policy_type!r}')

        chosen = mean if greedy else action
        return chosen.squeeze().cpu().detach().numpy()


    def predict_as_np(self,
                      observation: ty.Union[torch.tensor, np.ndarray],
                      action: ty.Union[torch.tensor, np.ndarray],
        ) -> ty.Tuple[np.ndarray, np.ndarray]:
        '''
            Evaluate the first critic (``qf1``) on an observation/action pair.

            Non-tensor inputs are converted to float32 tensors on
            ``self.device``.  Unless the matching space is
            ``spaces.Discrete``, an input whose shape equals the space's shape
            gets a leading batch dimension added.

            NOTE(review): despite the method name and the return annotation,
            the value returned is the tensor produced by ``qf1`` (not a NumPy
            array and not a tuple) -- confirm what callers expect.
        '''
        def to_batched_tensor(value, space):
            # Convert to a tensor if needed, then add a batch axis to a
            # single (unbatched) sample.
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=torch.float32, device=self.device)
            if isinstance(space, spaces.Discrete) or value.shape != space.shape:
                return value
            return torch.unsqueeze(value, 0)

        batched_ob = to_batched_tensor(observation, self.observation_space)
        batched_ac = to_batched_tensor(action, self.action_space)
        return self.qf1(batched_ob, batched_ac)


    def compute_target(self,
                       ob: torch.tensor,
                       ac: torch.tensor,
                       rew: torch.tensor,
                       next_ob: torch.tensor,
                       done: torch.tensor):

        batch_size = ob.shape[0]

        with torch.no_grad():
            if self.policy_type == 'gaussian':
                next_state_actions, next_state_log_pi, _ = self.actor.get_action(next_ob)
            elif self.policy_type == 'rnvp':
                next_state_actions, next_state_log_pi = self.actor.act(next_ob, with_log_prob=True)

            next_q_sa1 = self.qf1_target(next_ob, next_state_actions)
            next_q_sa2 = self.qf2_target(next_ob, next_state_actions)
            min_qf_next_target = torch.min(next_q_sa1, next_q_sa2)
            bonus = self.alpha * next_state_log_pi
            if self.bonus_squash:
                bonus = torch.tanh(bonus*0.1)
            next_soft_q_sa = min_qf_next_target - bonus
            target = rew.flatten() + (1 - done.flatten()) * self.gamma * (next_soft_q_sa).view(-1)

        return target


    def compute_critic_loss(
            self,
            ob: torch.tensor,
            ac: torch.tensor,
            rew: torch.tensor,
            next_ob: torch.tensor,
            done: torch.tensor,
        ):

        batch_size = ob.shape[0]

        target = self.compute_target(ob, ac, rew, next_ob, done)
        qf1_a_values = self.qf1(ob, ac).view(-1)
        qf2_a_values = self.qf2(ob, ac).view(-1)
        qf1_loss = F.mse_loss(qf1_a_values, target)
        qf2_loss = F.mse_loss(qf2_a_values, target)
        qf_loss = qf1_loss + qf2_loss

        return qf_loss, target - qf1_a_values



    def compute_entropy_loss(self, ob: torch.tensor):
        '''
            should we use next action sampled here to compute critic graduent?
            or should we use actions in batch? <- this might be incorrect
        '''
        with torch.no_grad():
            if self.policy_type == 'gaussian':
                _, log_pi, _ = self.actor.get_action(ob)
            elif self.policy_type == 'rnvp':
                _, log_pi = self.actor.act(ob, with_log_prob=True)

        alpha_loss = (-self.log_alpha.exp() * (log_pi + self.target_entropy)).mean()

        return alpha_loss



    def compute_actor_loss(
            self,
            ob: torch.tensor,
        ):
        batch_size = ob.shape[0]

        if self.policy_type == 'gaussian':
            pi, log_pi, _ = self.actor.get_action(ob)
        elif self.policy_type == 'rnvp':
            pi, log_pi = self.actor.act(ob, with_log_prob=True)
        qf1_pi = self.qf1(ob, pi)
        qf2_pi = self.qf2(ob, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        return actor_loss


    def update_target_network(self):
        '''
            theta' <- polyak * theta' * (1 - polyak) * theta
        '''
        for param, target_param in zip(self.qf1.parameters(), self.qf1_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.qf2.parameters(), self.qf2_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


    def update(self, use_all_samples: bool = False, update_policy: bool = True):
        if use_all_samples:
            batch = self.buffer.sample_uniformly()
            # print(batch)
        else:
            batch = self.buffer.sample(batch_size=self.batch_size)
        ob, ac, rew, ob2, done = \
            batch['ob'], batch['ac'], batch['rew'], batch['ob2'], batch['done']
        ob   = torch.as_tensor(ob, dtype=torch.float32, device=self.device)
        ac	 = torch.as_tensor(ac, dtype=torch.float32, device=self.device)
        ob2  = torch.as_tensor(ob2, dtype=torch.float32, device=self.device)
        rew	 = torch.as_tensor(rew, dtype=torch.float32, device=self.device).view(-1, 1)
        done = torch.as_tensor(done, dtype=torch.float32, device=self.device).view(-1, 1)

        if use_all_samples:
            batch_size = len(ob)

        DEBUGGING = False
        # DEBUGGING = True
        if DEBUGGING:
            check_param_norm([self.qf1, self.qf2])
        self.q_optimizer.zero_grad()
        critic_loss, error = self.compute_critic_loss(ob, ac, rew, ob2, done)
        critic_loss.backward()
        if DEBUGGING:
            check_gradient([self.qf1, self.qf2])
        self.q_optimizer.step()
        error = error.detach().abs().mean().item()
        critic_loss = critic_loss.detach().item()
        if DEBUGGING:
            check_param_norm([self.qf1, self.qf2])

        if update_policy:
            if DEBUGGING:
                check_param_norm([self.actor])
            self.actor_optimizer.zero_grad()
            actor_loss = self.compute_actor_loss(ob)
            actor_loss.backward()
            if DEBUGGING:
                check_gradient([self.actor])
            self.actor_optimizer.step()
            actor_loss = actor_loss.detach().item()
            if DEBUGGING:
                check_param_norm([self.actor])

            if self.autotune_ent_coeff:
                self.a_optimizer.zero_grad()
                entropy_loss = self.compute_entropy_loss(ob)
                entropy_loss.backward()
                self.a_optimizer.step()
                self.alpha = self.log_alpha.exp().item()
                # print(entropy_loss)
                entropy_loss = entropy_loss.detach().item()

        self.update_target_network()
        # print(f'eloss = {entropy_loss}')
        # print(f'error = {error}')
        # print(f'closs = {critic_loss}')
        # print(f'aloss = {actor_loss}')

        return critic_loss, error


    def save(self, model_dir: Path, tag: str, verbose=False):

        def save_one_model(module, save_path):
            torch.save(module.state_dict(), save_path)
            if verbose:
                print(f'saved as {save_path}')

        save_one_model(self.actor, model_dir/f'actor_{tag}.pt')
        save_one_model(self.qf1, model_dir/f'qf1_{tag}.pt')
        save_one_model(self.qf1_target, model_dir/f'qf1_target_{tag}.pt')
        save_one_model(self.qf2, model_dir/f'qf2_{tag}.pt')
        save_one_model(self.qf2_target, model_dir/f'qf2_target_{tag}.pt')


    def load(self, model_dir: Path, tag: str, verbose=False):

        def load_one_model(module, load_path):
            module.load_state_dict(torch.load(load_path))
            if verbose:
                print(f'loaded from {load_path}')

        load_one_model(self.actor, model_dir/f'actor_{tag}.pt')
        load_one_model(self.qf1, model_dir/f'qf1_{tag}.pt')
        load_one_model(self.qf1_target, model_dir/f'qf1_target_{tag}.pt')
        load_one_model(self.qf2, model_dir/f'qf2_{tag}.pt')
        load_one_model(self.qf2_target, model_dir/f'qf2_target_{tag}.pt')


    def get_state_dict(self) -> dict:
        checkpoint = dict()
        checkpoint['buffer'] = self.buffer
        checkpoint['qf1_state_dict'] = self.qf1.state_dict()
        checkpoint['qf2_state_dict'] = self.qf2.state_dict()
        checkpoint['qf1_target_state_dict'] = self.qf1_target.state_dict()
        checkpoint['qf2_target_state_dict'] = self.qf2_target.state_dict()
        checkpoint['q_optimizer_state_dict'] = self.q_optimizer.state_dict()
        checkpoint['actor'] = self.actor.state_dict()
        checkpoint['actor_optimizer'] = self.actor_optimizer.state_dict()
        if self.autotune_ent_coeff:
            checkpoint['log_alpha'] = self.log_alpha
            checkpoint['a_optimizer'] = self.a_optimizer.state_dict()
        return checkpoint


    def set_state_dict(self, checkpoint: dict):
        self.buffer = checkpoint['buffer']
        self.qf1.load_state_dict(checkpoint['qf1_state_dict'])
        self.qf2.load_state_dict(checkpoint['qf2_state_dict'])
        self.qf1_target.load_state_dict(checkpoint['qf1_target_state_dict'])
        self.qf2_target.load_state_dict(checkpoint['qf2_target_state_dict'])
        self.q_optimizer.load_state_dict(checkpoint['q_optimizer_state_dict'])
        self.actor.load_state_dict(checkpoint['actor'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        if self.autotune_ent_coeff:
            self.log_alpha = checkpoint['log_alpha']
            self.a_optimizer.load_state_dict(checkpoint['a_optimizer'])
