from audioop import reverse
from email import policy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
import copy
import math
import torch.nn.functional as F
import utils
import hydra
import time
from collections import OrderedDict
import torch.autograd as autograd


def get_tensor_values(obs, actions, network=None):
    """Evaluate a two-output network on several actions per observation.

    Every observation row is tiled so that it lines up with its group of
    consecutive rows in ``actions``; both outputs of ``network`` are then
    folded back to one slot per (observation, action) pair.

    Args:
        obs: 2-D tensor ``[bs, dim]``.
        actions: Tensor whose first dimension is ``bs * num_repeat``. The
            repeat count is derived as ``int(actions.shape[0] / obs.shape[0])``.
        network: Callable invoked as ``network(obs, actions)`` returning two
            tensors with ``bs * num_repeat`` elements each. The ``None``
            default is only a placeholder; a callable must be supplied.

    Returns:
        Tuple ``(q1_preds, q2_preds)``, each viewed as ``[bs, num_repeat, 1]``.
    """
    batch_size = obs.shape[0]
    num_repeat = int(actions.shape[0] / batch_size)

    # [bs, dim] -> [bs, num_repeat, dim] -> [bs * num_repeat, dim]
    tiled_obs = obs.unsqueeze(1).repeat(1, num_repeat, 1)
    flat_obs = tiled_obs.view(batch_size * num_repeat, obs.shape[1])

    first, second = network(flat_obs, actions)
    out_shape = (batch_size, num_repeat, 1)
    return first.view(*out_shape), second.view(*out_shape)

def get_policy_actions(obs, num_actions, network=None):
    """Draw ``num_actions`` reparameterised policy samples per observation.

    Args:
        obs: 2-D tensor ``[bs, dim]``.
        num_actions: Number of samples to draw for every observation row.
        network: Callable mapping the tiled observations
            ``[bs * num_actions, dim]`` to a distribution object exposing
            ``rsample()`` and ``log_prob()``. The ``None`` default is only a
            placeholder; a callable must be supplied.

    Returns:
        Tuple ``(actions, log_pi)``. ``actions`` is returned exactly as
        sampled (one row per tiled observation); ``log_pi`` is the
        log-probability summed over the last axis and viewed as
        ``[bs, num_actions, 1]``.
    """
    batch_size, obs_dim = obs.shape[0], obs.shape[1]

    # [bs, dim] -> [bs, num_actions, dim] -> [bs * num_actions, dim]
    tiled_obs = obs.unsqueeze(1).repeat(1, num_actions, 1)
    flat_obs = tiled_obs.view(batch_size * num_actions, obs_dim)

    policy_dist = network(flat_obs)
    sampled = policy_dist.rsample()
    log_pi = policy_dist.log_prob(sampled).sum(-1, keepdim=True)

    return sampled, log_pi.view(batch_size, num_actions, 1)
    

def disable_gradients(network):
    """Freeze ``network`` by turning off ``requires_grad`` on every parameter.

    Args:
        network: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
    """
    for weight in network.parameters():
        weight.requires_grad_(False)

class StateEncoder(nn.Module):
    """One-layer MLP encoder (Linear + ReLU) followed by a linear projector.

    When ``custom_encode_obs`` names a known environment, only a fixed slice
    of the last observation axis is encoded and the first layer is sized to
    that slice; otherwise the whole observation (``obs_shape[-1]`` features)
    is encoded.

    Args:
        obs_shape: Shape of a single observation; must have exactly one entry.
        proj_dim: Output size of the projector; must not exceed ``repr_dim``
            (256). The original note says the default used is 128, presumably
            in the project configs.
        custom_encode_obs: ``None``, ``'sawyer'`` or
            ``'tabletop_manipulation'``.
        spectral_norm: If truthy, wrap the first linear layer with
            ``nn.utils.spectral_norm``.

    Raises:
        ValueError: If ``custom_encode_obs`` is neither ``None`` nor a known
            environment name. (Previously this failed with an unbound
            ``obs_dim`` local.)
    """

    # Environment name -> slice of the last observation axis that is encoded.
    _CUSTOM_OBS_SLICES = {
        'sawyer': slice(4, 7),  # only door position
        'tabletop_manipulation': slice(2, 4),  # only object position
    }

    def __init__(self, obs_shape, proj_dim, custom_encode_obs = None, spectral_norm = False):
        super().__init__()
        assert len(obs_shape) == 1
        self.repr_dim = 256
        assert proj_dim <= self.repr_dim
        if (custom_encode_obs is not None
                and custom_encode_obs not in self._CUSTOM_OBS_SLICES):
            raise ValueError(
                'unknown custom_encode_obs {!r}; expected None or one of {}'.format(
                    custom_encode_obs, sorted(self._CUSTOM_OBS_SLICES)))

        self.spectral_norm = spectral_norm
        self.custom_encode_obs = custom_encode_obs
        if custom_encode_obs is not None:
            print('currently state encoder is use custom encode obs : ', custom_encode_obs)
            window = self._CUSTOM_OBS_SLICES[custom_encode_obs]
            input_dim = window.stop - window.start
        else:
            input_dim = obs_shape[-1]

        first_layer = nn.Linear(input_dim, self.repr_dim)
        if self.spectral_norm:
            first_layer = nn.utils.spectral_norm(first_layer)
        self.mlp_layer = nn.Sequential(first_layer, nn.ReLU())

        self.projector = nn.Linear(self.repr_dim, proj_dim)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def encode(self, obs):
        """Return the ``repr_dim``-sized hidden features for ``obs``.

        The environment-specific slice (if any) is applied to the last axis
        before the MLP.
        """
        window = self._CUSTOM_OBS_SLICES.get(self.custom_encode_obs)
        if window is not None:
            obs = obs[..., window]

        return self.mlp_layer(obs)

    def forward(self, obs):
        """Encode ``obs`` and project the features down to ``proj_dim``."""
        h = self.encode(obs)
        z = self.projector(h)
        return z

class IdentityEncoder(nn.Module):
    """Pass-through encoder for flat observations.

    ``encode`` returns its input unchanged. ``forward`` optionally applies a
    square linear layer (``repr_dim -> repr_dim``) so the module exposes a
    ``projector`` like the other encoders.

    Args:
        obs_shape: Shape of a single observation; must have exactly one entry.
        project_for_state_input: If truthy, create and apply the square
            linear projector in ``forward``.
    """

    def __init__(self, obs_shape, project_for_state_input = False):
        super().__init__()
        assert len(obs_shape) == 1

        obs_dim = obs_shape[-1]
        self.repr_dim = obs_dim
        # Original note: only for matching the "outpacetype" dim.
        self.project_for_state_input = project_for_state_input
        if self.project_for_state_input:
            self.projector = nn.Linear(obs_dim, obs_dim)

    def encode(self, obs):
        """Return ``obs`` untouched."""
        return obs

    def forward(self, obs):
        """Return ``obs``, passed through the projector when it is enabled."""
        features = self.encode(obs)
        if not self.project_for_state_input:
            return features
        return self.projector(features)

class Encoder(nn.Module):
    """Convolutional encoder for image observations with a linear projector.

    Four 3x3 convolutions with 32 filters each (the first with stride 2, the
    rest with stride 1), every one followed by a ReLU.

    Args:
        obs_shape: Three-entry observation shape; ``obs_shape[0]`` is used as
            the number of input channels.
        proj_dim: Output size of the projector.
    """

    def __init__(self, obs_shape, proj_dim):
        super().__init__()

        assert len(obs_shape) == 3

        num_filters = 32
        layers = [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2), nn.ReLU()]
        for _ in range(3):
            layers.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
            layers.append(nn.ReLU())
        self.conv = nn.Sequential(*layers)

        # Hard-coded flattened size: 32 channels on a 35x35 map, which is
        # what this conv stack yields for e.g. 84x84 inputs (41 -> 39 -> 37 -> 35).
        self.repr_dim = 32 * 35 * 35

        self.projector = nn.Linear(self.repr_dim, proj_dim)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def encode(self, obs):
        """Scale ``obs`` by 1/255, run the conv stack and flatten per sample."""
        scaled = obs / 255.
        feature_map = self.conv(scaled)
        return feature_map.view(feature_map.shape[0], -1)

    def forward(self, obs):
        """Encode ``obs`` and project the flattened features to ``proj_dim``."""
        return self.projector(self.encode(obs))


class Actor(nn.Module):
    """Stochastic policy head on top of encoder features.

    The input is passed through Linear + LayerNorm, then through a network
    built by ``utils.mlp`` whose output is split in two halves along the last
    axis: a mean and a raw log-std.

    Args:
        repr_dim: Size of the incoming representation.
        feature_dim: Size after the Linear + LayerNorm trunk.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
        log_std_bounds: Pair ``(log_std_min, log_std_max)``.
    """

    def __init__(self, repr_dim, feature_dim, action_shape, hidden_dim,
                 hidden_depth, log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds

        trunk = [nn.Linear(repr_dim, feature_dim), nn.LayerNorm(feature_dim)]
        self.pre_fc = nn.Sequential(*trunk)

        # Twice the action size: one half for the mean, one for the log-std.
        out_dim = 2 * action_shape[0]
        self.fc = utils.mlp(feature_dim, hidden_dim, out_dim, hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return ``utils.SquashedNormal(mu, std)`` for ``obs``."""
        features = self.pre_fc(obs)
        mu, raw_log_std = self.fc(features).chunk(2, dim=-1)

        # tanh bounds the raw value to [-1, 1]; it is then mapped affinely
        # onto [log_std_min, log_std_max].
        squashed = torch.tanh(raw_log_std)
        low, high = self.log_std_bounds
        log_std = low + 0.5 * (high - low) * (squashed + 1)

        return utils.SquashedNormal(mu, log_std.exp())

class StateActor(nn.Module):
    """Stochastic policy operating directly on flat state features.

    There is no pre-processing trunk here: the input goes straight into a
    network built by ``utils.mlp`` whose output is split in two halves along
    the last axis, a mean and a raw log-std.

    Args:
        feature_dim: Size of the input features.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
        log_std_bounds: Pair ``(log_std_min, log_std_max)``.
    """

    def __init__(self, feature_dim, action_shape, hidden_dim,
                 hidden_depth, log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds

        # Twice the action size: one half for the mean, one for the log-std.
        out_dim = 2 * action_shape[0]
        self.fc = utils.mlp(feature_dim, hidden_dim, out_dim, hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return ``utils.SquashedNormal(mu, std)`` for ``obs``."""
        mu, raw_log_std = self.fc(obs).chunk(2, dim=-1)

        # tanh bounds the raw value to [-1, 1]; it is then mapped affinely
        # onto [log_std_min, log_std_max].
        squashed = torch.tanh(raw_log_std)
        low, high = self.log_std_bounds
        log_std = low + 0.5 * (high - low) * (squashed + 1)

        return utils.SquashedNormal(mu, log_std.exp())

class StateCritic(nn.Module):
    """Twin Q-function over flat state features.

    Two modes are supported:

    * default: ``Q1`` and ``Q2`` each take ``cat([obs, action])`` and output
      one value;
    * ``use_bvn``: each Q is the inner product (summed over the last axis) of
      two embeddings, ``f(state, action)`` and ``phi(state, goal)``. ``state``
      is the first ``state_dim`` entries of ``obs`` and ``goal`` the last
      ``goal_dim`` entries (the original note assumes the layout
      ``[state, ag, dg]``).

    Args:
        feature_dim: Input feature size (used in the default mode only).
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
        use_bvn: Select the bilinear mode described above.
        bvn_kwargs: Mapping with ``'state_dim'`` and ``'goal_dim'`` (both
            required when ``use_bvn`` is set) and optional
            ``'bvn_output_dim'`` (default 16).

    Raises:
        ValueError: If ``use_bvn`` is set but ``bvn_kwargs`` is ``None``,
            lacks ``'state_dim'``/``'goal_dim'``, or ``goal_dim`` is not
            positive.
    """

    def __init__(self, feature_dim, action_shape, hidden_dim,
                 hidden_depth, use_bvn = False, bvn_kwargs = None):
        super().__init__()

        self.use_bvn = use_bvn
        if use_bvn:
            if bvn_kwargs is None:
                raise ValueError(
                    "use_bvn=True requires bvn_kwargs with 'state_dim' and 'goal_dim'")
            self.bvn_output_dim = bvn_kwargs.get('bvn_output_dim', 16)
            self.state_dim = bvn_kwargs.get('state_dim')
            self.goal_dim = bvn_kwargs.get('goal_dim')
            if self.state_dim is None or self.goal_dim is None:
                raise ValueError(
                    "bvn_kwargs must provide both 'state_dim' and 'goal_dim'")
            if self.goal_dim < 1:
                # obs[..., -0:] would silently select the whole observation.
                raise ValueError("'goal_dim' must be a positive integer")

            state_action_dim = self.state_dim + action_shape[0]
            state_goal_dim = self.state_dim + self.goal_dim
            # Creation order is kept (f1, phi1, f2, phi2) so parameter order
            # and random initialisation order stay as before.
            self.f_s_a_1 = utils.mlp(state_action_dim, hidden_dim, self.bvn_output_dim, hidden_depth)
            self.phi_s_g_1 = utils.mlp(state_goal_dim, hidden_dim, self.bvn_output_dim, hidden_depth)
            self.f_s_a_2 = utils.mlp(state_action_dim, hidden_dim, self.bvn_output_dim, hidden_depth)
            self.phi_s_g_2 = utils.mlp(state_goal_dim, hidden_dim, self.bvn_output_dim, hidden_depth)
        else:
            self.Q1 = utils.mlp(feature_dim + action_shape[0], hidden_dim, 1,
                                hidden_depth)
            self.Q2 = utils.mlp(feature_dim + action_shape[0], hidden_dim, 1,
                                hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return ``(q1, q2)`` for the given observation/action batch.

        ``obs`` and ``action`` must have the same size along dimension 0. In
        the bilinear mode each output keeps a trailing axis of size 1.
        """
        assert obs.size(0) == action.size(0)
        if self.use_bvn:
            # Assume [state, ag, dg]
            state = obs[..., :self.state_dim]
            goal = obs[..., -self.goal_dim:]
            # Both heads consume the same inputs, so build them once.
            state_action = torch.cat([state, action], dim=-1)
            state_goal = torch.cat([state, goal], dim=-1)
            q1 = torch.sum(self.f_s_a_1(state_action) * self.phi_s_g_1(state_goal),
                           dim=-1, keepdim=True)  # [bs, 1]
            q2 = torch.sum(self.f_s_a_2(state_action) * self.phi_s_g_2(state_goal),
                           dim=-1, keepdim=True)  # [bs, 1]
        else:
            h_action = torch.cat([obs, action], dim=-1)
            q1 = self.Q1(h_action)
            q2 = self.Q2(h_action)

        return q1, q2

class StateActorTD3(nn.Module):
    """Deterministic policy: a ``utils.mlp`` network followed by ``tanh``.

    Args:
        feature_dim: Size of the input features.
        action_shape: Action shape; ``action_shape[0]`` is the output size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
    """

    def __init__(self, feature_dim, action_shape, hidden_dim,
                 hidden_depth):
        super().__init__()

        action_dim = action_shape[0]
        self.fc = utils.mlp(feature_dim, hidden_dim, action_dim, hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the action for ``obs``, squashed into (-1, 1) by ``tanh``."""
        pre_activation = self.fc(obs)
        return torch.tanh(pre_activation)

class StateCriticTD3(nn.Module):
    """Twin Q-function: two independent ``utils.mlp`` heads on ``[obs, action]``.

    Args:
        feature_dim: Size of the observation features.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
    """

    def __init__(self, feature_dim, action_shape, hidden_dim,
                 hidden_depth):
        super().__init__()

        input_dim = feature_dim + action_shape[0]
        self.Q1 = utils.mlp(input_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(input_dim, hidden_dim, 1, hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return ``(q1, q2)``; inputs must agree in size along dimension 0."""
        assert obs.size(0) == action.size(0)
        joint = torch.cat([obs, action], dim=-1)
        return self.Q1(joint), self.Q2(joint)


class StateVf(nn.Module):
    """State-value function: a single ``utils.mlp`` head with one output.

    Args:
        feature_dim: Size of the input features.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
    """

    def __init__(self, feature_dim, hidden_dim,
                 hidden_depth):
        super().__init__()

        self.V = utils.mlp(feature_dim, hidden_dim, 1, hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the value head's output for ``obs``."""
        return self.V(obs)

class StateCriticEnsemble(nn.Module):
    """Ensemble of independent Q-heads on ``[obs, action]``.

    With ``rpf`` enabled, every head is paired with a second network of the
    same architecture whose parameters have gradients disabled; the pair's
    outputs are summed.

    Args:
        feature_dim: Size of the observation features.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
        n_ensemble: Number of heads.
        rpf: Enable the frozen companion networks described above.
    """

    def __init__(self, feature_dim, action_shape, hidden_dim,
                 hidden_depth, n_ensemble, rpf =False):
        super().__init__()
        self.n_ensemble = n_ensemble

        input_dim = feature_dim + action_shape[0]

        heads = []
        for _ in range(n_ensemble):
            heads.append(utils.mlp(input_dim, hidden_dim, 1, hidden_depth))
        self.Q_ensemble = nn.ModuleList(heads)

        self.rpf = rpf
        if rpf:
            frozen = []
            for _ in range(n_ensemble):
                frozen.append(utils.mlp(input_dim, hidden_dim, 1, hidden_depth))
            self.Q_ensemble_rpf = nn.ModuleList(frozen)
            # Freezes the parameters of every companion network at once.
            disable_gradients(self.Q_ensemble_rpf)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return a list with one output tensor per ensemble member."""
        assert obs.size(0) == action.size(0)

        joint = torch.cat([obs, action], dim=-1)
        if not self.rpf:
            return [head(joint) for head in self.Q_ensemble]

        outputs = []
        for head, frozen_head in zip(self.Q_ensemble, self.Q_ensemble_rpf):
            outputs.append(head(joint) + frozen_head(joint))
        return outputs

    def std(self, obs, action):
        """Return ``torch.std`` of the member outputs across the ensemble."""
        stacked = torch.stack(self.forward(obs, action), dim=1)  # [bs, n_ensemble, 1]
        return stacked.std(dim=1, keepdim=False)  # [bs, 1]

class StateVfEnsemble(nn.Module):
    """Ensemble of independent state-value heads.

    With ``rpf`` enabled, every head is paired with a second network of the
    same architecture whose parameters have gradients disabled; the pair's
    outputs are summed.

    Args:
        feature_dim: Size of the input features.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
        n_ensemble: Number of heads.
        rpf: Enable the frozen companion networks described above.
    """

    def __init__(self, feature_dim, hidden_dim,
                 hidden_depth, n_ensemble, rpf = False):
        super().__init__()
        self.n_ensemble = n_ensemble

        heads = []
        for _ in range(n_ensemble):
            heads.append(utils.mlp(feature_dim, hidden_dim, 1, hidden_depth))
        self.V_ensemble = nn.ModuleList(heads)

        self.rpf = rpf
        if rpf:
            frozen = []
            for _ in range(n_ensemble):
                frozen.append(utils.mlp(feature_dim, hidden_dim, 1, hidden_depth))
            self.V_ensemble_rpf = nn.ModuleList(frozen)
            # Freezes the parameters of every companion network at once.
            disable_gradients(self.V_ensemble_rpf)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return a list with one output tensor per ensemble member."""
        if not self.rpf:
            return [head(obs) for head in self.V_ensemble]

        outputs = []
        for head, frozen_head in zip(self.V_ensemble, self.V_ensemble_rpf):
            outputs.append(head(obs) + frozen_head(obs))
        return outputs

    def std(self, obs):
        """Return ``torch.std`` of the member outputs across the ensemble."""
        stacked = torch.stack(self.forward(obs), dim=1)  # [bs, n_ensemble, 1]
        return stacked.std(dim=1, keepdim=False)  # [bs, 1]


class Critic(nn.Module):
    """Twin Q-function on top of encoder features.

    The observation representation goes through Linear + LayerNorm, is
    concatenated with the action, and is fed to two independent
    ``utils.mlp`` heads.

    Args:
        repr_dim: Size of the incoming representation.
        feature_dim: Size after the Linear + LayerNorm trunk.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        hidden_dim: Forwarded to ``utils.mlp``.
        hidden_depth: Forwarded to ``utils.mlp``.
    """

    def __init__(self, repr_dim, feature_dim, action_shape, hidden_dim,
                 hidden_depth):
        super().__init__()

        trunk = [nn.Linear(repr_dim, feature_dim), nn.LayerNorm(feature_dim)]
        self.pre_fc = nn.Sequential(*trunk)

        input_dim = feature_dim + action_shape[0]
        self.Q1 = utils.mlp(input_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(input_dim, hidden_dim, 1, hidden_depth)

        # Run the project's initialiser over every submodule.
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return ``(q1, q2)``; inputs must agree in size along dimension 0."""
        assert obs.size(0) == action.size(0)
        features = self.pre_fc(obs)
        joint = torch.cat([features, action], dim=-1)
        return self.Q1(joint), self.Q2(joint)
