#!/usr/bin/env python3
import os
import hydra
import torch
import numpy as np
import logging
import torch.nn as nn

from omegaconf import DictConfig, OmegaConf
from diffusers import DDIMScheduler
from torch.utils.data import DataLoader
from torch import Tensor
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

from encoder import TimeConditionedEncoder
from data import get_dataset
from spectral import compute_eigenvalues, whitening
from utils import save_eigenvalue_plot

from typing import Tuple

# Shared logger for this training script (named "train" rather than __name__).
log = logging.getLogger(name="train")

def train_one_epoch(
    x: Tensor,
    t: Tensor,
    encoder: nn.Module,
    noise_scheduler: DDIMScheduler,
    num_chunks: int,
    optimizer,
    grad_clip: float,
    eps: float,
    ridge: float,
) -> Tuple[Tensor, Tensor]:
    """Run a single optimizer step of the encoder on one batch.

    Despite its name, this performs one optimization step on ``x`` (``main``
    calls it once per batch), not a pass over the whole dataset.

    Two independently noised views of ``x`` at timesteps ``t`` are produced with
    ``noise_scheduler.add_noise``. View A is encoded without gradients and its
    full-batch mean / whitener are estimated with ``whitening``. View B is then
    encoded chunk by chunk (with float16 autocast on CUDA), and the loss
    ``1 - mean(eigenvalues)`` from ``compute_eigenvalues`` is back-propagated per
    chunk. Gradients accumulate across chunks, are optionally clipped, and the
    optimizer is stepped once and zeroed.

    Args:
        x: Input batch; split along dim 0.
        t: Per-sample timesteps, split along dim 0 in lockstep with ``x``.
        encoder: Module called as ``encoder(x_t, t)``.
        noise_scheduler: Only its ``add_noise`` method is used.
        num_chunks: Requested number of chunks. ``torch.chunk`` may return fewer
            and/or unequal chunks, so each chunk's loss is weighted by its share
            of the batch.
        optimizer: Stepped once, then ``zero_grad()`` is called.
        grad_clip: Max global gradient norm; ``None`` disables clipping.
        eps: Forwarded to ``whitening``.
        ridge: Forwarded to ``whitening``.

    Returns:
        ``(loss, eigenvalues)``:

        * ``loss`` is the detached sum of the back-propagated, batch-weighted
          chunk losses.
        * ``eigenvalues`` is recomputed on the full batch from the features
          produced *before* the optimizer step.
    """
    # NOTE(review): float16 autocast is used below without a GradScaler, so small
    # gradients may underflow; consider torch.cuda.amp.GradScaler if that matters.
    noise_a = torch.randn_like(x)
    noise_b = torch.randn_like(x)
    xt_a = noise_scheduler.add_noise(x, noise_a, t)
    xt_b = noise_scheduler.add_noise(x, noise_b, t)
    xt_a_chunks = torch.chunk(xt_a, num_chunks, dim=0)
    xt_b_chunks = torch.chunk(xt_b, num_chunks, dim=0)
    t_batch_chunks = torch.chunk(t, num_chunks, dim=0)
    batch_size = x.size(0)

    # View A only provides statistics for the objective: no gradients.
    with torch.no_grad():
        phi_a_chunks = [
            encoder(xt_a_chunk, t_batch_chunk)
            for xt_a_chunk, t_batch_chunk in zip(xt_a_chunks, t_batch_chunks)
        ]
    phi_a = torch.cat(phi_a_chunks)

    # Mean and whitener estimated on the full (no-grad) phi_a, shared by all chunks.
    phi_a_mu, phi_a_whitener = whitening(phi_a, eps=eps, ridge=ridge)

    # Only enable CUDA autocast when the data lives on a CUDA device; requesting
    # it without CUDA just emits a warning and is disabled anyway.
    use_amp = x.is_cuda

    phi_b_chunks = []
    total_loss = torch.zeros((), device=x.device)
    for phi_a_chunk, xt_b_chunk, t_batch_chunk in zip(phi_a_chunks, xt_b_chunks, t_batch_chunks):
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            phi_b_chunk = encoder(xt_b_chunk, t_batch_chunk)
        phi_b_chunks.append(phi_b_chunk.detach())

        eigenvalues = compute_eigenvalues(
            phi_a_chunk.float(),  # no grad
            phi_b_chunk.float(),
            mu=phi_a_mu,  # no grad
            whitener=phi_a_whitener,  # no grad
            normalize=True,
        )

        # Weight by this chunk's share of the batch instead of 1/num_chunks:
        # torch.chunk may return fewer or uneven chunks (e.g. a short final
        # DataLoader batch), which would otherwise mis-scale the gradient.
        weight = xt_b_chunk.size(0) / batch_size
        loss = (1.0 - eigenvalues.mean()) * weight
        loss.backward()
        total_loss = total_loss + loss.detach()

    phi_b = torch.cat(phi_b_chunks)

    if grad_clip is not None:
        clip_grad_norm_(encoder.parameters(), max_norm=grad_clip)

    optimizer.step()
    optimizer.zero_grad()

    # Full-batch eigenvalues for monitoring only (features are pre-step, detached).
    eigenvalues = compute_eigenvalues(
        phi_a.float().detach(),
        phi_b.float().detach(),
        mu=phi_a_mu,
        whitener=phi_a_whitener,
        normalize=True,
    )

    return total_loss, eigenvalues


def _atomic_torch_save(obj, path: str) -> None:
    """Save ``obj`` to ``path`` with ``torch.save`` without leaving a partial file.

    The object is serialized to ``<path>.tmp`` first and then moved over ``path``
    with ``os.replace``. An interruption mid-save therefore keeps the previous
    file intact.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


@hydra.main(config_path="conf", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    """Train the time-conditioned encoder described by ``cfg``.

    Setup, all driven by ``cfg``:

    * the training dataset and DataLoader;
    * a ``TimeConditionedEncoder``;
    * a DDIM noise scheduler (pretrained or from explicit betas);
    * AdamW and an exponential LR schedule.

    If ``cfg.training.checkpoint_path`` exists, training resumes from it.
    For each batch, one timestep is drawn uniformly from
    ``range(cfg.scheduler.start, cfg.scheduler.num_train_timesteps,
    cfg.scheduler.step)`` and ``train_one_epoch`` runs one optimizer step.

    Per-epoch artifacts are written relative to the current working directory.
    NOTE(review): whether that is Hydra's run dir depends on ``hydra.job.chdir``;
    confirm. The artifacts are:

    * TensorBoard events (learning rate, eigenvalue plot);
    * ``evals.png``;
    * ``state_dict.pt`` (encoder weights only);
    * ``checkpoint.pt`` (encoder/optimizer/scheduler state, epoch, per-timestep
      eigenvalues and the resolved config, i.e. the format the resume branch
      reads).

    Raises:
        ValueError: if the configured timestep range is empty.
    """
    log.info("Starting training")
    log.info("Config:\n" + OmegaConf.to_yaml(cfg))

    dataset = get_dataset(
        dataset=cfg.dataset.name,
        split="train",
        augment=True,
        cache_dir=cfg.dataset.cache_dir,
    )

    train_loader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,
        pin_memory=cfg.training.pin_memory,
        num_workers=cfg.training.num_workers,
        persistent_workers=cfg.training.persistent_workers,
        shuffle=True,
        collate_fn=getattr(dataset, "collate_fn", None),
    )

    phi_encoder = TimeConditionedEncoder(
        image_size=cfg.dataset.image_size,
        out_dim=cfg.model.num_eigenfunctions,
        time_emb_dim=cfg.model.time_emb_dim,
        base_channels=cfg.model.base_channels,
        channel_mults=cfg.model.channel_mults,
        min_resolution=cfg.model.min_resolution,
        max_channels=cfg.model.max_channels,
        num_train_timesteps=cfg.scheduler.num_train_timesteps,
    )
    phi_encoder = phi_encoder.to(cfg.device)

    if cfg.scheduler.pretrained is not None:
        noise_scheduler = DDIMScheduler.from_pretrained(cfg.scheduler.pretrained)
    else:
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=cfg.scheduler.num_train_timesteps,
            beta_start=cfg.scheduler.beta_start,
            beta_end=cfg.scheduler.beta_end,
            beta_schedule=cfg.scheduler.beta_schedule,
        )

    optimizer = AdamW(
        params=phi_encoder.parameters(),
        lr=cfg.training.lr,
        weight_decay=cfg.training.weight_decay,
    )

    scheduler = ExponentialLR(optimizer, gamma=cfg.training.gamma)

    # Candidate timesteps, validated once here rather than failing inside the loop.
    timesteps = range(cfg.scheduler.start, cfg.scheduler.num_train_timesteps, cfg.scheduler.step)
    if len(timesteps) == 0:
        raise ValueError(
            f"Empty timestep range: start={cfg.scheduler.start}, "
            f"stop={cfg.scheduler.num_train_timesteps}, step={cfg.scheduler.step}"
        )

    start_epoch = 0
    eigenvalues_t = {}  # timestep -> latest full-batch eigenvalues (on CPU)

    checkpoint_path = cfg.training.checkpoint_path
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        log.info(f"Loading checkpoint from {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only resume from trusted files.
        checkpoint = torch.load(checkpoint_path, map_location=cfg.device, weights_only=False)
        phi_encoder.load_state_dict(checkpoint["phi_encoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        eigenvalues_t = {k: v.cpu() for k, v in checkpoint.get("eigenvalues_t", {}).items()}
        start_epoch = checkpoint.get("epoch", 0) + 1
    elif checkpoint_path is not None:
        log.warning(f"Checkpoint {checkpoint_path} not found; training from scratch")

    writer = SummaryWriter(".")
    try:
        for epoch in range(start_epoch, cfg.training.epochs):
            pbar = tqdm(train_loader, desc=f"Epoch {epoch}", colour="blue")
            for batch in pbar:
                x = batch[0] if isinstance(batch, (list, tuple)) else batch
                x = x.to(cfg.device)
                batch_size = x.size(0)

                # Plain int (not a numpy scalar) for dict keys and an explicit
                # integer dtype for the scheduler's timestep indexing.
                t = int(np.random.choice(timesteps))
                t_batch = torch.full((batch_size,), t, device=cfg.device, dtype=torch.long)

                loss, eigenvalues = train_one_epoch(
                    x,
                    t_batch,
                    encoder=phi_encoder,
                    noise_scheduler=noise_scheduler,
                    num_chunks=cfg.training.num_chunks,
                    optimizer=optimizer,
                    grad_clip=cfg.training.grad_clip,
                    eps=cfg.training.eps,
                    ridge=cfg.training.ridge,
                )

                eigenvalues_t[t] = eigenvalues.cpu()

                pbar.set_postfix({
                    "psi_loss": loss.item(),
                    "t": t,
                    "lr": scheduler.get_last_lr()[0],
                })

            scheduler.step()
            writer.add_scalar("train/lr", scheduler.get_last_lr()[0], epoch)
            save_eigenvalue_plot(eigenvalues_t, "evals.png", epoch, writer)
            _atomic_torch_save(phi_encoder.state_dict(), "state_dict.pt")
            _atomic_torch_save({
                "phi_encoder": phi_encoder.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "eigenvalues_t": eigenvalues_t,
                "config": OmegaConf.to_container(cfg, resolve=True),
            }, "checkpoint.pt")
            log.info("Saved checkpoint: checkpoint.pt (weights: state_dict.pt)")
    finally:
        # Flush and close TensorBoard events even if training raises.
        writer.close()

if __name__ == "__main__":
    # Script entry point. main() is wrapped by @hydra.main, which composes the
    # "config" config from the conf/ directory (plus any command-line overrides)
    # and passes it in as cfg, so no arguments are given here.
    main()