import warnings
from pathlib import Path

import timm
import torch
import yaml

from .projectors import create_intrinsic_model


def create_model(model_name=None, num_classes=None, base_width=None, in_chans=None,
                 seed=None, intrinsic_dim=0, intrinsic_mode='filmrdkron', global_param=None,
                 cfg_path=None, transfer=False, device_id=None, log_dir=None):
  """Build a timm model and optionally move it to a CUDA device.

  Only the base network is constructed here. The intrinsic-dimension
  options and ``cfg_path`` are accepted for interface compatibility, but
  this function does not apply them. A ``UserWarning`` is emitted when they
  are set, so the caller is not silently handed a plain network.

  Args:
    model_name: timm model name, forwarded to ``timm.create_model``.
    num_classes: forwarded to ``timm.create_model``.
    base_width: forwarded as ``base_width`` only when not ``None``.
    in_chans: forwarded to ``timm.create_model``.
    seed: not used by this function.
    intrinsic_dim: when > 0, a warning is emitted (no intrinsic wrapper is
      applied). ``None`` is treated as 0.
    intrinsic_mode: not used beyond the warning message.
    global_param: not used by this function.
    cfg_path: when not ``None``, a warning is emitted; the file is not read.
    transfer: not used by this function.
    device_id: if an ``int``, the model is moved to ``cuda:<device_id>``.
      Any other value leaves the model where timm created it.
    log_dir: not used by this function.

  Returns:
    The module returned by ``timm.create_model``, moved to the requested
    CUDA device if ``device_id`` is an int.
  """
  # Resolve the target device first, so an invalid index fails before the
  # (potentially expensive) model construction.
  device = torch.device(f'cuda:{device_id}') if isinstance(device_id, int) else None

  # NOTE(review): create_intrinsic_model is imported by this module but is
  # never applied here. Warn rather than silently dropping the request.
  if (intrinsic_dim or 0) > 0:
    warnings.warn(
        f'intrinsic_dim={intrinsic_dim} (intrinsic_mode={intrinsic_mode!r}) was '
        'requested, but create_model only builds the base network; no '
        'intrinsic reparameterization is applied.',
        stacklevel=2,
    )
  if cfg_path is not None:
    warnings.warn(
        f'cfg_path={cfg_path!r} is ignored by create_model; the network is '
        'configured from keyword arguments only.',
        stacklevel=2,
    )

  # Assemble timm kwargs. base_width is only forwarded when explicitly set,
  # because not every timm architecture accepts it.
  net_cfg = dict(model_name=model_name, num_classes=num_classes, in_chans=in_chans)
  if base_width is not None:
    net_cfg['base_width'] = base_width

  base_net = timm.create_model(**net_cfg)

  if device is not None:
    base_net = base_net.to(device)

  return base_net



    


    
      
        
  


