import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
from scipy import stats
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# adapted from causal-learn pnl

# Choose the compute device a single time, at import: prefer CUDA, then Apple
# MPS, and fall back to the CPU when neither backend reports availability.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Running on {device}")


class PairDataset(Dataset):
    """Dataset that serves the rows of a NumPy array as float32 tensors.

    The array is converted a single time in the constructor and stored on the
    module-level ``device``; indexing then returns slices of that tensor, so a
    ``DataLoader`` built on top of it yields tensors directly.
    """

    def __init__(self, data):
        """Cast ``data`` to float32, wrap it as a tensor and move it to ``device``."""
        super().__init__()
        as_float32 = data.astype(np.float32)
        self.data = torch.from_numpy(as_float32).to(device)

    def __len__(self):
        """Return the number of rows (size of the first tensor dimension)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the stored row(s) selected by ``idx``."""
        return self.data[idx]


class MLP(nn.Module):
    """Fully connected ReLU network.

    Layout: ``Linear(n_inputs, n_units)``, then ``n_layers`` repetitions of
    ``ReLU -> Linear(n_units, n_units)``, then ``ReLU -> Linear(n_units,
    n_outputs)``. No activation is applied after the last linear layer.
    """

    def __init__(self, n_inputs, n_outputs, n_layers=1, n_units=100):
        super().__init__()
        # The linear layers are created in network order (input, hidden, output)
        # so that seeded parameter initialisation stays reproducible.
        stem = nn.Linear(n_inputs, n_units)
        hidden = [
            module
            for _ in range(n_layers)
            for module in (nn.ReLU(), nn.Linear(n_units, n_units))
        ]
        head = nn.Linear(n_units, n_outputs)
        self.layers = nn.Sequential(stem, *hidden, nn.ReLU(), head)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.layers(x)


class PNL:
    """Pairwise causal-direction scoring with two small MLPs (adapted from causal-learn).

    For a pair of variables, two networks ``G1`` and ``G2`` are trained so that
    the residual ``e = G2(x2) - G1(x1)`` minimises a squared-residual term plus
    a negative log-Jacobian term. A p-value is then computed between the first
    variable and the residual, once for each ordering of the pair.
    """

    def __init__(self, epochs=3000, device=device):
        """Store the training configuration.

        Args:
            epochs: Number of passes over the data used for each fit.
            device: ``torch.device`` or device string (e.g. ``"cpu"``) on which
                the networks are trained. Defaults to the module-level device.
        """
        self.epochs = epochs
        # Normalise so that plain strings work as well as torch.device objects.
        self.device = torch.device(device)
        # Reproducibility: seed the CUDA generator here; the global torch seed
        # is reset at the start of every cause_or_effect call.
        if self.device.type == "cuda":
            torch.cuda.manual_seed(0)

    def nica_mnd(self, X, TotalEpoch):
        """Fit ``G1``/``G2`` on a two-column sample and return the residuals.

        Args:
            X: NumPy array of shape ``(n, 2)``; column 0 is fed to ``G1`` and
                column 1 to ``G2``.
            TotalEpoch: Number of passes over the data.

        Returns:
            Tuple ``(x1_all, e_all)`` of tensors on ``self.device``, each of
            shape ``(n, 1)``: the first column of ``X`` and the residual
            ``G2(x2) - G1(x1)`` evaluated on the full data. Both are detached
            from the autograd graph.
        """
        dataset = PairDataset(X)
        loader = DataLoader(dataset, batch_size=1024, drop_last=False, shuffle=True)

        # Build both networks on the training device and optimise them jointly.
        G1 = MLP(1, 1, n_layers=3, n_units=12).to(self.device)
        G2 = MLP(1, 1, n_layers=1, n_units=12).to(self.device)
        optimizer = torch.optim.Adam(
            list(G1.parameters()) + list(G2.parameters()), lr=1e-4, betas=(0.9, 0.99)
        )

        for _ in tqdm(range(TotalEpoch), desc="PNL training"):
            for batch in loader:
                # The dataset may live on a different device than the one this
                # instance was configured with; this is a no-op when they match.
                batch = batch.to(self.device)
                x1 = batch[:, [0]]
                # Only x2 needs gradients: the Jacobian term is d e / d x2.
                x2 = batch[:, [1]].requires_grad_()

                # Clear gradients before every step; clearing only once per
                # epoch would make later batches step on accumulated gradients.
                optimizer.zero_grad()

                e = G2(x2) - G1(x1)
                loss_pdf = 0.5 * (e**2).sum()
                jacobian = autograd.grad(
                    outputs=e,
                    inputs=x2,
                    grad_outputs=torch.ones_like(e),
                    create_graph=True,  # the Jacobian itself is part of the loss
                    retain_graph=True,
                )[0]
                # The small constant keeps log() finite when the Jacobian is zero.
                loss_jacob = -(torch.log(jacobian.abs() + 1e-16)).sum()

                (loss_pdf + loss_jacob).backward()
                optimizer.step()

        # After training, evaluate on the full dataset; no gradients are needed.
        with torch.no_grad():
            Xall = torch.from_numpy(X.astype(np.float32)).to(self.device)
            x1_all = Xall[:, [0]]
            x2_all = Xall[:, [1]]
            e_all = G2(x2_all) - G1(x1_all)
        return x1_all, e_all

    def _direction_pvalue(self, data):
        """Fit on ``data`` (shape ``(n, 2)``) and return the t-test p-value.

        The p-value comes from ``scipy.stats.ttest_ind`` applied to the first
        column and the fitted residual.
        """
        y1, y2 = self.nica_mnd(data, self.epochs)
        return stats.ttest_ind(
            y1.cpu().detach().numpy(), y2.cpu().detach().numpy()
        ).pvalue

    def cause_or_effect(self, data_x, data_y):
        """Score both causal directions for a pair of variables.

        Args:
            data_x: NumPy array for the first variable; concatenated with
                ``data_y`` along axis 1, so both are expected to be column
                shaped ``(n, 1)``.
            data_y: NumPy array for the second variable, same shape as
                ``data_x``.

        Returns:
            Tuple ``(p_fwd, p_bwd)`` of p-values rounded to three decimals:
            ``p_fwd`` from fitting on ``(x, y)`` and ``p_bwd`` from fitting on
            the swapped columns ``(y, x)``.

        NOTE(review): ``ttest_ind`` compares the means of the two samples; it
        is not an independence test. Confirm that this statistic is the one
        intended for deciding the direction.
        """
        # Seed again for each call so repeated calls are reproducible.
        torch.manual_seed(0)
        data = np.concatenate((data_x, data_y), axis=1)

        # Forward direction first, then the same fit with the columns swapped.
        p_fwd = self._direction_pvalue(data)
        p_bwd = self._direction_pvalue(data[:, [1, 0]])

        return np.round(p_fwd, 3), np.round(p_bwd, 3)
