import os
import torch
from models.critic_model import Critic
from utils.checkpoint_util import checkpoints_dic,get_args,find_latest_checkpoint



def initialize_q_network(state_dim, act_dim, iql_q_hiddens, iql_layernorm, gym_name, load_state_dic=True,
                         *, seed=0, checkpoint_file='qf_1000000.pth', checkpoint_root='critic_checkpoints/iql'):
    """Build a ``Critic`` Q-network and optionally load pretrained IQL weights.

    Args:
        state_dim: State dimension forwarded to ``Critic``.
        act_dim: Action dimension forwarded to ``Critic``.
        iql_q_hiddens: Forwarded to ``Critic`` (presumably the number/size of
            hidden layers -- confirm against ``models.critic_model``).
        iql_layernorm: Forwarded to ``Critic`` (presumably toggles LayerNorm).
        gym_name: Environment name; used as a sub-directory of
            ``checkpoint_root`` when loading weights.
        load_state_dic: When True, load weights from
            ``<checkpoint_root>/<gym_name>/<seed>/<checkpoint_file>``
            (relative to the current working directory).
        seed: Sub-directory under ``gym_name`` to load from (default ``0``).
        checkpoint_file: Checkpoint filename (default ``'qf_1000000.pth'``).
        checkpoint_root: Root directory of the IQL critic checkpoints.

    Returns:
        The constructed ``Critic`` instance, with weights loaded if requested.

    Raises:
        FileNotFoundError: If the checkpoint directory or file does not exist.
    """
    # NOTE(review): the third positional argument (256) is passed to Critic
    # unchanged from the original code; its meaning is defined in Critic.
    qf = Critic(state_dim, act_dim, 256, iql_q_hiddens, iql_layernorm)
    if not load_state_dic:
        return qf

    model_dir_path = os.path.join(checkpoint_root, gym_name, str(seed))
    full_file_path = os.path.join(model_dir_path, checkpoint_file)

    # Check the directory first so the error message pinpoints which part is missing.
    if not os.path.isdir(model_dir_path):
        raise FileNotFoundError(f"Directory does not exist: {model_dir_path}")
    if not os.path.isfile(full_file_path):
        raise FileNotFoundError(f"File does not exist: {full_file_path}")

    # Load onto CPU so checkpoints saved from a GPU run also load on CPU-only
    # machines; load_state_dict copies the tensors into qf's own parameters.
    qf.load_state_dict(torch.load(full_file_path, map_location="cpu"))
    print(f"Model loaded from {full_file_path}")
    return qf


def load_aqdt_q_network(state_dim, act_dim, iql_q_hiddens, iql_layernorm, gym_name, device, wandb_run=None):
    """Load the critic saved in the latest AQDT training checkpoint.

    The run is identified by ``wandb_run``. When it is omitted, the first
    entry of ``checkpoints_dic[gym_name]`` is used. That run's arguments
    (from ``get_args``) supply the experiment ``name``. The latest checkpoint
    under ``aqdt_runs/<name>`` is then loaded, and its ``'critic'`` entry is
    used as the critic's state dict.

    Args:
        state_dim, act_dim, iql_q_hiddens, iql_layernorm, gym_name:
            Forwarded to ``initialize_q_network`` to build the network
            (no pretrained IQL weights are loaded there).
        device: Used as ``map_location`` for ``torch.load`` and as the target
            of ``critic.to``.
        wandb_run: Optional run identifier passed to ``get_args``.

    Returns:
        The critic on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: If no checkpoint file is found for the run.
    """
    if wandb_run is None:
        wandb_run = checkpoints_dic[gym_name][0]

    run_args = vars(get_args([wandb_run])[0])
    exp_prefix = run_args['name']

    run_dir = os.path.join("aqdt_runs", exp_prefix)
    checkpoint_path = find_latest_checkpoint(run_dir)
    # NOTE(review): assumes find_latest_checkpoint returns a falsy value (or a
    # non-existent path) when nothing is found -- fail clearly instead of
    # letting torch.load raise an obscure error.
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"No checkpoint found under: {run_dir}")

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoints = torch.load(checkpoint_path, map_location=device)

    critic = initialize_q_network(state_dim, act_dim, iql_q_hiddens, iql_layernorm, gym_name,
                                  load_state_dic=False)
    critic.load_state_dict(checkpoints['critic'])
    critic.to(device)
    critic.eval()
    return critic