import torch, pdb
import torch.nn as nn
from rl_algorithm.ddqn.config import discount, batch_size, device
import numpy as np

# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
    """Re-initialise a module whose class name contains ``"Linear"``.

    Weights are drawn from a standard normal distribution and every row
    (summing over dimension 1) is then rescaled to unit L2 norm. A bias,
    when the module has one, is filled with zeros. Modules whose class name
    does not contain ``"Linear"`` are left untouched.

    Intended to be passed to ``nn.Module.apply``.
    """
    if "Linear" not in m.__class__.__name__:
        return

    weights = m.weight.data
    weights.normal_(0, 1)
    # Per-row Euclidean norm, kept 2-D so it broadcasts over the columns.
    row_norms = torch.sqrt(weights.pow(2).sum(1, keepdim=True))
    weights *= 1 / row_norms

    bias = m.bias
    if bias is not None:
        bias.data.fill_(0)


class DDQN(nn.Module):
    """Fully connected Q-network over a flattened image observation.

    The image is flattened and passed through a three-layer MLP
    (``image_conv``) that produces a 64-dimensional embedding. When
    ``is_use_mission`` is true, the text observation is embedded, encoded by
    a GRU, and its final hidden state is concatenated to the image embedding.
    The ``actor`` head maps the combined embedding to one value per action.

    Args:
        obs_space: Mapping where ``obs_space["image"]`` holds three integer
            dimensions (their product is the MLP input size) and, when
            ``is_use_mission`` is true, ``obs_space["text"]`` is the
            vocabulary size passed to ``nn.Embedding``.
        action_space: Object whose ``n`` attribute is the number of outputs.
        is_init: If true, apply the module-level ``init_params`` to every
            sub-module after construction.
        is_use_mission: If true, build and use the text encoder.
    """

    def __init__(self, obs_space, action_space, is_init, is_use_mission):
        super().__init__()
        # Define image embedding
        n = obs_space["image"][0]
        m = obs_space["image"][1]
        z = obs_space["image"][2]
        shape = n * m * z
        # Despite the name, this is an MLP over the flattened image; the
        # attribute name is kept so existing checkpoints still load.
        self.image_conv = nn.Sequential(
            nn.Linear(shape, shape // 2),
            nn.ReLU(),
            nn.Linear(shape // 2, shape // 4),
            nn.ReLU(),
            nn.Linear(shape // 4, 64),
            nn.ReLU()
        )
        self.embedding_size = 64
        self.is_use_mission = is_use_mission

        if is_use_mission:
            self.word_embedding_size = 32
            self.word_embedding = nn.Embedding(obs_space["text"],
                                               self.word_embedding_size)
            self.text_embedding_size = 128
            self.text_rnn = nn.GRU(self.word_embedding_size,
                                   self.text_embedding_size, batch_first=True)
            self.embedding_size += self.text_embedding_size

        # Define model
        self.actor = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, action_space.n)
        )

        # Initialize parameters correctly
        if is_init:
            self.apply(init_params)

    def init_weights(self, m):
        """Xavier-uniform re-initialisation for a linear layer.

        Non-linear modules are ignored. The bias is zeroed only when the
        layer actually has one, so ``bias=False`` layers are supported.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    def reinit_weight(self, ver=0):
        """Re-initialise a subset of the linear layers with ``init_weights``.

        Args:
            ver: ``0`` resets both ``actor`` linear layers, ``1`` resets only
                the output layer of ``actor``, ``2`` resets both ``actor``
                layers and all three ``image_conv`` linear layers. Any other
                value is a no-op.
        """
        head_layers = (self.actor[0], self.actor[2])
        if ver == 0:
            layers = head_layers
        elif ver == 1:
            layers = (self.actor[2],)
        elif ver == 2:
            layers = head_layers + (
                self.image_conv[0],
                self.image_conv[2],
                self.image_conv[4],
            )
        else:
            return

        for layer in layers:
            layer.apply(self.init_weights)

    def forward(self, obs):
        """Return the network output for a batch of observations.

        Args:
            obs: Object with an ``image`` tensor attribute (at least 4-D,
                batch first) and, when ``is_use_mission`` is true, a ``text``
                tensor of token indices.

        Returns:
            Tensor with one row per batch element and ``action_space.n``
            columns.
        """
        # Move axis 3 in front of axes 1 and 2, then flatten per sample.
        x = obs.image.transpose(1, 3).transpose(2, 3)
        x = x.reshape(x.shape[0], -1)
        x = self.image_conv(x)

        if self.is_use_mission:
            embed_text = self._get_embed_text(obs.text)
            embedding = torch.cat((x, embed_text), dim=1)
        else:
            embedding = x

        return self.actor(embedding)

    def _get_embed_text(self, text):
        """Encode token indices with the GRU and return its last hidden state."""
        _, hidden = self.text_rnn(self.word_embedding(text))
        return hidden[-1]

    @classmethod
    def train_model(
        cls, online_net, target_net, optimizer, collected_experience, is_rnd
    ):
        """Run one optimisation step on ``online_net`` and return the loss.

        Args:
            online_net: Network being trained; called on the ``"obs"`` batch.
            target_net: Network used for the bootstrap value; called on the
                ``"new_obs"`` batch without gradient tracking.
            optimizer: Optimizer that is zeroed, then stepped after backward.
            collected_experience: Mapping with the keys ``"obs"``,
                ``"new_obs"``, ``"actions"``, ``"rewards"`` and ``"dones"``.
            is_rnd: Unused; kept for interface compatibility.

        Returns:
            The smooth-L1 loss tensor between the predicted and target values.
        """
        obs = collected_experience["obs"]
        new_obs = collected_experience["new_obs"]
        actions = collected_experience["actions"]
        rewards = collected_experience["rewards"]
        dones = collected_experience["dones"]
        # Row indices come from the configured batch size, so the batch is
        # expected to contain exactly ``batch_size`` transitions.
        indices = np.arange(batch_size)

        rewards = torch.tensor(rewards, device=device)
        # NOTE(review): this is multiplied straight into the bootstrap term,
        # so it must be 1 for non-terminal and 0 for terminal transitions --
        # confirm against the code that fills "dones".
        continue_mask = torch.tensor(dones, device=device)

        Q_policy = online_net(obs)[indices, actions]

        # The target is a constant for this update: without no_grad the
        # backward pass would also run through target_net and accumulate
        # gradients on its parameters.
        # NOTE(review): the max is taken over target_net's own outputs; a
        # Double-DQN target would pick the action with online_net -- confirm
        # which is intended.
        with torch.no_grad():
            max_actions = target_net(new_obs).max(dim=1)[0]
            Q_target = rewards + discount * max_actions * continue_mask

        # compute loss
        loss = nn.functional.smooth_l1_loss(Q_policy, Q_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss