from functools import partial

import torch
import torch.nn as nn
import os
from abc import abstractmethod
from pathlib import Path
from geomloss import SamplesLoss

from margflow.utils.math_utils import remove_neg_inf
from margflow.utils.training_utils import batched_evaluation


class AbstractModel(nn.Module):
    """Base class for density models with evaluation and checkpointing helpers.

    Subclasses are expected to implement ``sample`` and ``log_prob``, to provide
    ``sample_and_log_prob`` (called by ``evaluate_metrics`` but not defined here),
    and to fill ``self.trainable_params`` with the ``nn.Parameter`` / ``nn.Module``
    objects that should be counted, saved and restored.

    NOTE: this class does not use ``abc.ABCMeta`` (``nn.Module`` uses plain
    ``type``), so the ``@abstractmethod`` decorators are not enforced at
    instantiation time; calling an unimplemented method raises
    ``NotImplementedError`` instead.
    """

    def __init__(
        self,
        model_name: str,  # name of the model
        x_dim: int,  # dimensionality of data/density
        device: str = "cuda",  # device on which model is initialized and trained
        signature: str = None,  # identifier of dataset and training hyperparams
        script_path: Path = None,  # path to script directory
        dtype: torch.dtype = torch.float32,  # dtype of tensors
    ):
        super().__init__()

        self.model_name = model_name
        self.x_dim = x_dim
        self.device = device
        self.script_path = script_path
        self.signature = signature
        self.dtype = dtype

        # Maps a name to an nn.Parameter or nn.Module; to be specified in the child class.
        self.trainable_params = {}

    @abstractmethod
    def sample(self, n_samples, context=None, **kwargs):
        raise NotImplementedError()

    @abstractmethod
    def log_prob(self, x, context=None, **kwargs):
        raise NotImplementedError()

    def _can_save_to_disk(self):
        """Return True when both ``script_path`` and ``signature`` are set.

        This is the single condition under which ``set_model_signature`` defines
        ``model_directory`` / ``model_path``; every disk helper checks it so that a
        missing signature never surfaces as an ``AttributeError``.
        """
        return self.script_path is not None and self.signature is not None

    def set_model_signature(self):
        """Define ``model_directory``, ``model_signature`` and ``model_path``.

        The signature embeds the model name, the parameter count and
        ``self.signature``. When ``script_path`` or ``signature`` is missing, a
        notice is printed and the attributes are left unset (nothing is saved).
        """
        if self._can_save_to_disk():
            self.model_directory = self.script_path / "trained_models" / self.model_name
            self.model_signature = (
                f"{self.model_name}_np{self.count_parameters()}_{self.signature}"
            )
            self.model_path = self.model_directory / self.model_signature
        else:
            print(
                "Since script_path and/or signature is None, the model will not be saved to disk"
            )

    @torch.no_grad()
    def evaluate_metrics(
        self,
        metrics,
        n_samples,
        val_samples,
        dataset,
        n_mixtures=1024,
        batch_size=10_000,
        max_ot_samples=2048,
    ):
        """Compute evaluation metrics of the model against reference data.

        Args:
            metrics: iterable of metric names. Supported: "mse_logp", "kl_rev",
                "kl_for", "log_lik", and the geomloss ``SamplesLoss`` losses
                "sinkhorn", "energy", "gaussian", "laplacian".
            n_samples: number of samples drawn from the model.
            val_samples: tensor of held-out data samples.
            dataset: object exposing ``log_prob(x)`` for the ground-truth density;
                only used by "mse_logp", "kl_rev" and "kl_for".
            n_mixtures: forwarded to ``sample_and_log_prob`` / ``log_prob`` on the
                primary code path.
            batch_size: batch size for evaluating ``log_prob`` on ``val_samples``
                on the primary code path.
            max_ot_samples: at most this many model and validation samples are
                passed to the ``SamplesLoss`` metrics.

        Returns:
            dict mapping each requested metric name to a Python float.

        Raises:
            ValueError: if a metric name is not recognized.
        """
        try:
            # Primary path: sample_and_log_prob returns a 4-tuple and both calls
            # accept ``n_mixtures``; validation log-probs are evaluated in batches.
            _, model_samples, model_log_prob_samples, _ = self.sample_and_log_prob(
                n_samples=n_samples, n_mixtures=n_mixtures
            )
            log_prob_function = partial(self.log_prob, n_mixtures=n_mixtures)
            model_log_prob_val = batched_evaluation(
                data=val_samples, batch_size=batch_size, function=log_prob_function
            )
            model_log_prob_val = torch.from_numpy(model_log_prob_val).to(val_samples.device)
        except Exception:
            # Fallback path: no ``n_mixtures`` and a 2-tuple from sample_and_log_prob.
            # ``except Exception`` (not a bare ``except``) so KeyboardInterrupt and
            # SystemExit still propagate.
            model_samples, model_log_prob_samples = self.sample_and_log_prob(n_samples=n_samples)
            model_log_prob_val = self.log_prob(x=val_samples)

        ot_metrics = ("sinkhorn", "energy", "gaussian", "laplacian")
        # Ground-truth log-density of val_samples, shared by "mse_logp" and "kl_for";
        # computed at most once and only if one of those metrics is requested.
        gt_log_prob_val = None

        metrics_dict = {}
        for metric in metrics:
            if metric in ("mse_logp", "kl_for") and gt_log_prob_val is None:
                gt_log_prob_val = remove_neg_inf(dataset.log_prob(val_samples))

            if metric == "mse_logp":
                mse_logp = ((gt_log_prob_val - model_log_prob_val) ** 2).mean()
                metrics_dict[metric] = mse_logp.item()
            elif metric == "kl_rev":  # zero-forcing / mode-seeking
                # Monte Carlo estimate over samples drawn from the model.
                gt_log_prob_samples = remove_neg_inf(dataset.log_prob(model_samples))
                kl_rev = torch.mean(model_log_prob_samples - gt_log_prob_samples)
                metrics_dict[metric] = kl_rev.item()
            elif metric == "kl_for":  # mass-covering / mean-seeking
                # Monte Carlo estimate over the validation samples.
                kl_for = torch.mean(gt_log_prob_val - model_log_prob_val)
                metrics_dict[metric] = kl_for.item()
            elif metric == "log_lik":
                log_lik = torch.mean(model_log_prob_val)
                metrics_dict[metric] = log_lik.item()
            elif metric in ot_metrics:
                loss_function = SamplesLoss(loss=metric)
                loss = loss_function(
                    model_samples[:max_ot_samples], val_samples[:max_ot_samples]
                )
                metrics_dict[metric] = loss.item()
            else:
                raise ValueError(f"metric not recognized: {metric!r}")

        return metrics_dict

    def existing_trained_model(self, overwrite):
        """Return True if a checkpoint exists on disk and should be reused.

        Creates ``model_directory`` if needed. Returns False when the model cannot
        be saved to disk (no ``script_path`` or no ``signature``), when no
        checkpoint file exists, or when ``overwrite`` is truthy. Requires
        ``set_model_signature`` to have been called beforehand.
        """
        if not self._can_save_to_disk():
            return False
        os.makedirs(self.model_directory, exist_ok=True)
        return os.path.isfile(self.model_path) and not overwrite

    def save_trained_model(self):
        """Save ``trainable_params`` to ``model_path`` (no-op if saving is disabled).

        ``nn.Parameter`` entries are stored as-is and ``nn.Module`` entries as
        their ``state_dict()``.

        Raises:
            TypeError: if an entry is neither an ``nn.Parameter`` nor an ``nn.Module``.
        """
        if not self._can_save_to_disk():
            return
        trained_params = {}
        for param_name, param in self.trainable_params.items():
            if isinstance(param, nn.Parameter):
                trained_params[param_name] = param
            elif isinstance(param, nn.Module):
                trained_params[param_name] = param.state_dict()
            else:
                raise TypeError(f"Unsupported type {type(param)} in trainable_params.")

        torch.save(trained_params, self.model_path)

    def load_trained_model(self):
        """Restore ``trainable_params`` in place from ``model_path``.

        Does nothing if saving is disabled (no ``script_path`` or no ``signature``).

        Raises:
            KeyError: if an entry of ``trainable_params`` is missing from the file.
            TypeError: if an entry is neither an ``nn.Parameter`` nor an ``nn.Module``.
        """
        if not self._can_save_to_disk():
            return
        # Deserialize onto CPU. The values are then copied into the existing
        # parameters/modules on whatever device they live on. This lets a
        # checkpoint written on GPU be restored on a CPU-only machine.
        trained_model = torch.load(self.model_path, map_location="cpu", weights_only=True)
        for param_name, param in self.trainable_params.items():
            if param_name not in trained_model:
                raise KeyError(f"Missing {param_name} in saved file.")
            if isinstance(param, nn.Parameter):
                with torch.no_grad():
                    param.copy_(trained_model[param_name])
            elif isinstance(param, nn.Module):
                param.load_state_dict(trained_model[param_name])
            else:
                raise TypeError(f"Unsupported type {type(param)} in trainable_params.")

    def count_parameters(self):
        """Return the number of scalar parameters in ``trainable_params``.

        ``nn.Parameter`` entries are counted in full. For ``nn.Module`` entries,
        only parameters with ``requires_grad=True`` are counted. The count is
        embedded in ``model_signature`` and therefore in the checkpoint filename,
        so changing this rule would orphan existing checkpoints.

        Raises:
            TypeError: if an entry is neither an ``nn.Parameter`` nor an ``nn.Module``.
        """
        total = 0
        for item in self.trainable_params.values():
            if isinstance(item, nn.Parameter):
                total += item.numel()
            elif isinstance(item, nn.Module):
                total += sum(p.numel() for p in item.parameters() if p.requires_grad)
            else:
                raise TypeError(f"Unsupported type {type(item)} in trainable_params.")
        return total
