import argparse
from pathlib import Path
import yaml
import time
import shutil
import numpy as np
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import math

# Import the DKAN layer
from ldKAN.dkan_2d import DKAN_2D_Layer

# Parse arguments
# NOTE: parsing happens at module import time, so both options must be present
# on the command line whenever this file is executed or imported.
parser = argparse.ArgumentParser(description="Train DKAN on CIFAR-10")
parser.add_argument("--path_config", type=Path, required=True,
                    help="YAML configuration file")
parser.add_argument("--path_save", type=Path, required=True,
                    help="Directory where checkpoints & summary will be "
                    "written")
args = parser.parse_args()

# Load config
# Decode explicitly as UTF-8 so the parsed values do not depend on the
# platform's locale encoding.
with open(args.path_config, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# An empty file parses to None and a top-level list/scalar cannot be indexed
# by key; fail here with a readable message instead of a TypeError at the
# first cfg[...] lookup below.
if not isinstance(cfg, dict):
    parser.error(
        f"{args.path_config} must contain a YAML mapping of parameters, "
        f"got {type(cfg).__name__}"
    )

# Get augmentation parameters from config
# MIXUP_ALPHA / CUTMIX_BETA are each used as BOTH concentration parameters of
# the Beta distribution that mixup_cutmix() samples the mixing ratio from.
MIXUP_ALPHA = float(cfg["MIXUP_ALPHA"])
CUTMIX_BETA = float(cfg["CUTMIX_BETA"])
# Per-batch probability of applying CutMix; MixUp is applied otherwise.
CUTMIX_PROB = float(cfg["CUTMIX_PROB"])

# DKAN-specific parameters
# These are handed to the DKAN constructor, which forwards them positionally
# to DKAN_2D_Layer; their exact meaning is defined by that layer.
INIT_SCALE = float(cfg["INIT_SCALE"])
N_CHUNKS = int(cfg["N_CHUNKS"])
BLOCK_SIZE_FORWARD = int(cfg["BLOCK_SIZE_FORWARD"])
BLOCK_SIZE_BACKWARD = int(cfg["BLOCK_SIZE_BACKWARD"])
TILE_SIZE_FORWARD = int(cfg["TILE_SIZE_FORWARD"])
TILE_SIZE_BACKWARD = int(cfg["TILE_SIZE_BACKWARD"])

# Training parameters
BATCH_SIZE = int(cfg["BATCH_SIZE"])

# Multi-stage training parameters
# All phase lengths are measured in epochs (the config keys are named *_EPOCHS,
# not *_STEPS). The four phases run back to back in the order listed and are
# interpreted by get_params().
# Phase 1: dkan_weight fixed at 0.0, lr fixed at 1e-3.
PURE_MLP_EPOCHS = int(cfg["PURE_MLP_EPOCHS"])
# Phase 2: dkan_weight = min(offset / DKAN_TURN_ON_SCALE, DKAN_TURN_ON_CAP).
# The scale is parsed with int(), so a fractional numeric value is truncated.
DKAN_TURN_ON_EPOCHS = int(cfg["DKAN_TURN_ON_EPOCHS"])
DKAN_TURN_ON_SCALE = int(cfg["DKAN_TURN_ON_SCALE"])
DKAN_TURN_ON_CAP = float(cfg["DKAN_TURN_ON_CAP"])
# Phase 3: the Frobenius weight is divided by 10 every
# DKAN_FROBENIUS_DECAY_SCALE epochs. Despite its name, FROBENIUS_WEIGHT_CAP is
# applied as a lower bound (floor) on the decayed weight.
DKAN_FROBENIUS_DECAY_EPOCHS = int(cfg["DKAN_FROBENIUS_DECAY_EPOCHS"])
DKAN_FROBENIUS_DECAY_SCALE = int(cfg["DKAN_FROBENIUS_DECAY_SCALE"])
FROBENIUS_WEIGHT_CAP = float(cfg["FROBENIUS_WEIGHT_CAP"])
# Phase 4: cosine learning-rate decay from DKAN_BASE_LR down to 0.
# DKAN_LEARNING_RATE_DECAY_SCALE is read here but not referenced by any code
# visible in this file.
DKAN_LEARNING_RATE_DECAY_EPOCHS = int(cfg["DKAN_LEARNING_RATE_DECAY_EPOCHS"])
DKAN_LEARNING_RATE_DECAY_SCALE = int(cfg["DKAN_LEARNING_RATE_DECAY_SCALE"])
# Frobenius regularisation weight used until phase 3 starts decaying it.
INITIAL_FROBENIUS_WEIGHT = float(cfg["INITIAL_FROBENIUS_WEIGHT"])
# Learning rate used in phases 2 and 3 and as the start of the cosine decay.
DKAN_BASE_LR = float(cfg["DKAN_BASE_LR"])

# Setup save directory
# Created eagerly (including parents) so later writes cannot fail on a
# missing directory.
PATH_SAVE = args.path_save
PATH_SAVE.mkdir(parents=True, exist_ok=True)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SEED only seeds the generator that draws the train/val split in
# make_loaders(); no global RNG is seeded in the code visible here.
SEED = 42  # For reproducible split

# Define CPU transforms (just convert to tensor)
# This is the only transform run inside the DataLoader workers; everything
# else is applied on `device` after the batch has been transferred.
cpu_to_tensor = transforms.ToTensor()

# Define GPU transforms
# Per-channel statistics passed to T.Normalize.
# presumably the CIFAR-10 training-set mean / std — TODO confirm
MEAN: Tuple[float, float, float] = (0.4914, 0.4822, 0.4465)
STD: Tuple[float, float, float]  = (0.2470, 0.2435, 0.2616)

# Training-time augmentation pipeline; the training loop calls it on a whole
# batch tensor at once.
# NOTE(review): torchvision documents that random transforms applied to a
# batched tensor draw ONE set of random parameters for the entire batch, so
# every image in a batch likely receives the same crop / flip / jitter /
# erase. Confirm this is intended.
train_gpu_tf = nn.Sequential(
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.3, 0.3, 0.3, 0.05),
    T.RandAugment(2, 7),
    T.RandomErasing(p=0.25, scale=(0.05, 0.2), ratio=(0.3, 3.3)),
    T.Normalize(MEAN, STD),
).to(device)

# Evaluation pipeline: normalisation only, no augmentation.
val_gpu_tf = nn.Sequential(
    T.Normalize(MEAN, STD)
).to(device)


# Create data loaders with train/val split
def make_loaders(batch_size, val_size=5_000, num_workers=4,
                 data_root="./data"):
    """Build the CIFAR-10 train / validation / test DataLoaders.

    The official training set is split per class (stratified) so every class
    contributes the same number of validation images. With the defaults this
    yields the 45 000 / 5 000 split. A generator seeded with ``SEED`` draws
    the split, so it is reproducible across runs.

    Args:
        batch_size: Batch size used by all three loaders.
        val_size: Requested total number of validation images. It is divided
            evenly between the classes with integer division, so the actual
            validation size is ``(val_size // num_classes) * num_classes``.
        num_workers: Worker processes per DataLoader.
        data_root: Directory the dataset is stored in (downloaded if absent).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles.

    Raises:
        ValueError: If ``val_size`` is negative or asks for more images per
            class than a class contains.
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")

    full_train = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=cpu_to_tensor
    )

    targets = torch.tensor(full_train.targets)
    num_classes = targets.max().item() + 1          # 10 for CIFAR-10
    n_val_per_class = val_size // num_classes       # 500 each by default

    g = torch.Generator().manual_seed(SEED)
    val_idx, train_idx = [], []

    # Classes are visited in ascending order and each draws exactly one
    # permutation from `g`; keeping this order keeps the split stable.
    for c in range(num_classes):
        cls_indices = torch.nonzero(targets == c, as_tuple=False).view(-1)
        if n_val_per_class > len(cls_indices):
            raise ValueError(
                f"val_size={val_size} needs {n_val_per_class} images per "
                f"class but class {c} only has {len(cls_indices)}"
            )
        cls_indices = cls_indices[torch.randperm(
            len(cls_indices), generator=g)]
        val_idx.extend(cls_indices[:n_val_per_class].tolist())
        train_idx.extend(cls_indices[n_val_per_class:].tolist())

    # Both subsets wrap the same dataset object and therefore share its
    # (tensor-conversion only) transform.
    train_set = Subset(full_train, train_idx)
    val_set = Subset(full_train, val_idx)

    test_set = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=cpu_to_tensor
    )

    # Settings common to all three loaders.
    loader_kwargs = dict(
        batch_size=batch_size, num_workers=num_workers, pin_memory=True
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader


# Build the loaders once at module level; they are used as globals by the
# training loop and by the final test evaluation. This call may download the
# dataset into ./data on first use.
trainloader, valloader, testloader = make_loaders(BATCH_SIZE)
print(f"Dataset sizes → train: {len(trainloader.dataset)}, "
      f"val: {len(valloader.dataset)}, test: {len(testloader.dataset)}")


# MixUp / CutMix helper
def mixup_cutmix(x, y):
    """Apply CutMix or MixUp to one flattened image batch.

    One uniform draw per call decides the augmentation: CutMix when the draw
    is below ``CUTMIX_PROB``, MixUp otherwise. Both pair every sample with a
    randomly permuted partner from the same batch.

    Args:
        x: Batch-first tensor of flattened images. The CutMix branch views it
            as ``(batch, 3, 32, 32)``, so it must be ``(batch, 3072)`` and
            viewable without a copy.
        y: Labels, indexable by a permutation living on ``x``'s device.

    Returns:
        ``(x_mixed, y, y_partner, lam)`` where ``lam`` is a Python float
        giving the weight of the original sample; the caller is expected to
        combine the two losses as ``lam * loss(y) + (1 - lam) *
        loss(y_partner)``.

    Note:
        In the CutMix branch the patch is written through a view of ``x``, so
        the caller's tensor is modified in place. The MixUp branch returns a
        new tensor and leaves ``x`` untouched.
    """
    bs = x.size(0)
    # Create index / random tensors on the input's own device rather than the
    # module-level `device`, so the helper also works for tensors that live
    # elsewhere. For the training loop the two are the same device.
    dev = x.device
    idx = torch.randperm(bs, device=dev)
    y2 = y[idx]
    if torch.rand(1, device=dev) < CUTMIX_PROB:
        lam = torch.distributions.Beta(
            CUTMIX_BETA, CUTMIX_BETA).sample().item()
        H, W = 32, 32  # CIFAR-10 image dimensions
        c = 3  # channels
        # Side ratio chosen so the patch covers a (1 - lam) area fraction.
        cut_rat = math.sqrt(1 - lam)
        cw, ch = int(W * cut_rat), int(H * cut_rat)
        # Random patch centre; the patch is clipped to the image borders.
        cx = torch.randint(W, (1,), device=dev).item()
        cy = torch.randint(H, (1,), device=dev).item()
        x1, x2 = max(cx - cw // 2, 0), min(cx + cw // 2, W)
        y1, y2p = max(cy - ch // 2, 0), min(cy + ch // 2, H)

        # Reshape to (batch, channels, height, width) for CutMix.
        # x[idx] uses advanced indexing and therefore yields a copy, so the
        # in-place write below cannot corrupt the partner pixels.
        x_2d = x.view(bs, c, H, W)
        x_idx_2d = x[idx].view(bs, c, H, W)

        x_2d[:, :, y1:y2p, x1:x2] = x_idx_2d[:, :, y1:y2p, x1:x2]
        x = x_2d.view(bs, c * H * W)

        # Clipping may have shrunk the patch, so recompute lam from the area
        # that was actually replaced.
        lam = 1 - ((x2 - x1) * (y2p - y1) / (W * H))
    else:
        lam = torch.distributions.Beta(
            MIXUP_ALPHA, MIXUP_ALPHA).sample().item()
        x = lam * x + (1 - lam) * x[idx]

    return x, y, y2, lam


# DKAN model definition
class DKAN(nn.Module):
    """Two ``DKAN_2D_Layer`` blocks with a non-affine BatchNorm in between.

    The network consumes batch-last input ``(features, batch)`` and returns
    batch-first logits ``(batch, output_dim)``. The second layer is built
    with its output width rounded up to a multiple of ``tile_size_forward``;
    the extra columns are sliced off in ``forward``.
    """

    def __init__(
        self,
        input_dim=3072,
        hidden_dim=256,
        output_dim=10,
        n_chunks=4,
        block_size_forward=16,
        block_size_backward=16,
        tile_size_forward=16,
        tile_size_backward=16,
        init_scale=1.0
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Round the output width up to a whole number of forward tiles.
        n_output_tiles = (output_dim - 1) // tile_size_forward + 1
        self.padded_output_dim = n_output_tiles * tile_size_forward

        # Normalisation between the two layers, without learnable scale/shift.
        self.bn1 = nn.BatchNorm1d(hidden_dim, affine=False)

        # Trailing positional arguments shared by both layers.
        # NOTE(review): the meaning of the four boolean flags and the final
        # True is defined by DKAN_2D_Layer and is not visible in this file.
        shared_layer_args = (
            block_size_forward, block_size_backward,
            tile_size_forward, tile_size_backward,
            False, False, True, False, init_scale, True,
        )

        # Input -> hidden
        self.fc1 = DKAN_2D_Layer(
            n_chunks, input_dim, hidden_dim, *shared_layer_args
        )
        # Hidden -> padded output
        self.fc2 = DKAN_2D_Layer(
            n_chunks, hidden_dim, self.padded_output_dim, *shared_layer_args
        )

    def forward(self, x, dkan_weight=1.0):
        """Compute logits for a batch-last input.

        Args:
            x: Input in batch-last layout ``(features, batch)``.
            dkan_weight: Scalar forwarded unchanged to both DKAN layers.

        Returns:
            Batch-first tensor holding the first ``output_dim`` columns of
            the padded second-layer output.
        """
        # NOTE(review): the third positional argument differs between the two
        # layers (False, then True); its meaning is defined by DKAN_2D_Layer.
        hidden = self.fc1(x, dkan_weight, False, relu_last=False)

        # BatchNorm1d wants batch-first data: flip in, normalise, flip back.
        hidden = self.bn1(hidden.transpose(0, 1)).transpose(0, 1)

        padded_out = self.fc2(hidden, dkan_weight, True, relu_last=False)

        # Back to batch-first, then drop the padding columns.
        return padded_out.transpose(0, 1)[:, :self.output_dim]

    def get_frobenius_regularization(self):
        """Return the sum of both layers' Frobenius regularisation terms."""
        first_term = self.fc1.get_frobenius_regularization()
        second_term = self.fc2.get_frobenius_regularization()
        return first_term + second_term


# Count model parameters
def count_parameters(model):
    """Return the total number of elements in the model's trainable tensors.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Define model
# input_dim=3072 matches the flattened image size that the training and
# evaluation loops produce with view(-1, 3072).
model = DKAN(
    input_dim=3072,
    hidden_dim=256,
    output_dim=10,
    n_chunks=N_CHUNKS,
    block_size_forward=BLOCK_SIZE_FORWARD,
    block_size_backward=BLOCK_SIZE_BACKWARD,
    tile_size_forward=TILE_SIZE_FORWARD,
    tile_size_backward=TILE_SIZE_BACKWARD,
    init_scale=INIT_SCALE
)
model.to(device)

# Count and print the number of parameters
n_params = count_parameters(model)
print(f"Model: DKAN, Number of parameters: {n_params:,}")

# Define the evaluation criterion
# Used by eval_loader(); the training loop calls F.cross_entropy directly so
# it can blend the two MixUp/CutMix targets. The optimizer is created further
# down, once the schedule function exists.
criterion = nn.CrossEntropyLoss()

# Calculate total epochs for the training schedule
# The four phases run back to back, so the total is simply their sum.
total_epochs = (
    PURE_MLP_EPOCHS
    + DKAN_TURN_ON_EPOCHS
    + DKAN_FROBENIUS_DECAY_EPOCHS
    + DKAN_LEARNING_RATE_DECAY_EPOCHS
)

print(f"Total training epochs: {total_epochs}")


# Define the multi-stage training parameters
def get_params(epoch):
    """Return ``(lr, dkan_weight, frobenius_weight)`` for a 0-based epoch.

    The schedule has four consecutive phases, all measured in epochs:

    1. Pure MLP: lr 1e-3, DKAN weight 0, initial Frobenius weight.
    2. DKAN turn-on: base lr; DKAN weight ramps linearly with the phase
       offset (divided by ``DKAN_TURN_ON_SCALE``) up to ``DKAN_TURN_ON_CAP``.
    3. Frobenius decay: base lr; DKAN weight held at its end-of-ramp value;
       Frobenius weight divided by 10 every ``DKAN_FROBENIUS_DECAY_SCALE``
       epochs, never dropping below ``FROBENIUS_WEIGHT_CAP``.
    4. LR decay: cosine decay of the lr from the base value to 0 across the
       phase; the other two values are held at their end-of-phase-3 levels.
    """
    # First epoch index of phases 2, 3 and 4.
    turn_on_start = PURE_MLP_EPOCHS
    fro_decay_start = turn_on_start + DKAN_TURN_ON_EPOCHS
    lr_decay_start = fro_decay_start + DKAN_FROBENIUS_DECAY_EPOCHS

    # Phase 1: DKAN path switched off.
    if epoch < turn_on_start:
        return 1e-3, 0.0, INITIAL_FROBENIUS_WEIGHT

    # Phase 2: ramp the DKAN weight up.
    if epoch < fro_decay_start:
        ramp = (epoch - turn_on_start) / DKAN_TURN_ON_SCALE
        return (DKAN_BASE_LR,
                min(ramp, DKAN_TURN_ON_CAP),
                INITIAL_FROBENIUS_WEIGHT)

    # Phases 3 and 4 both hold the DKAN weight at the end-of-ramp value.
    dkan_weight = min(DKAN_TURN_ON_EPOCHS / DKAN_TURN_ON_SCALE,
                      DKAN_TURN_ON_CAP)

    if epoch < lr_decay_start:
        # Phase 3: the Frobenius weight decays with the offset in this phase.
        lr = DKAN_BASE_LR
        fro_decay_epochs = epoch - fro_decay_start
    else:
        # Phase 4: cosine lr decay; the Frobenius decay is frozen at its
        # final phase-3 amount.
        n_decay_epochs = DKAN_LEARNING_RATE_DECAY_EPOCHS
        if n_decay_epochs > 1:
            progress = (epoch - lr_decay_start) / (n_decay_epochs - 1)
            lr = 0.5 * DKAN_BASE_LR * (1.0 + math.cos(math.pi * progress))
        else:
            lr = 0.0
        fro_decay_epochs = DKAN_FROBENIUS_DECAY_EPOCHS

    # FROBENIUS_WEIGHT_CAP acts as a floor on the decayed weight.
    frobenius_weight = max(
        INITIAL_FROBENIUS_WEIGHT
        / (10 ** (fro_decay_epochs / DKAN_FROBENIUS_DECAY_SCALE)),
        FROBENIUS_WEIGHT_CAP,
    )
    return lr, dkan_weight, frobenius_weight


# Initial parameters
# Only the epoch-0 learning rate is needed to construct the optimizer; main()
# overwrites the lr of every param group at the start of each epoch. The other
# two values are unpacked but not referenced by the code visible in this file.
init_lr, init_dkan_weight, init_frobenius_weight = get_params(0)
optimizer = optim.Adam(model.parameters(), lr=init_lr)


# Evaluation function
def eval_loader(loader, dkan_weight):
    """Evaluate the global ``model`` on every batch of ``loader``.

    Both metrics are averaged over samples rather than over batches, so a
    shorter final batch does not carry more weight per image than the others.

    Args:
        loader: Iterable of ``(images, labels)`` batches whose images can be
            flattened to 3072 features each.
        dkan_weight: Scalar forwarded to the model's forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.

    Note:
        Leaves the model in eval mode; callers that continue training must
        switch it back with ``model.train()``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for images_cpu, labels in loader:
            # Move to the compute device, then normalise there.
            images = images_cpu.to(device, non_blocking=True)
            labels = labels.to(device)
            images = val_gpu_tf(images)

            # Flatten (batch-first), then switch to the batch-last layout
            # the model expects: (3072, batch_size).
            images = images.view(-1, 3072)
            images = images.transpose(0, 1).contiguous()

            output = model(images, dkan_weight)

            # `criterion` returns the batch mean; scale it back to a sum so
            # the final division yields a true per-sample average.
            batch_size = labels.size(0)
            loss_sum += criterion(output, labels).item() * batch_size
            n_correct += (output.argmax(dim=1) == labels).sum().item()
            n_samples += batch_size

    return loss_sum / n_samples, n_correct / n_samples * 100


def _train_one_epoch(epoch, lr, dkan_weight, frobenius_weight):
    """Run one optimisation pass over ``trainloader``.

    Args:
        epoch: 0-based epoch index (only used for the progress-bar label).
        lr: Current learning rate (only displayed; the caller has already
            written it into the optimizer).
        dkan_weight: Scalar forwarded to the model's forward pass.
        frobenius_weight: Multiplier of the Frobenius regularisation term.

    Returns:
        ``(mean_accuracy, mean_loss)``, each the mean of the per-batch
        values. Accuracy is a fraction in [0, 1] measured against the
        original labels even though the inputs are MixUp/CutMix blends. The
        loss includes the regularisation term.
    """
    model.train()
    batch_bar = tqdm(trainloader, leave=False, desc=f"Epoch {epoch+1}")
    acc_sum = 0.0
    loss_sum = 0.0

    for images_cpu, labels in batch_bar:
        # Move data to device
        images = images_cpu.to(device, non_blocking=True)
        labels = labels.to(device)

        # Apply GPU-based augmentations
        images = train_gpu_tf(images)

        # Flatten for DKAN (batch-first format): (batch_size, 3072)
        images = images.view(-1, 3072)

        # Apply mixup/cutmix (still in batch-first format)
        images, y1, y2, lam = mixup_cutmix(images, labels)

        # Right before the forward pass, switch to batch-last format:
        # (3072, batch_size)
        images = images.transpose(0, 1).contiguous()

        optimizer.zero_grad()
        output = model(images, dkan_weight)

        # Cross-entropy blended between the two mixed targets
        pure_loss = lam * F.cross_entropy(output, y1)
        pure_loss += (1 - lam) * F.cross_entropy(output, y2)

        # Add Frobenius regularization
        fro_reg = model.get_frobenius_regularization()
        loss = pure_loss + frobenius_weight * fro_reg

        loss.backward()
        optimizer.step()

        # Read each metric back from the device once per batch.
        batch_acc = (output.argmax(dim=1) == labels).float().mean().item()
        batch_loss = loss.item()
        acc_sum += batch_acc
        loss_sum += batch_loss

        batch_bar.set_postfix(
            loss=f"{batch_loss:.4f}",
            acc=f"{batch_acc*100:.2f}%",
            lr=f"{lr:.1e}",
            dkan=f"{dkan_weight:.3f}"
        )

    n_batches = len(batch_bar)
    return acc_sum / n_batches, loss_sum / n_batches


def main():
    """Train the model through the full schedule and write all artefacts.

    For every epoch the schedule values are fetched from ``get_params``, the
    model is trained for one pass and validated, and the weights with the
    best validation accuracy so far are snapshotted. Afterwards the function
    writes ``final_model.pt``, ``best_model.pt`` (if any epoch improved on
    0 % validation accuracy), ``logs.npz``, ``summary.txt`` and a copy of the
    config into ``PATH_SAVE``, and reports test accuracy of the best weights.
    """
    start_time = time.time()
    best_val_acc = 0.0
    best_state = None
    best_epoch = -1

    # Lists to store logs
    log_epochs = []
    log_lrs = []
    log_dkan_weights = []
    log_frobenius_weights = []
    log_train_accs = []
    log_train_losses = []
    log_val_accs = []
    log_val_losses = []

    epoch_bar = tqdm(range(total_epochs), desc="Epochs")
    for epoch in epoch_bar:
        # Get current schedule parameters for this epoch
        lr, dkan_weight, frobenius_weight = get_params(epoch)

        # Update optimizer LR
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Train
        epoch_train_acc, epoch_train_loss = _train_one_epoch(
            epoch, lr, dkan_weight, frobenius_weight
        )

        # Validation
        val_loss, val_acc = eval_loader(valloader, dkan_weight)

        # Log epoch-level metrics
        log_epochs.append(epoch + 1)
        log_lrs.append(lr)
        log_dkan_weights.append(dkan_weight)
        log_frobenius_weights.append(frobenius_weight)
        log_train_accs.append(epoch_train_acc * 100)  # Convert to percentage
        log_train_losses.append(epoch_train_loss)
        log_val_accs.append(val_acc)
        log_val_losses.append(val_loss)

        epoch_bar.set_postfix(
            train=f"{epoch_train_acc*100:.2f}%",
            val=f"{val_acc:.2f}%",
            lr=f"{lr:.1e}",
            dkan=f"{dkan_weight:.3f}"
        )

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            # copy=True forces a real copy even when the model already lives
            # on the CPU; a plain .cpu() would alias the live parameters
            # there and the snapshot would follow all later training.
            best_state = {
                k: v.detach().to("cpu", copy=True)
                for k, v in model.state_dict().items()
            }

    # Save final weights (stored as CPU tensors). This moves the live model
    # to the CPU; it is moved back before the test evaluation below.
    torch.save(model.cpu().state_dict(), PATH_SAVE / "final_model.pt")

    if best_state is not None:
        # Evaluate the best weights with the DKAN weight of their own epoch.
        torch.save(best_state, PATH_SAVE / "best_model.pt")
        model.load_state_dict(best_state)
        eval_epoch = best_epoch - 1
    else:
        # No epoch beat 0 % validation accuracy (or no epoch ran): evaluate
        # the final weights with the last scheduled DKAN weight instead of
        # querying the schedule at a negative epoch.
        eval_epoch = max(total_epochs - 1, 0)

    # Always return the model to the compute device; eval_loader places the
    # input batches there.
    model.to(device)

    _, best_dkan_weight, _ = get_params(eval_epoch)
    _, test_acc = eval_loader(testloader, best_dkan_weight)
    elapsed = time.time() - start_time

    # Save logs to npz file
    np.savez(
        PATH_SAVE / "logs.npz",
        epochs=np.array(log_epochs),
        lr=np.array(log_lrs),
        dkan_weight=np.array(log_dkan_weights),
        frobenius_weight=np.array(log_frobenius_weights),
        train_acc=np.array(log_train_accs),
        train_loss=np.array(log_train_losses),
        val_acc=np.array(log_val_accs),
        val_loss=np.array(log_val_losses)
    )

    # Write summary
    with open(PATH_SAVE / "summary.txt", "w", encoding="utf-8") as f:
        f.write("Model: DKAN\n")
        f.write(f"Number of parameters: {n_params:,}\n")
        f.write(f"Batch size: {BATCH_SIZE}\n")
        f.write(f"Best validation accuracy: {best_val_acc:.2f}% "
                f"(Epoch {best_epoch})\n")
        f.write(f"Test accuracy: {test_acc:.2f}%\n")
        f.write(f"DKAN weight at best epoch: {best_dkan_weight:.4f}\n")
        f.write(f"Total training epochs: {total_epochs}\n")
        f.write(f"Elapsed time (s): {elapsed:.1f}\n")

    # Copy the config file used for this run
    shutil.copy(args.path_config, PATH_SAVE / "used_config.yaml")

    print(
        f"Best val acc: {best_val_acc:.2f}% | Test acc: {test_acc:.2f}% | "
        f"Elapsed: {elapsed:.1f}s.\n"
        f"Weights and summary saved to {PATH_SAVE}"
    )


# Start training only when executed as a script. Note that the argument
# parsing, data loading and model construction above run at import time
# regardless of this guard.
if __name__ == "__main__":
    main()