import os

import torch
import torch.nn as nn
from lampe.inference import NPE
from lampe.inference import NPELoss

from .base import Model
from .base import ModelFactory


class NPEFactory(ModelFactory):
    """Factory for :class:`NPEModel` with simulation-budget-specific hyperparameters.

    ``config["train_batch_size"]`` and ``config["weight_decay"]`` are indexed in
    parallel with ``config["simulation_budgets"]``: the entries at the position of
    ``simulation_budget`` are written as scalars into a shallow copy of ``config``,
    which is then handed to the base factory.
    """

    def __init__(self, config, benchmark, simulation_budget):
        budgets = list(config["simulation_budgets"])
        try:
            # First matching position, same ``==`` comparison as a manual scan.
            idx = budgets.index(simulation_budget)
        except ValueError:
            # Previously an unknown budget silently fell back to the *last* entry's
            # hyperparameters (or raised NameError on an empty list).
            raise ValueError(
                f"simulation_budget {simulation_budget!r} not found in "
                f"config['simulation_budgets'] ({budgets!r})"
            ) from None

        # Shallow copy so the caller's config keeps its per-budget lists.
        config_run = config.copy()
        config_run["train_batch_size"] = config["train_batch_size"][idx]
        config_run["weight_decay"] = config["weight_decay"][idx]

        super().__init__(config_run, benchmark, simulation_budget, NPEModel)

    def get_train_time(self, benchmark_time, epochs):
        """Return twice the base factory's training-time estimate."""
        return 2 * super().get_train_time(benchmark_time, epochs)


# class NPEWithEmbedding(nn.Module):
#     def __init__(self, npe, embedding):
#         super().__init__()
#         self.npe = npe
#         self.embedding = embedding

#     def forward(self, theta, x):
#         return self.npe(theta, self.embedding(x))

#     def sample(self, x, shape):
#         return self.npe.sample(self.embedding(x), shape)


class NPEWithEmbedding(nn.Module):
    """Couple an NPE estimator with an observation embedding and normalization.

    Observations are normalized, then embedded, and the result is used as the
    estimator's context. Parameters are normalized before density evaluation and
    mapped back through ``unnormalize_parameters_fct`` after sampling.
    """

    def __init__(
        self,
        npe,
        embedding,
        normalize_observation_fct,
        unnormalize_observation_fct,
        normalize_parameters_fct,
        unnormalize_parameters_fct,
        normalization_log_jacobian,
    ):
        super().__init__()
        # Networks (registered as submodules when they are nn.Module instances);
        # npe is assigned first so submodule order matches the checkpoint layout.
        self.npe = npe
        self.embedding = embedding

        # Normalization callables; unnormalize_observation_fct is kept for callers
        # but not used inside this wrapper.
        self.normalize_observation_fct = normalize_observation_fct
        self.unnormalize_observation_fct = unnormalize_observation_fct
        self.normalize_parameters_fct = normalize_parameters_fct
        self.unnormalize_parameters_fct = unnormalize_parameters_fct

        # Added to every log-density in forward(); presumably the log-Jacobian of
        # the parameter normalization — TODO confirm against the base Model.
        self.normalization_log_jacobian = normalization_log_jacobian

    def forward(self, theta, x):
        """Return ``npe(normalized theta, embedded normalized x)`` plus the log-Jacobian term."""
        x_normalized = self.normalize_observation_fct(x)
        theta_normalized = self.normalize_parameters_fct(theta)
        context = self.embedding(x_normalized)
        log_density = self.npe(theta_normalized, context)
        return log_density + self.normalization_log_jacobian

    def sample(self, x, shape):
        """Draw ``shape`` samples from ``npe.flow`` conditioned on ``x`` and unnormalize them."""
        context = self.embedding(self.normalize_observation_fct(x))
        normalized_draws = self.npe.flow(context).sample(shape)
        return self.unnormalize_parameters_fct(normalized_draws)


class NPEModel(Model):
    """Neural posterior estimator (NPE) with a learned observation embedding.

    A benchmark-provided embedding network maps observations to a context vector
    that conditions a ``lampe`` :class:`NPE` module. Both are wrapped, together
    with the normalization helpers provided by the base :class:`Model`, in
    :class:`NPEWithEmbedding` (``self.model``). Weights are persisted as
    ``embedding.pt`` and ``flow.pt`` inside ``model_path``.
    """

    class _Posterior:
        """Lightweight posterior handle exposing ``sample`` and ``log_prob`` callables."""

        def __init__(self, sampling_fct, log_prob_fct):
            self.sample = sampling_fct
            self.log_prob = log_prob_fct

    def __init__(self, benchmark, model_path, config, normalization_constants):
        """Build the embedding, the NPE module and their normalized wrapper.

        Args:
            benchmark: Supplies shapes, device, prior and the builders for the
                embedding and flow networks.
            model_path: Directory holding ``embedding.pt`` and ``flow.pt``.
            config: Run configuration (``"ensemble"`` is read by :meth:`is_ensemble`).
            normalization_constants: Forwarded to the base :class:`Model`.
        """
        super().__init__(normalization_constants)
        self.observable_shape = benchmark.get_observable_shape()
        self.embedding_dim = benchmark.get_embedding_dim()
        self.parameter_dim = benchmark.get_parameter_dim()
        self.device = benchmark.get_device()

        self.model_path = model_path

        self.prior = benchmark.get_prior()
        self.config = config

        embedding_build = benchmark.get_embedding_build()
        self.embedding = embedding_build(self.embedding_dim, self.observable_shape).to(
            self.device
        )

        # ``self.flow`` is the lampe NPE module; NPEWithEmbedding samples through
        # its ``.flow`` attribute.
        flow_build, flow_kwargs = benchmark.get_flow_build()
        self.flow = NPE(
            self.parameter_dim, self.embedding_dim, build=flow_build, **flow_kwargs
        ).to(self.device)
        self.model = NPEWithEmbedding(
            self.flow,
            self.embedding,
            self.normalize_observation,
            self.unnormalize_observation,
            self.normalize_parameters,
            self.unnormalize_parameters,
            self.get_normalization_log_jacobian(),
        )

    @classmethod
    def is_trained(cls, model_path):
        """Return True when both weight files exist under ``model_path``."""
        return os.path.exists(
            os.path.join(model_path, "embedding.pt")
        ) and os.path.exists(os.path.join(model_path, "flow.pt"))

    def get_loss_fct(self, config):
        """Return the lampe ``NPELoss`` class itself (not an instance); ``config`` is unused."""
        return NPELoss

    def log_prob(self, theta, x):
        """Return the wrapped model's log-density of ``theta`` given ``x``.

        Inputs are moved to ``self.device``; the result stays on that device.
        """
        x = x.to(self.device)
        theta = theta.to(self.device)
        return self.model(theta, x)

    def get_posterior_fct(self):
        """Return a function mapping an observation ``x`` to a posterior handle.

        The handle's ``sample(shape)`` and ``log_prob(theta)`` run on
        ``self.device`` and return CPU tensors.
        """

        def get_posterior(x):
            return self._Posterior(
                lambda shape: self.model.sample(x.to(self.device), shape).cpu(),
                lambda theta: self.model(
                    theta.to(self.device), x.to(self.device)
                ).cpu(),
            )

        return get_posterior

    def __call__(self, theta, x):
        """Alias for :meth:`log_prob`."""
        return self.log_prob(theta, x)

    def sampling_enabled(self):
        """NPE supports direct posterior sampling."""
        return True

    def save(self):
        """Write embedding and flow state dicts to ``model_path``, creating it if needed."""
        if self.model_path:
            # torch.save does not create missing parent directories.
            os.makedirs(self.model_path, exist_ok=True)
        torch.save(
            self.embedding.state_dict(), os.path.join(self.model_path, "embedding.pt")
        )
        torch.save(self.flow.state_dict(), os.path.join(self.model_path, "flow.pt"))

    def load(self):
        """Load embedding and flow weights from ``model_path`` onto ``self.device``.

        NOTE(review): ``torch.load`` unpickles its input; only load checkpoints
        from trusted locations.
        """
        self.embedding.load_state_dict(
            torch.load(
                os.path.join(self.model_path, "embedding.pt"), map_location=self.device
            )
        )
        self.flow.load_state_dict(
            torch.load(os.path.join(self.model_path, "flow.pt"), map_location=self.device)
        )

    def train(self):
        """Put the wrapped networks in training mode."""
        self.model.train()

    def eval(self):
        """Put the wrapped networks in evaluation mode."""
        self.model.eval()

    def sample(self, x, shape):
        """Draw ``shape`` posterior samples for ``x``; the result stays on ``self.device``."""
        x = x.to(self.device)
        return self.model.sample(x, shape)

    def is_ensemble(self):
        """Return ``config["ensemble"]`` if present, otherwise False."""
        return self.config.get("ensemble", False)