from abc import ABC, abstractmethod
import os
from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from torch.optim import lr_scheduler

from nn.lang_actor_critic import LangPiNetwork, LangTwinnedQNetwork
from nn.actor_critic import PiNetwork, TwinnedQNetwork
from .utils import soft_update, save_model, StepLRMargin, tie_weights



class SAC_Base(ABC):
    def __init__(self, CONFIG):
        self.CONFIG = CONFIG
        self.EVAL = CONFIG.EVAL
        self.device = CONFIG.DEVICE

        # == PARAM FOR TRAINING ==
        # Learning Rate
        self.LR_A_SCHEDULE = CONFIG.LR_A_SCHEDULE
        self.LR_C_SCHEDULE = CONFIG.LR_C_SCHEDULE
        if self.LR_A_SCHEDULE:
            self.LR_A_PERIOD = CONFIG.LR_A_PERIOD
            self.LR_A_DECAY = CONFIG.LR_A_DECAY
            self.LR_A_END = CONFIG.LR_A_END
        if self.LR_C_SCHEDULE:
            self.LR_C_PERIOD = CONFIG.LR_C_PERIOD
            self.LR_C_DECAY = CONFIG.LR_C_DECAY
            self.LR_C_END = CONFIG.LR_C_END
        self.LR_C = CONFIG.LR_C
        self.LR_A = CONFIG.LR_A

        # Discount factor
        self.GAMMA_SCHEDULE = CONFIG.GAMMA_SCHEDULE
        if self.GAMMA_SCHEDULE:
            self.GammaScheduler = StepLRMargin(initValue=CONFIG.GAMMA,
                                                period=CONFIG.GAMMA_PERIOD,
                                                decay=CONFIG.GAMMA_DECAY,
                                                endValue=CONFIG.GAMMA_END,
                                                goalValue=1.)
            self.GAMMA = self.GammaScheduler.get_variable()
        else:
            self.GAMMA = CONFIG.GAMMA

        # Target Network Update
        self.TAU = CONFIG.TAU

        # alpha-related hyper-parameters
        self.init_alpha = CONFIG.ALPHA
        self.LEARN_ALPHA = CONFIG.LEARN_ALPHA
        self.log_alpha = torch.tensor(np.log(self.init_alpha)).to(
            self.device)
        self.target_entropy = CONFIG.TARGET_ENTROPY
        if self.LEARN_ALPHA:
            self.log_alpha.requires_grad = True
            self.LR_Al = CONFIG.LR_Al
            self.LR_Al_SCHEDULE = CONFIG.LR_Al_SCHEDULE
            if self.LR_Al_SCHEDULE:
                self.LR_Al_PERIOD = CONFIG.LR_Al_PERIOD
                self.LR_Al_DECAY = CONFIG.LR_Al_DECAY
                self.LR_Al_END = CONFIG.LR_Al_END
            print("SAC with learnable alpha and target entropy = {:.1e}".
                    format(self.target_entropy))
        else:
            print("SAC with fixed alpha = {:.1e}".format(self.init_alpha))


    @property
    def alpha(self):
        return self.log_alpha.exp()


    @property
    @abstractmethod
    def has_latent(self):
        """Abstract property that concrete subclasses must override.

        The base implementation only raises ``NotImplementedError``.
        NOTE(review): presumably reports whether the agent uses a latent
        variable -- confirm against the subclasses.
        """
        raise NotImplementedError()


    @property
    @abstractmethod
    def latent_dist(self):
        """Abstract property that concrete subclasses must override.

        The base implementation only raises ``NotImplementedError``.
        NOTE(review): presumably exposes the distribution of the latent
        variable -- confirm against the subclasses.
        """
        raise NotImplementedError()


    def parameters(self):
        return self.actor.parameters()


    def build_network(self,
                      CONFIG_ARCH,
                      verbose=True):
        """Create ``self.actor``, ``self.critic`` and ``self.critic_target``.

        The language-conditioned network classes are used when
        ``CONFIG_ARCH.LANG`` is truthy, the plain ones otherwise. Optional
        checkpoints are loaded before the target critic is cloned, and the
        actor shares conv / language weights with the critic when requested.

        Args:
            CONFIG_ARCH: architecture configuration; the attributes read are
                the ones referenced in the body below.
            verbose: forwarded to both network constructors.
        """
        use_lang = CONFIG_ARCH.LANG
        if use_lang:
            actor_cls, critic_cls = LangPiNetwork, LangTwinnedQNetwork
        else:
            actor_cls, critic_cls = PiNetwork, TwinnedQNetwork

        # Constructor arguments that are identical for actor and critic.
        # latent_dim (CONFIG_ARCH.LATENT_DIM) is currently not passed to
        # either network.
        shared_kwargs = dict(
            input_n_channel=CONFIG_ARCH.OBS_CHANNEL,
            action_dim=CONFIG_ARCH.ACTION_DIM,
            kernel_sz=CONFIG_ARCH.KERNEL_SIZE,
            stride=CONFIG_ARCH.STRIDE,
            padding=CONFIG_ARCH.PADDING,
            n_channel=CONFIG_ARCH.N_CHANNEL,
            use_sm=CONFIG_ARCH.USE_SM,
            use_ln=CONFIG_ARCH.USE_LN,
            use_film=CONFIG_ARCH.USE_FILM,
            dual_conv=CONFIG_ARCH.DUAL_CONV,
            lang_dim=CONFIG_ARCH.LANG_DIM,
            device=self.device,
            verbose=verbose)

        def role_kwargs(role):
            # Add the per-role settings; a fresh img_sz list is built for
            # every network so the two never share the same list object.
            return dict(
                shared_kwargs,
                img_sz=[CONFIG_ARCH.IMG_H, CONFIG_ARCH.IMG_W],
                mlp_dim=CONFIG_ARCH.MLP_DIM[role],
                append_dim=CONFIG_ARCH.APPEND_DIM[role],
                activation_type=CONFIG_ARCH.ACTIVATION[role])

        self.actor = actor_cls(action_mag=1, **role_kwargs('actor'))
        self.critic = critic_cls(**role_kwargs('critic'))

        # Load model if specified
        actor_path = CONFIG_ARCH.ACTOR_PATH
        if actor_path is not None:
            actor_weights = torch.load(actor_path, map_location=self.device)
            self.actor.load_state_dict(actor_weights)
            print("--> Load actor wights from {}".format(actor_path))
        critic_path = CONFIG_ARCH.CRITIC_PATH
        if critic_path is not None:
            critic_weights = torch.load(critic_path, map_location=self.device)
            self.critic.load_state_dict(critic_weights)
            print("--> Load critic wights from {}".format(critic_path))

        # Copy for critic target (taken after any checkpoint was loaded)
        self.critic_target = deepcopy(self.critic)

        # Tie weights for conv layers
        if CONFIG_ARCH.TIE_CONV:
            self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        # Tie weights for language module
        if use_lang:
            # children() works for both Sequential and nn.Module
            layer_pairs = zip(self.critic.mlp_lang.children(),
                              self.actor.mlp_lang.children())
            for critic_layer, actor_layer in layer_pairs:
                if not isinstance(actor_layer, torch.nn.Linear):
                    continue
                tie_weights(src=critic_layer, trg=actor_layer)


    def build_optimizer(self):
        # print("Build basic optimizers.")
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.LR_C)
        self.actor_optimizer = Adam(self.actor.parameters(), lr=self.LR_A)

        if self.LR_C_SCHEDULE:
            self.critic_scheduler = lr_scheduler.StepLR(
                self.critic_optimizer,
                step_size=self.LR_C_PERIOD,
                gamma=self.LR_C_DECAY)
        if self.LR_A_SCHEDULE:
            self.actor_scheduler = lr_scheduler.StepLR(
                self.actor_optimizer,
                step_size=self.LR_A_PERIOD,
                gamma=self.LR_A_DECAY)

        if self.LEARN_ALPHA:
            self.log_alpha_optimizer = Adam([self.log_alpha], lr=self.LR_Al)
            if self.LR_Al_SCHEDULE:
                self.log_alpha_scheduler = lr_scheduler.StepLR(
                    self.log_alpha_optimizer,
                    step_size=self.LR_Al_PERIOD,
                    gamma=self.LR_Al_DECAY)


    # region: update functions
    def update_alpha_hyperParam(self):
        if self.LR_Al_SCHEDULE:
            lr = self.log_alpha_optimizer.state_dict()['param_groups'][0]['lr']
            if lr <= self.LR_Al_END:
                for param_group in self.log_alpha_optimizer.param_groups:
                    param_group['lr'] = self.LR_Al_END
            else:
                self.log_alpha_scheduler.step()


    def update_critic_hyperParam(self):
        if self.LR_C_SCHEDULE:
            lr = self.critic_optimizer.state_dict()['param_groups'][0]['lr']
            if lr <= self.LR_C_END:
                for param_group in self.critic_optimizer.param_groups:
                    param_group['lr'] = self.LR_C_END
            else:
                self.critic_scheduler.step()
        if self.GAMMA_SCHEDULE:
            self.GammaScheduler.step()
            self.GAMMA = self.GammaScheduler.get_variable()


    def update_actor_hyperParam(self):
        if self.LR_A_SCHEDULE:
            lr = self.actor_optimizer.state_dict()['param_groups'][0]['lr']
            if lr <= self.LR_A_END:
                for param_group in self.actor_optimizer.param_groups:
                    param_group['lr'] = self.LR_A_END
            else:
                self.actor_scheduler.step()


    def update_hyper_param(self):
        """Refresh critic, actor and (if learnable) alpha hyper-parameters.

        The critic update runs first, then the actor update; the alpha
        update is only invoked when ``self.LEARN_ALPHA`` is truthy.
        """
        self.update_critic_hyperParam()
        self.update_actor_hyperParam()
        if not self.LEARN_ALPHA:
            return
        self.update_alpha_hyperParam()


    def update_target_networks(self):
        """Run ``soft_update`` on the target critic with coefficient ``TAU``.

        Passes ``self.critic_target``, ``self.critic`` and ``self.TAU`` to
        the ``soft_update`` helper, in that order.
        NOTE(review): presumably a Polyak average of the critic into the
        target -- confirm against ``soft_update`` in ``utils``.
        """
        target_critic = self.critic_target
        online_critic = self.critic
        soft_update(target_critic, online_critic, self.TAU)


    @abstractmethod
    def update_actor(self, batch):
        """Abstract hook that concrete subclasses must override.

        The base implementation only raises ``NotImplementedError``.

        Args:
            batch: passed through by the caller; its structure is defined by
                the subclasses, not here.
        """
        raise NotImplementedError()


    @abstractmethod
    def update_critic(self, batch):
        """Abstract hook that concrete subclasses must override.

        The base implementation only raises ``NotImplementedError``.

        Args:
            batch: passed through by the caller; its structure is defined by
                the subclasses, not here.
        """
        raise NotImplementedError()


    @abstractmethod
    def update(self, batch, timer, update_period=2):
        """Abstract hook that concrete subclasses must override.

        The base implementation only raises ``NotImplementedError``.

        Args:
            batch: passed through by the caller; structure defined by the
                subclasses.
            timer: meaning defined by the subclasses.
            update_period: defaults to 2; meaning defined by the subclasses.
        """
        raise NotImplementedError()


    # utils
    @abstractmethod
    def value(self, obs, append):
        """Abstract hook that concrete subclasses must override.

        The base implementation only raises ``NotImplementedError``.

        Args:
            obs: meaning defined by the subclasses.
            append: meaning defined by the subclasses.
        """
        raise NotImplementedError()


    def save(self, step, logs_path, max_model=None):
        """Hand the critic and then the actor to ``save_model``.

        Each network is saved under ``<logs_path>/<name>`` with ``<name>``
        being ``'critic'`` or ``'actor'``.

        Args:
            step: forwarded to ``save_model``.
            logs_path: directory under which the per-network folders live.
            max_model: forwarded to ``save_model`` (defaults to ``None``).
        """
        critic_dir = os.path.join(logs_path, 'critic')
        actor_dir = os.path.join(logs_path, 'actor')
        jobs = (
            (self.critic, critic_dir, 'critic'),
            (self.actor, actor_dir, 'actor'),
        )
        for network, folder, label in jobs:
            save_model(network, step, folder, label, max_model)


    def remove(self, step, logs_path):
        """Delete the actor and critic checkpoint files of a given step.

        The targets are ``<logs_path>/critic/critic-<step>.pth`` and
        ``<logs_path>/actor/actor-<step>.pth``. Both paths are printed
        first; files that do not exist are skipped.

        Args:
            step: value substituted into the checkpoint file names.
            logs_path: directory containing the ``critic`` and ``actor``
                folders.
        """
        critic_file = os.path.join(logs_path, 'critic',
                                   'critic-{}.pth'.format(step))
        actor_file = os.path.join(logs_path, 'actor',
                                  'actor-{}.pth'.format(step))

        # Announce both targets (actor first), then delete critic first.
        for target in (actor_file, critic_file):
            print("Remove", target)
        for target in (critic_file, actor_file):
            if not os.path.exists(target):
                continue
            os.remove(target)


    def get_model(self):
        return deepcopy(self.actor.state_dict()), \
               deepcopy(self.critic.state_dict()), \
               deepcopy(self.critic_target.state_dict())


    def restore_model(self, state_dicts):
        self.actor.load_state_dict(state_dicts[0])
        self.critic.load_state_dict(state_dicts[1])
        self.critic_target.load_state_dict(state_dicts[2])


    def get_optimizer(self):
        return deepcopy(self.actor_optimizer.state_dict()), \
               deepcopy(self.critic_optimizer.state_dict())


    def restore_optimizer(self, state_dicts):
        self.actor_optimizer.load_state_dict(state_dicts[0])
        self.critic_optimizer.load_state_dict(state_dicts[1])
