from torch.autograd import Variable
from torch.optim import Adam,RMSprop,AdamW
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

 
class BCNetwork_Vec(nn.Module):
    """Behavioral-cloning MLP over flat (vector) observations.

    Maps an observation vector to a probability distribution over actions
    (learns action probabilities) with two hidden layers and a softmax head.

    Args:
        input_dim: Size of each input feature vector.
        out_dim: Number of actions (output probabilities).
        hidden_dim: Width of both hidden layers.
        activation: Non-linearity applied after each hidden layer.
    """

    def __init__(self, input_dim, out_dim, hidden_dim=64, activation=F.relu):
        super().__init__()
        # Layer creation order is kept so parameter init / state_dict order match.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.activation = activation

    def forward(self, X):
        """Return softmax action probabilities, normalised over dim 1."""
        hidden = X
        for layer in (self.fc1, self.fc2):
            hidden = self.activation(layer(hidden))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=1)

# Categorical
class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution whose samples and modes keep a trailing size-1 axis.

    ``sample`` and ``mode`` return shape ``(..., 1)``; ``log_probs`` accepts
    actions shaped that way and returns per-row summed log-probabilities of
    shape ``(batch, 1)``.
    """

    def sample(self):
        """Draw one sample per distribution and append a trailing axis."""
        draw = super().sample()
        return draw[..., None]

    def log_probs(self, actions):
        """Log-probability of ``actions``, summed per row of the leading dim."""
        flat_actions = actions.squeeze(-1)
        per_action = super().log_prob(flat_actions)
        batch_size = actions.size(0)
        return per_action.view(batch_size, -1).sum(-1, keepdim=True)

    def mode(self):
        """Most probable category, with a trailing axis kept."""
        return torch.argmax(self.probs, dim=-1, keepdim=True)

    def mode(self):
        return self.probs.argmax(dim=-1, keepdim=True)

def init(module, weight_init, gain=1):
    """Initialise ``module.weight`` in place and hand the module back.

    Args:
        module: A layer exposing a ``weight`` parameter (bias is left untouched).
        weight_init: In-place initialiser called as ``weight_init(tensor, gain=gain)``,
            e.g. ``nn.init.orthogonal_`` or ``nn.init.xavier_uniform_``.
        gain: Scaling factor forwarded to ``weight_init``.

    Returns:
        The same ``module`` object, so calls can wrap layer construction.
    """
    weight_tensor = module.weight.data
    weight_init(weight_tensor, gain=gain)
    return module

class BCNetWork_Mat(nn.Module):
    """Convolutional behavioral-cloning policy over channels-last grid observations.

    Three 3x3 convolutions (stride 1, padding 1, no bias, ReLU) preserve the
    spatial size. They are followed by three ReLU linear layers and a linear
    head over ``_NUM_ACTIONS`` actions. The forward pass returns softmax
    probabilities.

    Args:
        map_name: Layout key in ``_MAP_SPECS``. It selects the base input
            channel count and the flattened feature size.
        hidden_dim: Width of the fully connected layers.
        cnn_layers_params: Unused; accepted for interface compatibility.
        meta_dim: Extra input channels added to the layout's base channel count.

    Raises:
        ValueError: If ``map_name`` is not a key of ``_MAP_SPECS``.
    """

    # map_name -> (base input channels, number of spatial cells H*W).
    # The cell count must match the observation grid, because the convs keep
    # the spatial size and fc1 consumes 32 * H * W features.
    _MAP_SPECS = {
        'random3': (20, 8 * 5),
        'mo1': (26, 5 * 5),
        'unident_s': (20, 9 * 5),
        'distant_tomato': (26, 7 * 5),
        'soup_coordination': (26, 11 * 5),
    }
    _NUM_ACTIONS = 6

    def __init__(self, map_name, hidden_dim, cnn_layers_params=None, meta_dim=0):
        super(BCNetWork_Mat, self).__init__()
        if map_name not in self._MAP_SPECS:
            raise ValueError(
                f"Unknown map_name {map_name!r}; expected one of {sorted(self._MAP_SPECS)}"
            )
        base_channels, spatial_cells = self._MAP_SPECS[map_name]
        self.hidden_size = hidden_dim

        relu_gain = nn.init.calculate_gain('relu')

        def init_cnn(m):
            # Xavier-uniform weights scaled for ReLU; biases keep PyTorch's default init.
            nn.init.xavier_uniform_(m.weight.data, gain=relu_gain)
            return m

        def init_linear(m):
            # Orthogonal init with a small gain (0.01) for the action head.
            nn.init.orthogonal_(m.weight.data, gain=0.01)
            return m

        def conv3x3(in_channels, out_channels):
            # Same-padding 3x3 conv: spatial size is preserved.
            return init_cnn(nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3),
                                      stride=(1, 1), padding=(1, 1), bias=False))

        # Construction order matches the original so parameter init consumes RNG identically.
        self.conv1 = conv3x3(base_channels + meta_dim, 32)
        self.conv2 = conv3x3(32, 64)
        self.conv3 = conv3x3(64, 32)

        self.fc1 = init_cnn(nn.Linear(32 * spatial_cells, self.hidden_size))
        self.fc2 = init_cnn(nn.Linear(self.hidden_size, self.hidden_size))
        self.fc3 = init_cnn(nn.Linear(self.hidden_size, self.hidden_size))
        self.linear = init_linear(nn.Linear(self.hidden_size, self._NUM_ACTIONS))

    def _build_cnn_input(self, obs):
        """Turn channels-last observations into a scaled NCHW batch.

        A 3-D (unbatched) observation gets a leading batch axis. The channel
        axis is moved to position 1, and values are divided by 255.0.
        """
        if obs.ndim == 3:
            obs = torch.unsqueeze(obs, 0)
        return obs.permute(0, 3, 1, 2) / 255.0

    def forward(self, inputs, available_actions=None, deterministic=False):
        """Compute action probabilities.

        Args:
            inputs: Observation tensor, channels-last, batched (4-D) or single (3-D).
            available_actions: Optional mask, tensor or array-like. Entries equal
                to 0 mark unavailable actions. It must broadcast to the
                ``(batch, _NUM_ACTIONS)`` logits.
            deterministic: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(batch, _NUM_ACTIONS)`` with softmax over dim 1.
            Masked actions get probability 0.
        """
        x = self._build_cnn_input(inputs)
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        x = x.reshape(x.size(0), -1)
        for fc in (self.fc1, self.fc2, self.fc3):
            x = F.relu(fc(x))
        logits = self.linear(x)

        if available_actions is not None:
            # Out-of-place masking: accepts non-tensor masks and broadcasts
            # (e.g. a 1-D mask for a single observation).
            mask = torch.as_tensor(available_actions, device=logits.device) == 0
            logits = logits.masked_fill(mask, -1e10)
        return F.softmax(logits, dim=1)

class BCAgent(nn.Module):
    '''
    General class for BC agents.

    Wraps a BCNetWork_Mat policy and an RMSprop optimizer over its parameters.

    Args:
        params: Configuration dict. Required keys: 'lr', 'gamma', 'obs_dim',
            'action_dim', 'hidden_dim', 'env' (map name passed to BCNetWork_Mat).
            Optional keys: 'meta_dim' (default 0) and 'device' (where the
            policy is placed, default 'cuda').
    '''
    def __init__(self, params):
        super(BCAgent, self).__init__()

        self.lr = params['lr']
        self.gamma = params['gamma']  # stored only; not used within this class

        self.obs_dim = params['obs_dim']
        self.action_dim = params['action_dim']
        self.hidden_dim = params['hidden_dim']
        self.map_name = params['env']

        device = params.get('device', 'cuda')
        self.policy = BCNetWork_Mat(
            map_name=self.map_name,
            hidden_dim=self.hidden_dim,
            meta_dim=params.get('meta_dim', 0),
        ).to(device)

        self.policy_optimizer = RMSprop(self.policy.parameters(), lr=self.lr, eps=1e-8)

    def act(self, obs, deterministic=False):
        """Pick an action index for a single observation.

        Args:
            obs: One observation, as a torch tensor or array-like (e.g. numpy).
                It is converted to float32 on the policy's device. A 1-D
                observation gets a batch axis.
            deterministic: If True, take the most probable action (argmax).
                Otherwise sample from the policy's action distribution.

        Returns:
            int: The chosen action index.
        """
        # Follow the policy's device instead of hard-coding CUDA; this also
        # moves CPU tensors that the old dtype-string check left in place.
        device = next(self.policy.parameters()).device
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
        if obs.dim() == 1:
            obs = obs.unsqueeze(dim=0)

        with torch.no_grad():
            probs = self.policy(obs)

        if deterministic:
            action = probs.argmax(dim=-1)
        else:
            # Stochastic action selection, as in the BC agent of
            # https://hrl.boyuai.com/chapter/3/%E6%A8%A1%E4%BB%BF%E5%AD%A6%E4%B9%A0
            action = torch.distributions.Categorical(probs).sample()

        return action.item()

    def get_params(self):
        """Return the state dicts of the policy and its optimizer."""
        return {'policy': self.policy.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                }
    

        
        
        