from typing import Optional, Tuple, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torchvision.models as models
import einops
import math

# Clamp bounds, presumably for the log-std and mean outputs of a Gaussian
# policy/head. NOTE(review): none of these constants is referenced in this
# chunk of the file — confirm they are still used elsewhere before relying
# on (or removing) them.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -5
MEAN_MIN = -9.0
MEAN_MAX = 9.0



def set_parameter_requires_grad(model, requires_grad):
    """Set ``requires_grad`` on every parameter of ``model`` (in place).

    Walks ``model.parameters()``, which covers both the parameters registered
    directly on ``model`` and those of all nested submodules. Iterating only
    ``named_children()`` would miss parameters owned by ``model`` itself
    (e.g. a bare ``nn.Linear`` has no children at all).

    :param model: module whose parameters are updated.
    :param requires_grad: value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad = requires_grad
            
            
def freeze_params(model):
    """Freeze ``model`` in place.

    Thin wrapper that delegates to ``set_parameter_requires_grad`` with
    ``requires_grad=False``; that function decides exactly which parameters
    are affected.
    """
    set_parameter_requires_grad(model, False)
    

def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Swap out every submodule of ``root_module`` that matches ``predicate``.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use (called with the module being replaced).

    If ``root_module`` itself matches, ``func(root_module)`` is returned and
    nothing is modified. Otherwise each matching submodule is replaced on its
    parent (by integer index for ``nn.Sequential`` parents, by attribute name
    otherwise) and ``root_module`` itself is returned.

    Raises AssertionError (unless Python runs with -O) if anything matching
    ``predicate`` is still present afterwards, e.g. because ``func`` returned
    a module that still satisfies the predicate.
    """
    if predicate(root_module):
        return func(root_module)

    def _paths_to_matches():
        # Dotted names of matching submodules, split into path components.
        return [name.split('.')
                for name, module in root_module.named_modules(remove_duplicate=True)
                if predicate(module)]

    for path in _paths_to_matches():
        *parent_path, leaf = path
        parent = root_module.get_submodule('.'.join(parent_path)) if parent_path else root_module
        by_index = isinstance(parent, nn.Sequential)
        replacement = func(parent[int(leaf)] if by_index else getattr(parent, leaf))
        if by_index:
            parent[int(leaf)] = replacement
        else:
            setattr(parent, leaf, replacement)

    # Post-condition: no module matching the predicate is left.
    leftovers = _paths_to_matches()
    assert len(leftovers) == 0
    return root_module


class SpatialSoftmax(nn.Module):
    """Spatial soft-argmax ("expected keypoint") layer.

    Reference: "Learning visual feature spaces for robotic manipulation with
    deep spatial autoencoders." Finn et al., http://arxiv.org/abs/1509.06113.

    For every channel of an (N, C, H, W) feature map, a softmax over the H*W
    locations weights two fixed coordinate maps with values in [-1, 1]; the
    weighted sums are returned as (N, 2*C), laid out per channel as
    (x_map expectation, y_map expectation) pairs.

    NOTE: the coordinate maps are ``meshgrid(linspace(num_cols),
    linspace(num_rows), indexing="ij")`` flattened, i.e. a (num_cols, num_rows)
    grid. It lines up with the row-major H*W flattening of the input when
    num_cols == H and num_rows == W, which matches the parameter descriptions
    below; under that convention ``x_map`` varies along H and ``y_map`` along W.
    """

    def __init__(self, num_rows: int, num_cols: int, temperature: Optional[float] = None):
        """
        :param num_rows:  size related to original image width
        :param num_cols:  size related to original image height
        :param temperature: Softmax temperature (optional). If falsy (None or
            0), a learnable temperature initialised to 1 is created instead of
            a fixed buffer.
        """
        super().__init__()
        self.num_rows = num_rows
        self.num_cols = num_cols

        col_axis = torch.linspace(-1.0, 1.0, num_cols)
        row_axis = torch.linspace(-1.0, 1.0, num_rows)
        grid_x, grid_y = torch.meshgrid(col_axis, row_axis, indexing="ij")
        self.register_buffer("x_map", grid_x.reshape(-1))
        self.register_buffer("y_map", grid_y.reshape(-1))

        if not temperature:
            self.temperature = Parameter(torch.ones(1))
        else:
            self.register_buffer("temperature", torch.ones(1) * temperature)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, channels, height, width = x.shape
        # One row per (sample, channel): (N*C, H*W).
        logits = x.contiguous().view(-1, height * width)
        weights = F.softmax(logits / self.temperature, dim=1)
        coord_x = (self.x_map * weights).sum(dim=1, keepdim=True)
        coord_y = (self.y_map * weights).sum(dim=1, keepdim=True)
        # (N*C, 2) -> (N, 2*C); the result is also kept on self.coords.
        self.coords = torch.cat((coord_x, coord_y), dim=1).view(-1, channels * 2)
        return self.coords



class BesoResNetEncoder(nn.Module):
    """ResNet-18 image encoder with GroupNorm in place of BatchNorm.

    ``forward`` accepts a batch of images ``(B, C, H, W)`` or a time series
    ``(B, T, C, H, W)``; time steps are folded into the batch for the
    backbone and unfolded afterwards, giving ``(B, D)`` or ``(B, T, D)``.
    ``D`` is ``latent_dim`` when ``use_mlp`` is True, otherwise the backbone's
    pooled feature width (``backbone.fc.in_features``).

    :param latent_dim: output width of the optional linear head.
    :param pretrained: forwarded to ``torchvision.models.resnet18``.
    :param freeze_backbone: if True, disable gradients for the backbone. This
        is applied after the BatchNorm -> GroupNorm swap so the new GroupNorm
        parameters are frozen as well.
    :param use_mlp: append a single ``nn.Linear(n_inputs, latent_dim)`` head.
    :param device: accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        latent_dim: int = 128,
        pretrained: bool = False,
        freeze_backbone: bool = False,
        use_mlp: bool = True,
        device: str = 'cuda:0'
    ):
        super(BesoResNetEncoder, self).__init__()
        self.latent_dim = latent_dim
        backbone = models.resnet18(pretrained=pretrained)
        n_inputs = backbone.fc.in_features
        # Drop the last child (in torchvision's ResNet, the ``fc`` classifier
        # head whose input width was read above).
        modules = list(backbone.children())[:-1]
        self.backbone = nn.Sequential(*modules)

        # Substitute GroupNorm for every BatchNorm2d (the original note ties
        # this to EMA / diffusion training). The GroupNorm layers are freshly
        # constructed, so with pretrained=True the checkpoint's BatchNorm
        # parameters are not carried over.
        replace_submodules(
            root_module=self.backbone,
            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
            func=lambda x: nn.GroupNorm(
                num_groups=x.num_features // 16,
                num_channels=x.num_features),
        )

        # Freeze only AFTER the swap: freezing first left the newly created
        # GroupNorm weights/biases trainable even with freeze_backbone=True.
        if freeze_backbone:
            freeze_params(self.backbone)

        self.use_mlp = use_mlp
        if self.use_mlp:
            self.fc_layers = nn.Sequential(nn.Linear(n_inputs, latent_dim))

    def conv_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and flatten everything after the batch dim."""
        x = self.backbone(x)
        return torch.flatten(x, start_dim=1)

    def forward(self, x):
        """Encode images; see the class docstring for accepted shapes."""
        batch_size = len(x)
        t_steps = 1
        time_series = False

        if x.dim() == 5:
            # Fold time into batch: (B, T, C, H, W) -> (B*T, C, H, W).
            t_steps = x.shape[1]
            x = einops.rearrange(x, 'b t n x_dim y_dim -> (b t) n x_dim y_dim')
            time_series = True

        if x.dim() == 2:
            # NOTE(review): this produces a 3-D tensor; confirm what 2-D
            # inputs are meant to be and how the conv backbone should see them.
            x = x.unsqueeze(1)

        x = self.conv_forward(x)
        if self.use_mlp:
            x = self.fc_layers(x)

        if time_series:
            # Unfold back to (B, T, D). D is inferred rather than pinned to
            # latent_dim: with use_mlp=False it is the backbone width, and
            # pinning it made this rearrange fail.
            x = einops.rearrange(x, '(b t) d -> b t d', b=batch_size, t=t_steps)
        return x


def depth_to_pointcloud(depth: torch.Tensor, fov_deg: float = 10) -> torch.Tensor:
    '''
    Back-project a batch of depth maps into camera-frame point clouds.

    Uses a pinhole model with the principal point at the image centre
    (cx = W / 2, cy = H / 2) and a focal length derived from the horizontal
    field of view ``fov_deg``; fy = fx.

    :param depth: depth maps -> [batch_size, H, W]. H and W are read from the
        tensor (they used to be hard-coded to 200 x 200, so 200 x 200 input
        behaves exactly as before).
    :param fov_deg: horizontal field of view in degrees (default 10, the
        previously hard-coded value).
    :return: point cloud -> [batch_size, H*W, 3], points in row-major pixel
        order, each (x, y, z) with z equal to the depth value.
    '''
    device = depth.device
    batch_size = depth.shape[0]
    height, width = depth.shape[-2:]

    # Camera intrinsics
    fov_rad = math.radians(fov_deg)
    fx = width / (2 * math.tan(fov_rad / 2))
    fy = fx
    cx = width / 2
    cy = height / 2

    u = torch.arange(0, width, device=device)
    v = torch.arange(0, height, device=device)
    grid_u, grid_v = torch.meshgrid(u, v, indexing='xy')  # both [H, W]

    z = depth  # [B, H, W]; the [H, W] pixel grids broadcast over the batch
    x = (grid_u - cx) * z / fx
    y = (grid_v - cy) * z / fy

    return torch.stack([x, y, z], dim=-1).reshape(batch_size, -1, 3)

class PointNet(nn.Module):
    '''
    PointNet-style encoder that starts from a depth map.

    Input: depth map, converted inside ``forward`` by ``depth_to_pointcloud``
        into a point cloud [batch_size, N, 3] (N must equal the ``N`` given
        here).
    Output: feature -> [batch_size, out_features]

    A learned ``nn.Linear(N, num_point)`` mixes the N input points down to
    ``num_point`` (256) pseudo-points before the shared per-point MLP.
    NOTE(review): unlike the original PointNet, this mixing depends on point
    order, so the encoder is not permutation invariant.
    '''
    def __init__(
        self,
        N = 40000,
        out_features = 32
    ):
        super().__init__()

        # Number of pseudo-points kept after the learned reduction.
        self.num_point = 256
        self.reduce = nn.Linear(N, self.num_point)

        # Shared per-point MLP written as convolutions; conv1's (1, 3) kernel
        # spans the xyz axis, the rest are 1x1.
        self.conv1 = nn.Conv2d(1, 64, (1, 3))
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, (1, 1))
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, (1, 1))
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, (1, 1))
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 1024, (1, 1))
        self.bn5 = nn.BatchNorm2d(1024)

        # MLP head on the max-pooled global feature.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, out_features)

    def forward(self, x):
        points = depth_to_pointcloud(x)                                # [B, N, 3]
        # Learned mixing along the point axis: N -> num_point.
        points = self.reduce(points.transpose(1, 2)).transpose(1, 2)  # [B, num_point, 3]

        feat = points.unsqueeze(1)                                     # [B, 1, num_point, 3]
        stages = (
            (self.conv1, self.bn1),   # -> [B, 64, num_point, 1]
            (self.conv2, self.bn2),   # -> [B, 64, num_point, 1]
            (self.conv3, self.bn3),   # -> [B, 64, num_point, 1]
            (self.conv4, self.bn4),   # -> [B, 128, num_point, 1]
            (self.conv5, self.bn5),   # -> [B, 1024, num_point, 1]
        )
        for conv, norm in stages:
            feat = F.relu(norm(conv(feat)))

        # Max over all pseudo-points: [B, 1024, 1, 1] -> [B, 1024].
        feat = torch.flatten(F.max_pool2d(feat, kernel_size=(self.num_point, 1)), 1)
        feat = F.relu(self.fc1(feat))
        feat = F.relu(self.fc2(feat))
        return self.fc3(feat)

class RgbdEncoder(nn.Module):
    """Fuse an RGB image and a depth input into a single latent vector.

    RGB path: [B, C, in_size, in_size] -> flatten pixels per channel ->
    learned ``nn.Linear(in_size*in_size, 32*32)`` -> [B, C, 32, 32] ->
    ``BesoResNetEncoder`` -> [B, latent_dim].
    Depth path: ``pc`` -> ``PointNet`` -> [B, latent_dim]. Despite its name,
    ``pc`` goes to PointNet, whose forward runs ``depth_to_pointcloud`` on it,
    so it is expected to be a depth map.
    Fusion: concat -> [B, 2 * latent_dim] -> Linear/ReLU/Linear -> [B, latent_dim].

    ``device`` is accepted for interface compatibility and is not used here.
    The attribute name ``RbgEncoder`` (sic) is kept as-is because it is part
    of the state_dict keys.
    """
    def __init__(
        self,
        in_size = 224,
        N = 40000,
        latent_dim : int = 64,
        device: str = 'cuda:0'
    ):
        super().__init__()
        self.latent_dim = latent_dim
        # Learned downsampling of each channel's flattened pixels to 32x32.
        self.rgb_down = nn.Linear(in_size * in_size, 32 * 32)
        self.RbgEncoder = BesoResNetEncoder(latent_dim=latent_dim)
        self.DepthEncoder = PointNet(N, latent_dim)
        self.fusion_fc = nn.Sequential(
            nn.Linear(latent_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        )

    def forward(self, rgb, pc):
        n_batch, n_chan = rgb.shape[0], rgb.shape[1]
        # [B, C, H, W] -> [B, C, H*W] -> [B, C, 32*32] -> [B, C, 32, 32]
        pixels = rgb.reshape(n_batch, n_chan, -1)
        rgb_small = self.rgb_down(pixels).reshape(n_batch, n_chan, 32, 32)

        feat_rgb = self.RbgEncoder(rgb_small)    # [B, latent_dim]
        feat_depth = self.DepthEncoder(pc)       # [B, latent_dim]

        fused = torch.cat((feat_rgb, feat_depth), dim=1)  # [B, 2 * latent_dim]
        return self.fusion_fc(fused)                      # [B, latent_dim]
        

class concept_mlp(nn.Module):
    """MLP: in_feature -> 512 -> 1024 -> out_feature.

    ReLU follows each hidden layer; the output layer is linear.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, out_feature)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
    
class concept_mlp10242048(nn.Module):
    """MLP: in_feature -> 512 -> 1024 -> 2048 -> out_feature.

    ReLU follows each hidden layer; the output layer is linear.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.fc4 = nn.Linear(2048, out_feature)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
        return self.fc4(x)
    
class concept_mlp2048(nn.Module):
    """MLP: in_feature -> 512 -> 2048 -> out_feature.

    ReLU follows each hidden layer; the output layer is linear.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, 512)
        self.fc2 = nn.Linear(512, 2048)
        self.fc3 = nn.Linear(2048, out_feature)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
    
class concept_mlp_Norm(nn.Module):
    """MLP: in_feature -> 512 -> 1024 -> out_feature with LayerNorm and dropout.

    Each hidden stage is Linear -> LayerNorm -> ReLU -> Dropout(dropout_p);
    the output layer is linear. A single Dropout module is shared by both
    hidden stages.
    """

    def __init__(self, in_feature, out_feature, dropout_p=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, out_feature)

        self.norm1 = nn.LayerNorm(512)
        self.norm2 = nn.LayerNorm(1024)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        for linear, norm in ((self.fc1, self.norm1), (self.fc2, self.norm2)):
            x = self.dropout(F.relu(norm(linear(x))))
        return self.fc3(x)

class concept_para_Norm(nn.Module):
    """Single projection: tanh(LayerNorm(ReLU(Linear(x)))).

    The tanh keeps every output component strictly inside (-1, 1).
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.fc = nn.Linear(in_feature, out_feature)
        self.norm = nn.LayerNorm(out_feature)

    def forward(self, x):
        activated = F.relu(self.fc(x))
        return torch.tanh(self.norm(activated))
