import torch
import os
from utils.conf import create_path 

def _get_tag(args):
    tag = 'dis' if not args.joint else 'joint'
    if args.task == 'product' and args.model in ['mnistsl', 'mnistslrec']:
        tag = tag+'-prod'
    if args.task == 'multiop':
        tag = tag+'-multiop'
    return tag

def create_load_ckpt(model, args):
    create_path('data/runs')
    create_path('data/ckpts')
    
    tag = _get_tag(args)

    PATH = f'data/runs/{args.dataset}-{args.model}-{tag}-start.pt'
    
    if args.checkin is not None:
        model.load_state_dict(torch.load(args.checkin))  
    elif os.path.exists(PATH):
        print('Loaded',PATH, '\n')
        model.load_state_dict(torch.load(PATH))
    else:
        print('Created',PATH, '\n')
        torch.save(model.state_dict(), PATH)

    return model      

def save_model(model, args):
    create_path('data/ckpts')
    tag = _get_tag(args)

    if not args.active_learning:
        PATH = f'data/ckpts/{args.dataset}-{args.model}-{tag}-{args.seed}-end.pt'
    else:
        PATH = f'data/ckpts/{args.dataset}-{args.model}-active-learning-{tag}-{args.seed}-end.pt'
    
    if args.checkout:
        print('Saved',PATH, '\n')
        torch.save(model.state_dict(), PATH)

def get_model_name(args):
    return f"dset_{args.dataset}-model_{args.model}-tag_{_get_tag(args)}"

def load_checkpoint(model, args, checkin=None):
    create_path('data/ckpts')
    tag = _get_tag(args)

    if checkin is not None: PATH = checkin
    else:                   PATH = f'data/ckpts/{args.dataset}-{args.model}-{tag}-{args.seed}-end.pt'
    
    if not os.path.exists(PATH):
        raise ValueError(f"You have to train the model first, missing {PATH}")

    print('Loaded',PATH, '\n')
    print("Path", PATH)
    model.load_state_dict(torch.load(PATH))

    return model        
