import argparse
import math
from pathlib import Path
from typing import Tuple, Dict
import time
import shutil
import yaml

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms.v2 as T
from tqdm import tqdm

from ldKAN.dkan_2d import DKAN_2D_Layer

# -------------------------------------------------------
# Argument parsing (requires config + save dir)
# -------------------------------------------------------
# Both flags are mandatory; argparse prints usage and exits if either is missing.
parser = argparse.ArgumentParser(
    description="DKAN 2×2-stride-2 CNN for CIFAR-10",
)
parser.add_argument(
    "--path_config",
    required=True,
    type=Path,
    help="YAML configuration file (e.g. ./config_dkan.yaml)",
)
parser.add_argument(
    "--path_save",
    required=True,
    type=Path,
    help="Directory where checkpoints & summary will be written",
)
args = parser.parse_args()

# -------------------------------------------------------
# Config loading (all required keys – KeyError if missing)
# -------------------------------------------------------
with open(args.path_config, "r") as f:
    cfg = yaml.safe_load(f)
# An empty YAML file loads as None (and a bare scalar or list as a non-dict);
# fail here with a clear message instead of an opaque TypeError at the first
# cfg[...] lookup below. Missing keys still raise KeyError at their lookup.
if not isinstance(cfg, dict):
    raise ValueError(
        f"Config file {args.path_config} must contain a YAML mapping of "
        f"hyper-parameters, got {type(cfg).__name__}"
    )

# --- Dataset / augmentation hyper-params (shared with MLP version) ---
# Numeric values are wrapped in float(...) where a float is expected; PyYAML
# presumably reads exponent-only literals such as `1e-3` as strings.
WIDTH: int = cfg["WIDTH"]                     # base width (divisible by TILE_SIZE_FORWARD)
SCALE_TYPE: str = cfg["SCALE_TYPE"]           # "progressive" | "constant"
BATCH_SIZE: int = cfg["BATCH_SIZE"]
# MixUp lam ~ Beta(MIXUP_ALPHA, MIXUP_ALPHA).
MIXUP_ALPHA: float = float(cfg["MIXUP_ALPHA"])
# CutMix lam ~ Beta(CUTMIX_BETA, CUTMIX_BETA); a value <= 0 disables CutMix.
CUTMIX_BETA: float = float(cfg["CUTMIX_BETA"])
# Per-batch probability of using CutMix instead of MixUp.
CUTMIX_PROB: float = float(cfg["CUTMIX_PROB"])

# --- DKAN scheduling hyper-params ---
# Phase lengths (in epochs) and shape parameters consumed by get_params().
# Phase 1: linear LR warm-up, DKAN weight 0.
WARMUP_EPOCHS: int = cfg["WARMUP_EPOCHS"]
# Phase 2: constant PURE_MLP_BASE_LR, DKAN weight 0.
PURE_MLP_EPOCHS: int = cfg["PURE_MLP_EPOCHS"]
# Phase 3: DKAN weight ramps as (offset + 1) / DKAN_TURN_ON_SCALE, capped at DKAN_TURN_ON_CAP.
DKAN_TURN_ON_EPOCHS: int = cfg["DKAN_TURN_ON_EPOCHS"]
DKAN_TURN_ON_SCALE: int = cfg["DKAN_TURN_ON_SCALE"]
DKAN_TURN_ON_CAP: float = float(cfg["DKAN_TURN_ON_CAP"])
# Phase 4: Frobenius weight shrinks 10x every DKAN_FROBENIUS_DECAY_SCALE epochs.
DKAN_FROBENIUS_DECAY_EPOCHS: int = cfg["DKAN_FROBENIUS_DECAY_EPOCHS"]
DKAN_FROBENIUS_DECAY_SCALE: int = cfg["DKAN_FROBENIUS_DECAY_SCALE"]
# Despite the name, applied with max(...): it is a lower bound on the Frobenius weight.
FROBENIUS_WEIGHT_CAP: float = float(cfg["FROBENIUS_WEIGHT_CAP"])
INITIAL_FROBENIUS_WEIGHT: float = float(cfg["INITIAL_FROBENIUS_WEIGHT"])
# LR used once DKAN is being turned on (phases 3-5).
DKAN_BASE_LR: float = float(cfg["DKAN_BASE_LR"])
# LR used in phases 1-2 (target of the warm-up ramp).
PURE_MLP_BASE_LR: float = float(cfg["PURE_MLP_BASE_LR"])

# --- Additional LR-decay hyper-params ---
# Phase 5: cosine decay of the LR starting from DKAN_BASE_LR.
DKAN_LEARNING_RATE_DECAY_EPOCHS: int = cfg["DKAN_LEARNING_RATE_DECAY_EPOCHS"]

# --- DKAN layer hyper-params ---
# Forwarded unchanged to every DKAN_2D_Layer in the model.
N_CHUNKS: int = cfg["N_CHUNKS"]
# Layer input dims (and the classifier's output dim) are zero-padded to a multiple of this.
TILE_SIZE_FORWARD: int = cfg["TILE_SIZE_FORWARD"]
TILE_SIZE_BACKWARD: int = cfg["TILE_SIZE_BACKWARD"]
BLOCK_SIZE_FORWARD: int = cfg["BLOCK_SIZE_FORWARD"]
BLOCK_SIZE_BACKWARD: int = cfg["BLOCK_SIZE_BACKWARD"]
INIT_SCALE: float = float(cfg["INIT_SCALE"])

NUM_CLASSES: int = 10  # Fixed for CIFAR-10
PATCH_KERNEL_SIZE: int = 2  # 2×2 stride-2 patches
NUM_WORKERS: int = 4  # DataLoader worker processes for every loader
# -------------------------------------------------------
# Paths / constants
# -------------------------------------------------------
PATH_SAVE: Path = args.path_save
# Created at import time, before any training starts.
PATH_SAVE.mkdir(parents=True, exist_ok=True)

# Hard-coded: the script assumes a CUDA GPU is available.
DEVICE: torch.device = torch.device("cuda:0")
# Per-channel normalisation statistics passed to T.Normalize.
MEAN: Tuple[float, float, float] = (0.4914, 0.4822, 0.4465)
STD: Tuple[float, float, float] = (0.2470, 0.2435, 0.2616)
# Seeds the train/val split generator in make_loaders; no global torch.manual_seed
# call appears in this file, so model init and augmentation are not seeded.
SEED: int = 42

# -------------------------------------------------------
# Data transforms (GPU)
# -------------------------------------------------------
# Per-sample CPU transform used by the datasets in make_loaders.
cpu_to_tensor = T.ToTensor()


class _PerSample(nn.Module):
    """Apply ``transform`` independently to every image of a batch.

    torchvision v2 random transforms draw one set of random parameters per
    call. Calling them on a whole ``(B, C, H, W)`` batch would give every image
    the same crop, flip, colour jitter, RandAugment ops and erasing box. Looping
    over the leading dimension restores per-image randomness. The cost is one
    transform call per image.
    """

    def __init__(self, transform: nn.Module):
        super().__init__()
        self.transform = transform

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.transform(img) for img in batch])


# Random augmentations run per image. Normalisation is deterministic, so it
# stays batched.
train_gpu_tf = nn.Sequential(
    _PerSample(
        nn.Sequential(
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.3, 0.3, 0.3, 0.05),
            T.RandAugment(2, 7),
            T.RandomErasing(p=0.25, scale=(0.05, 0.2), ratio=(0.3, 3.3)),
        )
    ),
    T.Normalize(MEAN, STD),
).to(DEVICE)

val_gpu_tf = nn.Sequential(T.Normalize(MEAN, STD)).to(DEVICE)

# -------------------------------------------------------
# Data loaders
# -------------------------------------------------------

def make_loaders(batch_size: int, n_val: int = 5_000):
    """Build CIFAR-10 train / val / test loaders with a stratified val split.

    The CIFAR-10 training set is split so that every class contributes
    ``n_val // num_classes`` images to validation. With the default
    ``n_val=5_000`` that is 500 per class, i.e. a 45 000 / 5 000 split. The
    split is reproducible through ``SEED``. All datasets only apply
    ``cpu_to_tensor``.

    Args:
        batch_size: Batch size used by all three loaders.
        n_val: Total validation images, split evenly across classes (any
            remainder after integer division stays in the training split).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    full_train = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=cpu_to_tensor
    )

    targets = torch.as_tensor(full_train.targets)
    num_classes = targets.max().item() + 1          # 10 for CIFAR-10
    n_val_per_class = n_val // num_classes

    g = torch.Generator().manual_seed(SEED)
    val_idx, train_idx = [], []
    for c in range(num_classes):
        cls_indices = torch.nonzero(targets == c, as_tuple=False).view(-1)
        cls_indices = cls_indices[torch.randperm(len(cls_indices), generator=g)]
        val_idx.extend(cls_indices[:n_val_per_class].tolist())
        train_idx.extend(cls_indices[n_val_per_class:].tolist())

    train_set = torch.utils.data.Subset(full_train, train_idx)
    val_set = torch.utils.data.Subset(full_train, val_idx)
    test_set = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=cpu_to_tensor
    )

    loader_kwargs = dict(
        batch_size=batch_size, num_workers=NUM_WORKERS, pin_memory=True
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        # The model's BatchNorm1d layers see one row per image after the last
        # 1×1 patch stage, and BatchNorm raises in train mode on a single
        # sample. Drop the final batch only when it would contain exactly one.
        drop_last=len(train_set) % batch_size == 1,
        **loader_kwargs,
    )
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader

# Build the three loaders once at import time and report the split sizes.
train_loader, val_loader, test_loader = make_loaders(BATCH_SIZE)
print(
    "Dataset sizes → train: {}, val: {}, test: {}".format(
        len(train_loader.dataset),
        len(val_loader.dataset),
        len(test_loader.dataset),
    )
)

# -------------------------------------------------------
# DKAN building blocks
# -------------------------------------------------------

class PatchDKAN(nn.Module):
    """2×2-stride-2 "convolution" implemented with a single DKAN_2D_Layer.

    Each non-overlapping ``k×k`` patch (``k = PATCH_KERNEL_SIZE``) is flattened
    to an ``in_ch * k * k`` feature vector and zero-padded up to a multiple of
    ``TILE_SIZE_FORWARD``. It is then mapped to ``out_ch`` features by the DKAN
    layer and normalised by a non-affine BatchNorm. Spatial size is divided
    by ``k``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels. It is not padded, so it must be a
            multiple of ``TILE_SIZE_FORWARD``.

    Raises:
        ValueError: if ``out_ch`` is not a multiple of ``TILE_SIZE_FORWARD``.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.k = PATCH_KERNEL_SIZE
        self.in_ch = in_ch
        self.out_ch = out_ch

        patch_feat_dim = in_ch * self.k * self.k
        # Round the patch feature size up to the next multiple of the forward tile.
        dkan_in_dim = ((patch_feat_dim - 1) // TILE_SIZE_FORWARD + 1) * TILE_SIZE_FORWARD
        self.input_padding = dkan_in_dim - patch_feat_dim

        # The output is not padded (the classifier head pads its own output for
        # the same reason), so out_ch itself must be tile-aligned. Enforce the
        # invariant instead of relying on the config being right.
        if out_ch % TILE_SIZE_FORWARD:
            raise ValueError(
                f"out_ch={out_ch} must be a multiple of "
                f"TILE_SIZE_FORWARD={TILE_SIZE_FORWARD}"
            )
        dkan_out_dim = out_ch

        self.bn = nn.BatchNorm1d(dkan_out_dim, affine=False)
        self.dkan = DKAN_2D_Layer(
            n_chunks=N_CHUNKS,
            input_dim=dkan_in_dim,
            output_dim=dkan_out_dim,
            block_size_forward=BLOCK_SIZE_FORWARD,
            block_size_backward=BLOCK_SIZE_BACKWARD,
            tile_size_forward=TILE_SIZE_FORWARD,
            tile_size_backward=TILE_SIZE_BACKWARD,
            apply_scale=False,
            apply_bias=False,
            cdf_grid=True,
            apply_tanh=False,
            init_scale=INIT_SCALE,
            batch_last=True,
        )

    def forward(self, x: torch.Tensor, weight_dkan: float, apply_relu_linear: bool):
        """Map a ``(B, C, H, W)`` tensor to ``(B, out_ch, H/k, W/k)``.

        Args:
            x: Input feature map; H and W must be divisible by ``k``.
            weight_dkan: Schedule-controlled DKAN weight, forwarded unchanged
                to the DKAN layer.
            apply_relu_linear: Forwarded unchanged to the DKAN layer.

        Raises:
            ValueError: if H or W is not divisible by ``k``.
        """
        b, _, h, w = x.shape
        # Explicit check (not assert) so it survives `python -O`.
        if h % self.k or w % self.k:
            raise ValueError(
                f"H,W must be divisible by kernel (got {h}x{w}, k={self.k})"
            )
        h2, w2 = h // self.k, w // self.k
        n_patches = h2 * w2

        # Unshuffle & flatten each patch
        x = F.pixel_unshuffle(x, self.k)  # (B, C*k*k, H/k, W/k)
        x = x.permute(0, 2, 3, 1).reshape(b * n_patches, -1)  # (B*n_patches, patch_feat_dim)
        if self.input_padding:
            x = F.pad(x, (0, self.input_padding))
        # The layer is built with batch_last=True: features first, samples last.
        x = x.transpose(0, 1).contiguous()  # (dkan_in_dim, B*n_patches)

        x = self.dkan(x, weight_dkan=weight_dkan, apply_relu_linear=apply_relu_linear, relu_last=False)
        x = x.transpose(0, 1)  # (B*n_patches, out_ch)
        x = self.bn(x)
        # Back to channels-first image layout.
        x = x.reshape(b, n_patches, self.out_ch).transpose(1, 2).reshape(b, self.out_ch, h2, w2)
        return x

    def get_frobenius_regularization(self):
        """Return the DKAN layer's Frobenius regularisation term."""
        return self.dkan.get_frobenius_regularization()


class SimpleCNNDKAN(nn.Module):
    """CNN made of PatchDKAN stages followed by a two-layer DKAN classifier.

    Starting from a 3-channel 32×32 input, each PatchDKAN stage halves the
    spatial size until it reaches 1×1, which takes five stages. The first
    stage outputs ``width`` channels. Later stages either double the width
    per stage (``"progressive"``) or keep it (``"constant"``). The 1×1
    feature vector goes through ``dkan_fc1``, a non-affine BatchNorm, and
    ``dkan_fc2``. The output of ``dkan_fc2`` is padded to a multiple of
    ``TILE_SIZE_FORWARD`` and sliced back to ``NUM_CLASSES`` logits.

    Args:
        width: Channel width of the first stage.
        scale_type: ``"progressive"`` or ``"constant"``.

    Raises:
        ValueError: for any other ``scale_type`` (previously such values
            silently behaved like ``"constant"``).
    """
    def __init__(self, width: int, scale_type: str):
        super().__init__()
        if scale_type not in ("progressive", "constant"):
            raise ValueError(
                f"scale_type must be 'progressive' or 'constant', got {scale_type!r}"
            )

        in_ch, stage, size = 3, 0, 32  # RGB 32×32 input
        out_ch = width
        patch_layers = []
        while True:
            patch_layers.append(PatchDKAN(in_ch, out_ch))
            stage += 1
            size //= 2
            if size == 1:
                break
            in_ch = out_ch
            out_ch = width * (2 ** stage) if scale_type == "progressive" else width
        self.patch_layers = nn.ModuleList(patch_layers)
        final_feat_dim = out_ch  # channels at 1×1 spatial

        # FC1 DKAN (relu_linear=True)
        self.dkan_fc1 = DKAN_2D_Layer(
            N_CHUNKS, final_feat_dim, final_feat_dim,
            BLOCK_SIZE_FORWARD, BLOCK_SIZE_BACKWARD,
            TILE_SIZE_FORWARD, TILE_SIZE_BACKWARD,
            False, False, True, False, INIT_SCALE, True
        )
        # FC2 DKAN (output padded to multiple of TILE_SIZE_FORWARD, sliced in forward)
        dkan_fc2_out = ((NUM_CLASSES - 1) // TILE_SIZE_FORWARD + 1) * TILE_SIZE_FORWARD
        self.fc2_output_slicing = dkan_fc2_out - NUM_CLASSES
        self.bn_fc2 = nn.BatchNorm1d(final_feat_dim, affine=False)
        self.dkan_fc2 = DKAN_2D_Layer(
            N_CHUNKS, final_feat_dim, dkan_fc2_out,
            BLOCK_SIZE_FORWARD, BLOCK_SIZE_BACKWARD,
            TILE_SIZE_FORWARD, TILE_SIZE_BACKWARD,
            False, False, True, False, INIT_SCALE, True
        )

    def forward(self, x: torch.Tensor, weight_dkan: float):
        """Return ``(B, NUM_CLASSES)`` logits for a ``(B, 3, 32, 32)`` batch.

        ``weight_dkan`` is forwarded unchanged to every DKAN layer.
        """
        # The first stage runs with apply_relu_linear=False, all others with True.
        x = self.patch_layers[0](x, weight_dkan, apply_relu_linear=False)
        for layer in self.patch_layers[1:]:
            x = layer(x, weight_dkan, apply_relu_linear=True)
        x = torch.flatten(x, 1)  # (B, final_feat_dim) since spatial is 1×1
        # DKAN layers are built with batch_last=True: transpose to (features, B).
        x = x.transpose(0, 1).contiguous()
        x = self.dkan_fc1(x, weight_dkan, apply_relu_linear=True, relu_last=False)
        x = x.transpose(0, 1)
        x = self.bn_fc2(x)
        x = x.transpose(0, 1).contiguous()
        x = self.dkan_fc2(x, weight_dkan, apply_relu_linear=True, relu_last=False)
        x = x.transpose(0, 1)
        if self.fc2_output_slicing:
            x = x[:, :NUM_CLASSES]  # drop the tile padding columns
        return x.contiguous()

    def get_frobenius_regularization(self):
        """Sum the Frobenius regularisation terms of all DKAN layers."""
        reg = 0.0
        for layer in self.patch_layers:
            reg += layer.get_frobenius_regularization()
        reg += self.dkan_fc1.get_frobenius_regularization()
        reg += self.dkan_fc2.get_frobenius_regularization()
        return reg

# -------------------------------------------------------
# Learning-rate / DKAN / Frobenius schedule (initial mode only)
# -------------------------------------------------------

def get_params(epoch: int):
    """Return ``(lr, dkan_w, fro_w)`` for the 0-based ``epoch``.

    The schedule consists of five consecutive phases whose lengths come from
    the config:

    1. Warm-up (WARMUP_EPOCHS): LR ramps linearly up to PURE_MLP_BASE_LR;
       dkan_w = 0.
    2. Pure MLP (PURE_MLP_EPOCHS): LR = PURE_MLP_BASE_LR; dkan_w = 0.
    3. DKAN turn-on (DKAN_TURN_ON_EPOCHS): LR = DKAN_BASE_LR; dkan_w ramps as
       (offset + 1) / DKAN_TURN_ON_SCALE, capped at DKAN_TURN_ON_CAP.
    4. Frobenius decay (DKAN_FROBENIUS_DECAY_EPOCHS): dkan_w is held at its
       final ramp value. fro_w shrinks 10x every DKAN_FROBENIUS_DECAY_SCALE
       epochs, floored at FROBENIUS_WEIGHT_CAP.
    5. LR decay (DKAN_LEARNING_RATE_DECAY_EPOCHS): cosine decay of the LR from
       DKAN_BASE_LR towards (but not reaching) 0. dkan_w and fro_w are held
       at their phase-4 end values.

    Raises:
        RuntimeError: if ``epoch`` is outside ``[0, total schedule length)``.
    """
    phase1_end = WARMUP_EPOCHS
    phase2_end = phase1_end + PURE_MLP_EPOCHS
    phase3_end = phase2_end + DKAN_TURN_ON_EPOCHS
    phase4_end = phase3_end + DKAN_FROBENIUS_DECAY_EPOCHS
    phase5_end = phase4_end + DKAN_LEARNING_RATE_DECAY_EPOCHS

    # A negative epoch would otherwise fall into the warm-up branch and yield a
    # zero or negative learning rate.
    if not 0 <= epoch < phase5_end:
        raise RuntimeError(f"Epoch {epoch} outside schedule range [0, {phase5_end})")

    # DKAN weight reached on the last turn-on epoch; held constant afterwards.
    dkan_w_full = min(DKAN_TURN_ON_EPOCHS / max(1, DKAN_TURN_ON_SCALE), DKAN_TURN_ON_CAP)

    if epoch < phase1_end:  # Warm-up
        lr = PURE_MLP_BASE_LR * (epoch + 1) / max(1, WARMUP_EPOCHS)
        dkan_w = 0.0
        fro_w = INITIAL_FROBENIUS_WEIGHT
    elif epoch < phase2_end:  # Pure MLP
        lr = PURE_MLP_BASE_LR
        dkan_w = 0.0
        fro_w = INITIAL_FROBENIUS_WEIGHT
    elif epoch < phase3_end:  # DKAN turn-on
        offset = epoch - phase2_end
        lr = DKAN_BASE_LR
        dkan_w = min((offset + 1) / max(1, DKAN_TURN_ON_SCALE), DKAN_TURN_ON_CAP)
        fro_w = INITIAL_FROBENIUS_WEIGHT
    elif epoch < phase4_end:  # Frobenius decay
        offset = epoch - phase3_end
        lr = DKAN_BASE_LR
        dkan_w = dkan_w_full
        fro_w = INITIAL_FROBENIUS_WEIGHT / (10 ** ((offset + 1) / max(1, DKAN_FROBENIUS_DECAY_SCALE)))
        fro_w = max(fro_w, FROBENIUS_WEIGHT_CAP)  # the "cap" acts as a floor
    else:  # Final LR decay phase (epoch < phase5_end is guaranteed above)
        offset = epoch - phase4_end
        # Cosine decay from DKAN_BASE_LR towards 0 over DKAN_LEARNING_RATE_DECAY_EPOCHS
        cos_inner = math.pi * offset / max(1, DKAN_LEARNING_RATE_DECAY_EPOCHS)
        lr = DKAN_BASE_LR * 0.5 * (1 + math.cos(cos_inner))
        dkan_w = dkan_w_full
        # Same value as the last Frobenius-decay epoch, so the schedule is continuous.
        fro_w = max(
            INITIAL_FROBENIUS_WEIGHT / (10 ** (DKAN_FROBENIUS_DECAY_EPOCHS / max(1, DKAN_FROBENIUS_DECAY_SCALE))),
            FROBENIUS_WEIGHT_CAP,
        )

    return lr, dkan_w, fro_w

# Length of the whole schedule: the sum of all five phase lengths used by get_params.
TOTAL_EPOCHS = sum(
    (
        WARMUP_EPOCHS,
        PURE_MLP_EPOCHS,
        DKAN_TURN_ON_EPOCHS,
        DKAN_FROBENIUS_DECAY_EPOCHS,
        DKAN_LEARNING_RATE_DECAY_EPOCHS,
    )
)

# -------------------------------------------------------
# Model & optimiser
# -------------------------------------------------------

# Build the network and move it to DEVICE (Module.to returns the same module).
model = SimpleCNNDKAN(WIDTH, SCALE_TYPE)
model = model.to(DEVICE)
# Initial LR is a placeholder; train_epoch overwrites it from the schedule each epoch.
optimizer = Adam(params=model.parameters(), lr=0.0)
# NOTE(review): not referenced in this file; train_epoch uses F.cross_entropy directly.
criterion = nn.CrossEntropyLoss()

# -------------------------------------------------------
# MixUp / CutMix utility
# -------------------------------------------------------

def mixup_cutmix(x, y):
    """Apply CutMix (with probability CUTMIX_PROB) or MixUp to a batch.

    Each sample is paired with a random partner from the same batch.

    Args:
        x: Image batch ``(B, C, H, W)`` on DEVICE. Modified in place in the
            CutMix branch.
        y: Class labels, one per sample.

    Returns:
        ``(x_mixed, y, y_partner, lam)``. The mixed loss is
        ``lam * CE(logits, y) + (1 - lam) * CE(logits, y_partner)``. For CutMix,
        ``lam`` is the exact fraction of pixels kept from the original image.
        If MixUp is selected but ``MIXUP_ALPHA <= 0``, the batch is returned
        unmixed with ``lam = 1.0``.
    """
    bs = x.size(0)
    idx = torch.randperm(bs, device=DEVICE)
    y2 = y[idx]
    if torch.rand(1, device=DEVICE) < CUTMIX_PROB and CUTMIX_BETA > 0:
        lam = torch.distributions.Beta(CUTMIX_BETA, CUTMIX_BETA).sample().item()
        H, W = x.shape[2:]
        # Box sides chosen so the box covers roughly (1 - lam) of the image.
        cut_rat = math.sqrt(1 - lam)
        cw, ch = int(W * cut_rat), int(H * cut_rat)
        cx, cy = torch.randint(W, (1,), device=DEVICE).item(), torch.randint(H, (1,), device=DEVICE).item()
        x1, x2 = max(cx - cw // 2, 0), min(cx + cw // 2, W)
        y1, y2p = max(cy - ch // 2, 0), min(cy + ch // 2, H)
        # The right-hand side uses advanced indexing and is therefore a copy,
        # so the in-place write cannot read already-overwritten pixels.
        x[:, :, y1:y2p, x1:x2] = x[idx, :, y1:y2p, x1:x2]
        # Recompute lam from the (border-clipped) area actually pasted.
        lam = 1 - ((x2 - x1) * (y2p - y1) / (W * H))
    elif MIXUP_ALPHA > 0:
        lam = torch.distributions.Beta(MIXUP_ALPHA, MIXUP_ALPHA).sample().item()
        x = lam * x + (1 - lam) * x[idx]
    else:
        # MixUp disabled: Beta(0, 0) is not a valid distribution, so mirror the
        # CutMix guard above and leave the batch unmixed.
        lam = 1.0
    return x, y, y2, lam

# -------------------------------------------------------
# Accuracy helper
# -------------------------------------------------------

def accuracy(logits, labels):
    """Return the top-1 accuracy of ``logits`` against ``labels``, in percent."""
    predictions = logits.argmax(dim=1)
    hit_rate = predictions.eq(labels).float().mean().item()
    return hit_rate * 100

# -------------------------------------------------------
# Training / validation
# -------------------------------------------------------

def train_epoch(epoch: int):
    """Train ``model`` for one epoch using the schedule values for ``epoch``.

    The epoch proceeds as follows:

    - Set every optimizer param group's LR from ``get_params(epoch)``.
    - Apply the GPU augmentation and MixUp/CutMix to each batch.
    - Minimise the mixed cross-entropy plus ``fro_w`` times the model's
      Frobenius regulariser.

    Returns:
        ``(mean_loss, mean_pure_loss, train_acc, lr, dkan_w, fro_w)``.
        Losses and accuracy are averaged per sample. ``train_acc`` compares
        predictions on the *mixed* images with the original labels, so it is
        only a rough progress indicator.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    lr, dkan_w, fro_w = get_params(epoch)
    for pg in optimizer.param_groups:
        pg["lr"] = lr

    model.train()
    # (The Beta distributions are built inside mixup_cutmix; constructing them
    # here was unused and raised for non-positive MIXUP_ALPHA / CUTMIX_BETA.)
    acc_sum, loss_sum, pure_sum, n_samples = 0.0, 0.0, 0.0, 0

    for imgs_cpu, labels in train_loader:
        imgs = imgs_cpu.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        imgs = train_gpu_tf(imgs)

        # ----- MixUp / CutMix -----
        imgs, y1, y2, lam = mixup_cutmix(imgs, labels)

        optimizer.zero_grad(set_to_none=True)
        logits = model(imgs, dkan_w)

        pure_loss = lam * F.cross_entropy(logits, y1) + (1 - lam) * F.cross_entropy(logits, y2)
        loss = pure_loss + fro_w * model.get_frobenius_regularization()
        loss.backward()
        optimizer.step()

        # Weight per-batch means by batch size so a short last batch counts correctly.
        bs = imgs.size(0)
        n_samples += bs
        loss_sum += loss.item() * bs
        pure_sum += pure_loss.item() * bs
        acc_sum += accuracy(logits, labels) * bs

    if n_samples == 0:
        raise RuntimeError("train_loader yielded no batches")

    return (
        loss_sum / n_samples,
        pure_sum / n_samples,
        acc_sum / n_samples,
        lr,
        dkan_w,
        fro_w,
    )


def eval_loader(loader, dkan_w: float):
    """Return the top-1 accuracy (percent) of ``model`` on ``loader``.

    Runs in eval mode without gradients and applies only the normalisation
    transform. Accuracy is ``correct / total`` over all samples, so a smaller
    final batch is weighted by its true size. The previous mean of per-batch
    accuracies gave it the weight of a full batch.

    Args:
        loader: DataLoader yielding ``(images, labels)`` CPU batches.
        dkan_w: DKAN weight passed to the model.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs_cpu, labels in loader:
            imgs = imgs_cpu.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            imgs = val_gpu_tf(imgs)
            logits = model(imgs, dkan_w)
            correct += (logits.argmax(1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("eval_loader received a loader with no samples")
    return 100.0 * correct / total

# -------------------------------------------------------
# Main driver
# -------------------------------------------------------

def main():
    """Run the full schedule, keep the best-val checkpoint, and report test accuracy.

    Writes the following files into ``PATH_SAVE``:

    - ``final_model.pt``
    - ``best_model.pt`` (only if some epoch beat 0 % val accuracy)
    - ``logs.npz``
    - ``summary.txt``
    - a copy of the config

    The test set is evaluated with the best checkpoint, using the DKAN weight
    that was active in the epoch where that checkpoint was taken.
    """
    start_time = time.time()
    history = {k: [] for k in [
        "train_loss", "train_pure_loss", "train_acc", "val_acc", "lr", "dkan_w", "fro_w"
    ]}

    best_val, best_state, best_epoch = 0.0, None, -1

    epoch_iter = tqdm(range(TOTAL_EPOCHS), desc="Epochs")
    for epoch in epoch_iter:
        tr_loss, tr_pure, tr_acc, lr, dkan_w, fro_w = train_epoch(epoch)
        val_acc = eval_loader(val_loader, dkan_w)

        # Logging
        history["train_loss"].append(tr_loss)
        history["train_pure_loss"].append(tr_pure)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(val_acc)
        history["lr"].append(lr)
        history["dkan_w"].append(dkan_w)
        history["fro_w"].append(fro_w)

        best_so_far = max(best_val, val_acc)
        epoch_iter.set_postfix(
            train=f"{tr_acc:.2f}%",
            val=f"{val_acc:.2f}%",
            best=f"{best_so_far:.2f}%",
            lr=f"{lr:.1e}",
            dk_w=dkan_w,
            fr_w=fro_w,
        )

        if val_acc > best_val:
            best_val = val_acc
            best_epoch = epoch + 1  # 1-based, as reported in summary.txt
            # Explicit CPU copies: `.cpu()` alone returns the same tensor when it
            # is already on the CPU, which would alias the live parameters.
            best_state = {
                k: v.detach().to("cpu", copy=True)
                for k, v in model.state_dict().items()
            }

    # Save final weights. model.cpu() moves the module in place, so move it back:
    # eval_loader sends inputs to DEVICE even when no best checkpoint exists.
    torch.save(model.cpu().state_dict(), PATH_SAVE / "final_model.pt")
    model.to(DEVICE)

    # Evaluate the best model on the test set with the DKAN weight it was
    # validated with; dkan_w differs between schedule phases.
    if best_state is not None:
        torch.save(best_state, PATH_SAVE / "best_model.pt")
        model.load_state_dict(best_state)
        test_dkan_w = history["dkan_w"][best_epoch - 1]
    elif history["dkan_w"]:
        test_dkan_w = history["dkan_w"][-1]
    else:  # no epochs were run
        test_dkan_w = 0.0
    test_acc = eval_loader(test_loader, dkan_w=test_dkan_w)
    elapsed = time.time() - start_time

    # Save logs
    np.savez(PATH_SAVE / "logs.npz", **{k: np.array(v) for k, v in history.items()})

    # Summary
    with open(PATH_SAVE / "summary.txt", "w") as f:
        f.write(f"Best val acc: {best_val:.2f}% (epoch {best_epoch})\n")
        f.write(f"Test acc: {test_acc:.2f}%\n")
        f.write(f"Elapsed time (s): {elapsed:.1f}\n")

    shutil.copy(args.path_config, PATH_SAVE / "used_config.yaml")

    print(
        f"Best val acc: {best_val:.2f}% | Test acc: {test_acc:.2f}% | Elapsed: {elapsed:.1f}s.\n"
        f"Weights and summary saved to {PATH_SAVE}"
    )

# Script entry point. Argument parsing, config loading, dataset download and
# model construction have already run at module top level by this point.
if __name__ == "__main__":
    main()
