import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.rnn_state_encoder import RNNStateEncoder
from custom_habitat_baselines.common.utils import CategoricalNet

from model.PCL.resnet_pcl import resnet18

class CriticHead(nn.Module):
    """Single linear layer that maps a feature vector to one scalar value.

    The weight is initialised orthogonally and the bias is set to zero.
    """

    def __init__(self, input_size):
        """Create the value layer for features of width ``input_size``."""
        super().__init__()
        value_layer = nn.Linear(input_size, 1)
        nn.init.orthogonal_(value_layer.weight)
        nn.init.constant_(value_layer.bias, 0)
        # Attribute name is part of the state_dict layout ("fc.weight", "fc.bias").
        self.fc = value_layer

    def forward(self, x):
        """Return the value estimate; the last dimension of ``x`` becomes 1."""
        value = self.fc(x)
        return value

class CNNLSTMPolicy(nn.Module):
    """Actor-critic policy built on top of ``CNNLSTMNet``.

    ``self.net`` turns observations into one feature vector per sample.
    ``self.action_distribution`` maps those features to a distribution over
    ``action_space.n`` discrete actions and ``self.critic`` maps them to a
    scalar value estimate.
    """

    def __init__(
            self,
            observation_space,  # a SpaceDict instance per the original author's note -- TODO confirm against the training script
            action_space,
            goal_sensor_uuid="pointgoal_with_gps_compass",
            hidden_size=512,
            num_recurrent_layers=2,
            rnn_type="LSTM",
            resnet_baseplanes=32,
            backbone="resnet50",
            normalize_visual_inputs=True,
            cfg=None
    ):
        """Build the recurrent trunk, the action head and the critic head.

        All arguments are forwarded unchanged to ``CNNLSTMNet``; ``cfg`` is
        additionally kept on the policy as ``self.config``.
        """
        super().__init__()
        trunk_kwargs = dict(
            observation_space=observation_space,
            action_space=action_space,
            goal_sensor_uuid=goal_sensor_uuid,
            hidden_size=hidden_size,
            num_recurrent_layers=num_recurrent_layers,
            rnn_type=rnn_type,
            backbone=backbone,
            resnet_baseplanes=resnet_baseplanes,
            normalize_visual_inputs=normalize_visual_inputs,
            cfg=cfg,
        )
        self.net = CNNLSTMNet(**trunk_kwargs)
        self.config = cfg
        # action_space is expected to expose ``n`` (e.g. a Discrete space).
        self.dim_actions = action_space.n

        feature_size = self.net.output_size
        self.action_distribution = CategoricalNet(feature_size, self.dim_actions)
        # Single fully-connected layer producing one value per sample.
        self.critic = CriticHead(feature_size)

    def act(
            self,
            observations,
            rnn_hidden_states,
            prev_actions,
            masks,
            env_global_node,
            deterministic=False,
            return_features=False,
            mask_stop=False
    ):
        """Choose an action for every sample in the batch.

        Original author's shape notes (not verifiable from this file):
        ``observations['panoramic_rgb']`` 64 x 252 x 3,
        ``observations['panoramic_depth']`` 64 x 252 x 1,
        ``observations['target_goal']`` 64 x 252 x 4,
        ``env_global_node`` b x 1 x 512, features num_processes x 512.

        Args:
            observations: observation dict handed to ``self.net``.
            rnn_hidden_states: recurrent state handed to ``self.net``.
            prev_actions: previous action ids handed to ``self.net``.
            masks: episode masks handed to ``self.net``.
            env_global_node: forwarded to ``self.net``.
            deterministic: take ``distribution.mode()`` instead of sampling.
            return_features: forwarded to ``self.net``.
            mask_stop: accepted but not used by this method.

        Returns:
            An 8-tuple ``(value, action, action_log_probs, rnn_hidden_states,
            None, x, preds, None)`` where ``x`` is the second output of the
            action head (presumably the logits -- TODO confirm in
            ``CategoricalNet``) and ``preds`` is the auxiliary-prediction
            tuple returned by ``self.net``.
        """
        features, rnn_hidden_states, preds = self.net(
            observations,
            rnn_hidden_states,
            prev_actions,
            masks,
            env_global_node,
            return_features=return_features,
        )

        distribution, x = self.action_distribution(features)
        # One scalar per sample from the fully-connected critic.
        value = self.critic(features)

        action = distribution.mode() if deterministic else distribution.sample()
        action_log_probs = distribution.log_probs(action)

        # The slot that other policies use for the distribution entropy
        # carries ``x`` here; the two ``None`` entries are placeholders.
        return value, action, action_log_probs, rnn_hidden_states, None, x, preds, None

    def get_value(self, observations, rnn_hidden_states, env_global_node, prev_actions, masks):
        """Return the critic's value estimate for the given observations.

        Note the parameter order differs from ``act``: ``env_global_node``
        comes before ``prev_actions`` and ``masks`` here.
        """
        net_outputs = self.net(
            observations,
            rnn_hidden_states,
            prev_actions,
            masks,
            env_global_node,
            disable_forgetting=True,
        )
        # Only the feature tensor (first output) is needed for the critic.
        features = net_outputs[0]
        return self.critic(features)

    def evaluate_actions(
            self, observations, rnn_hidden_states, env_global_node, prev_actions, masks, action, return_features=False
    ):
        """Score previously taken actions under the current policy.

        Returns:
            A 9-tuple ``(value, action_log_probs, distribution_entropy,
            preds[0], preds[1], None, rnn_hidden_states, env_global_node, x)``.
            ``distribution_entropy`` is the mean entropy over the batch and
            ``env_global_node`` is passed through unchanged.
        """
        features, rnn_hidden_states, preds = self.net(
            observations,
            rnn_hidden_states,
            prev_actions,
            masks,
            env_global_node,
            return_features=return_features,
            disable_forgetting=True,
        )
        distribution, x = self.action_distribution(features)
        value = self.critic(features)

        action_log_probs = distribution.log_probs(action)
        distribution_entropy = distribution.entropy().mean()

        return (
            value,
            action_log_probs,
            distribution_entropy,
            preds[0],
            preds[1],
            None,
            rnn_hidden_states,
            env_global_node,
            x,
        )

    def get_memory_span(self):
        """Delegate to ``self.net.get_memory_span()``.

        NOTE(review): ``CNNLSTMNet`` as written in this file defines no
        ``get_memory_span`` method, so this call would raise
        ``AttributeError`` -- confirm whether it is still needed.
        """
        return self.net.get_memory_span()

    def get_forget_idxs(self):
        """Return ``self.net.perception_unit.forget_idxs``.

        NOTE(review): ``CNNLSTMNet`` as written in this file never assigns
        ``perception_unit``, so this access would raise ``AttributeError``
        -- confirm whether it is still needed.
        """
        return self.net.perception_unit.forget_idxs

class CNNLSTMNet(nn.Module):
    """Recurrent trunk of the policy: frozen RGB-D encoder, MLP fusion, RNN.

    For every sample, the frames selected by ``observations['global_memory']``
    are gathered from the panoramic RGB / depth histories and encoded with a
    frozen ``resnet18`` whose weights are loaded from a checkpoint. The
    encoded frames are fused with the goal embedding, concatenated with an
    embedding of the previous action and passed through ``RNNStateEncoder``.
    Two small MLP heads produce auxiliary predictions.

    ``observation_space``, ``backbone``, ``resnet_baseplanes`` and
    ``normalize_visual_inputs`` are accepted for interface compatibility but
    are not used by this implementation.
    """

    def __init__(
            self,
            observation_space,
            action_space,
            goal_sensor_uuid,
            hidden_size,
            num_recurrent_layers,
            rnn_type,
            backbone,
            resnet_baseplanes,
            normalize_visual_inputs,
            cfg
    ):
        """Build all sub-modules and load the frozen encoder weights.

        Args:
            action_space: must expose ``n``, the number of discrete actions.
            goal_sensor_uuid: stored on the instance; not otherwise used here.
            hidden_size: width of the fused embedding and of the RNN state.
            num_recurrent_layers: forwarded to ``RNNStateEncoder``.
            rnn_type: forwarded to ``RNNStateEncoder``.
            cfg: config object; reads ``features.visual_feature_dim``,
                ``memory.embedding_size``, ``TORCH_GPU_ID`` and
                ``NUM_PROCESSES``.

        Side effects:
            Reads ``model/PCL/PCL_encoder.pth`` relative to the current
            working directory and prints parameter counts to stdout.
        """
        super().__init__()
        self.goal_sensor_uuid = goal_sensor_uuid
        # forward() shifts action ids by one and a zero mask maps to index 0,
        # hence the table has ``action_space.n + 1`` rows.
        self.prev_action_embedding = nn.Embedding(action_space.n + 1, 32)
        self._n_prev_action = 32
        self.feature_dim = cfg.features.visual_feature_dim
        self.memory_dim = cfg.memory.embedding_size
        self.num_category = 50
        self._n_input_goal = 0
        self.device = 'cuda:' + str(cfg.TORCH_GPU_ID) if torch.cuda.device_count() > 0 else 'cpu'
        self._hidden_size = hidden_size

        rnn_input_size = self._n_input_goal + self._n_prev_action

        # forward() splits the flat batch into ``self.B`` groups
        # (presumably one per parallel environment -- TODO confirm).
        self.B = cfg.NUM_PROCESSES

        # RGB-D encoder: the original classification layer is kept as the
        # last stage of a small MLP head, then checkpoint weights are loaded.
        self.rgbd_encoder = resnet18(num_classes=self.feature_dim)
        dim_mlp = self.rgbd_encoder.fc.weight.shape[1]
        self.rgbd_encoder.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.rgbd_encoder.fc)
        ckpt_pth = os.path.join('model/PCL', 'PCL_encoder.pth')
        ckpt = torch.load(ckpt_pth, map_location='cpu')
        self.rgbd_encoder.load_state_dict(ckpt)
        self.rgbd_encoder.eval()

        # The encoder is frozen: no gradients flow into its parameters.
        for p in self.rgbd_encoder.parameters():
            p.requires_grad = False

        self.rgbd_encoder.to(self.device)

        f_dim = cfg.features.visual_feature_dim

        # Collapses the flattened per-frame features (4 * f_dim values) to f_dim.
        self.reduce_obs_seq = nn.Sequential(
            nn.Linear(4*f_dim, f_dim),
            nn.ReLU(True)
        )

        # Fuses the visual feature with the goal embedding (2 * f_dim -> hidden_size).
        self.reduce_obs_goal = nn.Sequential(
            nn.Linear(2*f_dim, hidden_size),
            nn.ReLU(True)
        )

        # The Perception unit of the original design is currently disabled.

        # Auxiliary heads: one scalar from the visual feature, one from the
        # fused embedding.
        self.pred_aux1 = nn.Sequential(nn.Linear(f_dim, f_dim),
                                nn.ReLU(True),
                                nn.Linear(f_dim, 1))
        self.pred_aux2 = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                    nn.ReLU(True),
                                    nn.Linear(hidden_size, 1))

        self.state_encoder = RNNStateEncoder(
            self._hidden_size + rnn_input_size,
            self._hidden_size,
            rnn_type=rnn_type,
            num_layers=num_recurrent_layers,
        )
        self.train()

        self.calc_params()

    def train(self, mode=True):
        """Switch the trainable sub-modules between training and eval mode.

        ``rgbd_encoder`` is not touched here, so the frozen encoder keeps the
        eval mode it was given in ``__init__``.

        Returns:
            ``self``, as ``nn.Module.train`` does, so that ``net.eval()`` and
            chained calls such as ``net.train().to(device)`` work.
        """
        # Keep this module's own flag in sync with the requested mode.
        self.training = mode
        self.reduce_obs_seq.train(mode)
        self.reduce_obs_goal.train(mode)
        self.pred_aux1.train(mode)
        self.pred_aux2.train(mode)
        self.state_encoder.train(mode)
        return self

    @staticmethod
    def _count_params(module):
        """Return the total number of scalar parameters in ``module``."""
        return sum(p.numel() for p in module.parameters())

    def calc_params(self):
        """Print the parameter count of each major sub-module to stdout."""
        count = self._count_params
        s = "- rgbd encoder: {}\n".format(count(self.rgbd_encoder))
        s += "- reduce_obs_seq encoder: {}\n".format(count(self.reduce_obs_seq))
        s += "- reduce_obs_goal encoder: {}\n".format(count(self.reduce_obs_goal))
        s += "- state_encoder: {}\n".format(count(self.state_encoder))
        s += "- aux tasks: {}\n".format(count(self.pred_aux1) + count(self.pred_aux2))

        print(s)

    @property
    def output_size(self):
        """Width of the feature vector returned by ``forward``."""
        return self._hidden_size

    @property
    def is_blind(self):
        """Always ``False``: this network consumes visual input."""
        return False

    @property
    def num_recurrent_layers(self):
        """Number of recurrent layers reported by the state encoder."""
        return self.state_encoder.num_recurrent_layers

    def embed_obs_batch(self, obs_batch, b, step):
        """Embed the memory frames of a single sample.

        Original author's shape notes for ``obs_batch`` (not verifiable from
        this file): ``panoramic_rgb_history`` 257 x 4 x 64 x 252 x 3,
        ``panoramic_depth_history`` 257 x 4 x 64 x 252 x 1.

        Args:
            obs_batch: dict with ``global_memory``, ``panoramic_rgb_history``,
                ``panoramic_depth_history`` and ``goal_embedding``.
            b: index used on the second axis of the two history tensors.
            step: flat sample index into ``global_memory`` and
                ``goal_embedding``.

        Returns:
            ``(rgbd_feat, obs_emb_batch)``: the reduced visual feature with
            one row, and its fusion with the goal embedding.
        """
        # Indices of the history frames that form this sample's memory.
        global_memory_idxs = obs_batch['global_memory'][step].long().to(self.device)

        # Move the channel axis first; RGB is scaled by 1/255, depth is used as is.
        rgb_tensor = obs_batch['panoramic_rgb_history'][global_memory_idxs, b].permute(0,3,1,2) / 255.0
        depth_tensor = obs_batch['panoramic_depth_history'][global_memory_idxs, b].permute(0,3,1,2)

        # Encode every memory frame from its stacked RGB + depth channels.
        rgbd_seq_feat = self.rgbd_encoder(torch.cat([rgb_tensor, depth_tensor], dim=1))

        # Flatten all frame features into one row; reduce_obs_seq expects
        # 4 * f_dim values (presumably 4 memory frames -- TODO confirm).
        rgbd_feat = self.reduce_obs_seq(rgbd_seq_feat.view(1,-1))

        obs_emb_batch = self.reduce_obs_goal(torch.cat([rgbd_feat, obs_batch['goal_embedding'][step:step+1]], dim=1))

        return rgbd_feat, obs_emb_batch

    def forward(self, observations, rnn_hidden_states, prev_actions, masks, env_global_node, mode='', return_features=False, disable_forgetting=False):
        """Compute recurrent features and auxiliary predictions for a batch.

        Args:
            observations: dict consumed by ``embed_obs_batch``. It is
                modified in place: ``'global_memory'`` is replaced by the
                fused embeddings and ``'curr_embedding'`` is added.
            rnn_hidden_states: recurrent state for ``self.state_encoder``.
            prev_actions: previous action ids, one per sample.
            masks: multiplied into the shifted action ids and forwarded to
                the state encoder.
            env_global_node, mode, return_features, disable_forgetting:
                accepted for interface compatibility; not used here.

        Returns:
            ``(x, rnn_hidden_states, (pred1, pred2))`` where ``x`` is the
            state-encoder output and ``pred1`` / ``pred2`` come from the two
            auxiliary heads.

        The number of samples is expected to be a multiple of ``self.B``;
        only ``self.B * (num_samples // self.B)`` samples are embedded.
        """
        # Shift ids by one so that a zero mask selects embedding row 0.
        prev_actions = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().squeeze(-1)
        )

        num_samples = prev_actions.shape[0]
        num_step_per_batch = num_samples // self.B

        rgbd_feat_rows, obs_emb_rows = [], []
        for b in range(self.B):
            for step in range(num_step_per_batch):
                rgbd_feat, obs_emb_batch = self.embed_obs_batch(obs_batch=observations, b=b, step=b*num_step_per_batch + step)
                rgbd_feat_rows.append(rgbd_feat)
                obs_emb_rows.append(obs_emb_batch)

        # One row per embedded sample.
        rgbd_feats = torch.cat(rgbd_feat_rows, dim=0)
        global_memory = torch.cat(obs_emb_rows, dim=0)
        # NOTE(review): global_memory is 2-D here, so this keeps only its last
        # feature column -- confirm this is what consumers of
        # 'curr_embedding' expect.
        curr_embedding = global_memory[:,-1:]

        # NOTE(review): this overwrites the index tensor that embed_obs_batch
        # reads from 'global_memory'; calling forward twice on the same dict
        # would then index with embeddings -- confirm callers rebuild the dict.
        observations['global_memory'] = global_memory
        observations['curr_embedding'] = curr_embedding

        pred1 = self.pred_aux1(rgbd_feats)
        pred2 = self.pred_aux2(global_memory)

        x = torch.cat([global_memory, prev_actions], dim=1)
        x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks)

        # x is what the policy heads use to build the action distribution.
        return x, rnn_hidden_states, (pred1, pred2)
