import torch
from BSpline import BSpline
from torch import Tensor
from torch import nn
from tqdm import tqdm

def inv_sqrt_matrix(M: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return the inverse matrix square root ``M^(-1/2)`` of a symmetric matrix.

    The matrix is eigendecomposed with ``torch.linalg.eigh`` and every
    eigenvalue is raised to the power ``-1/2`` after being clamped from below
    at ``eps``, so singular or slightly indefinite inputs yield a finite
    result instead of ``inf``/``nan``.

    Args:
        M: Symmetric matrix of shape ``(d, d)``, or a batch of such matrices
            of shape ``(..., d, d)``.
        eps: Lower bound applied to the eigenvalues before inversion.

    Returns:
        Tensor with the same shape as ``M`` holding ``Q diag(w^-1/2) Q^T``.
    """
    w, Q = torch.linalg.eigh(M)
    w_inv_sqrt = w.clamp(min=eps) ** -0.5
    # Q @ diag(s) is Q with its columns scaled by s; broadcasting avoids
    # materialising a dense diagonal matrix. Transposing only the last two
    # dims (instead of `.T`) keeps the function valid for batched input.
    return (Q * w_inv_sqrt.unsqueeze(-2)) @ Q.transpose(-2, -1)

def _spline_features(z_list, bspline, device):
    """Evaluate the B-spline basis on every non-empty sample of ``z_list``."""
    feats = []
    for z in z_list:
        z = z.squeeze(0).to(device)
        if z.numel() == 0:
            # A sample without tokens has no feature mean (it would be NaN
            # and poison every downstream sum), so it is left out.
            continue
        # Values below the spline's start are clamped up to it before evaluation.
        feats.append(bspline.predict(z.clamp_min(bspline.start)).to(device))
    return feats


def _sum_mean_covariances(feats, d, device):
    """Sum over samples of (unbiased feature covariance) / (token count)."""
    sigma = torch.zeros((d, d), device=device)
    for f in feats:
        n_tokens = f.shape[0]
        if n_tokens < 2:
            # The unbiased covariance divides by (n_tokens - 1); with a single
            # row that is 0/0, so the sample contributes nothing here.
            continue
        centered = f - f.mean(dim=0, keepdim=True)
        sigma += ((centered.T @ centered) / (n_tokens - 1)) / n_tokens
    return sigma


def compute_beta_hat(
    z_u_list: list[torch.Tensor],
    z_v_list: list[torch.Tensor],
    bspline
) -> torch.Tensor:
    """Estimate the unit-norm B-spline coefficient vector from two sample groups.

    Each tensor in the two lists is squeezed along dim 0, clamped from below
    at ``bspline.start`` and passed to ``bspline.predict``; the result is
    expected to be a ``(num_tokens, bspline.n_bases)`` feature matrix. The
    function then computes

    * ``delta``: the sum of per-sample feature means of the ``u`` group minus
      the same sum for the ``v`` group, and
    * ``Sigma``: the sum over all samples of both groups of the unbiased
      feature covariance divided by the sample's token count,

    and returns ``Sigma^(-1/2) @ delta`` scaled to unit Euclidean norm.

    Samples with no tokens are ignored, and samples with a single token
    contribute to ``delta`` but not to ``Sigma`` (their covariance is
    undefined). Previously either case turned the whole result into NaN.

    Args:
        z_u_list: Per-sample score tensors of the first group. The device of
            the first element is used for all computations.
        z_v_list: Per-sample score tensors of the second group.
        bspline: Object exposing ``n_bases``, ``start`` and ``predict``.

    Returns:
        Tensor of shape ``(bspline.n_bases,)`` with unit L2 norm.

    Raises:
        IndexError: If ``z_u_list`` is empty.
        RuntimeError: Expected from ``torch.stack`` if a group contains no
            non-empty sample.
    """
    device = z_u_list[0].device
    d = bspline.n_bases

    u_feats = _spline_features(z_u_list, bspline, device)
    v_feats = _spline_features(z_v_list, bspline, device)

    delta_u = torch.stack([f.mean(dim=0) for f in u_feats], dim=0).sum(dim=0)
    delta_v = torch.stack([f.mean(dim=0) for f in v_feats], dim=0).sum(dim=0)
    delta = delta_u - delta_v

    Sigma_u = _sum_mean_covariances(u_feats, d, device)
    Sigma_v = _sum_mean_covariances(v_feats, d, device)
    Sigma = Sigma_u + Sigma_v

    Sigma_inv_sqrt = inv_sqrt_matrix(Sigma)
    beta_tilde = Sigma_inv_sqrt @ delta       # (d,)
    beta_hat = beta_tilde / beta_tilde.norm(p=2)
    return beta_hat

def get_zij(text_list, tokenizer, model, args):
    """Collect the per-token log-probabilities the model gives to each text.

    The model is switched to eval mode. Every text is tokenized on its own and
    moved to ``args.device``; the logits at positions ``0..T-2`` are turned
    into log-probabilities and the entry of the actually observed next token
    (``input_ids[:, 1:]``) is picked out.

    Args:
        text_list: Sequence of texts to score.
        tokenizer: Callable tokenizer whose output has ``input_ids`` and
            supports ``.to(device)``.
        model: Model whose output exposes ``.logits``.
        args: Object providing ``device``.

    Returns:
        List with one tensor per text; each holds the log-probability of every
        observed next token, with the batch dimension kept in front.
    """
    model.eval()

    per_text_scores = []
    for text in tqdm(text_list):
        encoded = tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        # Position t predicts token t + 1, so targets drop the first token
        # and the logits (below) drop the last position.
        targets = encoded.input_ids[:, 1:]
        with torch.no_grad():
            logits = model(**encoded).logits[:, :-1]
        if targets.ndim == logits.ndim - 1:
            # gather needs an index with the same rank as the source tensor.
            targets = targets.unsqueeze(-1)
        log_probs = torch.log_softmax(logits, dim=-1)
        per_text_scores.append(log_probs.gather(dim=-1, index=targets).squeeze(-1))

    return per_text_scores

class BSplineW(nn.Module):
    """Function ``w`` represented as a linear combination of B-spline bases.

    ``fit`` estimates the coefficient vector ``beta_hat``; ``forward``
    evaluates the bases on its input and combines them with ``beta_hat``.
    ``beta_hat`` exists only after ``fit`` has run (or after it has been
    assigned manually).
    """

    def __init__(self, bspline_args):
        """Build the underlying ``BSpline`` from keyword arguments."""
        super().__init__()
        self.bspline = BSpline(**bspline_args)

    def fit(self, data, tokenizer, model, args):
        """Estimate ``beta_hat`` from ``data['sampled']`` and ``data['original']``.

        Both text collections are scored with ``get_zij`` and the resulting
        lists are passed to ``compute_beta_hat`` (``sampled`` first).
        """
        # NOTE(review): the progress messages call data['sampled'] "human" and
        # data['original'] "LLM" -- confirm this matches the dataset's keys.
        print("Learning w function...")
        print("Fetch log-likelihood of human texts...")
        sampled_scores = get_zij(data['sampled'], tokenizer, model, args)
        print("Fetch log-likelihood of LLM texts...")
        original_scores = get_zij(data['original'], tokenizer, model, args)
        print("Computing beta_hat...")
        coeffs = compute_beta_hat(sampled_scores, original_scores, self.bspline)
        self.beta_hat = coeffs
        print("beta_hat:", torch.round(coeffs, decimals=3))

    def forward(self, input: Tensor):
        """Evaluate ``w`` element-wise; the output has the shape of ``input``."""
        target_device = input.device
        # Values below the spline's start are clamped up to it, then the
        # input is flattened because ``predict`` is fed a 1-D tensor.
        clamped = input.clamp_min(self.bspline.start)
        basis = self.bspline.predict(clamped.reshape(-1)).to(target_device)
        coeffs = self.beta_hat.to(target_device)
        return (basis @ coeffs).reshape(input.shape)

def get_ci_list(text_list, tokenizer, model, w_fun, args):
    """Compute, for every text, the gap between observed and expected ``w``.

    For each text the model's next-token log-probabilities over the whole
    vocabulary are passed through ``w_fun``. Two per-position quantities are
    derived from that: the value of ``w`` at the observed next token, and the
    expectation of ``w`` under the model's own next-token distribution. The
    per-text result is the mean over positions of the first minus the mean
    over positions of the second, taken from the first batch row.

    Args:
        text_list: Sequence of texts to score.
        tokenizer: Callable tokenizer whose output has ``input_ids`` and
            supports ``.to(device)``.
        model: Model whose output exposes ``.logits``; it is put in eval mode.
        w_fun: Callable applied element-wise to the log-probability tensor.
        args: Object providing ``device``.

    Returns:
        List with one scalar tensor per text.
    """
    model.eval()

    gaps = []
    for text in tqdm(text_list):
        encoded = tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        # Position t predicts token t + 1, so targets drop the first token
        # and the logits (below) drop the last position.
        targets = encoded.input_ids[:, 1:]
        with torch.no_grad():
            logits = model(**encoded).logits[:, :-1]
        if targets.ndim == logits.ndim - 1:
            # gather needs an index with the same rank as the source tensor.
            targets = targets.unsqueeze(-1)
        w_all = w_fun(torch.log_softmax(logits, dim=-1))
        # Expectation of w under the model's next-token distribution.
        expected_w = (torch.softmax(logits, dim=-1) * w_all).sum(dim=-1)
        # Value of w at the token that actually follows.
        observed_w = w_all.gather(dim=-1, index=targets).squeeze(-1)

        gap = observed_w.mean(dim=-1) - expected_w.mean(dim=-1)
        gaps.append(gap[0])
    return gaps

class ShiftLearner(nn.Module):
    """Learns a scalar shift ``c_hat`` as the mean of per-text ``get_ci_list`` values.

    ``c_hat`` is only set once ``fit`` has been called.
    """

    def __init__(self):
        """Initialise the module; no state is created until ``fit`` runs."""
        super().__init__()

    def fit(self, data, tokenizer, model, w_func, args):
        """Set ``c_hat`` to the mean of ``get_ci_list`` over ``data['original']``."""
        print("Learning shift...")
        per_text_shift = get_ci_list(data['original'], tokenizer, model, w_func, args)
        # The list of scalar tensors is packed into one new tensor before averaging.
        shift = torch.tensor(per_text_shift).mean()
        self.c_hat = shift
        print("c_hat:", torch.round(shift, decimals=3))
