import torch
import torch.utils.data
from torch import nn, optim
from scipy.stats import gaussian_kde

import numpy as np
import random

from tqdm import tqdm


def _select_device(cuda, verbose):
    """Return ``(device, loader_kwargs)`` for the backend training should use."""
    cpu_kwargs = {"num_workers": 1, "pin_memory": False}
    # Force single-process loading on accelerator backends.
    accel_kwargs = {"num_workers": 0, "pin_memory": False}
    try:
        if cuda and torch.backends.mps.is_available():
            device, kwargs = torch.device("mps"), accel_kwargs
        elif cuda and torch.cuda.is_available():
            device, kwargs = torch.device("cuda"), accel_kwargs
        else:
            device, kwargs = torch.device("cpu"), cpu_kwargs

        if verbose:
            print(f"Using device: {device}")
    except Exception as e:
        # Best-effort: any failure while probing backends falls back to CPU.
        print(f"Error initializing device, falling back to CPU. Error: {str(e)}")
        device, kwargs = torch.device("cpu"), cpu_kwargs
    return device, kwargs


def _kl_weight(epoch, epochs, beta, schedule):
    """Return the KL weight for a 1-based ``epoch`` under the given schedule."""
    kind = schedule["type"]
    if kind == "linear":
        return beta * (epoch / epochs)
    if kind == "cyclical":
        cycle_length = epochs / schedule["M"]
        tau = ((epoch - 1) % cycle_length) / cycle_length
        ramp = schedule["R"]
        # Strict comparison: at tau == ramp the ramp value equals beta anyway,
        # and it keeps ramp == 0 from dividing by zero.
        if tau < ramp:
            return beta * tau / ramp
        return beta
    if kind == "constant":
        return beta
    raise ValueError(f"Unsupported training hyperparameter type: {kind}")


def fit(
    traindata,
    testdata=None,
    N=1,
    beta=1,
    batch_size=128,
    epochs=50,
    logpx=None,
    cuda=True,
    seed=0,
    log_interval=10,
    learning_rate=1e-2,
    prior_sdy=0.5,
    update_sdy=True,
    preload=False,
    debug=False,
    training_hyperparameters=None,
    verbose=False,
):
    """
    Train a CANM model on a causal pair and report its likelihood scores.

    :param traindata: Training data of the causal pair; column 0 is X and
        column 1 is Y. Must be a ``DataLoader`` when ``preload=True``.
    :param testdata: Optional testing data in the same layout as ``traindata``.
    :param N: The number of latent intermediate variables.
    :param beta: The beta parameter (as in beta-VAE) weighting the
        KL-divergence term of the ELBO.
    :param batch_size: The batch size.
    :param epochs: The number of training epochs (at least 1).
    :param logpx: The average log-likelihood of p(x) per sample. When None it
        is estimated from column 0 of ``traindata`` with a Gaussian KDE.
    :param cuda: Whether to use a GPU backend (MPS or CUDA) when available.
    :param seed: The random seed applied to ``random``, ``numpy`` and
        ``torch`` before the model is created.
    :param log_interval: With ``verbose=True``, print the training loss every
        ``log_interval`` batches.
    :param learning_rate: The learning rate of the Adam optimizer.
    :param prior_sdy: The initial standard deviation of the noise distribution.
    :param update_sdy: Whether to update the noise standard deviation by
        gradient descent.
    :param preload: If True, ``traindata`` and ``testdata`` must already be
        ``torch.utils.data.DataLoader`` objects and ``logpx`` must be given.
    :param debug: If True, the trained model object is added to the output
        under the key ``"model"``.
    :param training_hyperparameters: Dict selecting the KL-weight schedule via
        its ``"type"`` key: ``"constant"`` (always ``beta``), ``"linear"``
        (``beta * epoch / epochs``) or ``"cyclical"`` (``"M"`` cycles, each
        ramping linearly from 0 to ``beta`` over the fraction ``"R"`` of the
        cycle and then holding ``beta``). None means ``{"type": "constant"}``.
    :param verbose: Print progress output.
    :return: Dict with ``"train_likelihood"`` (negated average loss of the
        last epoch), ``"train_score"`` (that value for every epoch), ``"sdy"``
        (numpy array), ``"logpx"`` and ``"train_logging_infos"`` (per-epoch
        dicts of batch-averaged loss details); plus ``"test_likelihood"`` and
        ``"test_score"`` when ``testdata`` is given and ``"model"`` when
        ``debug`` is set.
    :raises ValueError: If ``epochs`` is below 1, the schedule type is
        unsupported, or ``logpx`` is missing while ``preload=True``.

    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")
    if training_hyperparameters is None:
        training_hyperparameters = {"type": "constant"}
    # Computing the whole schedule up front rejects a bad configuration
    # before any training work is done.
    kl_weights = [
        _kl_weight(epoch, epochs, beta, training_hyperparameters)
        for epoch in range(1, epochs + 1)
    ]

    torch.set_num_threads(1)
    device, kwargs = _select_device(cuda, verbose)

    # Seed before building the model so weight initialisation is reproducible.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    model = CANM(N).to(device)

    if logpx is None:
        if preload:
            raise ValueError("logpx must be provided when preload=True")
        pde = gaussian_kde(traindata[:, 0])
        logpx = pde.logpdf(traindata[:, 0]).mean()

    if preload:
        train_loader = traindata
        test_loader = testdata
    else:
        traindata = torch.from_numpy(np.copy(traindata)).float().to(device)
        train_loader = torch.utils.data.DataLoader(
            traindata, batch_size=batch_size, shuffle=True, **kwargs
        )
        test_loader = None
        if testdata is not None:
            testdata = torch.from_numpy(np.copy(testdata)).float().to(device)
            test_loader = torch.utils.data.DataLoader(
                testdata, batch_size=batch_size, shuffle=True, **kwargs
            )

    sdy = torch.tensor(
        [prior_sdy], device=device, dtype=torch.float, requires_grad=bool(update_sdy)
    )
    param_groups = [{"params": model.parameters()}]
    if update_sdy:
        param_groups.append({"params": sdy})
    optimizer = optim.Adam(param_groups, lr=learning_rate)

    train_scores = []
    test_scores = []
    train_logging_infos = []
    for epoch in tqdm(range(1, epochs + 1)):
        model.train()
        train_loss = 0
        wu_beta = kl_weights[epoch - 1]

        epoch_logging_info = {}
        num_batches = 0
        for batch_idx, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            y = data[:, 1].view(-1, 1)
            yhat, mu, logvar = model(data)

            loss, loss_details = loss_function(y, yhat, mu, logvar, sdy, wu_beta)
            # Add the marginal term log p(x) so the loss covers the joint.
            loss -= logpx * len(data)
            loss_details["loss"] = loss.item()
            loss_details["logpx"] = logpx * len(data)
            loss_details["len(data)"] = len(data)
            loss_details["epoch"] = epoch
            loss_details["beta"] = wu_beta

            # Accumulate logging info for averaging
            for key, value in loss_details.items():
                if value is None:
                    continue
                epoch_logging_info[key] = epoch_logging_info.get(key, 0.0) + value
            num_batches += 1

            loss.backward()

            train_loss += loss.item()
            optimizer.step()
            if update_sdy:
                # Keep sdy at or above 0.01 to avoid a NaN loss. The clamp is
                # in place so the optimizer keeps updating this same tensor.
                with torch.no_grad():
                    sdy.clamp_(min=0.01)
            if verbose and batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item() / len(data),
                    )
                )

        train_loss /= len(train_loader.dataset)
        train_scores.append(-train_loss)

        # Average the logging info over all batches in the epoch
        for key in epoch_logging_info:
            epoch_logging_info[key] /= num_batches
        train_logging_infos.append(epoch_logging_info)

        if verbose:
            print(
                "====> Epoch: {} Average loss: {:.4f}".format(epoch, train_loss),
                epoch_logging_info,
            )

        if testdata is not None:
            model.eval()
            test_loss = 0
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(device)
                    yhat, mu, logvar = model(data)
                    y = data[:, 1].view(-1, 1)
                    batch_loss = loss_function(y, yhat, mu, logvar, sdy, wu_beta)[0]
                    test_loss += batch_loss.item() - logpx * len(data)

                test_loss /= len(test_loader.dataset)
                if verbose:
                    print("====> Test set loss: {:.4f}".format(test_loss))
            test_scores.append(-test_loss)

    if testdata is not None:
        output = {
            "train_likelihood": -float(train_loss),
            "test_likelihood": -float(test_loss),
            "train_score": train_scores,
            "test_score": [float(score) for score in test_scores],
            "sdy": sdy.cpu().detach().numpy(),
        }
    else:
        output = {
            "train_likelihood": -float(train_loss),
            "train_score": train_scores,
            "sdy": sdy.cpu().detach().numpy(),
        }
    if debug:
        output["model"] = model

    output["logpx"] = logpx
    output["train_logging_infos"] = train_logging_infos

    return output


# Noise negative log-likelihood + weighted KL divergence, summed over the batch
def loss_function(y, yhat, mu, logvar, sdy, beta):
    """
    Compute the training loss and a breakdown of its two terms.

    The residual ``y - yhat`` is scored under a zero-mean Normal with scale
    ``sdy``; the KL term is the closed form from Appendix B of Kingma and
    Welling, "Auto-Encoding Variational Bayes" (https://arxiv.org/abs/1312.6114),
    multiplied by ``beta``.

    :param y: Observed targets.
    :param yhat: Model predictions for ``y``.
    :param mu: Mean of the approximate posterior over the latent variables.
    :param logvar: Log-variance of the approximate posterior.
    :param sdy: One-element tensor holding the noise standard deviation.
    :param beta: Weight applied to the KL term.
    :return: ``(loss, details)`` where ``loss`` is the tensor to minimise and
        ``details`` holds the float values of both terms under the keys
        ``"BCE"`` (noise negative log-likelihood) and ``"KLD"``.
    """
    residual = y - yhat

    # Normal needs a positive scale: mirror a non-positive sdy and offset it.
    if sdy.item() <= 0:
        sdy = -sdy + 0.01

    noise_dist = torch.distributions.Normal(0, sdy)

    # Negated because the likelihood is maximised while the loss is minimised.
    nll = -noise_dist.log_prob(residual).sum()

    # KL(q(z|x,y) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2);
    # not negated, since the divergence itself is what gets minimised.
    kl_summand = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * kl_summand.sum() * beta

    details = {
        "BCE": nll.item(),
        "KLD": kld.item(),
    }
    return nll + kld, details


class CANM(nn.Module):
    """
    Variational network that predicts y from x and N latent variables.

    The encoder maps each (x, y) pair to the mean and log-variance of the
    latent variables; the decoder maps x concatenated with a latent sample
    to the prediction ``yhat``.
    """

    def __init__(self, N):
        """Build the layers. ``N`` is the number of latent variables."""
        super().__init__()
        self.N = N

        # Encoder trunk; takes the two columns (x, y) as input.
        self.fc1 = nn.Linear(in_features=2, out_features=20)

        # Mean head
        self.fc21 = nn.Linear(in_features=20, out_features=12)
        self.fc22 = nn.Linear(in_features=12, out_features=7)
        self.fc23 = nn.Linear(in_features=7, out_features=N)

        # Log-variance head
        self.fc31 = nn.Linear(in_features=20, out_features=12)
        self.fc32 = nn.Linear(in_features=12, out_features=7)
        self.fc33 = nn.Linear(in_features=7, out_features=N)

        # Decoder: yhat = f(x, z)
        self.fc4 = nn.Linear(in_features=1 + N, out_features=10)
        self.fc5 = nn.Linear(in_features=10, out_features=7)
        self.fc6 = nn.Linear(in_features=7, out_features=5)
        self.fc7 = nn.Linear(in_features=5, out_features=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def encode(self, xy):
        """Return ``(mu, logvar)`` of the latent variables for each (x, y) row."""
        act = self.relu
        shared = act(self.fc1(xy.view(-1, 2)))

        mu = self.fc23(act(self.fc22(act(self.fc21(shared)))))
        logvar = self.fc33(act(self.fc32(act(self.fc31(shared)))))
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` while training; return ``mu`` otherwise."""
        if not self.training:
            return mu
        # exp(0.5 * logvar) = exp(log(std)) = std
        sigma = torch.exp(logvar * 0.5)
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def decode(self, x, z):
        """Predict ``yhat`` from x and the latent values z."""
        hidden = torch.cat([x.view(-1, 1), z.view(-1, self.N)], dim=1)
        for layer in (self.fc4, self.fc5, self.fc6):
            hidden = self.relu(layer(hidden))
        return self.fc7(hidden)

    def forward(self, data):
        """Run encoder, sampling and decoder; return ``(yhat, mu, logvar)``."""
        pairs = data.view(-1, 2)
        mu, logvar = self.encode(pairs)
        latent = self.reparameterize(mu, logvar)
        prediction = self.decode(pairs[:, 0], latent)
        return prediction, mu, logvar

if __name__ == "__main__":
    # Demo on synthetic data generated as X -> Z -> Y: fit the model in both
    # directions and report the direction with the higher best training score.
    np.random.seed(0)
    X = np.random.normal(0, 1, 10000)
    Z = np.power(X, 3) + 0.5 * np.random.normal(0, 1, 10000)
    Y = np.tanh(2 * Z) + 0.5 * np.random.normal(0, 1, 10000)

    # fit() indexes training_hyperparameters["type"], so the schedule is
    # passed explicitly; "constant" keeps the KL weight equal to beta.
    schedule = {"type": "constant"}

    traindata1 = np.array([X, Y]).transpose()
    result1 = fit(traindata1, N=1, training_hyperparameters=schedule, verbose=True)
    traindata2 = np.array([Y, X]).transpose()
    result2 = fit(traindata2, N=1, training_hyperparameters=schedule, verbose=True)
    if np.max(result1["train_score"]) > np.max(result2["train_score"]):
        print("X->Y")
    else:
        print("Y->X")
