# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
# Adapted for DUPS from AutoFocusFormer

import torch
from .dups_encoder import DUPSEncoder
from .mixres_vit import MixResViT
from .mixres_neighbour import MixResNeighbour

def build_model(config):
    """Build the model selected by ``config.MODEL.TYPE``.

    Only the ``'DUPS'`` type is supported. One backbone is instantiated per
    entry of ``config.MODEL.DUPS.NAME`` (``'MixResViT'`` or
    ``'MixResNeighbour'``), and the list of backbones is wrapped in a
    ``DUPSEncoder``.

    Per-layer settings (``DEPTHS``, ``EMBED_DIM``, ``NUM_HEADS``, ...) are read
    from ``config.MODEL.DUPS`` by indexing with the layer position.

    Args:
        config: Configuration object exposing ``MODEL.TYPE``,
            ``MODEL.NUM_CLASSES`` and the ``MODEL.DUPS`` sub-node via
            attribute access.

    Returns:
        The constructed ``DUPSEncoder``.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not ``'DUPS'`` or a
            layer name in ``config.MODEL.DUPS.NAME`` is not recognised.
    """
    model_type = config.MODEL.TYPE
    if model_type != 'DUPS':
        raise NotImplementedError(f"Unknown model: {model_type}")

    cfg = config.MODEL.DUPS

    # Input feature names handed to DUPSEncoder, one entry per backbone.
    # NOTE(review): this table is hard-coded with 8 entries -- confirm it is
    # meant to match len(cfg.NAME) for every supported configuration.
    bb_in_feats = [[None], ["res5"], ["res5", "res4"], ["res5", "res4", "res3"], ["res5", "res4", "res3"],
                   ["res5", "res4"], ["res5"], [None]]

    n_scales = cfg.N_RESOLUTION_SCALES
    n_layers = len(cfg.NAME)
    min_patch_size = cfg.PATCH_SIZES[n_scales - 1]

    # Drop-path rates grow linearly from 0 to DROP_PATH_RATE over all blocks
    # of all layers. The schedule does not depend on the layer, so it is
    # computed once here rather than on every loop iteration.
    dpr = [x.item() for x in torch.linspace(0, cfg.DROP_PATH_RATE, sum(cfg.DEPTHS))]

    all_backbones = []
    for layer_index, name in enumerate(cfg.NAME):
        first_layer = layer_index == 0

        if layer_index >= n_scales:
            # Layers past the first n_scales: the scale index counts back
            # down and the trailing patch sizes / output features are used.
            scale = n_layers - layer_index - 1
            patch_sizes = cfg.PATCH_SIZES[layer_index:]
            out_features = cfg.OUT_FEATURES[-(n_layers - layer_index):]
            # Input width is the previous layer's width plus the width of the
            # mirrored layer (index n_layers - layer_index - 1).
            in_chans = cfg.EMBED_DIM[layer_index - 1] + cfg.EMBED_DIM[n_layers - layer_index - 1]
        else:
            scale = layer_index
            patch_sizes = cfg.PATCH_SIZES[:layer_index + 1]
            out_features = cfg.OUT_FEATURES[-(layer_index + 1):]
            # The first layer gets 3 input channels; later layers take the
            # previous layer's embedding width.
            in_chans = 3 if first_layer else cfg.EMBED_DIM[layer_index - 1]

        # This layer's contiguous slice of the global drop-path schedule.
        drop_path = dpr[sum(cfg.DEPTHS[:layer_index]):sum(cfg.DEPTHS[:layer_index + 1])]

        if name not in ('MixResViT', 'MixResNeighbour'):
            raise NotImplementedError(f"Unknown model: {name}")

        # Keyword arguments accepted by both backbone classes.
        shared_kwargs = dict(
            patch_sizes=patch_sizes,
            n_layers=cfg.DEPTHS[layer_index],
            d_model=cfg.EMBED_DIM[layer_index],
            n_heads=cfg.NUM_HEADS[layer_index],
            mlp_ratio=cfg.MLP_RATIO[layer_index],
            dropout=cfg.DROP_RATE[layer_index],
            drop_path_rate=drop_path,
            split_ratio=cfg.SPLIT_RATIO[layer_index],
            channels=in_chans,
            n_scales=n_scales,
            min_patch_size=min_patch_size,
            upscale_ratio=cfg.UPSCALE_RATIO[layer_index],
            out_features=out_features,
            first_layer=first_layer,
            layer_scale=cfg.LAYER_SCALE,
        )

        if name == 'MixResViT':
            bb = MixResViT(num_register_tokens=cfg.NUM_REGISTER_TOKENS,
                           **shared_kwargs)
        else:
            bb = MixResNeighbour(attn_drop_rate=cfg.ATTN_DROP_RATE[layer_index],
                                 cluster_size=cfg.CLUSTER_SIZE[layer_index],
                                 nbhd_size=cfg.NBHD_SIZE[layer_index],
                                 keep_old_scale=cfg.KEEP_OLD_SCALE,
                                 scale=scale,
                                 add_image_data_to_all=cfg.ADD_IMAGE_DATA_TO_ALL,
                                 **shared_kwargs)
        all_backbones.append(bb)

    model = DUPSEncoder(backbones=all_backbones,
                        backbone_dims=cfg.EMBED_DIM,
                        out_dim=cfg.OUT_DIM,
                        all_out_features=cfg.OUT_FEATURES,
                        n_scales=cfg.N_RESOLUTION_SCALES,
                        num_classes=config.MODEL.NUM_CLASSES,
                        bb_in_feats=bb_in_feats,
                        aux_loss=cfg.AUX_LOSS)

    return model
