import torch
import math
from extract_data import get_sbj_data
from bav_sampler_multi import nll_bav_constant_gaussian

class BAVModel:
    """Log-joint (log-likelihood + log-prior) for one subject's data.

    The parameter vector theta has 7 entries, in this column order:

        0: log σV_low
        1: log σV_med
        2: log σV_high
        3: log σA
        4: log σs
        5: log σm
        6: logit(p_same)
    """

    def __init__(self, sbj, RHO_A, data_path, idx_path, truncated):
        """
        Load one subject's trials and cache them as tensors.

        Args:
        sbj: subject identifier, forwarded to ``get_sbj_data``.
        RHO_A: forwarded unchanged to ``nll_bav_constant_gaussian``.
        data_path: forwarded to ``get_sbj_data``.
        idx_path: forwarded to ``get_sbj_data``.
        truncated: if truthy, ``log_joint`` uses ``log_prior_truncated``;
            otherwise it uses ``log_prior``.
        """
        self.truncated = truncated
        self.RHO_A = RHO_A
        # NOTE(review): assumes get_sbj_data returns tensors with x of shape
        # [n_trials, >=4] -- confirm against extract_data.
        x, y = get_sbj_data(data_path, sbj, idx_path)
        print(f"Using {x.shape[0]} trials")
        # Responses, flattened and matched to x's device and dtype.
        self.R = y.reshape(-1).to(x.device, x.dtype)
        # Columns 0 and 1 are passed on as integer codes
        # (response_types and V_levels in log_joint).
        self.rt = x[:, 0].to(torch.long)
        self.vl = x[:, 1].to(torch.long)
        # Columns 2 and 3 are passed on as S_A and S_V.
        self.S_A = x[:, 2].to(x.dtype)
        self.S_V = x[:, 3].to(x.dtype)

    @staticmethod
    def _prior_components(theta: torch.Tensor):
        """Validate theta and return one (value, mean, sd) triple per column.

        Shared by both priors so the hierarchical structure is defined once.

        Raises:
        ValueError: if theta is not of shape [batch, 7].
        """
        if theta.ndim != 2 or theta.size(1) != 7:
            raise ValueError(f"theta must have shape [batch, 7], got {tuple(theta.shape)}")

        log_sigmaV_low = theta[:, 0]
        log_sigmaV_med = theta[:, 1]

        return [
            # hierarchical means: each visual level is centred on the previous one
            (log_sigmaV_low, torch.zeros_like(log_sigmaV_low), 1.5),
            (log_sigmaV_med, log_sigmaV_low + 1.0, 1.0),
            (theta[:, 2], log_sigmaV_med + 0.75, 0.5),
            # independent means
            (theta[:, 3], torch.full_like(log_sigmaV_low, 1.75), 0.5),
            (theta[:, 4], torch.full_like(log_sigmaV_low, 2.5), 1.0),
            (theta[:, 5], torch.zeros_like(log_sigmaV_low), 0.5),
            (theta[:, 6], torch.full_like(log_sigmaV_low, 1.5), 1.5),
        ]

    def log_prior(self, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute the joint log prior for a batch of θ under fully Gaussian priors
        (no truncation), using the hierarchical means:

        0: log σV_low    ~ N(0.00, 1.5)
        1: log σV_med    ~ N(log σV_low + 1, 1)
        2: log σV_high   ~ N(log σV_med + 0.75, 0.5)
        3: log σA        ~ N(1.75, 0.5)
        4: log σs        ~ N(2.5, 1)
        5: log σm        ~ N(0.00, 0.5)
        6: logit(p_same) ~ N(1.5, 1.5)

        The second argument of each N(., .) is a standard deviation.

        Args:
        theta: Tensor of shape [batch, 7]

        Returns:
        logp: Tensor of shape [batch], the joint log prior for each row.

        Raises:
        ValueError: if theta is not of shape [batch, 7].
        """
        components = self._prior_components(theta)

        device, dtype = theta.device, theta.dtype
        log2pi = torch.log(torch.tensor(2.0 * math.pi, device=device, dtype=dtype))

        # sum of log-densities (product of conditionals = joint)
        lp = None
        for x, mean, sd in components:
            sd_t = torch.as_tensor(sd, device=device, dtype=dtype)
            z = (x - mean) / sd_t
            term = -0.5 * (z * z + 2.0 * torch.log(sd_t) + log2pi)
            lp = term if lp is None else lp + term
        return lp

    def log_prior_truncated(self, theta: torch.Tensor) -> torch.Tensor:
        """
        Joint log prior with each Gaussian truncated to [mean - 2*sd, mean + 2*sd].
        Outside that interval, use a constant floor equal to log N(5 | 0, 1).

        Means and standard deviations are the same as in ``log_prior``; for the
        hierarchical parameters the interval is centred on the conditional mean.
        The interval bounds are inclusive. Because of the floor the result is
        finite everywhere (and therefore not a normalised density).

        Args:
        theta: Tensor of shape [batch, 7]

        Returns:
        logp: Tensor of shape [batch]

        Raises:
        ValueError: if theta is not of shape [batch, 7].
        """
        components = self._prior_components(theta)

        device, dtype = theta.device, theta.dtype
        log2pi = torch.log(torch.tensor(2.0 * math.pi, device=device, dtype=dtype))

        # Truncation width (k standard deviations) and normalization term for
        # symmetric truncation: P(|z| <= k) = erf(k / sqrt(2)).
        k = 2.0
        log_Z = torch.log(torch.tensor(math.erf(k / math.sqrt(2.0)), device=device, dtype=dtype))

        # Floor value: log N(5 | 0, 1) = -0.5*(5^2 + log(2π))
        log_floor = -0.5 * (torch.tensor(25.0, device=device, dtype=dtype) + log2pi)

        # sum of truncated log-densities (product of conditionals = joint)
        lp = None
        for x, mean, sd in components:
            sd_t = torch.as_tensor(sd, device=device, dtype=dtype)
            z = (x - mean) / sd_t
            # Truncated normal log-pdf inside the support
            logpdf = -0.5 * (z * z + log2pi) - torch.log(sd_t) - log_Z
            inside = (z >= -k) & (z <= k)
            # Use constant floor outside the support (broadcasts over batch)
            term = torch.where(inside, logpdf, log_floor)
            lp = term if lp is None else lp + term
        return lp

    def log_joint(self, theta):
        """
        Evaluate the log joint (log-likelihood + log-prior) at a single θ.

        Args:
        theta: the 7 parameters as a flat sequence, NumPy array or 1-D tensor.
            A batch axis is added here, so a 2-D input is rejected by the
            prior's shape check.

        Returns:
        A 0-d NumPy array holding the summed log joint.

        Raises:
        ValueError: if theta does not contain exactly 7 values in one dimension.
        """
        # torch.tensor() on an existing tensor warns and copies; detaching is
        # enough because theta is only read here.
        if isinstance(theta, torch.Tensor):
            theta = theta.detach()
        else:
            theta = torch.tensor(theta)
        theta = theta.unsqueeze(0)

        # NOTE(review): presumably returns the negative log-likelihood of the
        # cached trials -- confirm against bav_sampler_multi.
        log_likelihood = -nll_bav_constant_gaussian(
            RHO_A=self.RHO_A,
            theta=theta,
            R=self.R,
            S_V=self.S_V,
            S_A=self.S_A,
            response_types=self.rt,
            V_levels=self.vl,
        )

        if self.truncated:
            log_joint_torch = log_likelihood + self.log_prior_truncated(theta)
        else:
            log_joint_torch = log_likelihood + self.log_prior(theta)

        # Element-wise test so this works whatever shape the likelihood has.
        if torch.isneginf(log_joint_torch).any():
            print('thetas:', theta)

        # .cpu() so the NumPy conversion also works for tensors on an accelerator.
        return log_joint_torch.sum().detach().cpu().numpy()