import os.path as osp
import torch
import torch.nn as nn
import logging
from typing import Optional
import segmentation_models_pytorch as smp

from .base import ACT_REGISTRY
from config.registry import Registry
from utils.checkpoint import load_checkpoint

logger = logging.getLogger(__name__)

NETWORK_REGISTRY = Registry("model")


@NETWORK_REGISTRY.register("R50Fpn")
def R50Fpn(cfg) -> smp.FPN:
    """Construct an FPN segmentation model with a ResNet-50 encoder.

    Registered in ``NETWORK_REGISTRY`` under the name ``"R50Fpn"``.

    Args:
        cfg: config node. Reads ``cfg.MODEL.INPUT_CHANNELS``,
            ``cfg.MODEL.NUM_CLASSES``, ``cfg.MODEL.ACT_FUNC`` and
            ``cfg.MODEL.MULTI_LABEL``.

    Returns:
        The ``smp.FPN`` instance, with the extra attributes ``act``,
        ``multi_label``, ``n_classes`` and ``branch_number`` attached.
    """
    model = smp.FPN(
        encoder_name='resnet50',
        in_channels=cfg.MODEL.INPUT_CHANNELS,
        classes=cfg.MODEL.NUM_CLASSES
    )
    # The attributes below are not used in this file; they are attached for
    # downstream consumers. `act` is whatever ACT_REGISTRY maps
    # cfg.MODEL.ACT_FUNC to -- presumably the activation applied to the
    # logits; confirm against callers.
    setattr(model, 'act', ACT_REGISTRY.get(cfg.MODEL.ACT_FUNC))
    setattr(model, 'multi_label', cfg.MODEL.MULTI_LABEL)
    setattr(model, 'n_classes', cfg.MODEL.NUM_CLASSES)
    # NOTE(review): presumably the number of output branches/heads expected
    # by downstream code; this builder always produces a single one.
    setattr(model, "branch_number", 1)
    return model


@NETWORK_REGISTRY.register("R50Unet")
def build_R50Unet(cfg) -> smp.Unet:
    """Build a U-Net segmentation model on top of a ResNet-50 encoder.

    Registered in ``NETWORK_REGISTRY`` under the name ``"R50Unet"``. The
    decoder attention block is explicitly disabled.

    Args:
        cfg: config node. Reads ``cfg.MODEL.INPUT_CHANNELS``,
            ``cfg.MODEL.NUM_CLASSES``, ``cfg.MODEL.ACT_FUNC`` and
            ``cfg.MODEL.MULTI_LABEL``.

    Returns:
        The ``smp.Unet`` instance, with the extra attributes ``act``,
        ``multi_label``, ``n_classes`` and ``branch_number`` attached.
    """
    model_cfg = cfg.MODEL
    net = smp.Unet(
        encoder_name="resnet50",
        decoder_attention_type=None,
        in_channels=model_cfg.INPUT_CHANNELS,
        classes=model_cfg.NUM_CLASSES,
    )

    # Extra attributes for downstream consumers, applied in the same order
    # as they are listed. `act` is whatever ACT_REGISTRY maps ACT_FUNC to --
    # presumably an activation for the logits; confirm against callers.
    extra_attrs = {
        "act": ACT_REGISTRY.get(model_cfg.ACT_FUNC),
        "multi_label": model_cfg.MULTI_LABEL,
        "n_classes": model_cfg.NUM_CLASSES,
        "branch_number": 1,
    }
    for attr_name, attr_value in extra_attrs.items():
        setattr(net, attr_name, attr_value)
    return net


def build_model(cfg, model_path: Optional[str] = None):
    """Build the segmentation model selected by ``cfg.MODEL.ARCH``.

    The architecture name is looked up in ``NETWORK_REGISTRY`` and the
    registered builder is called with ``cfg``. Weights are then optionally
    loaded via ``load_checkpoint``.

    Args:
        cfg: configs used to build the model. Details can be seen in
            configs/defaults.py.
        model_path: optional checkpoint path. When truthy it takes precedence
            over ``cfg.MODEL.PRETRAINED``.

    Returns:
        The model produced by the registered builder, with weights loaded
        from ``model_path`` or, failing that, ``cfg.MODEL.PRETRAINED`` if
        either is set.
    """
    arch = cfg.MODEL.ARCH
    logger.info("Construct model : %s", arch)
    model = NETWORK_REGISTRY.get(arch)(cfg)

    # An explicit model_path wins; otherwise fall back to the configured
    # pretrained weights. `or` short-circuits, so cfg.MODEL.PRETRAINED is only
    # consulted when model_path is empty/None.
    weights_path = model_path or cfg.MODEL.PRETRAINED
    if weights_path:
        load_checkpoint(weights_path, model, cfg.DEVICE)
        logger.info("Successfully load weights from %s", weights_path)

    return model
