import torch
from pathlib import Path
import json
from typing import Literal
from dataclasses import dataclass
from tqdm import tqdm

@dataclass
class KroneckerFactorizedCovariance:
    """Covariance represented by the inverses of its two Kronecker factors.

    Instances are produced by ``_compute_covariance``, which stores the inverses
    of the two regularised factors.
    """

    # Inverse of the first (regularised) Kronecker factor.
    A_inv: torch.Tensor
    # Inverse of the second (regularised) Kronecker factor.
    B_inv: torch.Tensor

    def clone(self):
        """Return a new instance holding independent copies of both factors."""
        a_copy = self.A_inv.clone()
        b_copy = self.B_inv.clone()
        return KroneckerFactorizedCovariance(a_copy, b_copy)

    def to(self, device):
        """Move both factors to ``device`` *in place* and return ``self`` for chaining."""
        for field_name in ("A_inv", "B_inv"):
            setattr(self, field_name, getattr(self, field_name).to(device))
        return self

def l2_norm_squared(module: torch.nn.Module):
    """Return the sum of squared entries over every parameter of ``module``.

    Returns the integer ``0`` when the module has no parameters, otherwise a
    0-dim tensor.
    """
    total = 0
    for param in module.parameters():
        total = total + param.pow(2).sum()
    return total

def num_params(module: torch.nn.Module):
    """Return the total number of scalar entries across all parameters of ``module``."""
    count = 0
    for param in module.parameters():
        count += param.numel()
    return count

def compute_log_prior(l2_norm_squared: torch.Tensor, num_params: int, lmbda: float):
    """Return the lambda-dependent part of an isotropic Gaussian log-prior.

    For a zero-mean prior ``N(0, lmbda^{-1} I)`` over ``num_params`` parameters this
    evaluates ``-0.5 * lmbda * ||theta||^2 + 0.5 * num_params * log(lmbda)``; the
    lambda-independent ``-0.5 * num_params * log(2*pi)`` term is omitted.

    Args:
        l2_norm_squared: squared L2 norm of the parameters.
        num_params: number of parameters covered by the prior.
        lmbda: prior precision (> 0). A Python number or a tensor; tensors are
            used as-is, so gradients w.r.t. ``lmbda`` are preserved.

    Returns:
        The log-prior (up to a constant) as a tensor.
    """
    # torch.log rejects Python floats; as_tensor is a no-op for tensors (keeps autograd).
    lmbda = torch.as_tensor(lmbda)
    return -0.5 * lmbda * l2_norm_squared + 0.5 * num_params * torch.log(lmbda)

def compute_log_det_kfac(A: torch.Tensor, B: torch.Tensor):
    """Return ``log det(A ⊗ B)`` without materialising the Kronecker product.

    For square ``A`` (p x p) and ``B`` (q x q), ``det(A ⊗ B) = det(A)^q * det(B)^p``,
    hence ``log det(A ⊗ B) = q * log det(A) + p * log det(B)``.

    Args:
        A: first square Kronecker factor.
        B: second square Kronecker factor.

    Returns:
        The log-determinant of ``A ⊗ B`` as a 0-dim tensor (``nan`` if either
        factor has a negative determinant, as with ``torch.logdet``).
    """
    logdet_A = torch.logdet(A)
    logdet_B = torch.logdet(B)
    p, q = A.shape[0], B.shape[0]
    # Each factor's log-determinant is weighted by the size of the *other* factor.
    return logdet_A * q + logdet_B * p

def optimize_prior_precision(
    projection: torch.nn.Module,
    A: torch.Tensor,
    B: torch.Tensor,
    lmbda_init: float,
    n: float,
    lr: float,
    num_steps: int,
    device: str,
    retain_graph: bool = False,
    verbose: bool = False
) -> torch.Tensor:
    """Fit the prior precision by maximising the Laplace marginal likelihood.

    The posterior precision is approximated by the Kronecker product
    ``(sqrt(n) * A + sqrt(lmbda) * I) ⊗ (sqrt(n) * B + sqrt(lmbda) * I)`` and
    ``log(lmbda)`` is optimised with Adam, which keeps ``lmbda`` positive. Only the
    lambda-dependent terms of the log marginal likelihood are evaluated:
    ``log p(theta; lmbda) - 0.5 * log det(H(lmbda))``.

    Args:
        projection: module whose parameters enter the prior's L2 term. NOTE: its
            parameters get ``requires_grad = False`` as a side effect.
        A: first Kronecker factor (presumably a per-sample average, given the
            ``sqrt(n)`` scaling — confirm against the caller).
        B: second Kronecker factor (same assumption as ``A``).
        lmbda_init: initial prior precision (must be > 0).
        n: number of data points used to scale the factors.
        lr: Adam learning rate.
        num_steps: number of optimisation steps.
        device: device on which the optimisation runs.
        retain_graph: forwarded to ``backward`` on every step.
        verbose: show a tqdm progress bar.

    Returns:
        ``exp(log_lmbda)`` as a 0-dim tensor on ``device``. It is still attached to
        the autograd graph of the optimised parameter.
    """
    # Freeze the projection: only the prior precision is optimised here.
    for param in projection.parameters():
        param.requires_grad = False

    # Constant w.r.t. lmbda, so computed once and placed next to the other operands.
    projection_norm = torch.as_tensor(l2_norm_squared(projection), device=device)
    num_params_projection = num_params(projection)

    A = A.to(device)
    B = B.to(device)
    # Loop-invariant identities used to add the prior to each factor.
    eye_A = torch.eye(A.shape[0], device=device, dtype=A.dtype)
    eye_B = torch.eye(B.shape[0], device=device, dtype=B.dtype)

    # Optimise in log-space so that lmbda = exp(log_lmbda) is always positive.
    log_lmbda = torch.nn.Parameter(
        torch.tensor(lmbda_init, device=device, dtype=torch.float32).log()
    )
    sqrt_n = torch.tensor(n, device=device, dtype=torch.float32).sqrt()

    optimizer = torch.optim.Adam([log_lmbda], lr=lr, maximize=True)

    for _ in tqdm(range(num_steps), total=num_steps, disable=not verbose):
        optimizer.zero_grad()

        lmbda = log_lmbda.exp()
        sqrt_lmbda = lmbda.sqrt()

        # Add the prior to each Kronecker factor.
        A_ = A * sqrt_n + sqrt_lmbda * eye_A
        B_ = B * sqrt_n + sqrt_lmbda * eye_B

        log_prior = compute_log_prior(projection_norm, num_params_projection, lmbda)
        log_det = compute_log_det_kfac(A_, B_)
        # Laplace approximation: log Z = log p(D|theta) + log p(theta) - 0.5 * log det(H) + const;
        # the 0.5 matches the 0.5 factors already inside compute_log_prior.
        marglik = log_prior - 0.5 * log_det

        marglik.backward(retain_graph=retain_graph)
        optimizer.step()

    return log_lmbda.exp()

def _compute_covariance(
    A: torch.Tensor,
    B: torch.Tensor,
    n: torch.Tensor,
    lmbda: torch.Tensor,
):
    """Build the Kronecker-factored covariance from the two factors and the prior.

    Each factor is regularised as ``sqrt(n) * F + sqrt(lmbda) * I`` and then
    inverted.

    Args:
        A: first square Kronecker factor.
        B: second square Kronecker factor.
        n: number of data points (tensor or Python number).
        lmbda: prior precision (tensor or Python number).

    Returns:
        A ``KroneckerFactorizedCovariance`` holding the two inverted factors.
    """
    # Accept plain Python numbers too (matching how compute_covariances builds them);
    # tensors are passed through untouched.
    if not torch.is_tensor(n):
        n = torch.tensor(n, dtype=A.dtype, device=A.device)
    if not torch.is_tensor(lmbda):
        lmbda = torch.tensor(lmbda, dtype=A.dtype, device=A.device)

    sqrt_n = torch.sqrt(n)
    sqrt_lmbda = torch.sqrt(lmbda)
    A_reg = A * sqrt_n + sqrt_lmbda * torch.eye(A.size(0), device=A.device, dtype=A.dtype)
    B_reg = B * sqrt_n + sqrt_lmbda * torch.eye(B.size(0), device=B.device, dtype=B.dtype)

    return KroneckerFactorizedCovariance(
        A_inv=torch.linalg.inv(A_reg),
        B_inv=torch.linalg.inv(B_reg),
    )

def compute_covariances(
    A_img: torch.Tensor,
    B_img: torch.Tensor,
    A_txt: torch.Tensor,
    B_txt: torch.Tensor,
    info: dict,
):
    """Compute the image and text Kronecker-factored covariances.

    ``info`` must provide ``'n_img'``, ``'n_txt'``, ``'lambda_img'`` and
    ``'lambda_txt'``. Each value is converted to a tensor matching the dtype and
    device of the corresponding ``A`` factor.

    Returns:
        ``(cov_img, cov_txt)``.
    """
    def as_scalar(key, like):
        return torch.tensor(info[key], dtype=like.dtype, device=like.device)

    n_img, n_txt, lambda_img, lambda_txt = (
        as_scalar(key, like)
        for key, like in (
            ('n_img', A_img),
            ('n_txt', A_txt),
            ('lambda_img', A_img),
            ('lambda_txt', A_txt),
        )
    )

    return (
        _compute_covariance(A_img, B_img, n_img, lambda_img),
        _compute_covariance(A_txt, B_txt, n_txt, lambda_txt),
    )

def load_hessians(
    tag: Literal['img', 'txt'],
    model_name: str = "ViT-B-32",
    return_info: bool = False,
):
    """Load the precomputed Kronecker factors for one encoder.

    Files are read from a directory path that is relative to the current working
    directory: ``A_{tag}_analytic.pt`` and ``B_{tag}_analytic.pt``, plus optionally
    ``prior_precision_analytic.json``.

    Args:
        tag: which encoder's factors to load (``'img'`` or ``'txt'``).
        model_name: key into the table of available Hessians.
        return_info: also load and return the JSON info dict.

    Returns:
        ``(A, B)`` or, when ``return_info`` is true, ``(A, B, info)``. The tensors
        are loaded onto the CPU.

    Raises:
        ValueError: if no Hessians are available for ``model_name``.
    """
    hessians_avail = {
        "ViT-B-32": 'bayesvlm/hessian_CLIP-ViT-B-32-laion2B-s34B-b79K',
    }

    if model_name not in hessians_avail:
        raise ValueError(
            f"The only Hessians currently available are {list(hessians_avail.keys())}"
        )
    # NOTE(review): this path is resolved against the CWD, not this module — confirm intent.
    la_dir = Path(hessians_avail[model_name])

    # NOTE(security): torch.load unpickles its input; only load trusted files.
    A = torch.load(la_dir / f'A_{tag}_analytic.pt', map_location='cpu')
    B = torch.load(la_dir / f'B_{tag}_analytic.pt', map_location='cpu')

    if not return_info:
        return A, B

    with open(la_dir / 'prior_precision_analytic.json', encoding='utf-8') as f:
        info = json.load(f)

    return A, B, info