import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy
from tqdm import tqdm
import torch.utils.data
from torch import nn, optim
from scipy.stats import gaussian_kde
from sklearn.model_selection import train_test_split

from CANM.inst.python.CANM import fit


def _marginal_log_density(data):
    """Return the mean Gaussian-KDE log-density of column 0 of ``data``, evaluated at its own points."""
    # Rows are samples (``train_test_split`` below splits along axis 0), so
    # the first variable is ``data[:, 0]``. The former ``data.T[:, 0]``
    # picked ``data[0, :]`` -- the coordinates of a single sample.
    x = numpy.asarray(data)[:, 0]
    return gaussian_kde(x).logpdf(x).mean()


def CANM_loss(
    data,
    epochs=50,
    batch_size=128,
    prior_sdy=0.5,
    minN=1,
    maxN=5,
    seed=0,
    update_sdy=True,
    debug=False,
    find_best_N=True,
    depth=1,
    training_hyperparameters=None,
    **kwargs
):
    """Fit a CANM model to ``data`` and return the fit results.

    ``data`` is treated as a 2-D array whose rows are samples. Its first
    column is used for a Gaussian kernel-density estimate whose mean
    log-density (``logpx``) is passed to every ``fit`` call.
    NOTE(review): presumably column 0 is the hypothesised cause and column 1
    the effect -- confirm against ``fit``.

    When ``find_best_N`` is true, ``fit`` is run once for every ``N`` in
    ``minN..maxN`` (inclusive) on an 80/20 train/test split. The ``N`` whose
    result has the largest ``"test_likelihood"`` is kept (ties go to the
    smallest ``N``). Otherwise ``N = depth - 1``. A final ``fit`` is then run
    on all of ``data`` with the chosen ``N``.

    Args:
        data: Sample matrix; split row-wise, column 0 used for the KDE.
        epochs, batch_size, prior_sdy, update_sdy: Forwarded to ``fit``.
        minN, maxN: Inclusive range of ``N`` searched when ``find_best_N``.
        seed: Forwarded to ``fit`` and used as the split's ``random_state``.
        debug: Print a message before the search over ``N``.
        find_best_N: Search ``minN..maxN`` instead of using ``depth``.
        depth: Used only when ``find_best_N`` is false; ``N = depth - 1``.
            NOTE(review): the default ``depth=1`` requests ``N=0``, below the
            default ``minN=1`` -- confirm ``fit`` supports ``N=0``.
        training_hyperparameters: Forwarded to ``fit``. ``None`` (the
            default) means ``{"type": "constant"}``, freshly built per call.
        **kwargs: Forwarded unchanged to every ``fit`` call.

    Returns:
        The mapping returned by the final ``fit`` call, with extra keys
        ``"bestN"``, ``"logpx"`` and, when ``find_best_N`` is true,
        ``"different_N_results"``: one ``fit`` result per searched ``N``, each
        copied into a new dict with an added ``"N"`` key.

    Raises:
        ValueError: If ``find_best_N`` is true and ``minN > maxN`` (nothing
            to search), or if it is false and ``depth`` is None.
    """
    # A dict literal as the default would be one object shared by every call
    # (and silently changed for all later calls if ``fit`` mutates it).
    if training_hyperparameters is None:
        training_hyperparameters = {"type": "constant"}

    # Validate before the KDE and model fits, which are the expensive part.
    if find_best_N:
        if minN > maxN:
            raise ValueError(
                f"minN ({minN}) must not exceed maxN ({maxN}); "
                "there is no N to search over"
            )
    elif depth is None:
        raise ValueError("depth must be specified when find_best_N is False")

    logpx = _marginal_log_density(data)

    # Keyword arguments common to every ``fit`` call. A key that also appears
    # in ``kwargs`` still raises TypeError, exactly as explicit keywords did.
    shared = dict(
        logpx=logpx,
        epochs=epochs,
        batch_size=batch_size,
        prior_sdy=prior_sdy,
        seed=seed,
        update_sdy=update_sdy,
        training_hyperparameters=training_hyperparameters,
    )

    if find_best_N:
        # The held-out split is only needed to compare the candidate N values.
        train, test = train_test_split(data, test_size=0.2, random_state=seed)
        if debug:
            print("Finding best N")
        results = [
            {
                **fit(traindata=train, testdata=test, N=N, **shared, **kwargs),
                "N": N,
            }
            for N in tqdm(range(minN, maxN + 1))
        ]
        # max() returns the first maximal element, i.e. the smallest N on ties.
        bestN = max(results, key=lambda r: r["test_likelihood"])["N"]
    else:
        bestN = depth - 1

    all_data_results = fit(traindata=data, N=bestN, **shared, **kwargs)
    all_data_results["bestN"] = bestN
    all_data_results["logpx"] = logpx

    if find_best_N:
        all_data_results["different_N_results"] = results
    return all_data_results
