import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from image_encoder import MobileNetEncoder
from PIL import Image
from transformers import ViTImageProcessor, ViTMAEModel
from transformers import SamModel, SamProcessor
from transformers import AutoImageProcessor, ResNetModel, ResNetForImageClassification
from transformers import AutoImageProcessor, ViTForImageClassification
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision import transforms
import requests

class IGNet(nn.Module):
    """Transformer that reasons jointly over a set of image-tagged locations.

    Design (PM-Net):
      (1) image encoder
      (2) position encoder, and input type encoder
      (3) transformer model
      (4) output heads: positions, angles, actions, navigation path

    Every location contributes the sum of four layer-normalized embeddings
    (image, map position, map yaw, location type). A Transformer encoder mixes
    the locations of each sample, and per-location heads predict egocentric
    and map poses, local/global paths, per-step action logits and a distance.

    Args:
        configs: dict-like config. Required keys: ``encoder``,
            ``len_trajectory_pred``, ``action_num``, ``num_transformer_layers``,
            ``freeze_image_encoder`` and, for the ``'MobileNet'`` encoder,
            ``image_size`` given as (width, height). Optional keys:
            ``contrast_embedding`` (default False), ``contrast_method``
            (``'subtract'`` or ``'concat'``, default ``'subtract'``) and
            ``disable_auxiliary`` (default ``'none'``).

    Raises:
        ValueError: unsupported ``configs['encoder']``.
        NotImplementedError: unsupported ``contrast_method`` while
            ``contrast_embedding`` is enabled.
    """

    def __init__(self, configs) -> None:
        super().__init__()

        self.configs = configs
        self.contrast_embedding = configs.get('contrast_embedding', False)
        self.contrast_method = configs.get('contrast_method', 'subtract')
        self.len_trajectory_pred = configs['len_trajectory_pred']
        self.action_num = configs['action_num']
        self.num_transformer_layers = configs['num_transformer_layers']
        self.disable_auxiliary = configs.get('disable_auxiliary', 'none')
        self.hidden_size = 768
        # Position in this list is the integer id used in data['location_types']
        # (forward() treats 0 as egocentric and 1 as map).
        self.location_types = [
            'egocentric', 'map', 'query', 'past'
        ]

        # Build the image encoder together with its matching preprocessing.
        # huggingface_image_processor selects the call convention used by
        # process_image().
        self.huggingface_image_processor = True
        if configs['encoder'] == 'MobileNet':
            mobilenet = MobileNetEncoder(num_images=1)
            self.image_encoder = mobilenet.features
            self.huggingface_image_processor = False
            self.image_processor = transforms.Compose([
                transforms.ToTensor(),
                # Resize takes (height, width); image_size is (width, height).
                transforms.Resize(
                    (configs['image_size'][1], configs['image_size'][0])
                ),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        elif configs['encoder'] == 'MobileNet-pretrained':
            mobilenet = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V2)
            self.image_processor = MobileNet_V2_Weights.IMAGENET1K_V2.transforms()
            self.image_encoder = mobilenet.features
            # MobileNetV2 features end with 1280 channels; project to hidden_size.
            self.last_fc_layer = nn.Linear(1280, self.hidden_size)
            self.huggingface_image_processor = False
        elif configs['encoder'] == 'ViT-classification':
            self.image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
            self.image_encoder = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
        elif configs['encoder'] == 'MAE':
            self.image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
            self.image_encoder = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
        elif configs['encoder'] == 'SAM':
            self.image_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
            self.image_encoder = SamModel.from_pretrained("facebook/sam-vit-base")
        elif configs['encoder'] == 'resnet-50':
            self.image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
            self.image_encoder = ResNetModel.from_pretrained("microsoft/resnet-50")
            # The pooled ResNet-50 feature is wider than hidden_size (2048 for
            # this checkpoint); project it so forward()'s shape check passes.
            self.last_fc_layer = nn.Linear(
                self.image_encoder.config.hidden_sizes[-1], self.hidden_size
            )
        else:
            raise ValueError(f'The image encoder {configs["encoder"]} is not supported')
        self.flatten = nn.Flatten()

        # Freezing only touches the backbone; projection layers stay trainable.
        if configs['freeze_image_encoder']:
            for param in self.image_encoder.parameters():
                param.requires_grad = False
            print('Freeze image encoder')
        else:
            print('Not freeze image encoder')

        # One learned embedding per location type.
        self.location_type_embeds = nn.Embedding(
            len(self.location_types),
            self.hidden_size
        )
        # Position and orientation encoders.
        self.position3d_encoder = PositionEncoder(self.hidden_size)
        self.yaw_encoder = PositionEncoder(self.hidden_size)

        # Shared normalization applied to each embedding before summation.
        self.embed_layer_norm = nn.LayerNorm(self.hidden_size)

        # batch_first=True: forward() feeds [batch_size, num_locations, hidden]
        # tensors, so attention must run across the locations of each sample.
        # PyTorch's default (batch_first=False) would treat the batch axis as
        # the sequence axis and mix information across samples.
        transformer_layer = nn.TransformerEncoderLayer(
            nhead=8, d_model=self.hidden_size, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            transformer_layer, num_layers=self.num_transformer_layers
        )

        if self.contrast_embedding:
            if self.contrast_method == 'subtract':
                self.contrast_fc = nn.Linear(self.hidden_size, self.hidden_size)
            elif self.contrast_method == 'concat':
                self.contrast_fc = nn.Linear(self.hidden_size * 2, self.hidden_size)
            else:
                raise NotImplementedError

        # Output heads: ego positions, map positions, navigation paths,
        # navigation actions and distance.
        self.ego_decoder = PositionDecoder(self.hidden_size)
        self.map_decoder = PositionDecoder(self.hidden_size)
        self.local_path_decoder = PathDecoder(
            self.hidden_size, self.len_trajectory_pred
        )
        self.global_path_decoder = PathDecoder(
            self.hidden_size, self.len_trajectory_pred
        )
        self.action_predictor = nn.Linear(
            self.hidden_size, self.len_trajectory_pred * self.action_num
        )
        self.dist_predictor = nn.Linear(
            self.hidden_size, 1
        )

        # Losses reported by forward():
        # (1) egocentric pose predictions
        # (2) map pose predictions (also split into known / unknown inputs)
        # (3) path predictions for goal locations
        # (4) action predictions for goal locations
        # (5) distance prediction for goal locations
        self.loss_names = [
            'ego_position3d_loss', 'map_position3d_loss',
            'ego_yaw_loss', 'map_yaw_loss',
            'position3d_loss', 'yaw_loss',
            'ego_location_loss', 'map_location_loss',
            'local_path_loss', 'global_path_loss',
            'dist_loss', 'action_loss', 'loss',
            'unknown_map_position3d_loss', 'unknown_map_yaw_loss',
            'known_map_position3d_loss', 'known_map_yaw_loss'
        ]

    def forward(self, data, eval_loss=True):
        """Run the model and, when training or ``eval_loss`` is set, the losses.

        Two frames are used: ``map_*`` quantities are relative to the map and
        ``ego_*`` quantities are relative to the agent; the model predicts
        position3d/yaw in both frames.

        Args:
            data: dict of tensors. Always read: ``images``
                ``[batch, num_locations, *img_shape]``, ``map_positions``,
                ``map_yaws`` (degrees) and ``location_types`` (integer ids into
                ``self.location_types``). Additionally read for the losses:
                ``ego_positions``, ``ego_yaws`` (degrees), ``local_path``,
                ``global_path``, ``local_path_scale``, ``pos_scale``,
                ``distance``, ``goal_masks`` and ``local_actions``.
            eval_loss: also compute losses in eval mode (they are always
                computed in training mode).

        Returns:
            dict with ``ego_positions_pred``, ``map_positions_pred``,
            ``local_paths_pred``, ``global_paths_pred``, ``action_logits`` and
            ``dist_pred``; when losses are computed, additionally every name in
            ``self.loss_names`` plus four embedding-norm diagnostics.

        Raises:
            ValueError: ``contrast_embedding`` is enabled and some sample does
                not have exactly one egocentric location.
        """
        images = data['images']
        map_position3d_normalized = data['map_positions']
        map_yaws = data['map_yaws'] * np.pi / 180
        location_types = data['location_types']
        # Only map locations (type 1) reveal their pose; for every other
        # location the encoders substitute a learned "unknown" embedding.
        map_position_masks = (location_types == 1).to(torch.int64)
        map_yaw_masks = map_position_masks
        # Yaw as a 3-vector (cos, sin, 0) so it shares the 3-d layout of positions.
        map_yaw_cos_sin = torch.stack(
            [torch.cos(map_yaws), torch.sin(map_yaws), torch.zeros_like(map_yaws)], -1
        )

        # Every embedding below is [batch_size, num_locations, hidden_size].
        batch_size, num_locations, *img_shape = images.shape
        image_embedding = self.encode_image(
            images.reshape(batch_size * num_locations, *img_shape)
        ).view(batch_size, num_locations, -1)
        assert image_embedding.ndim == 3 and image_embedding.shape[-1] == self.hidden_size, \
            f'The image embedding shape {image_embedding.shape} is not valid'
        position3d_embedding = self.position3d_encoder(map_position3d_normalized, map_position_masks)
        yaw_embedding = self.yaw_encoder(map_yaw_cos_sin, map_yaw_masks)
        location_type_embedding = self.location_type_embeds(location_types)

        # LayerNorm each embedding so no single source dominates the sum.
        image_embedding = self.embed_layer_norm(image_embedding)
        position3d_embedding = self.embed_layer_norm(position3d_embedding)
        yaw_embedding = self.embed_layer_norm(yaw_embedding)
        location_type_embedding = self.embed_layer_norm(location_type_embedding)

        location_embedding = (image_embedding + position3d_embedding
                              + yaw_embedding + location_type_embedding)

        # [batch_size, num_locations, hidden_size]; attention runs over locations.
        output_embedding = self.transformer(location_embedding)

        # The map head always reads the raw transformer output; the remaining
        # heads optionally read the egocentric-contrasted embedding.
        if self.contrast_embedding:
            head_input = self._contrast(output_embedding, location_types, num_locations)
        else:
            head_input = output_embedding

        # ego/map positions: [batch_size, num_locations, 6]
        # local/global paths: [batch_size, num_locations, len_trajectory_pred, 6]
        # action_logits: [batch_size, num_locations, len_trajectory_pred, action_num]
        # dist_pred: [batch_size, num_locations]
        map_positions_pred = self.map_decoder(output_embedding)
        ego_positions_pred = self.ego_decoder(head_input)
        local_paths_pred = self.local_path_decoder(head_input)
        global_paths_pred = self.global_path_decoder(head_input)
        action_logits = self.action_predictor(head_input).view(
            batch_size, num_locations, self.len_trajectory_pred, self.action_num
        )
        dist_pred = self.dist_predictor(head_input)[..., 0]

        outputs = {
            'ego_positions_pred': ego_positions_pred,
            'map_positions_pred': map_positions_pred,
            'local_paths_pred': local_paths_pred,
            'global_paths_pred': global_paths_pred,
            'action_logits': action_logits,
            'dist_pred': dist_pred
        }
        if not (self.training or eval_loss):
            return outputs

        # Map-frame pose losses over all locations.
        map_position3d_loss = F.mse_loss(
            map_positions_pred[..., :3],
            map_position3d_normalized
        )
        map_yaw_loss = F.mse_loss(
            map_positions_pred[..., 3:],
            map_yaw_cos_sin
        )

        # Agent-frame targets. Paths and distances are scaled with the scale of
        # the first sample only ([0]); assumes the scale is shared across the
        # batch — TODO confirm. The trailing ones leave the orientation
        # channels of each waypoint unscaled.
        device = data['local_path_scale'].device
        ego_position3d_normalized = data['ego_positions']
        ego_yaws = data['ego_yaws'] * np.pi / 180
        local_paths = data['local_path'] / torch.cat([data['local_path_scale'][0], torch.ones([3]).to(device)])
        global_paths = data['global_path'] / torch.cat([data['pos_scale'][0], torch.ones([3]).to(device)])
        dist = data['distance'] / data['pos_scale'][0].norm()

        ego_yaw_cos_sin = torch.stack(
            [torch.cos(ego_yaws), torch.sin(ego_yaws), torch.zeros_like(ego_yaws)], -1
        )
        ego_position3d_loss = F.mse_loss(
            ego_positions_pred[..., :3],
            ego_position3d_normalized
        )
        ego_yaw_loss = F.mse_loss(
            ego_positions_pred[..., 3:],
            ego_yaw_cos_sin
        )

        position3d_loss = ego_position3d_loss + map_position3d_loss
        yaw_loss = ego_yaw_loss + map_yaw_loss
        ego_location_loss = ego_position3d_loss + ego_yaw_loss
        map_location_loss = map_position3d_loss + map_yaw_loss

        # Path, distance and action losses only count goal locations
        # (goal_masks == 1).
        goal_masks_float = data['goal_masks'].to(torch.float32)
        local_path_loss = self._masked_mse(local_paths_pred, local_paths, goal_masks_float)
        global_path_loss = self._masked_mse(global_paths_pred, global_paths, goal_masks_float)
        dist_loss = self._masked_mse(dist_pred, dist, goal_masks_float)

        # cross_entropy expects the class axis at dim 1:
        # [batch_size, action_num, num_locations, len_trajectory_pred].
        action_loss = F.cross_entropy(
            action_logits.permute(0, 3, 1, 2),
            data['local_actions'],
            reduction='none'
        )
        action_loss = (action_loss * goal_masks_float.unsqueeze(-1)) \
            .mean(-1).sum() / (goal_masks_float.sum() + 1e-6)

        # Map-pose losses split by whether the pose was given as input.
        known_masks = map_position_masks.to(torch.float32)
        unknown_masks = (1 - map_position_masks).to(torch.float32)
        unknown_map_position3d_loss = self._masked_mse(
            map_positions_pred[..., :3], map_position3d_normalized, unknown_masks
        )
        unknown_map_yaw_loss = self._masked_mse(
            map_positions_pred[..., 3:], map_yaw_cos_sin, unknown_masks
        )
        known_map_position3d_loss = self._masked_mse(
            map_positions_pred[..., :3], map_position3d_normalized, known_masks
        )
        known_map_yaw_loss = self._masked_mse(
            map_positions_pred[..., 3:], map_yaw_cos_sin, known_masks
        )

        if self.disable_auxiliary == 'all':
            loss = action_loss
        elif self.disable_auxiliary == 'disable_position':
            loss = local_path_loss + global_path_loss + dist_loss + action_loss
        elif self.disable_auxiliary == 'disable_path':
            loss = map_yaw_loss + map_position3d_loss + ego_yaw_loss + ego_position3d_loss + action_loss
        else:
            loss = map_yaw_loss + map_position3d_loss + ego_yaw_loss + ego_position3d_loss + \
                local_path_loss + global_path_loss + dist_loss + action_loss
        outputs.update({
            'ego_position3d_loss': ego_position3d_loss,
            'map_position3d_loss': map_position3d_loss,
            'ego_yaw_loss': ego_yaw_loss,
            'map_yaw_loss': map_yaw_loss,
            'position3d_loss': position3d_loss,
            'yaw_loss': yaw_loss,
            'ego_location_loss': ego_location_loss,
            'map_location_loss': map_location_loss,
            'local_path_loss': local_path_loss,
            'global_path_loss': global_path_loss,
            'dist_loss': dist_loss,
            'action_loss': action_loss,
            'loss': loss,
            'image_embedding_norm': image_embedding.norm(dim=-1).mean(),
            'position3d_embedding_norm': position3d_embedding.norm(dim=-1).mean(),
            'yaw_embedding_norm': yaw_embedding.norm(dim=-1).mean(),
            'location_type_embedding': location_type_embedding.norm(dim=-1).mean(),
            'unknown_map_position3d_loss': unknown_map_position3d_loss,
            'unknown_map_yaw_loss': unknown_map_yaw_loss,
            'known_map_position3d_loss': known_map_position3d_loss,
            'known_map_yaw_loss': known_map_yaw_loss
        })
        return outputs

    def _contrast(self, output_embedding, location_types, num_locations):
        """Combine every location's output with its sample's egocentric output.

        The result, ``relu(contrast_fc(combine(z_i, z_ego)))``, feeds the
        egocentric, path, action and distance heads.
        """
        egocentric_boolean_mask = (location_types == 0)  # [batch_size, num_locations]
        # Boolean indexing yields one row per egocentric location; the result
        # lines up with the batch only with exactly one per sample.
        if (egocentric_boolean_mask.sum(-1) != 1).any():
            raise ValueError(
                'contrast_embedding requires exactly one egocentric location per sample'
            )
        # [batch_size, 1, hidden_size]
        egocentric_embedding = output_embedding[egocentric_boolean_mask].unsqueeze(1)
        if self.contrast_method == 'subtract':
            combined = output_embedding - egocentric_embedding
        elif self.contrast_method == 'concat':
            combined = torch.cat(
                [output_embedding, egocentric_embedding.tile(1, num_locations, 1)], dim=-1
            )
        else:
            raise NotImplementedError
        return F.relu(self.contrast_fc(combined))

    @staticmethod
    def _masked_mse(pred, target, mask):
        """Masked MSE averaged per location, then over masked locations.

        ``mask`` covers the leading (per-location) dims of ``pred``. The squared
        error is averaged over any trailing dims, summed over locations
        weighted by ``mask``, and divided by ``mask.sum() + 1e-6`` (so an
        all-zero mask gives 0 instead of NaN).
        """
        extra_dims = pred.ndim - mask.ndim
        expanded = mask.reshape(*mask.shape, *([1] * extra_dims))
        error = F.mse_loss(pred * expanded, target * expanded, reduction='none')
        if extra_dims:
            error = error.mean(dim=tuple(range(-extra_dims, 0)))
        return error.sum() / (mask.sum() + 1e-6)

    def encode_image(self, image) -> torch.Tensor:
        """Encode a batch of preprocessed images into one feature row each."""
        if self.configs['encoder'] == 'MobileNet':
            # NOTE(review): raw feature map without pooling or projection;
            # forward() requires it to flatten to hidden_size values per image
            # — confirm for MobileNetEncoder.
            return self.image_encoder(image)
        elif self.configs['encoder'] == 'MobileNet-pretrained':
            features = self.image_encoder(image)
            features = nn.functional.adaptive_avg_pool2d(features, (1, 1)).flatten(1)
            return self.last_fc_layer(features)
        elif self.configs['encoder'] == 'ViT-classification':
            # CLS token of the ViT backbone (the classifier head is skipped).
            return self.image_encoder.vit(image).last_hidden_state[:, 0, :]
        elif self.configs['encoder'] == 'MAE':
            # CLS token output. NOTE(review): ViTMAEModel masks patches at
            # random by default, which makes this embedding stochastic — confirm
            # whether mask_ratio should be 0 here.
            return self.image_encoder(image).last_hidden_state[:, 0, :]
        elif self.configs['encoder'] == 'SAM':
            # NOTE(review): vision_hidden_states is presumably a tuple of
            # per-layer tensors, which forward()'s `.view` cannot reshape —
            # confirm this branch works.
            return self.image_encoder(image, output_hidden_states=True).vision_hidden_states
        elif self.configs['encoder'] == 'resnet-50':
            return self.last_fc_layer(self.flatten(self.image_encoder(image).pooler_output))
        else:
            raise ValueError(f'The image encoder {self.configs["encoder"]} is not supported')

    def process_image(self, image):
        """Preprocess one image or a list of images into a batched tensor."""
        if self.huggingface_image_processor:
            return self.image_processor(image, return_tensors="pt")['pixel_values']
        if isinstance(image, list):
            return torch.stack([self.image_processor(im) for im in image], 0)
        return self.image_processor(image).unsqueeze(0)

class PositionEncoder(nn.Module):
    """Random-Fourier-feature encoder with a learned fallback for unknown inputs.

    3-d coordinates are projected through a fixed (non-trainable) Gaussian
    matrix and mapped to ``[sin, cos]`` features, ``hidden_size // 2`` of each.
    Wherever the mask is not 1, a single learned "unknown" vector is used
    instead.
    """

    def __init__(self, hidden_size=768) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        half_width = self.hidden_size // 2
        # Fixed random projection (an alternative would be a 2*hidden_size fc layer).
        self.positional_embedding = nn.Parameter(
            torch.randn(3, half_width), requires_grad=False
        )
        # Trainable vector substituted for masked-out (unknown) coordinates.
        self.unknown_positional_embedding = nn.Parameter(
            torch.randn(1, self.hidden_size), requires_grad=True
        )

    def forward(self, input_coords, input_masks):
        """Encode ``input_coords`` (last dim 3, expected in [-1, 1]).

        Args:
            input_coords: tensor ``[..., 3]``.
            input_masks: tensor ``[...]``; entries equal to 1 mark known
                coordinates.

        Returns:
            Tensor ``[..., 2 * (hidden_size // 2)]``.
        """
        projection = self.positional_embedding
        phases = (2 * np.pi) * input_coords.to(projection.dtype) @ projection
        fourier_features = torch.cat((phases.sin(), phases.cos()), dim=-1)
        is_known = input_masks.unsqueeze(-1) == 1
        return torch.where(is_known, fourier_features, self.unknown_positional_embedding)

class PositionDecoder(nn.Module):
    """Linear head mapping a latent vector to a 6-d pose.

    Output layout: ``[..., :3]`` is the 3-D position; ``[..., 3:]`` is an
    orientation vector rescaled to unit L2 norm.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        self.predictor = nn.Linear(hidden_size, 6)

    def forward(self, latent):
        """Decode ``latent`` (``[..., hidden_size]``) into ``[..., 6]``."""
        raw = self.predictor(latent)
        position, orientation = raw[..., :3], raw[..., 3:]
        return torch.cat((position, F.normalize(orientation, dim=-1)), dim=-1)
    
class PathDecoder(nn.Module):
    """Linear head predicting a path of ``len_trajectory_pred`` 6-d waypoints.

    For each waypoint, ``[..., :3]`` is a position and ``[..., 3:]`` an
    orientation vector. The linear layer predicts position deltas, which are
    accumulated along the trajectory axis into waypoints; orientations are
    rescaled to unit L2 norm.
    """

    def __init__(self, hidden_size=768, len_trajectory_pred=5):
        super().__init__()
        self.len_trajectory_pred = len_trajectory_pred
        self.predictor = nn.Linear(hidden_size, 6 * len_trajectory_pred)

    def forward(self, latent):
        """Decode ``latent`` (``[..., hidden_size]``) into ``[..., len_trajectory_pred, 6]``."""
        path_pred = self.predictor(latent)
        path_pred = path_pred.reshape(
            *path_pred.shape[:-1], self.len_trajectory_pred, 6
        )
        # Cumulative sum over the trajectory axis turns deltas into waypoints.
        positions = torch.cumsum(path_pred[..., :3], dim=-2)
        # Index the channel axis with `...` so this is correct for any number
        # of leading dims (with `[:, :, 3:]` a [B, L, T, 6] tensor would be
        # sliced along the trajectory axis instead).
        orientations = F.normalize(path_pred[..., 3:], dim=-1)
        return torch.cat((positions, orientations), dim=-1)

if __name__ == '__main__':
    # Smoke test / timing benchmark: build IGNet with the pretrained
    # MobileNetV2 encoder, run a few optimizer steps on random data, run a few
    # inference passes, then report trainable-parameter counts. Downloads a
    # COCO sample image (and torchvision weights), so it needs network access.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def synchronize():
        # CUDA kernels run asynchronously; wait for them so wall-clock
        # timings cover the actual work.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    batch_size = 4
    location_num = 40

    configs = {
        'encoder': 'MobileNet-pretrained',
        'len_trajectory_pred': 5,
        'action_num': 7,
        'image_size': [85, 64],
        # IGNet.__init__ reads these two keys unconditionally.
        'num_transformer_layers': 4,
        'freeze_image_encoder': False,
    }

    model = IGNet(configs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Preprocess one image and tile it to [batch_size, location_num, C, H, W].
    image_tensor = model.process_image([image])
    image_tensor = image_tensor.unsqueeze(1).repeat(batch_size, location_num, 1, 1, 1).to(device)
    print('image_tensor mean and std:', image_tensor.mean(), image_tensor.std())

    print(model.process_image([image, image, image, image]).shape)

    def generate_data():
        # Random training batch containing every key forward() reads for losses.
        start = time.time()
        data = {
            'images': image_tensor,
            'map_positions': torch.zeros(batch_size, location_num, 3, device=device),
            'map_yaws': torch.zeros(batch_size, location_num, device=device),
            'ego_positions': torch.zeros(batch_size, location_num, 3, device=device),
            'ego_yaws': torch.zeros(batch_size, location_num, device=device),
            'pos_scale': 2000.0 * torch.ones(batch_size, 3, device=device),
            'local_path_scale': 500.0 * torch.ones(batch_size, 3, device=device),
            'location_types': torch.randint(4, (batch_size, location_num), device=device),
            'goal_masks': torch.randint(2, (batch_size, location_num), device=device),
            'local_path': torch.rand(batch_size, location_num, 5, 6, device=device),
            'global_path': torch.rand(batch_size, location_num, 5, 6, device=device),
            'local_actions': torch.randint(7, (batch_size, location_num, 5), device=device),
            'distance': torch.rand(batch_size, location_num, device=device)
        }
        synchronize()
        print('data time (train):', time.time() - start)
        return data

    def generate_inference_data():
        # Minimal batch for forward(..., eval_loss=False).
        start = time.time()
        data = {
            'images': image_tensor,
            'map_positions': torch.zeros(batch_size, location_num, 3, device=device),
            'map_yaws': torch.zeros(batch_size, location_num, device=device),
            'pos_scale': torch.tensor([2000.0, 2000.0], device=device),
            'location_types': torch.randint(4, (batch_size, location_num), device=device),
        }
        synchronize()
        print('data time (inference):', time.time() - start)
        return data

    print_model = False
    if print_model:
        print(model)
        for name, parameter in model.named_parameters():
            print(name, parameter.size())

    for _ in range(5):
        data = generate_data()
        start = time.time()
        output = model(data)
        optimizer.zero_grad()
        output['loss'].backward()
        optimizer.step()
        synchronize()
        print(f'train time: {time.time() - start}')

    model.eval()

    with torch.no_grad():
        for _ in range(5):
            data = generate_inference_data()
            start = time.time()
            output = model(data, eval_loss=False)
            synchronize()
            print(f'inference time: {time.time() - start}')

    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_new_parameters = sum(
        p.numel() for name, p in model.named_parameters() if p.requires_grad and not name.startswith('image_encoder')
    )
    print(f'{num_parameters = }')
    print(f'{num_new_parameters = }')

    # Pause before exiting (presumably to inspect resource usage); press Enter.
    input('.')