from typing import Dict, Tuple, List, Union
import numpy as np
import torch
import torch.optim          as optim
import torch.nn             as nn
import torch.nn.functional  as F

from value_guided_data_filtering.misc.buffer   import Buffer
from value_guided_data_filtering.misc.utils    import soft_update, confirm_path_exist
from value_guided_data_filtering.model.policy  import SquashedGaussianPolicy
from value_guided_data_filtering.model.value   import QEnsemble


class SAC:
    """Soft Actor-Critic agent with a two-member Q ensemble.

    The agent owns a squashed-Gaussian policy, an online Q ensemble, a target
    Q ensemble that tracks the online one through ``soft_update``, and a
    replay ``Buffer``. When ``config['alpha']`` is ``None`` the entropy
    temperature is learned; otherwise the configured value is used as-is.

    Config keys read by the constructor: ``model_config`` (``s_dim``,
    ``a_dim``, ``policy_*`` and ``value_*`` entries), ``device``,
    ``exp_path``, ``lr``, ``gamma``, ``tau``, ``alpha``, ``training_delay``,
    ``batch_size`` and ``buffer_size``.
    """

    def __init__(self, config: Dict) -> None:
        """Build networks, optimizers and the replay buffer from ``config``."""
        self.config = config
        self.model_config = config['model_config']
        self.s_dim = self.model_config['s_dim']
        self.a_dim = self.model_config['a_dim']
        self.device = config['device']
        self.exp_path = config['exp_path']

        self.lr = config['lr']
        self.gamma = config['gamma']
        self.tau = config['tau']
        self.alpha = config['alpha']
        self.training_delay = config['training_delay']
        self.batch_size = config['batch_size']

        self.training_count = 0
        self.loss_log = {}

        # Adaptive temperature: a missing alpha means "learn it".
        if self.alpha is None:
            self.train_alpha = True
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha = torch.exp(self.log_alpha)
            # Target entropy is -a_dim; kept on the training device and in
            # float32 so it matches the tensors it is combined with.
            self.target_entropy = torch.tensor(
                -float(self.a_dim), dtype=torch.float32, device=self.device
            )
            self.optimizer_alpha = optim.Adam([self.log_alpha], lr=self.lr)
        else:
            self.train_alpha = False

        # Policy network and its optimizer.
        self.policy = SquashedGaussianPolicy(
            s_dim=self.s_dim,
            a_dim=self.a_dim,
            hidden_layers=self.model_config['policy_hiddens'],
            inner_nonlinear=self.model_config['policy_nonlinear'],
            log_std_min=self.model_config['policy_log_std_min'],
            log_std_max=self.model_config['policy_log_std_max'],
            initializer=self.model_config['policy_initializer'],
        ).to(self.device)
        self.optimizer_policy = optim.Adam(self.policy.parameters(), self.lr)

        # Online and target value functions; the target starts as an exact
        # copy of the online network.
        self.QFunction = self._build_q_ensemble()
        self.QFunction_tar = self._build_q_ensemble()
        self.QFunction_tar.load_state_dict(self.QFunction.state_dict())
        self.optimizer_value = optim.Adam(self.QFunction.parameters(), self.lr)
        self.buffer = Buffer(config['buffer_size'])

    def _build_q_ensemble(self):
        """Construct a two-member Q ensemble from ``model_config`` on ``self.device``."""
        return QEnsemble(
            ensemble_size=2,
            s_dim=self.s_dim,
            a_dim=self.a_dim,
            hiddens=self.model_config['value_hiddens'],
            inner_nonlinear=self.model_config['value_nonlinear'],
            initializer=self.model_config['value_initializer'],
        ).to(self.device)

    def sample_action(self, s: np.ndarray, with_noise: bool) -> np.ndarray:
        """Query the policy for an action given a NumPy observation.

        Args:
            s: Observation array; converted to a float32 tensor on ``self.device``.
            with_noise: Forwarded unchanged to ``policy.sample_action``.

        Returns:
            The policy output as a NumPy array on the CPU.
        """
        with torch.no_grad():
            s = torch.from_numpy(s).float().to(self.device)
            action = self.policy.sample_action(s, with_noise)
        return action.detach().cpu().numpy()

    def _pack_batch(
        self,
        s_batch: np.ndarray,
        a_batch: np.ndarray,
        r_batch: np.ndarray,
        done_batch: np.ndarray,
        next_s_batch: np.ndarray,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert a sampled NumPy batch to float32 tensors on ``self.device``.

        Returns:
            ``(s, a, r, done, next_s)`` in the same order as the arguments.
        """
        s_batch = torch.from_numpy(s_batch).float().to(self.device)
        a_batch = torch.from_numpy(a_batch).float().to(self.device)
        r_batch = torch.from_numpy(r_batch).float().to(self.device)
        done_batch = torch.from_numpy(done_batch).float().to(self.device)
        next_s_batch = torch.from_numpy(next_s_batch).float().to(self.device)
        return s_batch, a_batch, r_batch, done_batch, next_s_batch

    def train(self) -> None:
        """Run one training step on a batch sampled from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size`` items.
        The value function is updated on every call; the policy (and the
        temperature, when it is learned) only when ``training_count`` is a
        multiple of ``training_delay``. The target network is soft-updated
        on every call. Scalar diagnostics are written to ``self.loss_log``.
        """
        if len(self.buffer) < self.batch_size:
            return

        s_batch, a_batch, r_batch, done_batch, next_s_batch = self.buffer.sample(self.batch_size)
        s_batch, a_batch, r_batch, done_batch, next_s_batch = self._pack_batch(
            s_batch, a_batch, r_batch, done_batch, next_s_batch
        )

        # Value-function update.
        loss_value = self._compute_value_loss(s_batch, a_batch, r_batch, done_batch, next_s_batch)
        self.optimizer_value.zero_grad()
        loss_value.backward()
        self.optimizer_value.step()
        self.loss_log['loss_value'] = loss_value.detach().cpu().item()

        if self.training_count % self.training_delay == 0:
            # Delayed policy update.
            loss_policy, new_a_log_prob = self._compute_policy_loss(s_batch)
            self.optimizer_policy.zero_grad()
            loss_policy.backward()
            self.optimizer_policy.step()
            self.loss_log['loss_policy'] = loss_policy.detach().cpu().item()

            if self.train_alpha:
                # Temperature update; the log-probabilities are detached so
                # only log_alpha receives a gradient from this loss.
                loss_alpha = (
                    -torch.exp(self.log_alpha)
                    * (new_a_log_prob.detach() + self.target_entropy)
                ).mean()
                self.optimizer_alpha.zero_grad()
                loss_alpha.backward()
                self.optimizer_alpha.step()
                # Refresh the cached temperature after log_alpha changed.
                self.alpha = torch.exp(self.log_alpha)

                self.loss_log['alpha'] = self.alpha.detach().cpu().item()
                self.loss_log['loss_alpha'] = loss_alpha.detach().cpu().item()

        # Soft update of the target network.
        soft_update(self.QFunction, self.QFunction_tar, self.tau)
        self.training_count += 1

    def _compute_value_loss(
        self,
        s: torch.Tensor,
        a: torch.Tensor,
        r: torch.Tensor,
        done: torch.Tensor,
        next_s: torch.Tensor,
    ) -> torch.Tensor:
        """Return the summed MSE of both Q heads against the soft Bellman target.

        The target is ``r + (1 - done) * gamma * (min(Q1', Q2') - alpha * log_pi)``
        evaluated with the target ensemble and a fresh policy action at
        ``next_s``; it is computed without gradient tracking.
        """
        # NOTE(review): r and done are presumably shaped to broadcast against
        # the Q outputs without expanding them — confirm against Buffer.sample.
        with torch.no_grad():
            next_a, next_a_log_prob, _ = self.policy(next_s)
            next_sa_value_1, next_sa_value_2 = self.QFunction_tar(next_s, next_a)
            next_sa_value = torch.min(next_sa_value_1, next_sa_value_2)
            target_value = r + (1 - done) * self.gamma * (
                next_sa_value - self.alpha * next_a_log_prob
            )
        sa_value_1, sa_value_2 = self.QFunction(s, a)
        return F.mse_loss(sa_value_1, target_value) + F.mse_loss(sa_value_2, target_value)

    def _compute_policy_loss(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the policy loss and the log-probabilities of the sampled actions.

        The loss is the batch mean of ``alpha * log_pi - min(Q1, Q2)`` for
        actions drawn from the current policy at ``s``.
        """
        a, a_log_prob, _ = self.policy(s)
        q1_value, q2_value = self.QFunction(s, a)
        q_value = torch.min(q1_value, q2_value)
        # The temperature is a constant for the policy objective; detaching it
        # keeps this backward pass from reaching log_alpha. Policy gradients
        # are unaffected.
        alpha = self.alpha.detach() if torch.is_tensor(self.alpha) else self.alpha
        neg_entropy = a_log_prob
        return (-q_value + alpha * neg_entropy).mean(), a_log_prob

    def evaluate(self, env, num_episode: int) -> List[float]:
        """Roll out the policy with ``with_noise=False`` and collect episode returns.

        Args:
            env: Environment whose ``reset()`` returns an observation and whose
                ``step(a)`` returns a 4-tuple ``(next_obs, reward, done, info)``.
            num_episode: Number of episodes to run.

        Returns:
            One undiscounted return per episode.
        """
        total_r = []
        for _ in range(num_episode):
            s = env.reset()
            done = False
            episode_r = 0
            while not done:
                a = self.sample_action(s, False)
                s_, r, done, _ = env.step(a)
                episode_r += r
                s = s_
            total_r.append(episode_r)
        return total_r

    def save_all_module(self, remark: str, model_path: Union[str, None] = None) -> None:
        """Save the policy and online Q ensemble weights to ``model_path + remark``.

        When ``model_path`` is empty or ``None`` it defaults to
        ``exp_path + 'model/'`` and ``confirm_path_exist`` is called on it.
        Only the policy and online value state dicts are written; the target
        network, optimizers and temperature are not part of the checkpoint.
        """
        if not model_path:
            model_path = self.exp_path + 'model/'
            confirm_path_exist(model_path)
        model_path = model_path + f'{remark}'
        torch.save({
            'policy': self.policy.state_dict(),
            'value': self.QFunction.state_dict(),
        }, model_path)
        print(f"------- All modules saved to {model_path} ----------")

    def load_all_module(self, checkpoint_path: str) -> None:
        """Load policy and Q weights from a checkpoint written by ``save_all_module``.

        The target Q ensemble is reset to a copy of the loaded online ensemble.
        """
        # map_location lets a checkpoint saved on one device be restored on
        # another (e.g. GPU-trained weights on a CPU-only host).
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.policy.load_state_dict(state_dict['policy'])
        self.QFunction.load_state_dict(state_dict['value'])
        self.QFunction_tar.load_state_dict(self.QFunction.state_dict())
        print(f"------- Loaded all modules from {checkpoint_path} ----------")