import torch
import torch.nn.functional as F
from denflow.utils import define_model
from denflow.train_denoisers import GENERAL_DENOISER


class TEN_MODELS_DENOISER(torch.nn.Module):
    """Time-piecewise denoiser that dispatches each call to one of several sub-models.

    The interval [time_start, time_end) is split into ``len(self.list_model_den)``
    equal bins (ten with the registered run ids). A call with time ``t`` is
    forwarded to the sub-model owning that bin. Times at or after ``time_end``
    use the last sub-model, and times before ``time_start`` use the first.
    Sub-models are loaded from ``<args.root>mlruns/<experiment>/<run_id>/artifacts``
    folders and wrapped in ``GENERAL_DENOISER``.
    """

    # Experiment folder (under ``mlruns/``) holding the runs of each denoiser class.
    # These are fixed by the training runs; change them manually if the runs move.
    _EXPERIMENT_IDS = {
        "NN": "576768551737851586",
        "shifted": "590683623322889772",
    }

    def __init__(self, model, loss_denoising, class_denoiser, device, args):
        """Register the run ids and eagerly load every sub-model.

        Args:
            model: unused here; retained so the constructor signature is unchanged.
            loss_denoising: first key into ``self.run_ids``; only ``"den"`` has
                runs registered (``"FM"`` is empty).
            class_denoiser: ``"NN"`` or ``"shifted"``; selects the run ids and
                the experiment folder.
            device: passed as ``map_location`` when loading and used for ``.to()``.
            args: config namespace. Reads ``dim_image``, ``num_channels`` and
                ``root``, and is forwarded to ``define_model`` / ``GENERAL_DENOISER``.
        """
        super().__init__()
        self.d = args.dim_image
        self.num_channels = args.num_channels
        self.device = device
        self.args = args
        self.loss_denoising = loss_denoising
        self.class_denoiser = class_denoiser
        # run_ids[loss_denoising][class_denoiser] -> list of run ids.
        # The list is ordered by time bin: entry i serves bin i.
        self.run_ids = {
            "FM": {},
            "den": {
                "NN": ['a06e0e5c45fd4e89b9df9b5067cf5bac', '0764693da2c64f65b419856af6555e24',
                       '87e219758ff648f0bf92241ba309ba1d', '52cd162fb3cf4bfebff69fe078ccfd82',
                       '61f08da8575f4146bd56f716e89f0576', '700aa94d7227410db06cdc5f8f5f2347',
                       '7b806d19eaa943e1b84652ca96cdae74', 'fc645157462645e89eac53d7e0372acb',
                       'd2b5cb2bc833409985c44c3dae1efe41', '72e6f4e5c57c4eb78db32cce09e750d9'],
                "shifted": ['3585b618aec74f4da474fd15f1786c86', 'e5ed31198bb44621b376c08db0ce8e65',
                            '439873b099e34eadb299f1ad3dbefa45', '114f9e045baf4d2596f550ec4ebd4a01',
                            '1bbf3182f4db4fa68c8e1bdd0a72d333', '2bd7decc64c24195b01adbc284cc5f4f',
                            'd5ab20ef63744c9794e46be41258df32', '49fa75ac7fe44db5a8c6af4fb3ba4445',
                            '868e870fe4d24f5d899bb4558c95775d', 'cdd7e564ffb24fbcafdea527da70fafd'],
            },
        }
        self.time_start = 0.
        self.time_end = 1.0

        self.load_models()

    def load_models(self):
        """Load one wrapped sub-model per registered run id into ``self.list_model_den``.

        For each run id:
        - rebuild the architecture with ``define_model(self.args)``;
        - copy in the weights of the module saved at
          ``<args.root>mlruns/<experiment>/<run_id>/artifacts/denoiser_ema_final/data/model.pth``;
        - move it to ``self.device`` and switch it to eval mode;
        - wrap it in ``GENERAL_DENOISER``.

        Raises:
            KeyError: if no runs are registered for
                (``loss_denoising``, ``class_denoiser``).
        """
        # NOTE(review): a plain list is not registered with nn.Module, so calling
        # .to()/.eval()/state_dict() on this wrapper does not reach the sub-models.
        # torch.nn.ModuleList would fix that if GENERAL_DENOISER is an nn.Module; confirm.
        self.list_model_den = []
        run_ids = self.run_ids[self.loss_denoising][self.class_denoiser]
        experiment_id = self._EXPERIMENT_IDS[self.class_denoiser]
        for run_id in run_ids:
            model_den, _ = define_model(self.args)
            model_path = self.args.root + \
                'mlruns/{}/{}/artifacts/denoiser_ema_final/data/model.pth'.format(
                    experiment_id, run_id)
            # The checkpoint stores a whole pickled module, hence weights_only=False.
            # Unpickling can execute arbitrary code, so only load trusted files.
            saved_module = torch.load(model_path, map_location=self.device, weights_only=False)
            model_den.load_state_dict(saved_module.state_dict())
            model_den.eval().to(self.device)
            self.list_model_den.append(GENERAL_DENOISER(
                model_den, self.loss_denoising, self.class_denoiser, self.device, self.args))

    def _select_model(self, t):
        """Return the sub-model whose time bin contains the (shared) time ``t``.

        ``t`` may be a tensor (every entry must be equal), a Python number, or
        an indexable sequence whose first element is used.

        Raises:
            ValueError: if a tensor ``t`` holds different times.
            RuntimeError: if no sub-model has been loaded.
        """
        if torch.is_tensor(t):
            flat = t.reshape(-1)
            # One sub-model serves the whole batch, so all samples must share a time.
            if not torch.allclose(flat, flat[0].expand_as(flat)):
                raise ValueError("all entries of t must be equal to select a single sub-model")
            t_ = flat[0]
        elif isinstance(t, (int, float)):
            t_ = t
        else:
            t_ = t[0]

        n_models = len(self.list_model_den)
        if n_models == 0:
            raise RuntimeError("no sub-model loaded; call load_models() first")

        # Fraction of the way through [time_start, time_end). The arithmetic stays
        # in t's own dtype so bin boundaries match int(t * 10) for the defaults.
        position = (t_ - self.time_start) / (self.time_end - self.time_start)
        if not position < 1:
            # t >= time_end (and NaN) goes to the last sub-model.
            return self.list_model_den[n_models - 1]
        # Clamp so that t < time_start cannot wrap around to the end of the list.
        index = min(max(int(position * n_models), 0), n_models - 1)
        return self.list_model_den[index]

    def get_denoiser(self, xt, t):
        """Forward to ``get_denoiser`` of the sub-model selected by ``t``.

        xt: [B,C,H,W]
        t : [B], with all entries equal (a scalar is also accepted).
        """
        return self._select_model(t).get_denoiser(xt, t)

    def get_velocity(self, xt, t):
        """Forward to ``get_velocity`` of the sub-model selected by ``t``.

        xt: [B,C,H,W]
        t : [B], with all entries equal (a scalar is also accepted).
        """
        return self._select_model(t).get_velocity(xt, t)
