import os
import copy
import torch

from collections import OrderedDict

from geom.models import ULIP_models as models

def load_geom_encoder(args, pretrained=True, frozen=False):
    """Build a ULIP model and return copies of its point-cloud encoder and projection.

    Args:
        args: Config object. ``args.ulip_model`` names a constructor in
            ``geom.models.ULIP_models``, which is called as ``ctor(args=args)``.
            When ``pretrained`` is true, ``args.ulip_ckpt`` is the path of a
            checkpoint whose ``'state_dict'`` entry holds the model weights.
        pretrained: If true, load the checkpoint weights into the model with
            ``strict=True``. If false, the checkpoint is not read at all.
        frozen: If true, disable gradients on the returned encoder and
            projection.

    Returns:
        ``(point_encoder, pc_projection)``: deep copies of the model's
        ``point_encoder`` and ``pc_projection`` attributes. The full model is
        discarded.
    """
    prefix = 'module.'

    print("=> creating model: {}".format(args.ulip_model))
    model = getattr(models, args.ulip_model)(args=args)

    if pretrained:
        # Only touch the checkpoint when its weights are actually used.
        # NOTE(review): newer torch versions default torch.load to
        # weights_only=True; if this checkpoint pickles non-tensor objects,
        # loading may need weights_only=False (only for trusted files).
        ckpt = torch.load(args.ulip_ckpt, map_location='cpu')
        # Checkpoints saved from a DataParallel/DDP wrapper prefix every key
        # with 'module.'. Strip only that leading prefix so key names that
        # merely contain 'module.' elsewhere are left intact.
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in ckpt['state_dict'].items()
        )
        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}'".format(args.ulip_ckpt))
    else:
        print("=> new model without pretraining")

    # Deep-copy the two sub-components so they do not keep the rest of the
    # model alive once it is deleted.
    point_encoder = copy.deepcopy(model.point_encoder)
    pc_projection = copy.deepcopy(model.pc_projection)
    del model

    if frozen:
        # requires_grad_ exists on both nn.Module and Tensor/Parameter, so
        # this freezes pc_projection whether it is a module or a bare
        # parameter. Assigning `.requires_grad` on a Module would only set an
        # unused Python attribute.
        point_encoder.requires_grad_(False)
        pc_projection.requires_grad_(False)

    return point_encoder, pc_projection
 