import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import interp1d
import faiss
from time import time
from tqdm import tqdm
import math

import sys
# APPEND PATH TO PROJECT CODE TO ENABLE IMPORTS
from utils.compression import truncated_2d_dct

class CFDM(nn.Module):
    """Closed-form diffusion model. Supports score and log-probability computation."""

    def __init__(self, train_samples, scheduler, num_dct_coeffs=None):
        super().__init__()
        self.num_timesteps = scheduler.config.num_train_timesteps
        self.train_samples = train_samples  # (N, *) where * is any shape like (D) or (C,H,W)
        self.original_shape = train_samples.shape[1:]
        self.num_dct_coeffs = num_dct_coeffs

        if self.num_dct_coeffs is not None:
            print(f"Applying DCT and truncating to {self.num_dct_coeffs} coefficients per H, W dimension")
            x_2d_dct = truncated_2d_dct(train_samples, self.num_dct_coeffs, to_cpu=True).to(train_samples)  # (N, C, num_dct_coeffs, num_dct_coeffs)
            self.x_flat_dct = x_2d_dct.view(train_samples.shape[0], -1)  # (N, C * num_dct_coeffs * num_dct_coeffs)
            print(f"Shape after 2D DCT: {self.x_flat_dct.shape}")
            # Clear the CUDA cache to free up memory
            torch.cuda.empty_cache()

        betas_np = scheduler.betas.cpu().numpy()
        betas_rescaled = betas_np * scheduler.config.num_train_timesteps
        alpha_bars_np = scheduler.alphas_cumprod.cpu().numpy()
        sigmas_np = np.sqrt(1 - alpha_bars_np)

        _ts = np.linspace(0, 1, len(alpha_bars_np))
        self.beta_fn = interp1d(_ts, betas_rescaled, kind="linear", fill_value="extrapolate")
        self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")
        self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")

    def flatten(self, x):
        return x.view(x.shape[0], -1)

    def unflatten(self, x):
        return x.view(x.shape[0], *self.original_shape)
    
    def chunked_weighted_sum(self,
                             weights: torch.Tensor, 
                             x_t: torch.Tensor, 
                             chunk_size: int = 1024
                             ):
        """
        Args:
            weights: (B, N) tensor
            x_t: (N, D) tensor
            chunk_size: how many vectors to process at a time
        Returns:
            avg_x: (B, D) tensor
        """
        B, N = weights.shape
        D = x_t.shape[1]
        avg_x = torch.zeros(B, D, dtype=weights.dtype, device=weights.device)

        for start in range(0, N, chunk_size):
            end = min(start + chunk_size, N)
            w_chunk = weights[:, start:end]         # (B, chunk)
            x_chunk = x_t[start:end]                # (chunk, D)
            avg_x += w_chunk @ x_chunk              # (B, chunk) @ (chunk, D) -> (B, D)

        return avg_x

    def forward(self, z, t):
        """Compute the score s(z,t) at position z and time t."""
        if self.num_dct_coeffs is not None:
            x_for_dists = self.x_flat_dct  # (N, num_dct_coeffs)
        else:
            x_for_dists = self.flatten(self.train_samples)

        x = self.train_samples  # (N, *)

        if self.num_dct_coeffs is not None:
            z_for_dists = truncated_2d_dct(z, self.num_dct_coeffs, to_cpu=True).to(z)
            z_for_dists = self.flatten(z_for_dists)  # (B, C*num_dct_coeffs**num_dct_coeffs)
            z = self.flatten(z)  # (B, D)
        else:
            z = self.flatten(z) # (B, D)
            z_for_dists = z # (B, D)

        x = self.flatten(x)  # (N, D)

        if isinstance(t, torch.Tensor):
            t = t.cpu().item()

        if t == 1:
            # target_mean = torch.mean(x, dim=0)
            # score = target_mean - z  # (B, D)
            score = -z  # (B, D)
            return self.unflatten(score)
        else:
            alpha_bar = float(self.alpha_bar_fn(t))
            sigma = float(self.sigma_fn(t))

            x_t = x * math.sqrt(alpha_bar)  # (N, D)
            x_t_for_dists = x_for_dists * math.sqrt(alpha_bar)  # (N, D) or (N, num_dct_coeffs)
            # Adjust the shape of z to (B, 1, space_dims) for broadcasting

            # Compute weights: ||z[i] - tx[j]||^2 / (2 * sigma^2)
            dist_sq = torch.cdist(z_for_dists, x_t_for_dists, p=2) ** 2  # (B, N) -- contrary to the Pytorch docs, cdist also works with 2D tensors
            log_weights = -dist_sq / (2 * sigma**2)
            log_weights = log_weights - torch.max(log_weights, dim=1, keepdim=True)[0]  # numerical stability

            weights = torch.exp(log_weights)  # (B, N)
            # Normalize the weights
            weights_sum = torch.sum(weights, dim=1, keepdim=True) + 1e-12  # avoid division by zero 
            weights = weights / weights_sum  # (B, N)

            # Compute the weighted average of x_t for each batch element
            if self.num_dct_coeffs is not None:
                # Chunk the avg_x computation to avoid memory issues    
                avg_x = self.chunked_weighted_sum(weights, x_t, chunk_size=1024)  # (B, D)
            else:
                avg_x = torch.sum(weights[:, :, None] * x_t[None, :, :], dim=1)  # (B, D)
            score = (avg_x - z) / (sigma**2)  # (B, D)

            return self.unflatten(score)

    def log_prob(self, z, t):
        if self.num_dct_coeffs is not None:
            z = truncated_2d_dct(z, self.num_dct_coeffs, to_cpu=True).to(z)  # (B, C, num_dct_coeffs, num_dct_coeffs)
            z = self.flatten(z)  # (B, C * num_dct_coeffs * num_dct_coeffs)
            x = self.x_flat_dct  # (N, num_dct_coeffs)
        else:
            z = self.flatten(z) # (B, D)
            x = self.flatten(self.train_samples)  # (N, D)
        B, D = z.shape
        N = x.shape[0]

        if isinstance(t, torch.Tensor):
            t = t.cpu().item()

        alpha_bar = float(self.alpha_bar_fn(t))
        sigma = float(self.sigma_fn(t))
        variance = sigma**2 + 1e-6

        x_t = x * math.sqrt(alpha_bar)  # (N, D) or (N, num_dct_coeffs)
        sq_dist = torch.cdist(z, x_t, p=2) ** 2  # (B, N)
        mahalanobis_dist_squared = sq_dist / variance  # (B, N)

        log_det_cov = D * math.log(variance)  # stays in float
        log_normalizer = -0.5 * (D * math.log(2 * math.pi) + log_det_cov)
        log_probs_per_component = log_normalizer - 0.5 * mahalanobis_dist_squared  # (B, N)

        log_probs = torch.logsumexp(log_probs_per_component, dim=1) - math.log(N)  # (B,)
        return log_probs

    def pf_ode_func(self, z, t):
        score_val = self.forward(z, t)
        if isinstance(t, torch.Tensor):
            t = t.cpu().item()
        return -0.5 * self.beta_fn(t) * (z + score_val)

class CFDM_NN(nn.Module):
    def __init__(self, train_samples, scheduler, K=10, L=10, use_gpu=True, device_id=0,
                 num_dct_coeffs=None, nlist=100, nprobe=10):
        """Store the training set, build a FAISS IVF index over it and the schedule lookups.

        Args:
            train_samples: (N, *) tensor of training samples.
            scheduler: object exposing ``config.num_train_timesteps``,
                ``betas`` and ``alphas_cumprod`` (the latter two as tensors).
            K: number of nearest neighbours requested from the index per query.
            L: number of additional indices drawn at random per query.
            use_gpu: if True, the index is moved to the GPU ``device_id``.
            device_id: GPU ordinal handed to ``faiss.index_cpu_to_gpu``.
            num_dct_coeffs: when not None, the index is built over flattened
                truncated 2D DCT coefficients (cached in ``self.x_flat_dct``)
                instead of the raw flattened samples.
            nlist: number of IVF clusters (previously hard-coded to 100). It is
                clamped to N because the coarse quantizer cannot be trained
                with fewer points than clusters.
            nprobe: number of clusters visited per search (previously
                hard-coded to 10).
        """
        super().__init__()
        self.num_timesteps = scheduler.config.num_train_timesteps
        self.train_samples = train_samples
        self.original_shape = train_samples.shape[1:]
        self.K = K
        self.L = L
        self.N = train_samples.shape[0]
        self.num_dct_coeffs = num_dct_coeffs

        x_flat = train_samples.view(self.N, -1)
        print(f"Flattened train_samples shape: {x_flat.shape}")

        # Choose the representation the index is built on.
        if self.num_dct_coeffs is not None:
            print(f"Applying DCT and truncating to {self.num_dct_coeffs} coefficients")
            # Inline shape note from the original code: (N, C, num_dct_coeffs, num_dct_coeffs)
            x_2d_dct = truncated_2d_dct(train_samples, self.num_dct_coeffs, to_cpu=True).to(train_samples)
            self.x_flat_dct = x_2d_dct.view(train_samples.shape[0], -1)
            print(f"Shape after DCT: {self.x_flat_dct.shape}")
            # Release cached CUDA blocks left over from the transform.
            torch.cuda.empty_cache()
            x_flat = self.x_flat_dct.cpu().numpy().astype(np.float32)
        else:
            print("No DCT applied, using full flattened representation")
            x_flat = x_flat.cpu().numpy().astype(np.float32)

        d = x_flat.shape[1]
        # Coarse quantizer used to assign points to clusters.
        quantizer = faiss.IndexFlatL2(d)
        # Small datasets: never ask for more clusters than there are points.
        nlist = max(1, min(int(nlist), self.N))
        index_ivf = faiss.IndexIVFFlat(quantizer, d, nlist)
        if use_gpu:
            res = faiss.StandardGpuResources()
            gpu_id = device_id
            print(f"Using GPU ID: {gpu_id}")
            self.faiss_index = faiss.index_cpu_to_gpu(res, gpu_id, index_ivf)
        else:
            print("Using CPU for FAISS index")
            self.faiss_index = index_ivf

        # IVF indexes must be trained (clustered) before vectors are added.
        self.faiss_index.train(x_flat)
        self.faiss_index.add(x_flat)

        self.faiss_index.nprobe = nprobe  # number of clusters to search

        betas_np = scheduler.betas.cpu().numpy()
        betas_rescaled = betas_np * scheduler.config.num_train_timesteps
        alpha_bars_np = scheduler.alphas_cumprod.cpu().numpy()
        sigmas_np = np.sqrt(1 - alpha_bars_np)

        # Discrete steps are placed on a uniform grid over [0, 1] and
        # interpolated linearly (extrapolating outside the grid).
        _ts = np.linspace(0, 1, len(alpha_bars_np))
        self.beta_fn = interp1d(_ts, betas_rescaled, kind="linear", fill_value="extrapolate")
        self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")
        self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")

    def flatten(self, x):
        return x.view(x.shape[0], -1)

    def unflatten(self, x):
        return x.view(x.shape[0], *self.original_shape)
    
    def replace_invalid_indices_with_uniform_row_samples(self, I):
        """
        Replace -1s in each row of I with samples from valid entries using torch.multinomial.
        I: LongTensor of shape (B, K)
        Returns: LongTensor of same shape, with -1s replaced.
        """
        B, K = I.shape
        I_filled = I.clone()

        # Mask: valid = 1, invalid = 0
        valid_mask = (I != -1).float()  # (B, K)
        num_to_sample = (I == -1).sum(dim=1)  # (B,)

        if num_to_sample.max() == 0:
            return I_filled  # Nothing to replace

        # Build uniform probs over valid entries
        probs = valid_mask / valid_mask.sum(dim=1, keepdim=True)  # (B, K)

        # Max number of samples needed across all rows
        max_samples = num_to_sample.max().item()

        # Sample with replacement from each row
        sampled_indices = torch.multinomial(probs, max_samples, replacement=True)  # (B, max_samples)

        # Get the actual values from I
        valid_values = I.gather(1, sampled_indices)  # (B, max_samples)

        # Now fill in the -1s
        row_idx, col_idx = (I == -1).nonzero(as_tuple=True)
        fill_positions = num_to_sample.cumsum(0)  # end indices per row
        fill_positions = torch.cat([fill_positions.new_zeros(1), fill_positions[:-1]])
        offsets = row_idx * max_samples + (torch.arange(len(row_idx), device=I.device) - fill_positions[row_idx])

        sampled_flat = valid_values.view(-1)
        I_filled[row_idx, col_idx] = sampled_flat[offsets]

        return I_filled

    def sample_excluding_neighbors(self, I):
        """
        Uniformly sample L indices from [0, N) \ I for each row in batch using torch.multinomial.
        I: (B, K) tensor of excluded indices
        Returns: (B, L) tensor of valid random indices per batch
        """
        B, K = I.shape
        device = I.device

        # Build uniform weights with zeros at positions to exclude
        weights = torch.ones((B, self.N), device=device)
        assert (I >= 0).all(), f"Negative indices: {I.min()}"
        assert (I < self.N).all(), f"Out of bounds indices: {I.max()} >= {self.N}"
        weights.scatter_(1, I, 0.0)  # set excluded indices to 0

        # Sample without replacement from the valid indices
        sampled = torch.multinomial(weights, num_samples=self.L, replacement=False)  # (B, L)

        return sampled

    def get_neighbors_and_random(self, z_flat, t, device):
        # z_flat = self.flatten(z).detach().contiguous().cpu().numpy().astype(np.float32)
        # faiss.normalize_L2(z_flat)  # Optional for cosine similarity
        alpha_bar = self.alpha_bar_fn(t)
        sqrt_alpha_bar = float(math.sqrt(alpha_bar))

        D, I = self.faiss_index.search(z_flat / sqrt_alpha_bar, self.K)
        I = torch.tensor(I, dtype=torch.long, device=device)
        # Post-process to replace -1s with uniform samples from the valid entries
        I = self.replace_invalid_indices_with_uniform_row_samples(I)  # (B, K)

        rand_idx = self.sample_excluding_neighbors(I)  # (B, L)

        return I, rand_idx

    def forward(self, z, t, indices=None, return_indices=False):
        if isinstance(t, torch.Tensor):
            t = t.cpu().item()

        # z_flat = self.flatten(z)  # (B, D)
        # Compress z_flat to DCT coefficients if specified
        if self.num_dct_coeffs is not None:
            z_2d_dct = truncated_2d_dct(z, self.num_dct_coeffs, to_cpu=True).to(z)
            z_flat_for_dists = self.flatten(z_2d_dct)  # (B, C * num_dct_coeffs * num_dct_coeffs)
            z_flat = self.flatten(z)  # (B, D)
        else:
            z_flat = self.flatten(z)  # (B, D)
            z_flat_for_dists = z_flat # (B, D)

        alpha_bar = float(self.alpha_bar_fn(t))
        sigma = float(self.sigma_fn(t))
        variance = sigma ** 2

        x_flat = self.flatten(self.train_samples)  # (N, D)

        if self.num_dct_coeffs is not None:
            x_flat_for_dists = self.x_flat_dct  # (N, num_dct_coeffs)
        else:
            x_flat_for_dists = x_flat # (N, D)
        z_np = z_flat_for_dists.detach().contiguous().cpu().numpy().astype(np.float32)

        if indices is None:
            I_nn, I_rand = self.get_neighbors_and_random(z_np, t, z.device)  # (B, K), (B, L)
            indices = torch.cat([I_nn, I_rand], dim=1)  # (B, K+L)
        x_selected_for_dists = x_flat_for_dists[indices.cpu()].to(z_flat_for_dists) # (B, K+L, D) or (B, K+L, num_dct_coeffs)
        x_selected = x_flat[indices] # (B, K+L, D)
        x_t_for_dists = math.sqrt(alpha_bar) * x_selected_for_dists  # (B, K+L, D) or (B, K+L, num_dct_coeffs)
        x_t = math.sqrt(alpha_bar) * x_selected  # (B, K+L, D)

        dist_sq = torch.cdist(z_flat_for_dists.unsqueeze(1), x_t_for_dists, p=2).squeeze() ** 2  # (B, K+L)
        log_weights = -dist_sq / (2 * variance)  # unnormalized log weights

        # Subtract max log weight for numerical stability
        log_weights = log_weights - torch.max(log_weights, dim=1, keepdim=True)[0]  # (B, K+L)

        # Rescale weights for NN and random
        weights = torch.exp(log_weights)  # (B, K+L)
        weights_nn = weights[:, :self.K] * (1 / self.N)
        weights_rand = weights[:, self.K:] * ((self.N - self.K) / (self.L * self.N))
        weights = torch.cat([weights_nn, weights_rand], dim=1)

        weights_sum = weights.sum(dim=1, keepdim=True) + 1e-12
        weights = weights / weights_sum  # normalize weights (B, K+L)

        if self.num_dct_coeffs is not None:
            # Chunk the avg_x computation to avoid memory issues
            chunk_size = 1024
            B, M, D = x_t.shape  # M = K + L
            avg_x_chunks = []
            for start in range(0, D, chunk_size):
                end = min(start + chunk_size, D)
                x_chunk = x_t[:, :, start:end]  # (B, M, chunk)
                weighted_chunk = weights.unsqueeze(-1) * x_chunk  # (B, M, chunk)
                avg_chunk = weighted_chunk.sum(dim=1)  # (B, chunk)
                avg_x_chunks.append(avg_chunk)
            avg_x = torch.cat(avg_x_chunks, dim=1)  # (B, D)
        else:
            avg_x = torch.sum(weights.unsqueeze(-1) * x_t, dim=1)  # (B, D)
        score = (avg_x - z_flat) / variance  # (B, D)

        if return_indices:
            # Return the indices of the neighbors and random samples
            return self.unflatten(score), indices
        else:
            return self.unflatten(score)

    def log_prob(self, z, t, indices=None, return_indices=False):
        if isinstance(t, torch.Tensor):
            t = t.cpu().item()

        # z_flat = self.flatten(z)  # (B, D)
        # Compress z_flat to DCT coefficients if specified
        if self.num_dct_coeffs is not None:
            # z_flat = truncated_dct(z_flat, self.num_dct_coeffs, batch_size=z_flat.shape[0]).to(z) # (B, num_dct_coeffs)
            z_2d_dct = truncated_2d_dct(z, self.num_dct_coeffs, to_cpu=True).to(z)
            z_flat = self.flatten(z_2d_dct)  # (B, C * num_dct_coeffs * num_dct_coeffs)
        else:
            z_flat = self.flatten(z) # (B, D)
        B, D = z_flat.shape

        alpha_bar = float(self.alpha_bar_fn(t))
        sigma = float(self.sigma_fn(t))
        variance = sigma ** 2 + 1e-6 # Should there be a factor of 2 here?

        if self.num_dct_coeffs is not None:
            x_flat = self.x_flat_dct # (N, num_dct_coeffs)
        else:
            x_flat = self.flatten(self.train_samples) # (N, D)
        z_np = z_flat.detach().contiguous().cpu().numpy().astype(np.float32) 
        if indices is None:
            I_nn, I_rand = self.get_neighbors_and_random(z_np, t, z.device)  # (B, K), (B, L)
            indices = torch.cat([I_nn, I_rand], dim=1)  # (B, K+L)

        x_selected = x_flat[indices.cpu()].to(z_flat)  # (B, K+L, D) or (B, K+L, num_dct_coeffs)
        mean = math.sqrt(alpha_bar) * x_selected  # (B, K+L, D) or (B, K+L, num_dct_coeffs)

        mdist_sq = torch.cdist(z_flat.unsqueeze(1), mean, p=2).squeeze() ** 2 / variance  # (B, K+L)

        log_det = D * torch.log(torch.tensor(variance, dtype=z.dtype, device=z.device))
        log_norm = -0.5 * (D * math.log(2 * math.pi) + log_det)
        log_probs = log_norm - 0.5 * mdist_sq  # (B, K+L)

        # Apply weights before logsumexp
        weights = torch.ones_like(log_probs)
        weights[:, :self.K] *= 1 / self.N
        weights[:, self.K:] *= (self.N - self.K) / (self.L * self.N)

        weighted_log_probs = log_probs + torch.log(weights + 1e-12)
        log_prob = torch.logsumexp(weighted_log_probs, dim=1)  # (B,)

        if return_indices:
            # Return the indices of the neighbors and random samples
            return log_prob, indices
        else:
            return log_prob
    
    def pf_ode_func(self, z, t):
        score_val = self.forward(z, t)
        if isinstance(t, torch.Tensor):
            t = t.cpu().item()
        return -0.5 * self.beta_fn(t) * (z + score_val)