import os.path
from typing import Tuple, Union

import torch
from torch import nn
import wandb

from ddlm.modeling.diffusion import (
    DiffusionTransformer,
)
from ddlm.modeling.diffusion_simplified import (
    SimplifiedDiffusionTransformer,
)
from ddlm.time.time_wrapping import TimeWrapping


def _fetch_artifact_files(run, step: int) -> Tuple[str, Union[None, str]]:
    """Resolve local paths of the model / time-wrapping weights logged as an artifact.

    Reuses a previously downloaded copy under ``artifacts/`` when present,
    otherwise downloads from the first-level artifacts logged by ``run`` whose
    name contains the zero-padded step. Returns ``(model_path, tw_path)``,
    where ``tw_path`` is ``None`` when no time-wrapping weights are available.

    Raises:
        FileNotFoundError: if no cached copy exists and no logged artifact
            name contains the requested step.
    """
    filename = f"artifacts/{run.id}_model_step_{step:07}:v0/pytorch_model.bin"
    if os.path.exists(filename):
        # Cached from an earlier call: the optional time-wrapping weights (if
        # they were downloaded then) live in the same artifact directory.
        tw_model_filename = os.path.join(
            os.path.dirname(filename), "pytorch_model_1.bin"
        )
        if not os.path.exists(tw_model_filename):
            tw_model_filename = None
        return filename, tw_model_filename

    found = False
    tw_model_filename = None
    for a in run.logged_artifacts():
        if f"{step:07}" not in a.name:
            continue
        found = True
        filename = a.get_path("pytorch_model.bin").download()
        # Reset per artifact so the tw weights always come from the same
        # artifact as the model weights (last match wins, as before).
        tw_model_filename = None
        try:
            tw_model_filename = a.get_path("pytorch_model_1.bin").download()
            print("loaded tw_model")
        except Exception as err:  # the tw weights are optional
            print(f"Can't load tw model: {err!r}")

    if not found:
        raise FileNotFoundError(
            f"No cached file at {filename!r} and no logged artifact of run "
            f"{run.id!r} matches step {step:07}"
        )
    return filename, tw_model_filename


def _fetch_run_files(run, step: int) -> Tuple[str, str]:
    """Download the model weights saved as plain run files and return file names.

    Falls back to the ``training_state_<step>`` checkpoint when the standalone
    model file is empty. The time-wrapping file name is only computed here;
    it is downloaded later (best-effort) by ``_load_tw_model``.
    """
    filename = f"model_step_{step:07}.pth"
    f = run.file(filename)
    if f.size == 0:
        print("Loading model from state...")
        filename = f"training_state_{step:07}/pytorch_model.bin"
        f = run.file(filename)
    f.download(replace=True)
    # NOTE: not zero-padded, unlike the model file name; kept as-is because it
    # presumably mirrors the name the trainer writes — TODO confirm.
    tw_model_filename = f"tw_model_step_{step}.pth"
    return filename, tw_model_filename


def _load_tw_model(
    config, run, tw_model_filename: Union[None, str], download: bool
) -> Union[None, nn.Module]:
    """Best-effort construction of the optional TimeWrapping model.

    Returns ``None`` (after printing why) when no file is available or when
    downloading / loading / restoring the weights fails for any reason.
    """
    if tw_model_filename is None:
        print("Using model without time_wrapping")
        return None
    try:
        if download:
            run.file(tw_model_filename).download(replace=True)
        # map_location="cpu": the module below is built on CPU and
        # load_state_dict copies tensors into it, so this only avoids failures
        # when a GPU-saved checkpoint is opened on a CPU-only host.
        tw_state_dict = torch.load(tw_model_filename, map_location="cpu")
        tw_model = TimeWrapping(config=config)
        tw_model.load_state_dict(state_dict=tw_state_dict)
    except Exception as err:  # deliberate best-effort: tw model is optional
        print(f"Using model without time_wrapping: {err!r}")
        return None
    return tw_model


def get_model_from_run(
    config, run, step: int = 1000000, artifact: bool = False
) -> Tuple[nn.Module, Union[None, nn.Module]]:
    """Restore a DiffusionTransformer (and optional TimeWrapping model) from a run.

    Args:
        config: configuration passed to both ``DiffusionTransformer`` and
            ``TimeWrapping`` constructors.
        run: run object exposing ``id``, ``file(name)`` and
            ``logged_artifacts()`` (presumably a W&B public-API ``Run``).
        step: training step whose checkpoint to load.
        artifact: if True, read weights from logged artifacts; otherwise from
            plain run files.

    Returns:
        ``(model, tw_model)``; ``tw_model`` is ``None`` when no time-wrapping
        weights could be loaded.

    Raises:
        FileNotFoundError: in artifact mode, when no checkpoint for ``step``
            is cached locally or logged by the run.
    """
    if artifact:
        filename, tw_model_filename = _fetch_artifact_files(run, step)
    else:
        filename, tw_model_filename = _fetch_run_files(run, step)

    # See _load_tw_model for why map_location="cpu" is result-preserving.
    state_dict = torch.load(filename, map_location="cpu")
    model = DiffusionTransformer(config=config)
    model.load_state_dict(state_dict=state_dict)

    tw_model = _load_tw_model(config, run, tw_model_filename, download=not artifact)
    return model, tw_model
