import pyrootutils
import os
from typing import Union
from tqdm import tqdm
import pickle
import torch

from schnetpack.data import AtomsDataModule
from schnetpack.diffusion.utils import sample_R
from schnetpack import properties


def loss_per_step(
    data: AtomsDataModule,
    models_paths: list,
    noise_schedules: list,
    T: Union[int, list] = 1000,
    recompute=False,
    pred_key: str = "eps_pred",
    target_key: str = "eps",
):
    """
    Compute (or load from cache) the MSE loss of several models per diffusion step.

    For every directory in ``models_paths`` the result is cached in
    ``<dir>/loss_per_step.pkl``. If that file is missing, or ``recompute`` is
    true, the checkpoint ``<dir>/best_model`` is loaded on CPU and evaluated on
    one training batch per step; otherwise the checkpoint is not loaded at all.

    Args:
        data: data module providing ``train_transforms`` and
            ``train_dataloader()``. ``train_transforms[1]`` is mutated
            (``noise_schedule`` and ``t_train``); it is presumably the
            diffusion/noising transform -- verify against the caller's setup.
        models_paths: model directories, each containing ``best_model``.
        noise_schedules: one noise schedule per entry of ``models_paths``,
            matched by position.
        T: number of steps (evaluates ``range(T)``) or an explicit iterable of
            step indices.
        recompute: ignore an existing cache file and recompute the losses.
        pred_key: key of the prediction in the model output.
        target_key: key of the target in the batch.

    Returns:
        dict mapping the last path component of each model directory to a dict
        ``{step: mse_loss}``.
    """
    steps = range(T) if isinstance(T, int) else T

    for model_idx, model_dir in enumerate(models_paths):
        # The schedule is assigned even when the cache is used, so the state
        # of ``data`` after the call does not depend on the cache.
        data.train_transforms[1].noise_schedule = noise_schedules[model_idx]

        cache_file = os.path.join(model_dir, "loss_per_step.pkl")
        if os.path.exists(cache_file) and not recompute:
            continue

        # Load the checkpoint only when it is actually evaluated.
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        model = torch.load(os.path.join(model_dir, "best_model"), map_location="cpu")

        step_losses = {}
        try:
            for step in tqdm(steps):
                # Pin the transform to a single diffusion step for this batch.
                data.train_transforms[1].t_train = torch.tensor([step])
                # A fresh dataloader is built per step so that the updated
                # transform setting is picked up.
                batch = next(iter(data.train_dataloader()))
                diff_batch = model(batch)
                step_losses[step] = (
                    torch.nn.functional.mse_loss(
                        diff_batch[pred_key], batch[target_key]
                    )
                    .detach()
                    .item()
                )
        finally:
            # Always release the pinned step, even if the evaluation failed.
            data.train_transforms[1].t_train = None

        with open(cache_file, "wb") as f:
            pickle.dump(step_losses, f)

    results = {}
    for model_dir in models_paths:
        # normpath drops a trailing separator so "runs/model/" maps to "model"
        # instead of an empty key.
        model_name = os.path.basename(os.path.normpath(model_dir))
        # NOTE: pickle.load executes arbitrary code; only read trusted caches.
        with open(os.path.join(model_dir, "loss_per_step.pkl"), "rb") as f:
            results[model_name] = pickle.load(f)
    return results


def _resolve_model_path(model_id, models_paths):
    """Return the first entry of ``models_paths`` containing ``model_id`` as a substring."""
    for path in models_paths:
        if model_id in path:
            return path
    raise ValueError(f"Model {model_id} not found in models_paths")


def generate_samples(
    models,
    noise_schedules,
    models_paths,
    inputs,
    cutoff=5.0,
    T=1000,
    start=None,
    random=True,
    use_forces=True,
    save_progress=True,
    progress_stride=1,
):
    """
    Wrapper around sample_R to generate samples for multiple models and save
    them externally.

    For each model the samples are pickled to
    ``<model_path>/samples/samples_<n>.pkl``, where ``<model_path>`` is the
    first entry of ``models_paths`` that contains the model id as a substring
    and ``<n>`` is the lowest index not yet used in that directory.

    Args:
        models: dict mapping model id to model.
        noise_schedules: mapping from model id to its noise schedule.
        models_paths: candidate model directories.
        inputs: mapping from model id to the input batch; a shallow copy of it
            is handed to ``sample_R``.
        cutoff, T, start, random, use_forces, save_progress, progress_stride:
            forwarded unchanged to ``sample_R``.

    Raises:
        ValueError: if a model id matches none of ``models_paths``. This is
            checked for all models before any sampling starts.
    """
    # Resolve every output directory up front: a missing path is reported
    # before any expensive sampling, and a stale path from a previous model
    # can never be reused for the wrong model.
    model_dirs = {
        model_id: _resolve_model_path(model_id, models_paths) for model_id in models
    }

    for model_id, model in models.items():
        noise_sch = noise_schedules[model_id]
        model_inputs = inputs[model_id]
        samples, progress = sample_R(
            model_inputs.copy(),
            model,
            noise_sch,
            cutoff=cutoff,
            T=T,
            start=start,
            random=random,
            use_forces=use_forces,
            save_progress=save_progress,
            progress_stride=progress_stride,
        )

        samples_dir = os.path.join(model_dirs[model_id], "samples")
        os.makedirs(samples_dir, exist_ok=True)

        # Pick the first free file index so earlier runs are never overwritten.
        file_idx = 0
        while os.path.exists(os.path.join(samples_dir, f"samples_{file_idx}.pkl")):
            file_idx += 1

        tmp = {
            "samples": {
                0: {
                    properties.R: samples,
                    "progress_R": progress,
                    properties.Z: model_inputs[properties.Z],
                    properties.idx_m: model_inputs[properties.idx_m],
                    properties.idx: model_inputs[properties.idx],
                    properties.n_atoms: model_inputs[properties.n_atoms],
                }
            },
            "model_id": model_id,
            "start": start,
            "T": T,
            "random": random,
            "progress_stride": progress_stride,
        }

        with open(os.path.join(samples_dir, f"samples_{file_idx}.pkl"), "wb") as f:
            pickle.dump(tmp, f)
