from .vmamba import VSSM
import torch
from torch import nn
from torch.hub import load_state_dict_from_url

class VMUNet(nn.Module):
    """Segmentation model that wraps a ``VSSM`` backbone.

    Args:
        input_channel: Forwarded to ``VSSM`` as ``in_chans``. Also used by
            :meth:`forward` to decide how single-channel input is expanded.
        num_classes: Forwarded to ``VSSM`` as ``num_classes``.
        depths: Forwarded to ``VSSM`` as ``depths`` (a copy is passed).
        depths_decoder: Forwarded to ``VSSM`` as ``depths_decoder`` (a copy
            is passed).
        drop_path_rate: Forwarded to ``VSSM`` as ``drop_path_rate``.
        load_ckpt_path: Optional local checkpoint. When set,
            :meth:`load_from` reads it instead of downloading the default
            pretrained weights.
    """

    # Default location of the pretrained checkpoint used by ``load_from``.
    _PRETRAINED_URL = "[URL]"

    # Substring of an encoder key in the checkpoint -> substring it is
    # rewritten to so the same tensor can initialise a decoder stage.
    # Stage order is mirrored (first encoder stage -> last decoder stage).
    _ENCODER_TO_DECODER = (
        ('layers.0', 'layers_up.3'),
        ('layers.1', 'layers_up.2'),
        ('layers.2', 'layers_up.1'),
        ('layers.3', 'layers_up.0'),
    )

    def __init__(self, 
                 input_channel=3, 
                 num_classes=1,
                 depths=[2, 2, 9, 2], 
                 depths_decoder=[2, 9, 2, 2],
                 drop_path_rate=0.2,
                 load_ckpt_path=None,
                ):
        super().__init__()

        self.load_ckpt_path = load_ckpt_path
        self.num_classes = num_classes
        self.input_channel = input_channel

        # The signature keeps its list defaults for compatibility; copies are
        # handed to the backbone so the shared default objects can never be
        # mutated through it.
        self.vmunet = VSSM(in_chans=input_channel,
                           num_classes=num_classes,
                           depths=list(depths),
                           depths_decoder=list(depths_decoder),
                           drop_path_rate=drop_path_rate,
                        )

    def forward(self, x):
        """Run the backbone on ``x`` and return its output.

        Single-channel input is tiled along dimension 1 up to
        ``input_channel`` channels when the model was built for more than
        one channel. With the default ``input_channel=3`` this is the usual
        grayscale -> 3-channel expansion; a model built with
        ``input_channel=1`` receives the input unchanged.
        """
        # NOTE(review): assumes VSSM expects ``in_chans`` input channels --
        # confirm against the backbone implementation.
        if x.size(1) == 1 and self.input_channel > 1:
            x = x.repeat(1, self.input_channel, 1, 1)
        return self.vmunet(x)

    def load_from(self):
        """Initialise the backbone from a pretrained checkpoint.

        The checkpoint is read from ``load_ckpt_path`` when that was given,
        otherwise it is downloaded from the default URL. Its ``'model'``
        entry is applied twice: once under the original keys ("encoder"),
        and once with ``layers.N`` rewritten to the mirrored ``layers_up.M``
        ("decoder"). Only entries whose key exists in the backbone with an
        identical shape are copied.
        """
        # SECURITY: both loaders unpickle the file; only use checkpoints
        # from a trusted source.
        if self.load_ckpt_path is not None:
            checkpoint = torch.load(self.load_ckpt_path, map_location="cpu")
        else:
            checkpoint = load_state_dict_from_url(
                self._PRETRAINED_URL, progress=True, map_location="cpu"
            )
        pretrained = checkpoint['model']

        self._load_matching(pretrained)
        print("encoder loaded finished!")

        remapped = {}
        for key, value in pretrained.items():
            # First matching stage wins, mirroring an if/elif chain.
            for src, dst in self._ENCODER_TO_DECODER:
                if src in key:
                    remapped[key.replace(src, dst)] = value
                    break

        self._load_matching(remapped)
        print("decoder loaded finished!")

    def _load_matching(self, candidates):
        """Copy into the backbone every candidate tensor that fits it."""
        model_dict = self.vmunet.state_dict()

        matched = {}
        mismatched = []
        for key, value in candidates.items():
            if key not in model_dict:
                continue
            if model_dict[key].shape == value.shape:
                matched[key] = value
            else:
                # Same name, different shape: loading it would raise.
                mismatched.append(key)

        model_dict.update(matched)
        print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(candidates), len(matched)))
        if mismatched:
            print('Skipped keys with mismatched shape: {}'.format(mismatched))
        self.vmunet.load_state_dict(model_dict)
            
            
def vmunet(num_classes, input_channel=3):
    """Build a ``VMUNet`` and initialise it via ``load_from``.

    Args:
        num_classes: Forwarded to ``VMUNet`` as ``num_classes``.
        input_channel: Forwarded to ``VMUNet`` as ``input_channel``.

    Returns:
        The constructed ``VMUNet`` after ``load_from()`` has been called.
    """
    net = VMUNet(num_classes=num_classes, input_channel=input_channel)
    net.load_from()
    return net

            
            