import os
import torch
import torch.nn.functional as F
from copy import deepcopy

from .sac_mini import SAC_mini
from .utils import soft_update, save_model, tie_weights


class SAC_meta(SAC_mini):
    """SAC agent with a first-order meta-learning outer loop.

    Holds meta networks (``meta_critic``, ``meta_actor``, ``meta_critic_target``)
    next to the inner-loop networks (``critic``, ``actor``, ``critic_target``)
    inherited from ``SAC_mini``. Per task:

    1. :meth:`clone` copies the meta networks into the inner networks and
       creates fresh inner optimizers.
    2. The inner networks are adapted (presumably by ``SAC_mini``'s update
       routine -- not visible here).
    3. :meth:`store_grad` accumulates ``meta - adapted`` parameter differences
       as pseudo-gradients on the meta networks.

    :meth:`meta_update` then averages those pseudo-gradients over tasks and
    steps the meta optimizers (a Reptile-like update).

    The critic and actor share their encoder conv weights and, when
    ``CONFIG_ARCH.LANG`` is set, the Linear layers of their language module
    (``mlp_lang``). The meta update of those shared weights comes from the
    critic only.
    """

    def __init__(self, CONFIG):
        # Do not initialize optimizers in SAC_mini: build_network() passes
        # build_optimizer=False upward, and inner optimizers are created per
        # task by clone().
        super().__init__(CONFIG)
        # meta_update() steps the meta actor only when timer % period == 0.
        self.meta_update_period = CONFIG.META_UPDATE_PERIOD

    def build_network(self, CONFIG_ARCH, build_optimizer=True, verbose=True):
        """Build the inner networks (without optimizers) and their meta copies.

        Args:
            CONFIG_ARCH: architecture config; a truthy ``CONFIG_ARCH.LANG``
                enables tying of the language module (``mlp_lang``).
            build_optimizer: if True, create the meta critic/actor optimizers.
            verbose: forwarded to ``SAC_mini.build_network``.
        """
        super().build_network(CONFIG_ARCH, build_optimizer=False, verbose=verbose)

        # Separate deepcopy calls do not preserve references shared between
        # the critic and the actor, so the ties are restored explicitly below.
        self.meta_critic = deepcopy(self.critic)
        self.meta_actor = deepcopy(self.actor)
        self.meta_critic_target = deepcopy(self.critic)

        # Remembered so clone() can re-tie the language module as well.
        self._tie_lang_module = bool(CONFIG_ARCH.LANG)
        self._tie_shared_modules(self.meta_critic, self.meta_actor)

        if build_optimizer:
            self.meta_critic_optimizer = torch.optim.Adam(
                self.meta_critic.parameters(), lr=self.CONFIG.LR_C_META)
            # Parameters tied to the critic (encoder conv / language module)
            # are stepped by meta_critic_optimizer. Leave them out here so the
            # actor optimizer does not apply a second Adam step to them using
            # the critic-derived pseudo-gradient.
            critic_param_ids = {id(p) for p in self.meta_critic.parameters()}
            actor_only_params = [
                p for p in self.meta_actor.parameters()
                if id(p) not in critic_param_ids
            ]
            self.meta_actor_optimizer = torch.optim.Adam(
                actor_only_params, lr=self.CONFIG.LR_A_META)

    def _tie_shared_modules(self, critic, actor):
        """Tie the actor's encoder conv weights (and Linear layers of mlp_lang, if enabled) to the critic's."""
        actor.encoder.copy_conv_weights_from(critic.encoder)
        if getattr(self, '_tie_lang_module', False):
            # children() works for both Sequential and nn.Module.
            for source_layer, layer in zip(critic.mlp_lang.children(),
                                           actor.mlp_lang.children()):
                if isinstance(layer, torch.nn.Linear):
                    tie_weights(src=source_layer, trg=layer)

    def clone(self, optimizer_state=None):
        """Reset the inner-loop networks to the meta networks for a new task.

        Overwrites ``critic``, ``actor`` and ``critic_target`` and creates new
        inner optimizers.

        Args:
            optimizer_state: optional pair ``(critic_state_dict,
                actor_state_dict)``, e.g. from :meth:`get_optimizer_state`,
                loaded into the fresh inner optimizers.
        """
        self.critic = deepcopy(self.meta_critic)
        self.actor = deepcopy(self.meta_actor)
        self.critic_target = deepcopy(self.meta_critic_target)

        # The separate deepcopies break the critic/actor sharing; restore the
        # same ties that build_network() set up on the meta networks.
        self._tie_shared_modules(self.critic, self.actor)

        # Inner optimizers use Adam's default betas (the original design note
        # mentions beta_1 = 0 as in the paper; that variant is not used).
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.LR_C)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.LR_A)
        if optimizer_state is not None:
            self.critic_optimizer.load_state_dict(optimizer_state[0])
            self.actor_optimizer.load_state_dict(optimizer_state[1])

    def get_optimizer_state(self):
        """Return deep copies of the inner (critic, actor) optimizer state dicts."""
        return (deepcopy(self.critic_optimizer.state_dict()),
                deepcopy(self.actor_optimizer.state_dict()))

    def store_grad(self):
        """Accumulate ``meta - adapted`` parameter differences as meta gradients.

        Call once per adapted task; :meth:`meta_update` averages the totals.
        The sign is ``w_meta - w_task`` because the optimizer descends along
        ``grad``, which moves the meta weights toward the adapted ones.
        Actor parameters whose names contain ``conv`` or ``mlp_lang`` are
        skipped. The weights they are tied to receive their update through
        the critic.
        """
        with torch.no_grad():
            for w, w_t in zip(self.meta_critic.parameters(), self.critic.parameters()):
                self._accumulate_grad(w, w - w_t)

            for (name, w), (_, w_t) in zip(self.meta_actor.named_parameters(),
                                           self.actor.named_parameters()):
                if 'conv' in name or 'mlp_lang' in name:
                    continue
                self._accumulate_grad(w, w - w_t)

    @staticmethod
    def _accumulate_grad(param, delta):
        """Add ``delta`` to ``param.grad``, allocating a zero grad first if needed."""
        if param.grad is None:
            # zeros_like already matches param's device and dtype.
            param.grad = torch.zeros_like(param)
        param.grad.add_(delta)

    def _unique_meta_params(self):
        """Return meta critic + meta actor parameters, each shared tensor once."""
        seen, unique = set(), []
        for module in (self.meta_critic, self.meta_actor):
            for p in module.parameters():
                if id(p) not in seen:
                    seen.add(id(p))
                    unique.append(p)
        return unique

    def meta_update(self, num_task, timer=2):
        """Average the accumulated pseudo-gradients and step the meta networks.

        The caller maintains ``num_task`` and ``timer`` across
        :meth:`store_grad` calls.

        Args:
            num_task: number of tasks accumulated by :meth:`store_grad` since
                the last call; the pseudo-gradients are divided by it.
            timer: caller-maintained counter. The meta critic is stepped on
                every call; the meta actor only when
                ``timer % meta_update_period == 0``.

        Raises:
            ValueError: if ``num_task`` is not positive (averaging would
                produce inf/NaN weights).
        """
        if num_task <= 0:
            raise ValueError(f"num_task must be positive, got {num_task}")

        # Tied parameters appear in both modules; deduplicate so each
        # gradient is divided exactly once.
        params = self._unique_meta_params()
        with torch.no_grad():
            for w in params:
                if w.grad is not None:
                    w.grad.div_(num_task)

        self.meta_critic_optimizer.step()
        if timer % self.meta_update_period == 0:
            self.meta_actor_optimizer.step()

        # Clear the accumulators. Actor differences gathered in a round where
        # the actor was not stepped are discarded too.
        with torch.no_grad():
            for w in params:
                if w.grad is not None:
                    w.grad.zero_()

        soft_update(self.meta_critic_target, self.meta_critic, self.TAU)
