import argparse
import math
import random
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import wandb

from model import MultiHeadCLSEDM
from utils import get_date, cycle, init_wandb, log_wandb, finish_wandb

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _build_dataset(data_name, num_classes):
    """Return the training split of ``data_name`` with pixels mapped to [-1, 1].

    Args:
        data_name: ``"mnist"`` or ``"cifar10"``.
        num_classes: Must be 10 for both supported datasets.

    Raises:
        ValueError: If ``data_name`` is unsupported or ``num_classes`` != 10.
    """
    if data_name == "mnist":
        transform = transforms.Compose([
            transforms.Pad(padding=2),  # 28x28 -> 32x32
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
        ])
        dataset_cls = datasets.MNIST
    elif data_name == "cifar10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            # (0.5,) broadcasts across all three RGB channels: [0, 1] -> [-1, 1]
            transforms.Normalize((0.5,), (0.5,)),
        ])
        dataset_cls = datasets.CIFAR10
    else:
        raise ValueError(f"Unsupported dataset: {data_name}")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if num_classes != 10:
        raise ValueError(f"{data_name} has 10 classes, got num_classes={num_classes}")
    return dataset_cls(root="./dataset", train=True, download=True, transform=transform)


def _is_due(step, freq):
    """Return True when a periodic action with period ``freq`` fires at ``step``.

    A non-positive ``freq`` disables the action instead of raising
    ZeroDivisionError on the modulo.
    """
    return freq > 0 and step % freq == 0


def _warmup_lr(base_lr, step, warmup):
    """Linear warmup: 0 at step 0, rising to ``base_lr`` at step ``warmup``.

    A non-positive ``warmup`` disables warmup (returns ``base_lr``) instead of
    dividing by zero.
    """
    if warmup <= 0:
        return base_lr
    return base_lr * min(step / warmup, 1)


def _log_sample_grid(model, wb, step, num_classes, num_samples, device):
    """Sample class-conditional images and log them to wandb as one grid.

    Labels cycle 0, 1, ..., num_classes - 1 until ``num_samples`` labels are
    drawn. The model is switched to eval mode for sampling and back to train
    mode afterwards.
    """
    model.eval()
    with torch.no_grad():
        labels = torch.arange(num_classes, device=device)
        labels = labels.repeat(num_samples // num_classes + 1)[:num_samples]
        labels_onehot = F.one_hot(labels, num_classes=num_classes)
        fakes = model.inference(labels_onehot)
        fakes = ((fakes + 1) / 2).clamp(0, 1)  # [-1, 1] -> [0, 1]

        nrow = max(1, int(math.sqrt(num_samples)))  # roughly square grid
        grid_fake = make_grid(fakes.detach().cpu(), nrow=nrow)
        log_wandb(wb, {
                "sample": wandb.Image(grid_fake, caption=f"step {step}")
            }, step=step)
    model.train()


def _save_loss_curve(loss_steps, loss_history, output_dir):
    """Write the sampled losses to CSV and plot them against their true step."""
    pd.DataFrame(loss_history).to_csv(output_dir / "loss_history.csv", index=False)

    fig, ax = plt.subplots()
    # x-values are the actual training steps; losses are only recorded every
    # `freq_log` steps, so the list index is not the step number.
    ax.plot(loss_steps, loss_history)
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss")
    ax.grid(True)
    fig.savefig(output_dir / "loss_curve.png")
    plt.close(fig)  # release the figure; pyplot would otherwise keep it alive


def main(args):
    """Train a ``MultiHeadCLSEDM`` model on MNIST or CIFAR-10.

    Seeds all RNGs and builds the model from the YAML file at
    ``args.config_path``. It then trains for ``args.train_iterations`` steps
    with Adam, using a linear LR warmup over ``args.warmup`` steps and the
    loss ``loss_ce + lambda_reg * loss_reg``. Every step it logs the losses
    to wandb. It also runs three periodic actions:

    - every ``freq_log`` steps: prints the losses;
    - every ``freq_model_save`` steps: saves ``model_{step}.pt``;
    - every ``freq_image_save`` steps: logs a sample grid.

    A frequency of 0 disables that action. At the end it writes ``model.pt``,
    ``loss_history.csv`` and ``loss_curve.png`` to ``args.output_dir`` and
    closes the wandb run.

    Args:
        args: Namespace with attributes ``device``, ``seed``, ``data_name``,
            ``output_dir``, ``config_path``, ``train_iterations``,
            ``batch_size``, ``lr``, ``warmup``, ``num_classes``,
            ``lambda_reg``, ``freq_log``, ``freq_model_save``,
            ``freq_image_save``, ``num_samples`` and optionally ``wandb``,
            ``wandb_project``, ``wandb_run``.

    Raises:
        ValueError: If the dataset is unsupported or ``num_classes`` does not
            match it.
    """
    _seed_everything(args.seed)

    date = get_date()  # used to make the wandb run name unique
    device = torch.device(args.device)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    wb = init_wandb(
            project=getattr(args, "wandb_project", "diffcls"),
            name=getattr(args, "wandb_run", "multihead_clsedm") + f"_{date}",
            mode=getattr(args, "wandb", None),
            config=vars(args),
            dir_=str(output_dir),
    )

    with open(args.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = MultiHeadCLSEDM(**config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    ds = _build_dataset(args.data_name, args.num_classes)
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=device.type == "cuda",  # pinning only benefits host->GPU copies
        num_workers=4,
    )
    batches = cycle(dl)

    model.train()
    loss_steps = []    # training step at which each loss_history entry was taken
    loss_history = []
    for step in range(args.train_iterations):
        x, y = next(batches)
        x = x.to(device)
        y = y.to(device)
        y_onehot = F.one_hot(y, num_classes=args.num_classes)

        loss_ce, loss_reg = model.loss_func(x, y_onehot)
        loss = loss_ce + loss_reg * args.lambda_reg

        optimizer.zero_grad()
        loss.backward()

        # Learning-rate warmup, applied before this step's update.
        lr = _warmup_lr(args.lr, step, args.warmup)
        for group in optimizer.param_groups:
            group["lr"] = lr
        # Replace NaN/Inf gradients in place before the update. Covers every
        # parameter the optimizer updates (model.parameters()), not only model.net.
        for param in model.parameters():
            if param.grad is not None:
                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

        optimizer.step()

        # NOTE(review): this passes the fixed total `args.train_iterations` on
        # every step, not the current `step`; confirm this matches what
        # `model.update_ema` expects.
        model.update_ema(args.train_iterations, args.batch_size, **config)

        # Read each scalar once (each .item() forces a device sync).
        loss_val = loss.item()
        loss_ce_val = loss_ce.item()
        loss_reg_val = loss_reg.item()
        log_wandb(wb, {"loss": loss_val, "loss_ce": loss_ce_val, "loss_reg": loss_reg_val}, step)

        # Console logging & loss-history sampling
        if _is_due(step, args.freq_log):
            print(f"Step {step} / {args.train_iterations}, Loss: {loss_val}, Loss CE: {loss_ce_val}, Loss Reg: {loss_reg_val}")
            loss_steps.append(step)
            loss_history.append(loss_val)

        if _is_due(step, args.freq_model_save):
            torch.save(model.state_dict(), output_dir / f"model_{step}.pt")

        if _is_due(step, args.freq_image_save):
            _log_sample_grid(model, wb, step, args.num_classes, args.num_samples, device)

    torch.save(model.state_dict(), output_dir / "model.pt")
    _save_loss_curve(loss_steps, loss_history, output_dir)

    # Close the wandb run (finish_wandb was imported but never called).
    finish_wandb(wb)

if __name__ == "__main__":
    # Command-line options as (flag, add_argument kwargs) pairs. They are
    # registered in this order, so `--help` lists them in the same order.
    cli_options = [
        # runtime
        ("--device", {"type": str, "default": "cuda:0"}),
        ("--seed", {"type": int, "default": 42}),
        # data / paths
        ("--data_name", {"type": str, "default": "mnist"}),
        ("--output_dir", {"type": str, "default": "./expr/mnist/multihead_clsedm"}),
        ("--config_path", {"type": str, "default": "./multihead_edm/config_mnist.yaml"}),
        # optimisation
        ("--train_iterations", {"type": int, "default": 200000}),
        ("--batch_size", {"type": int, "default": 64}),
        ("--lr", {"type": float, "default": 1e-4}),
        ("--warmup", {"type": int, "default": 5000}),
        ("--num_classes", {"type": int, "default": 10}),
        ("--lambda_reg", {"type": float, "default": 1.0}),
        # logging / checkpointing cadence
        ("--freq_log", {"type": int, "default": 100}),
        ("--freq_model_save", {"type": int, "default": 50000}),
        ("--freq_image_save", {"type": int, "default": 10000}),
        ("--num_samples", {"type": int, "default": 16}),
        # experiment tracking
        ("--wandb", {"help": "optional wandb (default: off)", "choices": ["on", "off", "offline"], "default": "off", "type": str}),
        ("--wandb_project", {"help": "wandb project name", "default": "diffcls", "type": str}),
        ("--wandb_run", {"help": "wandb run name", "default": "multihead-clsedm", "type": str}),
    ]

    cli = argparse.ArgumentParser()
    for flag, option_kwargs in cli_options:
        cli.add_argument(flag, **option_kwargs)

    main(cli.parse_args())