import torch
import torch.nn as nn
from models.embeddings import ConvEmbedding, VideoConvEmbedding
from models.backbones import ResnetBackbone, FCNResnetBackbone, VideoResnetBackbone
from models.heads import ClsHead, ProjectHead, FCNHead, FCNProjectHead, VideoClsHead, VideoProjectHead

# Feature width of each supported backbone. build_models passes this as the
# first argument to the task (f_*) and projection (g_*) heads.
FEAT_DIM_DICT = dict(
    resnet18=512,
    resnet50x1=2048,
    resnet50x2=4096,
    resnet50x4=8192,
)

# Conv-embedding layout per backbone, handed to ConvEmbedding /
# VideoConvEmbedding. Each inner list is [channel, kernel_size, stride].
STRUCT_DICT = dict(
    resnet18=[[3, 7, 2], [64, 7, 2]],
    resnet50x1=[[3, 7, 2], [64, 7, 2]],
    resnet50x2=[[3, 7, 2], [128, 7, 2]],
    resnet50x4=[[3, 7, 2], [256, 7, 2]],
)

def init_weights(model, init_type=''):
    """Initialise the parameters of every submodule of ``model`` in place.

    * ``nn.Conv2d`` weights get Kaiming-normal init (``mode='fan_out'``,
      ``nonlinearity='relu'``); conv biases are left untouched.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm`` get weight = 1 and bias = 0,
      but only for the affine parameters that actually exist.

    Any other module type (e.g. ``nn.Linear``, ``nn.Conv3d``,
    ``nn.BatchNorm3d``) keeps the initialisation it was constructed with.
    NOTE(review): if the video modules use 3D convs/norms, they are not
    touched here -- confirm that is intended.

    Args:
        model: module whose ``modules()`` tree is walked (includes ``model``
            itself).
        init_type: currently unused; kept for backward compatibility.

    Returns:
        None. Parameters are modified in place.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            # Norm layers created with affine=False have weight/bias == None;
            # calling nn.init.constant_ on None would crash.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

def build_models(cfg):
    """Build the embedding, backbone and head modules described by ``cfg``.

    For each position ``idx`` in the parallel config lists, four modules are
    created (in this order) and stored under these keys:

    * ``psi_{idx}`` -- conv embedding from ``STRUCT_DICT[backbone_type]``
      (``VideoConvEmbedding`` for ``'video'``, else ``ConvEmbedding``);
    * ``phi_{idx}`` -- backbone (``FCNResnetBackbone`` for segmentation
      types, ``VideoResnetBackbone`` for ``'video'``, else ``ResnetBackbone``);
    * ``f_{idx}``   -- task head (``FCNHead`` / ``VideoClsHead`` / ``ClsHead``);
    * ``g_{idx}``   -- projection head (``FCNProjectHead`` /
      ``VideoProjectHead`` / ``ProjectHead``).

    All four are then passed to ``init_weights`` in the same order.

    Args:
        cfg: config object exposing the parallel sequences ``type``,
            ``backbone_type``, ``num_classes`` and ``proj_dims``.
            Segmentation entries (``'image_seg'``, ``'depth_seg'``) also read
            ``cfg.input_shapes[idx]``.

    Returns:
        dict mapping ``'psi_{idx}'``, ``'phi_{idx}'``, ``'f_{idx}'`` and
        ``'g_{idx}'`` to the constructed modules.

    Raises:
        ValueError: if ``cfg.type`` and ``cfg.backbone_type`` differ in
            length, or a model type is unsupported.
        KeyError: if a backbone type is missing from ``STRUCT_DICT`` or
            ``FEAT_DIM_DICT``.
    """
    supported_types = ('image', 'sketch', 'depth', 'image_seg', 'depth_seg', 'video', 'audio')
    seg_types = ('image_seg', 'depth_seg')

    model_types = list(cfg.type)
    backbone_types = list(cfg.backbone_type)
    # zip() would silently drop trailing entries if the two lists disagree.
    if len(model_types) != len(backbone_types):
        raise ValueError(
            f'cfg.type has {len(model_types)} entries but cfg.backbone_type '
            f'has {len(backbone_types)}'
        )

    models = {}
    for idx, (model_type, backbone_type) in enumerate(zip(model_types, backbone_types)):
        if model_type not in supported_types:
            raise ValueError(f'Unsupported model type {model_type}')

        is_seg = model_type in seg_types
        is_video = model_type == 'video'
        # [[channel, kernel_size, stride], ...]
        structures = STRUCT_DICT[backbone_type]
        feat_dim = FEAT_DIM_DICT[backbone_type]

        # Embedding layer.
        embed_cls = VideoConvEmbedding if is_video else ConvEmbedding
        models[f'psi_{idx}'] = embed_cls(structures)

        # Backbone, task head (f) and projection head (g). They are built in
        # the same phi -> f -> g order as before, so random-number consumption
        # during construction is unchanged.
        if is_seg:
            models[f'phi_{idx}'] = FCNResnetBackbone(backbone_type)
            models[f'f_{idx}'] = FCNHead(feat_dim, cfg.num_classes[idx], cfg.input_shapes[idx])
            models[f'g_{idx}'] = FCNProjectHead(feat_dim, cfg.proj_dims[idx])
        elif is_video:
            models[f'phi_{idx}'] = VideoResnetBackbone(backbone_type)
            models[f'f_{idx}'] = VideoClsHead(feat_dim, cfg.num_classes[idx])
            models[f'g_{idx}'] = VideoProjectHead(feat_dim, cfg.proj_dims[idx])
        else:
            models[f'phi_{idx}'] = ResnetBackbone(backbone_type)
            models[f'f_{idx}'] = ClsHead(feat_dim, cfg.num_classes[idx])
            models[f'g_{idx}'] = ProjectHead(feat_dim, cfg.proj_dims[idx])

        for key in (f'psi_{idx}', f'phi_{idx}', f'f_{idx}', f'g_{idx}'):
            init_weights(models[key])

    return models
    
from models.losses import CrossEntropyLoss, L1Loss, MSELoss, InfoNCELoss, CoInfoNCELoss, CMDLoss, CLIPLoss, CLIPCMDLoss, SegCELoss
def build_losses(cfg):
    """Build the loss module for every task listed in ``cfg``.

    Args:
        cfg: config exposing the parallel sequences ``type`` (loss names) and
            ``args`` (one entry per loss). The ``args`` entry is ignored for
            the argument-free losses ``CE``, ``SegCE``, ``L1`` and ``MSE``. It
            is unpacked as positional constructor arguments for ``InfoNCE``,
            ``CoInfoNCE``, ``CMD``, ``CLIPCMD`` and ``CLIP``.

    Returns:
        dict mapping ``'task_{idx}'`` to the loss module for the idx-th entry.

    Raises:
        ValueError: if ``cfg.type`` and ``cfg.args`` differ in length, or a
            loss name is not recognised.
    """
    # Losses constructed without arguments (their cfg.args entry is unused).
    no_arg_losses = {
        'CE': CrossEntropyLoss,
        'SegCE': SegCELoss,
        'L1': L1Loss,
        'MSE': MSELoss,
    }
    # Losses whose cfg.args entry is unpacked as positional arguments.
    arg_losses = {
        'InfoNCE': InfoNCELoss,
        'CoInfoNCE': CoInfoNCELoss,
        'CMD': CMDLoss,
        'CLIPCMD': CLIPCMDLoss,
        'CLIP': CLIPLoss,
    }

    loss_types = list(cfg.type)
    loss_args = list(cfg.args)
    # zip() would silently drop trailing losses if the two lists disagree.
    if len(loss_types) != len(loss_args):
        raise ValueError(
            f'cfg.type has {len(loss_types)} entries but cfg.args '
            f'has {len(loss_args)}'
        )

    losses = {}
    for idx, (loss_type, args) in enumerate(zip(loss_types, loss_args)):
        if loss_type in no_arg_losses:
            loss = no_arg_losses[loss_type]()
        elif loss_type in arg_losses:
            loss = arg_losses[loss_type](*args)
        else:
            raise ValueError(f'Unsupported loss type {loss_type}')
        losses[f'task_{idx}'] = loss
    return losses

