#%%
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from fasttransform import Transform
from typing import Union
from pathlib import Path
import functools as ft
import numpy as np
from tqdm.auto import tqdm
from memories import EpaMemory, LseMemory
from fastcore.foundation import patch
import jax
import jaxlib
import functools as ft
from diffusers import DiffusionPipeline, AutoencoderTiny

class TAESD_Wrapper(nn.Module):
    """Expose a diffusers ``AutoencoderTiny`` (TAESD) through the same API as ``BetaVAE``.

    ``encode`` returns latents flattened to ``(batch, prod(latent_shape))`` plus an
    all-zero ``logvar`` placeholder. ``decode`` reshapes flattened latents back to
    ``(batch, *latent_shape)`` before running the TAESD decoder.

    Args:
        vae: the TAESD autoencoder to wrap.
        latent_shape: per-sample latent shape used when un-flattening in ``decode``.
            Defaults to ``(4, 8, 8)``, the Tiny-ImageNet layout.
    """

    def __init__(self, vae: AutoencoderTiny, latent_shape=(4, 8, 8)):
        super().__init__()
        self.vae = vae
        self.latent_shape = tuple(latent_shape)
        # Identity pre/post-processing: images are passed to/from TAESD unchanged.
        self.preprocess = Transform(lambda x: x, lambda x: x)
        shape = self.latent_shape
        self.latent_process = Transform(
            lambda z: z.reshape(z.shape[0], -1),
            lambda z: z.reshape(z.shape[0], *shape),
        )

    def encode(self, x):
        """Encode images ``x`` to flattened latents and return ``(mu, logvar)``.

        No sampling happens here. ``logvar`` is a zero tensor shaped like ``mu``;
        it exists only to match the ``(mu, logvar)`` return of ``BetaVAE.encode``.
        """
        mu = self.latent_process(self.vae.encode(x).latents)
        logvar = torch.zeros_like(mu)  # Logvar isn't used in TAESD
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Return ``mu`` unchanged; ``logvar`` is ignored (no stochastic sampling)."""
        return mu

    def decode(self, z):
        """Decode flattened latents ``z`` and return the decoder's ``.sample`` image tensor."""
        z = self.latent_process.decode(z)
        return self.vae.decode(z).sample

    def encode_x(self, x):
        """From img to latent with preprocessing (runs under ``torch.no_grad``)."""
        with torch.no_grad():
            x_normalized = self.preprocess(x)
            z, _ = self.encode(x_normalized)
        return z

    def decode_z(self, z):
        """From latent to img with postprocessing (runs under ``torch.no_grad``)."""
        with torch.no_grad():
            # `decode` already unwraps `.sample`, so its result is the image tensor.
            xhat = self.decode(z)
            xhat = self.preprocess.decode(xhat)
        return xhat

    def to(self, *args, **kwargs):
        """Move/cast the wrapped VAE (same arguments as ``nn.Module.to``); returns ``self``."""
        self.vae.to(*args, **kwargs)
        return self

    def eval(self):
        """Put the wrapped VAE in eval mode; returns ``self``."""
        self.vae.eval()
        return self

    @property
    def device(self):
        """Device reported by the wrapped VAE."""
        return self.vae.device

    

@ft.lru_cache(maxsize=None)
def load_taesd_vae():
    """Load the TAESD tiny autoencoder (``madebyollin/taesd``) wrapped in ``TAESD_Wrapper``.

    The VAE is loaded in float32, switched to eval mode and left on the device
    ``from_pretrained`` places it on. Move it with ``.to(device)`` as needed.
    The result is cached, so every call returns the same wrapper instance.
    """
    # Only the tiny VAE is used. Building a full Stable Diffusion pipeline just to
    # swap its VAE out would download and allocate a UNet and text encoder that are
    # then discarded, and would require CUDA for no benefit.
    vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float32)
    vae.eval()
    return TAESD_Wrapper(vae)

class BetaVAE(nn.Module):
    """Fully-connected beta-VAE over flattened inputs.

    Encoder: two ``Linear -> LeakyReLU(0.2) -> BatchNorm1d`` stages, followed by linear
    heads for the posterior mean ``mu`` and log-variance ``logvar``. The decoder mirrors
    the encoder and ends in a ``Sigmoid``, so reconstructions lie in (0, 1).

    Args:
        input_dim: number of features per flattened input sample.
        hidden_dim: width of the hidden layers.
        latent_dim: dimensionality of the latent code.
        beta: weight on the KL term in ``loss_function``.
    """

    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=32, beta=4.0):
        super().__init__()
        self.beta = beta
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(hidden_dim)
        )

        # Mean and variance layers
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map flattened inputs ``(batch, input_dim)`` to ``(mu, logvar)``, each ``(batch, latent_dim)``."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        """Map latents ``(batch, latent_dim)`` to reconstructions ``(batch, input_dim)`` in (0, 1)."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode ``x``; returns ``(x_recon, mu, logvar)``.

        ``x`` may have any trailing shape; it is flattened to ``(batch, -1)`` first.
        """
        # `reshape` instead of `view`: channel-permuted batches (e.g. produced by
        # einops.rearrange) are non-contiguous, and `view` raises on them.
        x = x.reshape(x.size(0), -1)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

    def loss_function(self, x, x_recon, mu, logvar):
        """Return the summed loss ``BCE(x_recon, x) + beta * KL(q(z|x) || N(0, I))``.

        ``x`` may be passed unflattened (as accepted by ``forward``). It is flattened
        the same way when its shape differs from ``x_recon``. Targets are expected
        to lie in [0, 1], as binary cross-entropy requires.
        """
        if x.shape != x_recon.shape:
            x = x.reshape(x.size(0), -1)
        # Reconstruction loss (binary cross entropy)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        # KL divergence
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Total loss with beta weighting
        return recon_loss + self.beta * kl_loss

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

def _unflatten_square(xenc):
    """Split the last axis of ``xenc`` back into a square ``(h, w)`` grid.

    The side length is ``int(sqrt(last_dim))``, so this assumes square images.
    """
    side = int(np.sqrt(xenc.shape[-1]))
    return rearrange(xenc, "... (h w) -> ... h w", h=side, w=side)


# Encode: (..., h, w) -> (..., h*w). Decode: (..., h*w) -> (..., h, w), assuming h == w.
data_transform = Transform(
    lambda x: rearrange(x, "... h w -> ... (h w)"),
    _unflatten_square,
)

@ft.lru_cache
def load_bvae(
    path:Union[str, Path], # path to the model checkpoint
    beta=4.0, # beta value used to train the model. Doesn't affect inference
    ):
    """Build a ``BetaVAE`` from a saved ``state_dict``, inferring its layer sizes.

    ``input_dim`` and ``hidden_dim`` come from ``encoder.0.weight``; ``latent_dim``
    comes from ``fc_mu.weight``.

    Example usage:

        model = load_bvae("beta_vae_mnist.pt")

    The model is built on CPU and is not switched to eval mode here. Results are
    cached per ``(path, beta)``, so callers with the same arguments share one model
    instance. Note that ``"x.pt"`` and ``Path("x.pt")`` are different cache keys.
    """
    # map_location="cpu": checkpoints saved from a GPU run still load on CPU-only hosts.
    # The freshly built model is on CPU anyway, and load_state_dict copies into it.
    # weights_only=True: a state_dict is plain tensors, so refuse arbitrary pickled objects.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)

    # Get the dimensions of the model from checkpoint
    W1 = state_dict["encoder.0.weight"]  # (hidden_dim, input_dim)
    Wf = state_dict["fc_mu.weight"]      # (latent_dim, hidden_dim)
    input_dim = W1.shape[1]
    hidden_dim = W1.shape[0]
    latent_dim = Wf.shape[0]

    # Initialize the model
    model = BetaVAE(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim, beta=beta)
    model.load_state_dict(state_dict)
    return model

@ft.lru_cache
def load_data(dataset="mnist"):
    """Load the ``(Xtrain, Xtest)`` image arrays for ``dataset`` as float32 tensors.

    Reads ``Xtrain.npy`` and ``Xtest.npy`` from the dataset's folder under ``data/``.
    CIFAR-10 and Tiny-ImageNet arrays are stored as ``(b, h, w, c)`` and are rearranged
    to ``(b, c, h, w)``. MNIST arrays are returned in their stored layout.
    Results are cached per ``dataset``, so repeated calls return the same tensors.

    Raises:
        ValueError: if ``dataset`` is not "mnist", "cifar" or "tinyimagenet".
    """
    # dataset name -> (folder holding the .npy files, stored channels-last?)
    layouts = {
        "mnist": ("data/mnist", False),
        "cifar": ("data/cifar10", True),
        "tinyimagenet": ("data/tiny-imgnet", True),
    }
    if dataset not in layouts:
        raise ValueError(f"Dataset {dataset} not supported")
    folder, channels_last = layouts[dataset]

    Xtrain, Xtest = (
        torch.tensor(np.load(f"{folder}/{split}.npy"), dtype=torch.float32)
        for split in ("Xtrain", "Xtest")
    )
    if channels_last:
        Xtrain, Xtest = (rearrange(t, "b h w c -> b c h w") for t in (Xtrain, Xtest))
    return Xtrain, Xtest

def batch_encode_data(model, data, batch_size=256, do_transform=True):
    """Encode ``data`` in mini-batches and return NumPy arrays ``(z, mu, logvar)``.

    Args:
        model: object exposing ``eval()``, ``device``, ``encode(x) -> (mu, logvar)`` and
            ``reparameterize(mu, logvar)``, e.g. ``BetaVAE`` or ``TAESD_Wrapper``.
        data: batch-first tensor supporting slicing and ``.to(device)``.
        batch_size: number of rows per forward pass.
        do_transform: if True, pass each batch through ``data_transform`` (flattening
            the trailing ``(h, w)`` axes) before encoding.

    Returns:
        A tuple of three arrays, each the concatenation along axis 0 of the per-batch
        ``z``, ``mu`` and ``logvar`` results.

    Note:
        Calls ``model.eval()`` and does not restore the previous mode. Empty ``data``
        makes ``np.concatenate`` raise ``ValueError``.
    """
    print(f"Encoding {len(data)} images...")
    model = model.eval()
    z_parts, mu_parts, logvar_parts = [], [], []

    with torch.no_grad():
        for start in tqdm(range(0, len(data), batch_size)):
            chunk = data[start:start + batch_size].to(model.device)
            if do_transform:
                chunk = data_transform(chunk)
            mu, logvar = model.encode(chunk)
            z = model.reparameterize(mu, logvar)

            for store, tensor in ((z_parts, z), (mu_parts, mu), (logvar_parts, logvar)):
                store.append(tensor.cpu().numpy())

    return tuple(np.concatenate(parts) for parts in (z_parts, mu_parts, logvar_parts))

# Memory objects from the `memories` module, keyed by short name.
# NOTE(review): "lsr" is constructed exactly like "epa" (EpaMemory(eps=0., lmda=0.)).
# Confirm this is intended rather than a copy of the "epa" entry.
MEMS = dict(
    epa=EpaMemory(eps=0., lmda=0.),
    lsr=EpaMemory(eps=0., lmda=0.),
    lse=LseMemory(),
)

# Attach a `.numpify()` method to torch/JAX array types via fastcore's `@patch`.
# `@patch` picks the class to patch from the annotation on `self`, so these
# annotations are load-bearing and must not be removed or reordered.
@patch
def numpify(self: torch.Tensor) -> np.ndarray:
    """Return this tensor as a NumPy array, detached from autograd and moved to CPU first.

    For a tensor already on CPU, the result may share memory with the tensor.
    """
    return self.detach().cpu().numpy()

@patch
def numpify(self: jax.Array) -> np.ndarray:
    """Return a NumPy copy of this JAX array (``np.array`` copies by default)."""
    return np.array(self)

# NOTE(review): presumably patched separately because concrete JAX arrays do not pick
# up methods patched onto the abstract `jax.Array` — confirm.
# `jaxlib.xla_extension` has been moved/renamed in newer jaxlib releases, so this
# attribute lookup may fail at import time there; verify against the pinned jaxlib version.
@patch
def numpify(self: jaxlib.xla_extension.ArrayImpl) -> np.ndarray:
    """Return a NumPy copy of this concrete JAX array (``np.array`` copies by default)."""
    return np.array(self)
