import os
import networksvit.unet_adaptive_bins
import torch

from collections import OrderedDict
import networks, networksvit
from SGdepth.sgdepth import SGDepthCommon, SGDepthDepth
import lite_mono
import newcrfs


# When using litemono, set --scale 0 1 2 (its depth decoder is built with only three scales, scales=range(3)).


def __get_2e2_model_class(model):
    """Return the single-network model class registered under *model*.

    Handles the models that are not split into an encoder/decoder pair
    ('newcrfs' and 'adabins'). The "2e2" in the name presumably means
    "end-to-end" (TODO confirm).

    Raises:
        KeyError: if *model* is not one of the registered names.
    """
    return {
        'newcrfs': newcrfs.NewCRFDepth,
        'adabins': networksvit.unet_adaptive_bins.UnetAdaptiveBins,
    }[model]

def __get_encoder_decoder_model_class(model):
    """Return the ``(encoder_class, decoder_class)`` pair registered for *model*.

    Raises:
        KeyError: if *model* has no encoder/decoder registration.
    """
    class_pairs = {
        'monodepth2': (networks.ResnetEncoder, networks.DepthDecoder),
        'sgdepth': (SGDepthCommon, SGDepthDepth),
        # HR-Depth reuses the ResNet encoder but has its own decoder.
        'hrdepth': (networks.ResnetEncoder, networks.HRDepthDecoder),
        'monovit': (networksvit.mpvit_small, networksvit.DepthDecoder),
        'litemono': (lite_mono.LiteMono, lite_mono.DepthDecoder),
    }
    return class_pairs[model]

def __get_model_class(model='monodepth2'):
    """Look up the class(es) needed to build *model*.

    Returns:
        A single class for 'newcrfs' / 'adabins'; otherwise an
        ``(encoder_class, decoder_class)`` tuple.
    """
    single_network = model in ('newcrfs', 'adabins')
    if not single_network:
        return __get_encoder_decoder_model_class(model)
    return __get_2e2_model_class(model)

def __init_monodepth2(model):
    """Build an untrained Monodepth2 encoder/decoder pair.

    The encoder is created with ``num_layers=18`` and ``pretrained=False``.
    The decoder is sized from the encoder's ``num_ch_enc`` and predicts at
    scales 0-3.

    Returns:
        ``(encoder, decoder)`` module instances.
    """
    enc_cls, dec_cls = __get_model_class(model)
    enc = enc_cls(num_layers=18, pretrained=False)
    dec = dec_cls(num_ch_enc=enc.num_ch_enc, scales=range(4))
    return enc, dec

def __init_sgdepth(model):
    """Build an untrained SGDepth shared encoder and depth decoder.

    The encoder is created with ``num_layers=18``, ``split_pos=1`` and
    ``grad_scales=(1., 0.)``. The depth decoder is given the encoder instance
    itself plus the positional argument ``4``.

    Returns:
        ``(encoder, decoder)`` module instances.
    """
    common_cls, depth_cls = __get_model_class(model)
    common = common_cls(num_layers=18, split_pos=1, grad_scales=(1., 0.))
    depth_head = depth_cls(common, 4)
    return common, depth_head

def __init_hrdepth(model):
    """Build an untrained HR-Depth encoder/decoder pair.

    The encoder receives ``18`` positionally and ``pretrained=False``. The
    decoder is sized from the encoder's ``num_ch_enc``.

    Returns:
        ``(encoder, decoder)`` module instances.
    """
    enc_cls, dec_cls = __get_model_class(model)
    enc = enc_cls(18, pretrained=False)
    dec = dec_cls(num_ch_enc=enc.num_ch_enc)
    return enc, dec

def __init_monovit(model):
    """Build an untrained MonoViT encoder/decoder pair with default arguments.

    Returns:
        ``(encoder, decoder)`` module instances.
    """
    enc_cls, dec_cls = __get_model_class(model)
    return enc_cls(), dec_cls()

def __init_litemono(model):
    """Build an untrained Lite-Mono ('lite-mono-small') encoder/decoder pair.

    The decoder is sized from the encoder's ``num_ch_enc`` and predicts only
    scales 0-2, so training/eval must be configured with three scales.

    Returns:
        ``(encoder, decoder)`` module instances.
    """
    enc_cls, dec_cls = __get_model_class(model)
    enc = enc_cls(model='lite-mono-small')
    dec = dec_cls(num_ch_enc=enc.num_ch_enc, scales=range(3))
    return enc, dec

def __init_newcrfs(model):
    """Build an untrained NeWCRFs network.

    Uses ``version='large07'``, ``max_depth=80.`` and ``min_depth=1e-3``.

    Returns:
        The single NewCRFDepth module instance.
    """
    crf_cls = __get_model_class(model)
    # NOTE(review): `inv_dept` looks like a typo of `inv_depth`; confirm against
    # NewCRFDepth's signature (it may be silently absorbed by **kwargs).
    return crf_cls(version='large07', inv_dept=False, max_depth=80., min_depth=1e-3)

def __init_adabins(model):
    """Build an untrained AdaBins network via the class's ``build`` constructor.

    Configured with 256 bins, a value range of [1e-3, 80.] and 'linear' norm.

    Returns:
        The single AdaBins module instance.
    """
    adabins_cls = __get_model_class(model)
    return adabins_cls.build(n_bins=256, min_val=1e-3, max_val=80., norm='linear')


def __init_model(model):
    """Build the untrained network(s) registered under *model*.

    Every name supported by ``__get_model_class`` is dispatched here,
    including 'adabins' (previously missing, so ``__init_model('adabins')``
    raised KeyError).

    Returns:
        ``(encoder, decoder)`` for encoder/decoder models; a single module
        for 'newcrfs' and 'adabins'.

    Raises:
        KeyError: if *model* is not a supported name.
    """
    initializers = {
        'monodepth2': __init_monodepth2,
        'sgdepth': __init_sgdepth,
        'hrdepth': __init_hrdepth,
        'monovit': __init_monovit,
        'litemono': __init_litemono,
        'newcrfs': __init_newcrfs,
        'adabins': __init_adabins,
    }
    return initializers[model](model)

def __load_sgdepth(model_path):
    """Load a pretrained SGDepth encoder and depth decoder from *model_path*.

    Reads ``<model_path>/model.pth``, one state dict containing all
    sub-networks. Keys whose name contains 'common' go to the encoder. Other
    keys whose name contains 'depth' go to the depth decoder. In both cases
    the first dotted component of the key is dropped. All remaining keys are
    ignored.

    Returns:
        ``(encoder, decoder)`` with weights loaded.
    """
    encoder, decoder = __init_model('sgdepth')
    ckpt_path = os.path.join(model_path, 'model.pth')
    # Load onto CPU so checkpoints saved from a GPU also load on CPU-only
    # hosts; load_state_dict copies the values into the modules' own tensors.
    state_dict = torch.load(ckpt_path, map_location='cpu')

    encoder_dict = OrderedDict()
    depth_decoder_dict = OrderedDict()
    for name, value in state_dict.items():
        # Everything after the first '.' (empty if there is no '.').
        stripped = name.partition('.')[2]
        if 'common' in name:
            encoder_dict[stripped] = value
        elif 'depth' in name:
            depth_decoder_dict[stripped] = value

    encoder.load_state_dict(encoder_dict)
    decoder.load_state_dict(depth_decoder_dict)
    return encoder, decoder

def __load_monodepth2(model_path):
    """Load a pretrained Monodepth2 encoder/decoder from *model_path*.

    Expects ``encoder.pth`` and ``depth.pth`` inside *model_path*. Only the
    encoder-checkpoint entries whose keys exist in the encoder's own state
    dict are loaded; the decoder checkpoint is loaded strictly.

    Returns:
        ``(encoder, decoder)`` with weights loaded.
    """
    encoder_path = os.path.join(model_path, 'encoder.pth')
    decoder_path = os.path.join(model_path, 'depth.pth')
    encoder, decoder = __init_model('monodepth2')

    # map_location='cpu': GPU-saved checkpoints must also load on CPU-only hosts.
    encoder_dict = torch.load(encoder_path, map_location='cpu')
    # Compute the key set once instead of rebuilding state_dict() per entry.
    encoder_keys = encoder.state_dict().keys()
    encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in encoder_keys})

    decoder_dict = torch.load(decoder_path, map_location='cpu')
    decoder.load_state_dict(decoder_dict)
    return encoder, decoder

def __load_hrdepth(model_path):
    """Load a pretrained HR-Depth encoder/decoder from *model_path*.

    Expects ``encoder.pth`` and ``depth.pth`` inside *model_path*. Only the
    encoder-checkpoint entries whose keys exist in the encoder's own state
    dict are loaded; the decoder checkpoint is loaded strictly.

    Returns:
        ``(encoder, decoder)`` with weights loaded.
    """
    encoder_path = os.path.join(model_path, 'encoder.pth')
    decoder_path = os.path.join(model_path, 'depth.pth')
    encoder, decoder = __init_model('hrdepth')

    # map_location='cpu': GPU-saved checkpoints must also load on CPU-only hosts.
    encoder_dict = torch.load(encoder_path, map_location='cpu')
    decoder_dict = torch.load(decoder_path, map_location='cpu')

    # Compute the key set once instead of rebuilding state_dict() per entry.
    encoder_keys = encoder.state_dict().keys()
    encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in encoder_keys})
    decoder.load_state_dict(decoder_dict)
    return encoder, decoder
   
def __load_monovit(model_path):
    """Load a pretrained MonoViT encoder/decoder from *model_path*.

    Expects ``encoder.pth`` and ``depth.pth`` inside *model_path*. Only the
    encoder-checkpoint entries whose keys exist in the encoder's own state
    dict are loaded; the decoder checkpoint is loaded strictly.

    Returns:
        ``(encoder, decoder)`` with weights loaded.
    """
    encoder_path = os.path.join(model_path, 'encoder.pth')
    decoder_path = os.path.join(model_path, 'depth.pth')
    encoder, decoder = __init_model('monovit')

    # map_location='cpu': GPU-saved checkpoints must also load on CPU-only hosts.
    encoder_dict = torch.load(encoder_path, map_location='cpu')
    decoder_dict = torch.load(decoder_path, map_location='cpu')

    # Compute the key set once instead of rebuilding state_dict() per entry.
    encoder_keys = encoder.state_dict().keys()
    encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in encoder_keys})
    decoder.load_state_dict(decoder_dict)
    return encoder, decoder

def __load_litemono(model_path):
    """Load a pretrained Lite-Mono encoder/decoder from *model_path*.

    Expects ``encoder.pth`` and ``depth.pth`` inside *model_path*. For both
    networks, only the checkpoint entries whose keys exist in the target
    module's own state dict are loaded.

    Returns:
        ``(encoder, decoder)`` with weights loaded.
    """
    encoder_path = os.path.join(model_path, 'encoder.pth')
    decoder_path = os.path.join(model_path, 'depth.pth')
    encoder, decoder = __init_model('litemono')

    # map_location='cpu': GPU-saved checkpoints must also load on CPU-only hosts.
    encoder_dict = torch.load(encoder_path, map_location='cpu')
    decoder_dict = torch.load(decoder_path, map_location='cpu')

    # Compute each key set once instead of rebuilding state_dict() per entry.
    encoder_keys = encoder.state_dict().keys()
    decoder_keys = decoder.state_dict().keys()
    encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in encoder_keys})
    decoder.load_state_dict({k: v for k, v in decoder_dict.items() if k in decoder_keys})
    return encoder, decoder

def __load_newcrfs(model_path):
    """Load a pretrained NeWCRFs network from *model_path*.

    Reads ``<model_path>/model_kittieigen.ckpt`` and takes its ``'model'``
    entry. The first dotted component of every key is dropped before the
    weights are loaded strictly.

    Returns:
        The NewCRFDepth module with weights loaded.
    """
    ckpt_path = os.path.join(model_path, 'model_kittieigen.ckpt')
    # map_location='cpu': GPU-saved checkpoints must also load on CPU-only hosts.
    model_ckpt = torch.load(ckpt_path, map_location='cpu')
    model = __init_model('newcrfs')
    # Presumably strips a DataParallel-style wrapper prefix (e.g. 'module.'); TODO confirm.
    state_dict = OrderedDict(
        (k.partition('.')[2], v) for k, v in model_ckpt['model'].items()
    )
    model.load_state_dict(state_dict)
    return model

def __load_adabins(model_path):
    """Load a pretrained AdaBins (KITTI) network from *model_path*.

    Reads ``<model_path>/AdaBins_kitti.pt``. The weights are taken from the
    checkpoint's ``'model'`` entry when present, otherwise the file is
    treated as a bare state dict. Keys are then normalised before a strict
    load:

    * a leading ``'module.'`` (DataParallel wrapper) is removed; only the
      prefix is removed, not every occurrence of ``'module.'`` in the key;
    * older block names are renamed to the current architecture's names.

    Returns:
        The AdaBins module with weights loaded.
    """
    ckpt_path = os.path.join(model_path, 'AdaBins_kitti.pt')
    # map_location='cpu' (as upstream AdaBins does) lets GPU-saved
    # checkpoints load on CPU-only hosts.
    model_ckpt = torch.load(ckpt_path, map_location='cpu')
    model = __init_adabins('adabins')
    raw_state = model_ckpt.get('model', model_ckpt)

    wrapper_prefix = 'module.'
    # Backward compatibility with older naming of architecture blocks:
    # (old prefix, new prefix); the first matching rule wins.
    renamed_prefixes = (
        ('adaptive_bins_layer.embedding_conv.',
         'adaptive_bins_layer.conv3x3.'),
        ('adaptive_bins_layer.patch_transformer.embedding_encoder',
         'adaptive_bins_layer.patch_transformer.embedding_convPxP'),
    )

    state_dict = OrderedDict()
    for key, value in raw_state.items():
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        for old, new in renamed_prefixes:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        state_dict[key] = value

    model.load_state_dict(state_dict)
    return model

def model_factory(model):
    """Return the pretrained-weights loader for *model*.

    The returned callable takes the directory containing the checkpoint
    file(s). It returns ``(encoder, decoder)`` for encoder/decoder models, or
    a single module for 'newcrfs' / 'adabins'.

    Raises:
        KeyError: if *model* has no registered loader.
    """
    loaders = dict(
        monodepth2=__load_monodepth2,
        sgdepth=__load_sgdepth,
        hrdepth=__load_hrdepth,
        monovit=__load_monovit,
        litemono=__load_litemono,
        newcrfs=__load_newcrfs,
        adabins=__load_adabins,
    )
    return loaders[model]

def encoder_factory(model, *args, **kwargs):
    """Instantiate only the encoder registered for *model*.

    Positional and keyword arguments are forwarded unchanged to the encoder
    class. The class is taken from ``__get_encoder_decoder_model_class`` so
    the encoder registry lives in one place and cannot drift.

    Raises:
        KeyError: if *model* is not an encoder/decoder model.
    """
    encoder_cls, _ = __get_encoder_decoder_model_class(model)
    return encoder_cls(*args, **kwargs)

if __name__ == '__main__':
    # Smoke test: load the pretrained AdaBins weights from ./weights/adabins
    # and run one random KITTI-sized image through the network.
    mname = 'adabins'
    load_model = model_factory(mname)
    net = load_model(f'./weights/{mname}')
    net.eval()

    # NCHW layout: KITTI frames are 376 px high and 1241 px wide (the
    # previous (1, 3, 1241, 376) swapped height and width).
    inp = torch.randn(1, 3, 376, 1241)
    with torch.no_grad():  # inference only; no autograd graph needed
        out = net(inp)
    print(type(out))
    # Assumes the model returns at least two tensors; TODO confirm for other models.
    print(out[0].size(), out[1].size())
    
