import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm


class ConditionalMLP(nn.Module):
    """MLP with two ReLU hidden layers fed by an input joined with a condition.

    The network consumes ``in_dim + cond_dim`` features (the input and the
    conditioning vector concatenated on the last axis) and emits ``out_dim``
    features. Both hidden layers have ``hidden_dim`` units.
    """

    def __init__(self, in_dim, cond_dim, hidden_dim, out_dim):
        super().__init__()
        widths = [in_dim + cond_dim, hidden_dim, hidden_dim, out_dim]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.pop()  # the output layer is linear: drop the trailing ReLU

        self.net = nn.Sequential(*stack)

    def forward(self, x, cond):
        """Return the network output for ``x`` concatenated with ``cond``."""
        joint = torch.cat((x, cond), dim=-1)
        return self.net(joint)

class ConditionalAffineCoupling(nn.Module):
    """Affine coupling layer whose scale and shift depend on a condition.

    Entries where ``mask`` is 1 pass through unchanged and, together with
    ``cond``, drive two networks that produce a log-scale and a shift for the
    entries where ``mask`` is 0. The log-scale is squashed with ``tanh`` so it
    stays in (-1, 1).
    """

    def __init__(self, dim, cond_dim, hidden_dim, mask):
        super().__init__()
        self.dim = dim
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer("mask", mask)

        self.scale_net = ConditionalMLP(dim, cond_dim, hidden_dim, dim)
        self.shift_net = ConditionalMLP(dim, cond_dim, hidden_dim, dim)

    def _coupling_terms(self, y, cond):
        """Compute the pieces shared by ``forward`` and ``inverse``.

        Returns ``(kept, free, log_scale, shift)`` where ``kept`` is the
        pass-through part of ``y``, ``free`` is the complementary mask, and
        ``log_scale`` / ``shift`` are zero on the pass-through entries.
        """
        free = 1 - self.mask
        kept = y * self.mask

        # Zeroing before tanh keeps the log-scale exactly 0 on kept entries.
        log_scale = torch.tanh(self.scale_net(kept, cond) * free)
        shift = self.shift_net(kept, cond) * free
        return kept, free, log_scale, shift

    def forward(self, y, cond):
        """Apply ``y * exp(s) + t`` to the free entries.

        Returns the transformed tensor and the log-determinant (sum of the
        log-scales over the last axis).
        """
        kept, free, log_scale, shift = self._coupling_terms(y, cond)
        moved = y * log_scale.exp() + shift
        return kept + free * moved, log_scale.sum(dim=-1)

    def inverse(self, y, cond):
        """Apply ``(y - t) * exp(-s)`` to the free entries.

        Returns the transformed tensor and the log-determinant of this
        inverse map (the negated sum of the log-scales).
        """
        kept, free, log_scale, shift = self._coupling_terms(y, cond)
        restored = (y - shift) * (-log_scale).exp()
        return kept + free * restored, -log_scale.sum(dim=-1)

class ConditionalRealNVP(nn.Module):
    """Conditional RealNVP flow: a stack of masked affine coupling layers.

    The base density is a standard normal over ``y_dim`` dimensions.
    ``log_prob`` maps data to the latent space by running the layers'
    ``inverse`` in reverse order; ``sample`` maps latent draws to data by
    running the layers' ``forward`` in order.

    Args:
        y_dim: dimensionality of the modelled variable.
        cond_dim: dimensionality of the conditioning vector.
        hidden_dim: hidden width of each coupling network.
        n_layers: number of coupling layers; masks alternate between layers.
    """

    def __init__(self, y_dim, cond_dim, hidden_dim=128, n_layers=6):
        super().__init__()
        # Needed by sample() to size the latent draws.
        self.y_dim = y_dim
        self.layers = nn.ModuleList()

        for i in range(n_layers):
            mask = self._create_mask(y_dim, even=(i % 2 == 0))
            self.layers.append(
                ConditionalAffineCoupling(y_dim, cond_dim, hidden_dim, mask)
            )

        self._log_2pi = np.log(2*np.pi)

    def _create_mask(self, dim, even=True):
        """Return a 0/1 mask of length ``dim``.

        With ``even=True`` the even indices are 1; otherwise the odd ones are.
        """
        mask = torch.zeros(dim)
        mask[(0 if even else 1)::2] = 1
        return mask

    def log_prob(self, y, cond):
        """Return log p(y | cond), one value per row of ``y``."""
        log_det_sum = 0.0
        z = y

        # Data -> latent: undo the layers last-to-first, accumulating the
        # log-determinants of the inverse maps.
        for layer in reversed(self.layers):
            z, log_det = layer.inverse(z, cond)
            log_det_sum += log_det

        d = z.shape[-1]
        # Standard-normal log density of the latent code.
        log_base = -0.5 * (z.pow(2).sum(dim=-1) + d * self._log_2pi)

        return log_base + log_det_sum

    def sample(self, cond, n_samples=1):
        """Draw ``n_samples`` values of y for every row of ``cond``.

        Latent codes are drawn from the same standard normal that
        ``log_prob`` uses as its base density, on ``cond``'s device and dtype.

        Returns:
            Tensor of shape ``(cond.shape[0], n_samples, y_dim)``.
        """
        B = cond.shape[0]
        # The original read an undefined `self.base_dist`; sample the
        # standard-normal base directly instead.
        z = torch.randn(
            B * n_samples, self.y_dim, device=cond.device, dtype=cond.dtype
        )
        # Row b of cond is repeated n_samples times consecutively, so the
        # final view groups the draws by conditioning row.
        cond_rep = cond.repeat_interleave(n_samples, dim=0)

        y = z
        for layer in self.layers:
            y, _ = layer(y, cond_rep)

        return y.view(B, n_samples, -1)

class ConditionalFlowEstimator:
    """Fits a ConditionalRealNVP to p(Y | X, A) by minimizing the NLL.

    The conditioning vector is X with A appended as one extra column, hence
    ``cond_dim = x_dim + 1``.

    Args:
        y_dim: dimensionality of Y.
        x_dim: dimensionality of X.
        device: device the model lives on and batches are moved to.
        hidden_dim: hidden width of the coupling networks.
        n_layers: number of coupling layers.
        lr: Adam learning rate.
        verbose: show a tqdm progress bar during ``fit`` when True.
    """

    def __init__(self, y_dim, x_dim, device="cuda", hidden_dim=128, n_layers=6, lr=1e-3, verbose=False):

        self.device = torch.device(device)
        self.model = ConditionalRealNVP(
            y_dim=y_dim,
            cond_dim=x_dim + 1,  # X + A
            hidden_dim=hidden_dim,
            n_layers=n_layers
        ).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = None
        self.verbose = verbose

    def fit(
        self,
        X,
        A,
        Y,
        n_epochs=100,
        batch_size=256,
        grad_accum_steps=1,
    ):
        """
        Train the flow on (X, A, Y) by minimizing the negative log-likelihood.

        No AMP and no GradScaler are used. Memory-friendly by design:
        - the dataset tensors are not moved (they are expected to be on CPU)
        - only the current batch is moved to ``self.device``
        - gradients can be accumulated over ``grad_accum_steps`` batches

        A fresh cosine-annealing schedule over ``n_epochs`` is created on each
        call and stepped once per epoch. The model is left in train mode.

        Raises:
            ValueError: if the dataset is empty.
        """
        # Build conditioning features. reshape (not view) tolerates any
        # strides; the cast keeps the concatenation in X's dtype.
        A = A.reshape(-1, 1).to(dtype=X.dtype)
        Z = torch.cat([X, A], dim=1)

        dataset = torch.utils.data.TensorDataset(Z, Y)
        n_total = len(dataset)
        if n_total == 0:
            raise ValueError("fit() requires at least one sample")

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=self.device.type == "cuda",
            drop_last=False,
        )
        n_batches = len(loader)

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=n_epochs)

        self.model.train()

        epoch_bar = tqdm(
            range(n_epochs),
            desc="Training",
            dynamic_ncols=True,
            mininterval=0.1,
            leave=True,
            disable=not self.verbose,
        )

        for epoch in epoch_bar:
            nll_sum = 0.0
            self.optimizer.zero_grad(set_to_none=True)

            for step, (Zb_cpu, Yb_cpu) in enumerate(loader):
                # Move only this batch to the training device
                Zb = Zb_cpu.to(self.device, non_blocking=True)
                Yb = Yb_cpu.to(self.device, non_blocking=True)

                nll = -self.model.log_prob(Yb, Zb)
                loss = nll.mean() / grad_accum_steps  # for accumulation

                loss.backward()

                # Step every grad_accum_steps batches, and on the last batch
                # so leftover gradients are not carried into the next epoch.
                if (step + 1) % grad_accum_steps == 0 or (step + 1) == n_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)

                # Sum per-sample NLLs so a smaller final batch does not skew
                # the reported epoch average.
                nll_sum += nll.detach().sum().item()

                # Help GC in tight-memory settings
                del Zb, Yb, Zb_cpu, Yb_cpu, nll, loss

            self.scheduler.step()

            avg_nll = nll_sum / n_total
            epoch_bar.set_postfix(NLL=f"{avg_nll:.4f}")


def estimate_P_matrix(
    model,
    X,
    A,
    Y_atoms,
    batch_size=128,
    atom_chunk_size=512,   # << key knob
    device="cuda",
):
    """Return row-normalized probabilities of each atom for every (X, A) row.

    For each row the model's ``log_prob`` is evaluated at every atom in
    ``Y_atoms`` with the condition ``[X_row, A_row]``; a softmax over the atom
    axis turns these log-densities into weights that sum to 1.

    Work is tiled to bound device memory: rows are processed ``batch_size`` at
    a time and atoms ``atom_chunk_size`` at a time, so one ``log_prob`` call
    sees at most ``batch_size * atom_chunk_size`` pairs.

    Side effects: ``model`` is moved to ``device`` and put in eval mode.

    Args:
        model: object exposing ``to``, ``eval`` and ``log_prob(y, cond)``.
        X: covariates, one row per unit.
        A: treatment, 1-D or already column-shaped; cast to float32.
        Y_atoms: candidate outcome values, one row per atom.
        batch_size: number of rows per tile.
        atom_chunk_size: number of atoms per tile.
        device: device used for the computation.

    Returns:
        float32 CPU tensor of shape ``(X.shape[0], Y_atoms.shape[0])``.
    """
    target = torch.device(device)
    model = model.to(target).eval()

    # A becomes a float32 column so it can be concatenated with X.
    treatment = (A if A.ndim != 1 else A.unsqueeze(1)).to(dtype=torch.float32)
    atoms = Y_atoms.to(target, non_blocking=True)

    n_rows = X.shape[0]
    n_atoms = atoms.shape[0]

    # The result is assembled on the CPU one row tile at a time.
    probs = torch.empty((n_rows, n_atoms), device="cpu", dtype=torch.float32)

    def _chunk_log_probs(cond, chunk):
        """Log-density of every (row, atom) pair in this tile, shape (B, C)."""
        n_b, n_c = cond.shape[0], chunk.shape[0]
        # Pair every row with every atom: both become (B*C, features),
        # ordered row-major so the result reshapes back to (B, C).
        y_flat = chunk.unsqueeze(0).expand(n_b, n_c, -1).reshape(n_b * n_c, -1)
        c_flat = cond.unsqueeze(1).expand(n_b, n_c, -1).reshape(n_b * n_c, -1)
        return model.log_prob(y_flat, c_flat).view(n_b, n_c).float()

    with torch.inference_mode():
        for row0 in range(0, n_rows, batch_size):
            rows = slice(row0, row0 + batch_size)

            # Only this tile of rows is moved to the compute device.
            cond = torch.cat(
                [
                    X[rows].to(target, non_blocking=True),
                    treatment[rows].to(target, non_blocking=True),
                ],
                dim=1,
            )
            n_b = cond.shape[0]

            scores = torch.empty((n_b, n_atoms), device=target, dtype=torch.float32)

            for col0 in range(0, n_atoms, atom_chunk_size):
                chunk = atoms[col0 : col0 + atom_chunk_size]
                scores[:, col0 : col0 + chunk.shape[0]] = _chunk_log_probs(cond, chunk)
                del chunk

            # Normalize over atoms on the compute device, then copy to CPU.
            probs[row0 : row0 + n_b] = torch.softmax(scores, dim=1).cpu()
            del cond, scores

    return probs
