import os
import torch
from model import SEDD,SEDDWot,SEDDWotSM
import utils
from model.ema import ExponentialMovingAverage
import graph_lib
import noise_lib

from omegaconf import OmegaConf

import json
from omegaconf import OmegaConf

# def load_model(dir, device):
#     with open(f'{dir}/config.json', 'r') as f:
#         config = json.load(f)
#     config = OmegaConf.create(config)
#     graph = graph_lib.get_graph(config, device)
#     noise = noise_lib.get_noise(config).to(device)
#     if hasattr(config.model, 'remove_time_condition') == False or config.model.remove_time_condition == False:
#         model = SEDD(config).to(device)
#     elif config.model.add_softmax == False:
#         model = SEDDWot(config).to(device)
#     else:
#         model = SEDDWotSM(config).to(device)
#     model.load_state_dict(torch.load(f'{dir}/pytorch_model.bin',map_location=torch.device('cpu')))
#     return model, graph, noise

def load_model_RADD(dir, device):
    """Load a pretrained ``SEDDWotSM`` model with its graph and noise.

    Args:
        dir: Identifier or path handed to ``SEDDWotSM.from_pretrained``.
        device: Device the model and noise are moved to; also passed to
            ``graph_lib.get_graph``.

    Returns:
        Tuple ``(score_model, graph, noise)``.
    """
    pretrained = SEDDWotSM.from_pretrained(dir)
    pretrained = pretrained.to(device)

    # Graph and noise are both derived from the config shipped with the model.
    model_cfg = pretrained.config
    graph = graph_lib.get_graph(model_cfg, device)
    noise = noise_lib.get_noise(model_cfg).to(device)

    return pretrained, graph, noise

def load_model_hf(dir, device):
    """Load a pretrained ``SEDD`` model with its graph and noise.

    Args:
        dir: Identifier or path handed to ``SEDD.from_pretrained``.
        device: Device the model and noise are moved to; also passed to
            ``graph_lib.get_graph``.

    Returns:
        Tuple ``(score_model, graph, noise)``.
    """
    pretrained = SEDD.from_pretrained(dir)
    pretrained = pretrained.to(device)

    # Graph and noise are both derived from the config shipped with the model.
    model_cfg = pretrained.config
    graph = graph_lib.get_graph(model_cfg, device)
    noise = noise_lib.get_noise(model_cfg).to(device)

    return pretrained, graph, noise


def load_model_local(root_dir, device):
    """Load a score model, graph and noise from a local training run.

    The config is read with ``utils.load_hydra_config_from_run(root_dir)`` and
    the weights from ``<root_dir>/checkpoints-meta/checkpoint.pth``. After the
    raw weights are loaded, the EMA parameters stored in the same checkpoint
    are copied into the model, so the returned model carries the EMA weights.

    Model class selection:
        * ``model.remove_time_condition`` missing or falsy -> ``SEDD``
        * otherwise, ``model.add_softmax`` falsy -> ``SEDDWot``
        * otherwise -> ``SEDDWotSM``

    Args:
        root_dir: Directory of the training run.
        device: Device the model and noise are moved to; also passed to
            ``graph_lib.get_graph`` and used as ``map_location`` when reading
            the checkpoint.

    Returns:
        Tuple ``(score_model, graph, noise)``.
    """
    cfg = utils.load_hydra_config_from_run(root_dir)
    graph = graph_lib.get_graph(cfg, device)
    noise = noise_lib.get_noise(cfg).to(device)

    model_cfg = cfg.model
    # A config that does not define `remove_time_condition` is treated as a
    # time-conditioned model (plain SEDD) instead of failing on the lookup.
    if not getattr(model_cfg, "remove_time_condition", False):
        model_cls = SEDD
    elif not model_cfg.add_softmax:
        model_cls = SEDDWot
    else:
        model_cls = SEDDWotSM
    score_model = model_cls(cfg).to(device)

    ema = ExponentialMovingAverage(score_model.parameters(), decay=cfg.training.ema)

    ckpt_path = os.path.join(root_dir, "checkpoints-meta", "checkpoint.pth")
    loaded_state = torch.load(ckpt_path, map_location=device)

    score_model.load_state_dict(loaded_state['model'])
    ema.load_state_dict(loaded_state['ema'])

    # NOTE(review): `store` presumably snapshots the raw weights before they
    # are overwritten; nothing here restores them, the call is kept for parity.
    ema.store(score_model.parameters())
    # Overwrite the model parameters with the EMA-averaged ones.
    ema.copy_to(score_model.parameters())
    return score_model, graph, noise


def load_model(root_dir, device):
    """Load a model from a local run, falling back to ``load_model_hf``.

    ``load_model_local`` is tried first. If it fails with an ordinary
    exception, the same ``root_dir`` is handed to ``load_model_hf`` instead.

    Args:
        root_dir: Local run directory, or the identifier/path forwarded to
            ``load_model_hf`` when local loading fails.
        device: Device forwarded to whichever loader is used.

    Returns:
        Tuple ``(score_model, graph, noise)`` from the loader that succeeded.

    Raises:
        Exception: Whatever ``load_model_hf`` raises if the fallback fails
            too; the local loader's error is attached as its context.
    """
    try:
        return load_model_local(root_dir, device)
    except Exception:
        # Deliberately not a bare `except:`: KeyboardInterrupt and SystemExit
        # must propagate instead of silently starting the fallback load.
        return load_model_hf(root_dir, device)
