import argparse
from pathlib import Path
import yaml
import time
import shutil
import numpy as np
from typing import Tuple

# from efficient_kan import KAN
from fastkan import FastKAN as KAN

# Train on CIFAR-10
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import math

# Parse arguments
parser = argparse.ArgumentParser(description="Train FastKAN on CIFAR-10")
parser.add_argument("--path_config", type=Path, required=True,
                    help="YAML configuration file")
parser.add_argument("--path_save", type=Path, required=True,
                    help="Directory where checkpoints & summary will be "
                    "written")
args = parser.parse_args()

# Load config. yaml.safe_load only builds plain Python objects and returns
# None for an empty file, so make sure we actually got a mapping before
# indexing into it (otherwise the failure is an opaque TypeError).
with open(args.path_config, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
if not isinstance(cfg, dict):
    parser.error(f"{args.path_config}: expected a YAML mapping of "
                 f"hyper-parameters, got {type(cfg).__name__}")

# Augmentation & training hyper-parameters (mandatory keys)
MIXUP_ALPHA = float(cfg["MIXUP_ALPHA"])
CUTMIX_BETA = float(cfg["CUTMIX_BETA"])
CUTMIX_PROB = float(cfg["CUTMIX_PROB"])

NUM_EPOCHS = int(cfg["NUM_EPOCHS"])
NUM_GRIDS = int(cfg["NUM_GRIDS"])

# Learning-rate schedule hyper-parameters (mandatory)
BASE_LR = float(cfg["LR"])  # Peak LR after warm-up
WARMUP_EPOCHS = int(cfg["WARMUP_EPOCHS"])  # Warm-up duration in epochs

# Setup save directory
PATH_SAVE = args.path_save
PATH_SAVE.mkdir(parents=True, exist_ok=True)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed for the reproducible train/val split; optional "SEED" config key,
# defaults to 42.
SEED = int(cfg.get("SEED", 42))

# CPU-side transform: only PIL image -> float tensor in [0, 1]; augmentation
# and normalisation live in the device-side pipelines below.
cpu_to_tensor = transforms.ToTensor()

# Per-channel CIFAR-10 mean / std shared by both pipelines.
MEAN: Tuple[float, float, float] = (0.4914, 0.4822, 0.4465)
STD: Tuple[float, float, float] = (0.2470, 0.2435, 0.2616)

# Training-time augmentation, run on `device`.
# NOTE: torchvision v2 random transforms draw one set of random parameters
# per call, so applying this module to a batched (N, C, H, W) tensor gives
# every image in that batch the same crop / flip / jitter / erasing.
train_gpu_tf = nn.Sequential(
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
    T.RandAugment(num_ops=2, magnitude=7),
    T.RandomErasing(p=0.25, scale=(0.05, 0.2), ratio=(0.3, 3.3)),
    T.Normalize(mean=MEAN, std=STD),
).to(device)

# Evaluation pipeline: deterministic normalisation only.
val_gpu_tf = nn.Sequential(T.Normalize(mean=MEAN, std=STD)).to(device)


# Create data loaders with train/val split
def make_loaders(batch_size=64, val_size=5_000, num_workers=4,
                 root="./data"):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The official training set is split per class: ``val_size // num_classes``
    randomly chosen images of every class go to validation and the rest to
    training (45 000 / 5 000 with the defaults), so each class keeps the same
    ratio. A ``torch.Generator`` seeded with ``SEED`` makes the split
    reproducible. The official test set is used as-is.

    Args:
        batch_size: Batch size for all three loaders.
        val_size: Total number of validation images (rounded down to a
            multiple of the number of classes).
        num_workers: DataLoader worker processes per loader.
        root: Dataset directory (``download=True`` is passed to torchvision).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """
    full_train = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=cpu_to_tensor
    )

    targets = torch.tensor(full_train.targets)
    num_classes = int(targets.max().item()) + 1     # 10 for CIFAR-10
    n_val_per_class = val_size // num_classes       # 500 images each by default

    g = torch.Generator().manual_seed(SEED)
    val_idx, train_idx = [], []

    for c in range(num_classes):
        cls_indices = torch.nonzero(targets == c, as_tuple=False).view(-1)
        cls_indices = cls_indices[torch.randperm(
            len(cls_indices), generator=g)]
        val_idx.extend(cls_indices[:n_val_per_class].tolist())
        train_idx.extend(cls_indices[n_val_per_class:].tolist())

    train_set = Subset(full_train, train_idx)
    val_set = Subset(full_train, val_idx)

    test_set = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=cpu_to_tensor
    )

    # Pinned host memory only speeds up host->GPU copies; on a CPU-only run
    # it is useless and DataLoader warns about it.
    pin = device.type == "cuda"

    def _loader(dataset, shuffle):
        # Shared DataLoader settings for the three splits.
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin
        )

    return (
        _loader(train_set, True),
        _loader(val_set, False),
        _loader(test_set, False),
    )


# Build all three loaders once at import time (batch size 64 everywhere).
trainloader, valloader, testloader = make_loaders(batch_size=64)
print(
    f"Dataset sizes → train: {len(trainloader.dataset)}, "
    f"val: {len(valloader.dataset)}, test: {len(testloader.dataset)}"
)


# MixUp / CutMix helper
def mixup_cutmix(x, y, *, mixup_alpha=None, cutmix_beta=None,
                 cutmix_prob=None, img_shape=(3, 32, 32)):
    """Mix a batch of flattened images with a shuffled copy of itself.

    With probability ``cutmix_prob`` CutMix is applied: one rectangle
    (shared by the whole batch) is pasted in from each image's partner.
    Otherwise MixUp linearly blends every image with its partner.

    Args:
        x: Flattened images of shape ``(B, C*H*W)`` with
            ``(C, H, W) == img_shape``. ``x`` itself is never modified.
        y: Labels of shape ``(B,)``, on the same device as ``x``.
        mixup_alpha: Beta(alpha, alpha) parameter for MixUp; defaults to the
            module-level ``MIXUP_ALPHA``.
        cutmix_beta: Beta(beta, beta) parameter for CutMix; defaults to
            ``CUTMIX_BETA``.
        cutmix_prob: Probability of choosing CutMix over MixUp; defaults to
            ``CUTMIX_PROB``.
        img_shape: ``(C, H, W)`` used to un-flatten ``x`` for CutMix.

    A non-positive Beta parameter disables the corresponding mode: the batch
    is returned unmixed with ``lam = 1.0`` (``torch.distributions.Beta``
    raises on non-positive concentrations).

    Returns:
        ``(x_mixed, y, y_partner, lam)``; the mixed loss is
        ``lam * CE(out, y) + (1 - lam) * CE(out, y_partner)``. For CutMix,
        ``lam`` is recomputed from the box area actually pasted after
        clipping to the image border.
    """
    if mixup_alpha is None:
        mixup_alpha = MIXUP_ALPHA
    if cutmix_beta is None:
        cutmix_beta = CUTMIX_BETA
    if cutmix_prob is None:
        cutmix_prob = CUTMIX_PROB

    idx = torch.randperm(x.size(0), device=x.device)
    y_partner = y[idx]

    # Scalar random draws use the host RNG so they don't force a GPU->CPU
    # synchronisation on every batch.
    if torch.rand(1).item() < cutmix_prob:
        if cutmix_beta <= 0:
            return x, y, y_partner, 1.0
        lam = torch.distributions.Beta(
            cutmix_beta, cutmix_beta).sample().item()
        c, H, W = img_shape
        cut_rat = math.sqrt(1.0 - lam)
        cw, ch = int(W * cut_rat), int(H * cut_rat)
        cx = torch.randint(W, (1,)).item()
        cy = torch.randint(H, (1,)).item()
        left, right = max(cx - cw // 2, 0), min(cx + cw // 2, W)
        top, bottom = max(cy - ch // 2, 0), min(cy + ch // 2, H)

        imgs = x.reshape(-1, c, H, W)
        mixed = imgs.clone()  # work on a copy so the caller's x is untouched
        mixed[:, :, top:bottom, left:right] = imgs[idx, :, top:bottom,
                                                   left:right]

        # The box may have been clipped at the border: use its real area.
        lam = 1.0 - (right - left) * (bottom - top) / (W * H)
        return mixed.reshape(x.shape), y, y_partner, lam

    if mixup_alpha <= 0:
        return x, y, y_partner, 1.0
    lam = torch.distributions.Beta(mixup_alpha, mixup_alpha).sample().item()
    return lam * x + (1 - lam) * x[idx], y, y_partner, lam


# Count model parameters
def count_parameters(model):
    """Return the total element count of all parameters with requires_grad."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total += param.numel()
    return total


# Model: FastKAN with layer widths 3072 (flattened 3x32x32 image) -> 256 -> 10.
# nn.Module.to() returns the module itself, so construct-and-move in one go.
model = KAN([3072, 256, 10], num_grids=NUM_GRIDS).to(device)

# Report the trainable-parameter count
n_params = count_parameters(model)
print(f"Model: FastKAN, Number of parameters: {n_params:,}")

# AdamW (decoupled weight decay 1e-4) with BASE_LR as the peak learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=BASE_LR,
                        weight_decay=1e-4)

# Learning-rate scheduler: linear warm-up → cosine decay towards 0


def lr_lambda(current_epoch: int) -> float:
    """Return the LR multiplier (relative to BASE_LR) for a 0-indexed epoch.

    * Epochs ``0 .. WARMUP_EPOCHS - 1``: linear warm-up ``1/W, 2/W, ..., 1``.
    * Later epochs: half-cosine from 1 down to 0 over
      ``NUM_EPOCHS - WARMUP_EPOCHS`` epochs. The curve reaches 0 only at
      ``current_epoch == NUM_EPOCHS``, one step after the last 0-indexed
      epoch, so no in-range epoch gets a multiplier of exactly 0.
    * Epochs beyond ``NUM_EPOCHS`` are clamped to 0 instead of climbing
      back up the cosine.
    """
    if current_epoch < WARMUP_EPOCHS:
        return float(current_epoch + 1) / max(1, WARMUP_EPOCHS)

    decay_epochs = max(1, NUM_EPOCHS - WARMUP_EPOCHS)
    progress = min(1.0, (current_epoch - WARMUP_EPOCHS) / decay_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# LambdaLR's constructor already applies lr_lambda(0) to the optimizer, so
# epoch 0 trains with that value. Call scheduler.step() once at the END of
# each epoch (after the optimizer updates) so that epoch e uses lr_lambda(e).
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Cross-entropy averaged over the batch (used for evaluation loss)
criterion = nn.CrossEntropyLoss(reduction="mean")


# Training and evaluation functions
def train_epoch():
    """Train ``model`` for one pass over ``trainloader``.

    Per batch: move to ``device``, augment every image independently with
    ``train_gpu_tf``, flatten to 3072 features, apply ``mixup_cutmix``, and
    take one optimizer step on the ``lam``-weighted cross-entropy against
    both label sets.

    Returns:
        Fraction (0-1) of training samples whose prediction matches the
        original, un-mixed label, weighted by sample count (so a smaller
        final batch is not over-weighted). ``0.0`` if the loader is empty.
    """
    model.train()
    n_correct = 0
    n_seen = 0
    with tqdm(trainloader) as pbar:
        for images_cpu, labels in pbar:
            # Move to the training device
            images = images_cpu.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # torchvision v2 random transforms draw one set of random
            # parameters per call, so a single call on the whole batch would
            # crop/flip/jitter/erase every image identically. Augment each
            # image separately (slower, but gives per-sample diversity).
            images = torch.stack([train_gpu_tf(img) for img in images])

            # Flatten for KAN
            images = images.view(-1, 3072)

            # Apply mixup/cutmix
            images, y1, y2, lam = mixup_cutmix(images, labels)

            optimizer.zero_grad()
            output = model(images)

            # Mixed-label cross-entropy
            loss = (lam * F.cross_entropy(output, y1)
                    + (1 - lam) * F.cross_entropy(output, y2))

            loss.backward()
            optimizer.step()

            # Accuracy against the original labels (mixing ignored)
            batch_correct = (output.argmax(dim=1) == labels).sum().item()
            n_correct += batch_correct
            n_seen += labels.size(0)
            pbar.set_postfix(
                loss=loss.item(),
                accuracy=batch_correct / labels.size(0),
                lr=optimizer.param_groups[0]['lr']
            )
    return n_correct / n_seen if n_seen else 0.0


def eval_loader(loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Images are moved to ``device``, normalised with ``val_gpu_tf`` and
    flattened to 3072 features before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)``: the per-sample mean of
        ``criterion`` and the percentage of correctly classified samples.
        Both are weighted by sample count, so a smaller final batch does not
        bias the result. ``(0.0, 0.0)`` for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images_cpu, labels in loader:
            # Move to the device and apply the deterministic normalisation
            images = images_cpu.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            images = val_gpu_tf(images)

            # Flatten for KAN
            images = images.view(-1, 3072)

            output = model(images)
            batch = labels.size(0)
            # criterion averages over the batch; scale back to a sum so the
            # final mean is per sample rather than per batch.
            loss_sum += criterion(output, labels).item() * batch
            n_correct += (output.argmax(dim=1) == labels).sum().item()
            n_seen += batch

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen * 100


def main():
    """Train, select the best epoch on validation, test it, and save results.

    Runs NUM_EPOCHS epochs of ``train_epoch`` followed by validation. It
    keeps a CPU copy of the weights with the highest validation accuracy and
    evaluates those weights on the test set. The following files are written
    to PATH_SAVE: ``final_model.pt``, ``best_model.pt`` (only if some epoch
    beat 0 % validation accuracy), ``logs.npz``, ``summary.txt`` and
    ``used_config.yaml``.
    """
    start_time = time.time()
    best_val_acc = 0.0
    best_state = None
    best_epoch = -1

    # Per-epoch logs (accuracies in percent)
    log_epochs = []
    log_lrs = []
    log_train_accs = []
    log_val_accs = []

    epoch_bar = tqdm(range(NUM_EPOCHS), desc="Epochs")
    for epoch in epoch_bar:
        # LR used for this whole epoch: LambdaLR's constructor applied the
        # epoch-0 value, and each scheduler.step() below sets the next one.
        current_lr = optimizer.param_groups[0]['lr']

        # Train
        train_acc = train_epoch()

        # Validation
        _, val_acc = eval_loader(valloader)

        # Advance the schedule only AFTER this epoch's optimizer updates.
        # Stepping before the first optimizer.step() skips the schedule's
        # first value, and PyTorch warns about that call order.
        scheduler.step()

        # Log metrics
        log_epochs.append(epoch + 1)
        log_lrs.append(current_lr)
        log_train_accs.append(train_acc * 100)  # Convert to percentage
        log_val_accs.append(val_acc)

        epoch_bar.set_postfix(
            train=f"{train_acc*100:.2f}%",
            val=f"{val_acc:.2f}%",
            lr=f"{current_lr:.1e}"
        )

        # Keep the best model. copy=True matters on CPU runs: there .cpu()
        # returns the live tensors, so later epochs would silently overwrite
        # the "best" weights.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            best_state = {k: v.detach().to("cpu", copy=True)
                          for k, v in model.state_dict().items()}

    # Save final weights. .cpu() moves the model itself, so move it back:
    # otherwise the test pass would feed device inputs to a CPU model
    # whenever no best_state was recorded.
    torch.save(model.cpu().state_dict(), PATH_SAVE / "final_model.pt")
    model.to(device)

    # Evaluate best model on test set
    if best_state is not None:
        torch.save(best_state, PATH_SAVE / "best_model.pt")
        model.load_state_dict(best_state)

    _, test_acc = eval_loader(testloader)
    elapsed = time.time() - start_time

    # Save logs to npz file
    np.savez(
        PATH_SAVE / "logs.npz",
        epoch=np.array(log_epochs),
        lr=np.array(log_lrs),
        train_acc=np.array(log_train_accs),
        val_acc=np.array(log_val_accs)
    )

    # Write summary
    with open(PATH_SAVE / "summary.txt", "w", encoding="utf-8") as f:
        f.write("Model: FastKAN\n")
        f.write(f"Number of parameters: {n_params:,}\n")
        f.write(f"Best validation accuracy: {best_val_acc:.2f}% "
                f"(Epoch {best_epoch})\n")
        f.write(f"Test accuracy: {test_acc:.2f}%\n")
        f.write(f"Elapsed time (s): {elapsed:.1f}\n")

    # Copy the config file used for this run
    shutil.copy(args.path_config, PATH_SAVE / "used_config.yaml")

    print(
        f"Best val acc: {best_val_acc:.2f}% | Test acc: {test_acc:.2f}% | "
        f"Elapsed: {elapsed:.1f}s.\n"
        f"Weights and summary saved to {PATH_SAVE}"
    )


if __name__ == "__main__":
    main()
