import torch
import torch.nn.functional as F
import mlflow
import mlflow.pytorch as mlpt
from denflow.utils import define_model, load_model_runid
from denflow.train_denoisers import GENERAL_DENOISER


class CONCATENATED_DENOISER(torch.nn.Module):
    """Time-piecewise denoiser that dispatches to one of several sub-denoisers.

    Each entry of ``args.run_ids`` names an MLflow run. The model logged under
    ``runs:/<run_id>/<args.model_fold>`` is loaded, put in eval mode, moved to
    ``device`` and wrapped in a ``GENERAL_DENOISER``. The time interval it is
    responsible for is parsed from the run name. The float following the token
    ``t_min`` (resp. ``t_max``) is used, with commas stripped; the defaults are
    0.0 and 1.0 when the token is absent.

    At call time, all samples in a batch must share one timestep. The narrowest
    registered interval containing that timestep selects the sub-denoiser.
    """

    def __init__(self, loss_denoising, class_denoiser, device, args):
        super().__init__()
        self.d = args.dim_image
        self.num_channels = args.num_channels
        self.device = device
        self.args = args
        self.loss_denoising = loss_denoising
        self.class_denoiser = class_denoiser
        self.model_fold = args.model_fold
        self.run_ids = self.args.run_ids  # expect a list of run_ids
        # key = (t_min, t_max), value = GENERAL_DENOISER wrapping a model on `device`.
        # NOTE(review): this is a plain dict, which nn.Module does not track. If
        # GENERAL_DENOISER is an nn.Module, its entries are NOT registered
        # submodules of `self` (not reached by self.to()/parameters()/state_dict()).
        self.models_dict = {}
        # A single tracking client serves every run lookup.
        client = mlflow.MlflowClient()
        for run_id in self.run_ids:
            model = mlpt.load_model(f"runs:/{run_id}/{self.model_fold}")
            run_data_dict = client.get_run(run_id).data.to_dictionary()
            name_tokens = run_data_dict['tags']['mlflow.runName'].replace(
                ",", "").split()
            t_min = self._parse_time_bound(name_tokens, 't_min', 0.0)
            t_max = self._parse_time_bound(name_tokens, 't_max', 1.0)
            model.eval()
            model.to(device)
            # NOTE(review): a later run with the same (t_min, t_max) silently
            # replaces an earlier one.
            self.models_dict[(t_min, t_max)] = GENERAL_DENOISER(
                model, self.loss_denoising, self.class_denoiser, self.device, self.args)
        self.valid_intervals = list(self.models_dict.keys())
        self.time_start = 0.
        self.time_end = 1.0

    @staticmethod
    def _parse_time_bound(tokens, key, default):
        """Return the float following ``key`` in ``tokens``, or ``default`` if absent.

        Raises:
            ValueError: if ``key`` is the last token (no value follows it), or if
                the following token is not a valid float.
        """
        if key not in tokens:
            return default
        value_idx = tokens.index(key) + 1
        if value_idx >= len(tokens):
            raise ValueError(f"run name contains '{key}' but no value after it: {tokens}")
        return float(tokens[value_idx])

    def get_interval(self, t):
        """Return the narrowest registered (t_min, t_max) interval containing ``t``.

        Args:
            t: a scalar time, either a Python number or a single-element tensor.

        Returns:
            The (t_min, t_max) key of ``self.models_dict`` whose closed interval
            contains ``t`` and has the smallest width. On ties, the interval
            registered first wins.

        Raises:
            ValueError: if no registered interval contains ``t``.
        """
        t_val = float(t)
        valids = [(t_min, t_max) for (t_min, t_max) in self.valid_intervals
                  if t_min <= t_val <= t_max]
        if not valids:
            raise ValueError(
                "model does not have a valid interval for this timestep "
                f"(t={t_val}, intervals={self.valid_intervals})")
        # Prefer the most specialised model, i.e. the one with the shortest window.
        # min() returns the first minimal element, like a stable sort would.
        return min(valids, key=lambda inter: inter[1] - inter[0])

    def _select_model(self, t):
        """Return the sub-denoiser responsible for the (shared) timestep in ``t``.

        Raises:
            ValueError: if ``t`` is an empty tensor, if its entries are not all
                equal, or if no registered interval contains the timestep.
        """
        if torch.is_tensor(t):
            # reshape(-1) also accepts a 0-dim tensor (a single scalar time).
            t_flat = t.reshape(-1)
            if t_flat.numel() == 0:
                raise ValueError("t must contain at least one timestep")
            # A single sub-model serves the whole batch, so all times must agree.
            if not torch.allclose(t_flat, t_flat[0].expand_as(t_flat)):
                raise ValueError("all entries of t must share the same timestep")
            t_scalar = t_flat[0]
        else:
            t_scalar = t[0]
        return self.models_dict[self.get_interval(t_scalar)]

    def get_denoiser(self, xt, t):
        """
        Delegate to the sub-denoiser selected for the shared timestep in ``t``.

        xt: [B,C,H,W]
        t : [B]
        """
        return self._select_model(t).get_denoiser(xt, t)

    def get_velocity(self, xt, t):
        """
        Delegate to the sub-denoiser selected for the shared timestep in ``t``.

        xt: [B,C,H,W]
        t : [B]
        """
        return self._select_model(t).get_velocity(xt, t)
