import random
import torch
import argparse
from torch import vmap, logdet, det
from torch.func import jacfwd
from torch.nn.functional import softplus
import wandb
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np

from autoencoder import Autoencoder
from data import get_data
from plot import plot_latent_grid_2d

def _validate(model, val_loader, device):
    """Compute batch-averaged validation statistics.

    Args:
        model: autoencoder whose ``forward`` returns ``(reconstruction, latents)``
            and whose ``decoder`` is differentiated with ``torch.func.jacfwd``.
        val_loader: iterable of batches; element 0 of each batch is the input.
        device: device the inputs are moved to.

    Returns:
        ``(recon, vol, jacnorm)``: the mean squared reconstruction error, the
        mean ``logdet(J^T J + I)`` of the per-sample decoder Jacobian ``J``,
        and the mean per-sample L1 norm of ``J``, each averaged over batches.
        All three are 0.0 for an empty loader.
    """
    recon_sum, vol_sum, jacnorm_sum = 0.0, 0.0, 0.0
    n_batches = 0
    for batch in val_loader:
        x_val = batch[0].to(device)
        xh_val, sh_val = model(x_val)
        recon_sum += (x_val - xh_val).square().mean().item()

        # Per-sample decoder Jacobian; its last axis is the latent dimension,
        # so the identity is sized from this Jacobian (not a training one).
        val_jac = vmap(jacfwd(model.decoder))(sh_val)
        eye = torch.eye(val_jac.shape[-1], device=val_jac.device).unsqueeze(0)
        vol_sum += logdet(val_jac.transpose(-1, -2) @ val_jac + eye).mean().item()
        jacnorm_sum += val_jac.abs().sum(dim=(1, 2)).mean().item()
        n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0, 0.0
    return recon_sum / n_batches, vol_sum / n_batches, jacnorm_sum / n_batches


def _log_latent_traversals(model, digit_loaders, args, device):
    """Log one latent-traversal image grid per (latent dim, digit) to W&B.

    For each digit 0-9 one batch is drawn from ``digit_loaders[digit]`` and
    encoded; its first ``args.n_rows`` codes are the anchors. Row ``i`` of a
    grid sweeps latent ``dim`` of anchor ``i`` linearly over
    ``anchor[dim] ± args.std_coeff * std`` in ``args.n_steps`` steps (``std``
    is the batch std of that latent dim) and decodes each modified code.
    """
    n_rows, n_steps = args.n_rows, args.n_steps
    with torch.no_grad():
        # Encode one batch per digit once and reuse it for every latent dim,
        # instead of re-drawing and re-encoding a batch per (dim, digit) pair.
        digit_codes = {
            digit: model.encoder(next(iter(digit_loaders[digit]))[0].to(device))
            for digit in range(10)
        }

        for dim in range(args.latent_dim):
            for digit, latent_codes in digit_codes.items():
                # May hold fewer than n_rows codes if the batch is small.
                anchor_latents = latent_codes[:n_rows]
                half_range = (args.std_coeff * torch.std(latent_codes[:, dim])).item()

                grid = torch.zeros(
                    anchor_latents.shape[0], n_steps,
                    args.channels, args.img_dim, args.img_dim,
                    device=device,
                )
                for i, anchor in enumerate(anchor_latents):
                    center = anchor[dim].item()
                    values = torch.linspace(center - half_range, center + half_range, n_steps)
                    for j, val in enumerate(values):
                        latent_modified = anchor.clone()
                        latent_modified[dim] = val
                        # NOTE(review): assumes the decoder returns a leading
                        # size-1 axis for an unbatched latent — TODO confirm.
                        grid[i, j] = model.decoder(latent_modified)[0]

                fig = plot_latent_grid_2d(
                    grid,
                    dim1_name=f'dim{dim}',
                    dim2_name=f'digit{digit}',
                    figsize=(n_steps, n_rows))
                wandb.log({f"dim_{dim}/digit_{digit}": fig})
                plt.close('all')


def train(args):
    """Train the autoencoder with volume and Jacobian-L1 regularisation.

    Per batch the loss is
        recon - alpha * mean(logdet(J^T J + I)) + args.lam_norm * penalty,
    where ``J`` is the per-sample decoder Jacobian. ``alpha`` ramps linearly
    from 0 to ``args.lam_vol`` over ``args.warmup`` epochs (full weight
    immediately when ``warmup == 0``). While ``epoch <= args.warmup`` the
    penalty is the mean L1 norm of ``J``. At the end of epoch ``args.warmup``
    a threshold ``rho`` is set to the mean norm of the last 10 batches, and
    afterwards the penalty is ``mean(softplus(||J||_1 - rho))``.

    Every ``args.eval_iter`` epochs and after the final epoch, the encoder and
    decoder state dicts are saved under ``ckpt/<wandb run name>/``. Per-batch
    mean train/val statistics are logged to W&B, along with latent traversal
    grids.

    Requires ``wandb.init`` to have been called beforehand (``wandb.run`` is read).
    """
    # Fall back to CPU when CUDA is unavailable instead of failing on first use.
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

    # Init experiment
    model = Autoencoder(args.latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train_loader, val_loader, digit_loaders = get_data(args)

    # Dir for checkpoints
    ckpt_dir = f'ckpt/{wandb.run.name}/'
    os.makedirs(ckpt_dir, exist_ok=True)

    rho = None
    rho_history = []
    for epoch in range(args.num_iters):
        print(f"Iteration: {epoch+1}/{args.num_iters}")
        model.train()

        # The volume weight is constant within an epoch; compute it once.
        if args.warmup > 0:
            alpha = (args.lam_vol / args.warmup) * min(epoch, args.warmup)
        else:
            alpha = args.lam_vol
        in_warmup = epoch <= args.warmup

        train_loss, train_recon, train_vol, train_jacnorm = 0.0, 0.0, 0.0, 0.0
        for batch in tqdm(train_loader, total=len(train_loader)):
            x = batch[0].to(device)
            xh, sh = model(x)

            recon = (x - xh).square().mean()

            # Per-sample decoder Jacobian; last axis is the latent dimension.
            jac = vmap(jacfwd(model.decoder))(sh)
            eye = torch.eye(jac.shape[-1], device=jac.device).unsqueeze(0)
            vol = logdet(jac.transpose(-1, -2) @ jac + eye)
            jacnorm = jac.abs().sum(dim=(1, 2))

            if in_warmup:
                jac_penalty = jacnorm.mean()
                rho_history.append(jac_penalty.item())
            else:
                # Only penalise the Jacobian norm in excess of the warmup level.
                jac_penalty = softplus(jacnorm - rho).mean()
            loss = recon - alpha * vol.mean() + args.lam_norm * jac_penalty

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_recon += recon.item()
            train_vol += vol.mean().item()
            train_jacnorm += jacnorm.mean().item()

        # Report per-batch means so train/val stats share a scale regardless
        # of batch size and number of batches.
        n_train_batches = max(len(train_loader), 1)
        train_loss /= n_train_batches
        train_recon /= n_train_batches
        train_vol /= n_train_batches
        train_jacnorm /= n_train_batches

        # Fix the Jacobian-norm threshold at the end of warmup.
        if epoch == args.warmup:
            rho = float(np.mean(rho_history[-10:]))

        # range() stops at num_iters - 1, so that is the last epoch.
        is_last_epoch = epoch == args.num_iters - 1
        if epoch % args.eval_iter == 0 or is_last_epoch:
            model.eval()

            torch.save(model.encoder.state_dict(), os.path.join(ckpt_dir, f"encoder_{epoch}.pt"))
            torch.save(model.decoder.state_dict(), os.path.join(ckpt_dir, f"decoder_{epoch}.pt"))

            val_recon, val_vol, val_jacnorm = _validate(model, val_loader, device)

            wandb.log({
                "train/loss": train_loss,
                "train/recon": train_recon,
                "val/recon": val_recon,
                "train/vol": train_vol,
                "val/vol": val_vol,
                "train/jacnorm": train_jacnorm,
                "val/jacnorm": val_jacnorm,
                "train/epoch": epoch,
                "val/epoch": epoch,
                "train/alpha": alpha,
                "train/rho": rho,
            })

            _log_latent_traversals(model, digit_loaders, args, device)


def build_parser():
    """Return the command-line argument parser for this training script.

    Note that the ``*_iter(s)`` options count training epochs (full passes
    over the training loader), not optimiser steps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lam_vol",
        help="Volume regularization coefficient",
        type=float,
        default=1e-4
    )
    parser.add_argument(
        "--lam_norm",
        help="L1 norm regularization coefficient",
        type=float,
        default=1e-4
    )
    parser.add_argument(
        "--warmup",
        help="Number of warmup epochs",
        type=int,
        default=50
    )
    parser.add_argument(
        "--train_batchsize",
        help="Training batch size",
        type=int,
        default=1024
    )
    parser.add_argument(
        "--val_batchsize",
        help="Validation batch size",
        type=int,
        default=64
    )
    parser.add_argument(
        "--lr",
        help="Adam learning rate",
        type=float,
        default=1e-3
    )
    parser.add_argument(
        "--n_rows",
        help="Number of rows in the grid",
        type=int,
        default=10
    )
    parser.add_argument(
        "--n_steps",
        help="Number of steps in the grid",
        type=int,
        default=10
    )
    parser.add_argument(
        "--num_iters",
        help="Number of training epochs",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--eval_iter",
        help="Evaluate, log and checkpoint every N epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--latent_dim",
        help="Latent dim",
        type=int,
        default=10
    )
    parser.add_argument(
        "--img_dim",
        help="Image dimension",
        type=int,
        default=32
    )
    parser.add_argument(
        "--channels",
        help="Number of channels",
        type=int,
        default=1
    )
    parser.add_argument(
        "--cuda",
        help="CUDA device",
        type=int,
        default=0
    )
    parser.add_argument(
        "--std_coeff",
        help="Std coeff for varying latents",
        type=float,
        default=4.0
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Using the run as a context manager finishes it on success and marks it
    # failed (then re-raises) if training throws.
    with wandb.init(project="dica_mnist", config=vars(args)):
        train(args)