
import dataclasses as dc
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import typing as ty

from copy import copy, deepcopy
from gym import spaces
from pathlib import Path

from .base import ReplayBuffer, polyak_update, check_gradient, check_param_norm
# from .networks import ActionValueNetwork4ContinuousAction
# from .policy import SquashedGaussianPolicy,



@dc.dataclass
class Param4TD3:
    """Hyperparameter bundle whose field names mirror the ``TD3`` constructor keywords."""

    # for common
    batch_size      : int = 256
    buffer_size     : int = int(1e6)
    gamma           : float = 0.99
    seed            : int = 7
    cuda            : bool = True
    # Weight of the ONLINE parameters in TD3's soft target update
    # (theta' <- tau * theta + (1 - tau) * theta'). It must be small; a
    # polyak-style value such as 0.995 would make the targets track the online
    # networks almost exactly and defeat the purpose of target networks.
    tau             : float = 0.005
    policy_lr       : float = 3e-4
    q_lr            : float = 3e-4
    # Target-policy smoothing noise std and its clip bound (in action_scale units).
    policy_noise    : float = 0.2
    # Std of exploration noise added when acting (in action_scale units).
    exploration_noise: float = 0.1
    noise_clip      : float = 0.5


class QNetwork(nn.Module):
    """State-action value network: an MLP over the concatenated (observation, action).

    Two 256-unit ReLU hidden layers followed by a single linear output unit.
    """

    def __init__(self, observation_space, action_space):
        super().__init__()
        obs_dim = np.array(observation_space.shape).prod()
        act_dim = np.prod(action_space.shape)
        # Attribute names form the state_dict keys, so they must stay fc1/fc2/fc3.
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        """Return one Q-value per row for observations ``x`` and actions ``a``.

        ``x`` and ``a`` are concatenated along dim 1 before the MLP.
        """
        hidden = torch.cat([x, a], 1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class Actor(nn.Module):
    """Deterministic policy network mapping observations into the action box.

    A tanh-squashed MLP output in (-1, 1) is affinely rescaled with the
    ``action_scale`` / ``action_bias`` buffers onto ``[low, high]``.
    """

    def __init__(self, observation_space, action_space):
        super().__init__()
        obs_dim = np.array(observation_space.shape).prod()
        act_dim = np.prod(action_space.shape)
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, act_dim)

        high, low = action_space.high, action_space.low
        # Half-range and midpoint of the action box; registered as buffers so
        # they follow .to(device) and are stored in the state_dict.
        half_range = torch.tensor((high - low) / 2.0, dtype=torch.float32)
        midpoint = torch.tensor((high + low) / 2.0, dtype=torch.float32)
        self.register_buffer("action_scale", half_range)
        self.register_buffer("action_bias", midpoint)

    def forward(self, x):
        """Return the deterministic action for observations ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        squashed = torch.tanh(self.fc_mu(hidden))
        return squashed * self.action_scale + self.action_bias



class TD3(torch.nn.Module):

    '''
        TD3 from Clean RL

        Holds a deterministic actor and two critics (qf1, qf2), each with a
        target copy, a replay buffer, and two Adam optimizers: one shared by
        both critics and one for the actor.
    '''

    def __init__(self,
                 observation_space, action_space,
                 batch_size: int = 16,
                 q_lr: float = 1e-2,
                 gamma: float = 0.99,
                 seed: int = 13,
                 buffer_size: int = int(1e4),
                 cuda: bool = False,
                 tau: float = 0.005,
                 # actor
                 policy_lr: float = 1e-3,
                 policy_noise: float = 0.2,
                 exploration_noise: float = 0.1,
                 noise_clip: float = 0.5,
             ):
        """Build the networks, their target copies, the optimizers and the buffer.

        Args:
            observation_space, action_space: gym-style spaces; the action space
                must expose ``low``/``high`` arrays (used for rescaling and clipping).
            batch_size: minibatch size drawn by ``update``.
            q_lr: Adam learning rate shared by both critics.
            gamma: discount factor used in the TD target.
            seed: seeds ``self.dice`` (a numpy RandomState); torch is not seeded here.
            buffer_size: passed to ``ReplayBuffer`` (presumably its capacity).
            cuda: use CUDA when requested and available, else CPU.
            tau: weight of the online parameters in the soft target update.
            policy_lr: Adam learning rate for the actor.
            policy_noise, noise_clip: std and clip bound of the target-policy
                smoothing noise, both in units of ``action_scale``.
            exploration_noise: std of the exploration noise in ``act``, in units
                of ``action_scale``.
        """
        super(TD3, self).__init__()

        self.observation_space = observation_space
        self.action_space = action_space

        self.seed = seed
        self.dice = np.random.RandomState(seed)
        self.device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.batch_size = batch_size
        self.q_lr = q_lr
        self.policy_lr = policy_lr
        self.tau = tau

        self.policy_noise = policy_noise
        self.exploration_noise = exploration_noise
        self.noise_clip = noise_clip

        self.buffer = ReplayBuffer(self.observation_space, self.action_space, buffer_size)
        self.actor = Actor(self.observation_space, self.action_space).to(self.device)
        self.qf1 = QNetwork(self.observation_space, self.action_space).to(self.device)
        self.qf2 = QNetwork(self.observation_space, self.action_space).to(self.device)
        self.qf1_target = QNetwork(self.observation_space, self.action_space).to(self.device)
        self.qf2_target = QNetwork(self.observation_space, self.action_space).to(self.device)
        self.target_actor = Actor(self.observation_space, self.action_space).to(self.device)
        # Targets start as exact copies of the online networks.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.qf1_target.load_state_dict(self.qf1.state_dict())
        self.qf2_target.load_state_dict(self.qf2.state_dict())
        self.q_optimizer = optim.Adam(list(self.qf1.parameters()) + list(self.qf2.parameters()), lr=self.q_lr)
        self.actor_optimizer = optim.Adam(list(self.actor.parameters()), lr=self.policy_lr)

        print(f'qf1 = {self.qf1}')
        print(f'actor = {self.actor}')

    def act(self,
            observation: ty.Union[torch.Tensor, np.ndarray],
            greedy: bool = False, as_tensor=False) -> np.ndarray:
        """Select an action for ``observation`` and return it as a numpy array.

        Args:
            observation: one observation shaped like ``observation_space.shape``
                (a batch dimension is added) or an already-batched array/tensor.
                A bare ``int`` is wrapped into a 1-element array.
            greedy: if False, add Gaussian exploration noise with per-dimension
                std ``action_scale * exploration_noise``.
            as_tensor: currently unused; kept for interface compatibility.

        Returns:
            The action clipped to ``[action_space.low, action_space.high]``,
            with size-1 dimensions squeezed out.
        """
        if not isinstance(observation, torch.Tensor):
            if isinstance(observation, int):
                observation = np.array([observation])
            observation = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        # Tensors passed in directly may sit on another device or use another
        # dtype. The legacy torch.Tensor(...) constructor used previously raised
        # for CUDA and non-float32 inputs, so convert explicitly.
        observation = observation.to(device=self.device, dtype=torch.float32)
        if not isinstance(self.observation_space, spaces.Discrete) \
                and observation.shape == self.observation_space.shape:
            observation = torch.unsqueeze(observation, 0)

        # Acting needs no autograd graph.
        with torch.no_grad():
            action = self.actor(observation)
            if not greedy:
                action += torch.normal(0, self.actor.action_scale * self.exploration_noise)
        action = action.cpu().numpy().clip(self.action_space.low, self.action_space.high)
        return action.squeeze()

    def predict_as_np(self,
                      observation: ty.Union[torch.Tensor, np.ndarray],
                      action: ty.Union[torch.Tensor, np.ndarray],
        ) -> torch.Tensor:
        """Evaluate the first critic ``qf1`` on (observation, action).

        Unbatched inputs (shaped exactly like their space) get a batch dimension.
        NOTE(review): despite the name, this returns the raw ``qf1`` output
        tensor (with autograd history), not a numpy array; callers depend on
        that, so it is left unchanged.
        """
        if not isinstance(observation, torch.Tensor):
            observation = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        if not isinstance(action, torch.Tensor):
            action = torch.as_tensor(action, dtype=torch.float32, device=self.device)
        if not isinstance(self.observation_space, spaces.Discrete) \
                and observation.shape == self.observation_space.shape:
            observation = torch.unsqueeze(observation, 0)
        if not isinstance(self.action_space, spaces.Discrete) \
                and action.shape == self.action_space.shape:
            action = torch.unsqueeze(action, 0)

        return self.qf1(observation, action)

    def compute_target(self,
                       ob: torch.Tensor,
                       ac: torch.Tensor,
                       rew: torch.Tensor,
                       next_ob: torch.Tensor,
                       done: torch.Tensor):
        """Return the clipped double-Q TD target as a flat (1-D) tensor.

        target = rew + (1 - done) * gamma * min(Q1'(s', a'), Q2'(s', a')),
        where a' is the target actor's action plus clipped Gaussian smoothing
        noise, clipped to the action box. ``ob`` is unused; ``ac`` only
        provides the shape of the noise. Computed without gradients.
        """
        with torch.no_grad():
            # Target-policy smoothing noise, clipped, then scaled to the action range.
            clipped_noise = (torch.randn_like(ac, device=self.device) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            ) * self.target_actor.action_scale
            # Clip each action dimension to its own bounds. Previously only the
            # bounds of dimension 0 were applied to every dimension.
            low = torch.as_tensor(self.action_space.low, dtype=torch.float32, device=self.device)
            high = torch.as_tensor(self.action_space.high, dtype=torch.float32, device=self.device)
            noisy_actions = self.target_actor(next_ob) + clipped_noise
            next_state_actions = torch.max(torch.min(noisy_actions, high), low)

            qf1_next_target = self.qf1_target(next_ob, next_state_actions)
            qf2_next_target = self.qf2_target(next_ob, next_state_actions)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
            target = rew.flatten() + (1 - done.flatten()) * self.gamma * min_qf_next_target.view(-1)

        return target

    def compute_critic_loss(
            self,
            ob: torch.Tensor,
            ac: torch.Tensor,
            rew: torch.Tensor,
            next_ob: torch.Tensor,
            done: torch.Tensor,
        ):
        """Return ``(qf1_mse + qf2_mse, target - qf1(ob, ac))``.

        Both critics regress onto the same TD target. The second element is the
        per-sample TD error of ``qf1`` and still carries autograd history.
        """
        target = self.compute_target(ob, ac, rew, next_ob, done)
        qf1_a_values = self.qf1(ob, ac).view(-1)
        qf2_a_values = self.qf2(ob, ac).view(-1)
        qf1_loss = F.mse_loss(qf1_a_values, target)
        qf2_loss = F.mse_loss(qf2_a_values, target)
        qf_loss = qf1_loss + qf2_loss

        return qf_loss, target - qf1_a_values

    def compute_actor_loss(
            self,
            ob: torch.Tensor,
        ):
        """Deterministic policy-gradient loss: ``-mean(qf1(ob, actor(ob)))``."""
        return -self.qf1(ob, self.actor(ob)).mean()

    def update_target_network(self):
        '''
            Soft update of every target network:
            theta' <- tau * theta + (1 - tau) * theta'
        '''
        pairs = (
            (self.actor, self.target_actor),
            (self.qf1, self.qf1_target),
            (self.qf2, self.qf2_target),
        )
        for online, target in pairs:
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def update(self, use_all_samples: bool = False, update_policy: bool = True):
        """Run one training step and return ``(critic_loss, mean_abs_td_error)`` as floats.

        Args:
            use_all_samples: draw the batch with ``buffer.sample_uniformly()``
                instead of a ``batch_size`` minibatch.
            update_policy: also take one actor gradient step (the caller
                controls TD3's delayed-policy schedule with this flag).

        NOTE: target networks are soft-updated on every call, even when
        ``update_policy`` is False (canonical TD3 moves them only on policy
        steps). This is kept because callers that train only the critic rely on it.
        """
        if use_all_samples:
            batch = self.buffer.sample_uniformly()
        else:
            batch = self.buffer.sample(batch_size=self.batch_size)
        ob   = torch.as_tensor(batch['ob'], dtype=torch.float32, device=self.device)
        ac   = torch.as_tensor(batch['ac'], dtype=torch.float32, device=self.device)
        ob2  = torch.as_tensor(batch['ob2'], dtype=torch.float32, device=self.device)
        rew  = torch.as_tensor(batch['rew'], dtype=torch.float32, device=self.device).view(-1, 1)
        done = torch.as_tensor(batch['done'], dtype=torch.float32, device=self.device).view(-1, 1)

        DEBUGGING = False
        # DEBUGGING = True
        if DEBUGGING:
            check_param_norm([self.qf1, self.qf2])
        self.q_optimizer.zero_grad()
        critic_loss, error = self.compute_critic_loss(ob, ac, rew, ob2, done)
        critic_loss.backward()
        if DEBUGGING:
            check_gradient([self.qf1, self.qf2])
        self.q_optimizer.step()
        error = error.detach().abs().mean().item()
        critic_loss = critic_loss.detach().item()
        if DEBUGGING:
            check_param_norm([self.qf1, self.qf2])

        if update_policy:
            if DEBUGGING:
                check_param_norm([self.actor])
            self.actor_optimizer.zero_grad()
            actor_loss = self.compute_actor_loss(ob)
            actor_loss.backward()
            if DEBUGGING:
                check_gradient([self.actor])
            self.actor_optimizer.step()
            if DEBUGGING:
                check_param_norm([self.actor])

        self.update_target_network()

        return critic_loss, error

    def save(self, model_dir: Path, tag: str, verbose=False):
        """Write each network's state_dict to ``<model_dir>/<name>_<tag>.pt``.

        ``model_dir`` may be a ``Path`` or a ``str``.
        """
        model_dir = Path(model_dir)

        def save_one_model(module, save_path):
            torch.save(module.state_dict(), save_path)
            if verbose:
                print(f'saved as {save_path}')

        save_one_model(self.actor, model_dir/f'actor_{tag}.pt')
        save_one_model(self.target_actor, model_dir/f'target_actor_{tag}.pt')
        save_one_model(self.qf1, model_dir/f'qf1_{tag}.pt')
        save_one_model(self.qf1_target, model_dir/f'qf1_target_{tag}.pt')
        save_one_model(self.qf2, model_dir/f'qf2_{tag}.pt')
        save_one_model(self.qf2_target, model_dir/f'qf2_target_{tag}.pt')

    def load(self, model_dir: Path, tag: str, verbose=False):
        """Load the state_dicts written by ``save`` onto ``self.device``.

        ``map_location`` lets checkpoints saved on a GPU load on a CPU-only
        host. ``model_dir`` may be a ``Path`` or a ``str``.
        """
        model_dir = Path(model_dir)

        def load_one_model(module, load_path):
            module.load_state_dict(torch.load(load_path, map_location=self.device))
            if verbose:
                print(f'loaded from {load_path}')

        load_one_model(self.actor, model_dir/f'actor_{tag}.pt')
        load_one_model(self.target_actor, model_dir/f'target_actor_{tag}.pt')
        load_one_model(self.qf1, model_dir/f'qf1_{tag}.pt')
        load_one_model(self.qf1_target, model_dir/f'qf1_target_{tag}.pt')
        load_one_model(self.qf2, model_dir/f'qf2_{tag}.pt')
        load_one_model(self.qf2_target, model_dir/f'qf2_target_{tag}.pt')

    def get_state_dict(self) -> dict:
        """Return a checkpoint dict with network/optimizer states and the buffer.

        The replay buffer is stored by reference, not copied.
        """
        checkpoint = dict()
        checkpoint['buffer'] = self.buffer
        checkpoint['qf1_state_dict'] = self.qf1.state_dict()
        checkpoint['qf2_state_dict'] = self.qf2.state_dict()
        checkpoint['qf1_target_state_dict'] = self.qf1_target.state_dict()
        checkpoint['qf2_target_state_dict'] = self.qf2_target.state_dict()
        checkpoint['q_optimizer_state_dict'] = self.q_optimizer.state_dict()
        checkpoint['actor'] = self.actor.state_dict()
        checkpoint['target_actor'] = self.target_actor.state_dict()
        checkpoint['actor_optimizer'] = self.actor_optimizer.state_dict()
        return checkpoint

    def set_state_dict(self, checkpoint: dict):
        """Restore everything produced by ``get_state_dict``."""
        self.buffer = checkpoint['buffer']
        self.qf1.load_state_dict(checkpoint['qf1_state_dict'])
        self.qf2.load_state_dict(checkpoint['qf2_state_dict'])
        self.qf1_target.load_state_dict(checkpoint['qf1_target_state_dict'])
        self.qf2_target.load_state_dict(checkpoint['qf2_target_state_dict'])
        self.q_optimizer.load_state_dict(checkpoint['q_optimizer_state_dict'])
        self.actor.load_state_dict(checkpoint['actor'])
        self.target_actor.load_state_dict(checkpoint['target_actor'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])