import argparse
import math
from pathlib import Path
from typing import Tuple
import yaml
import time
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms.v2 as T
from tqdm import tqdm

# -------------------------------------------------------
# Argument parsing (both paths are required)
# -------------------------------------------------------
# Command-line interface.  Both options are mandatory, so the script has no
# implicit default locations for its config or its outputs.
parser = argparse.ArgumentParser(
    description="2×2‑stride‑2 CNN for CIFAR‑10",
)
parser.add_argument(
    "--path_config",
    required=True,
    type=Path,
    help="YAML configuration file (e.g. ./config_mlp.yaml)",
)
parser.add_argument(
    "--path_save",
    required=True,
    type=Path,
    help="Directory where checkpoints & summary will be written",
)
# Parsed at module level: the sections below read `args` directly.
args = parser.parse_args()

# -------------------------------------------------------
# Config loading
# -------------------------------------------------------
# Read the run configuration.  The encoding is given explicitly so the result
# does not depend on the platform's default locale encoding.
with open(args.path_config, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# An empty or non-mapping document would otherwise fail on the first key
# lookup with an opaque TypeError.
if not isinstance(cfg, dict):
    raise ValueError(
        f"{args.path_config} must contain a YAML mapping, "
        f"got {type(cfg).__name__}"
    )

# Report every missing key at once instead of one KeyError per attempt.
_REQUIRED_KEYS = (
    "WIDTH", "SCALE_TYPE", "BATCH_SIZE", "NUM_EPOCHS", "N_WARMUP",
    "MIXUP_ALPHA", "CUTMIX_BETA", "CUTMIX_PROB", "LEARNING_RATE",
)
_missing_keys = [key for key in _REQUIRED_KEYS if key not in cfg]
if _missing_keys:
    raise KeyError(
        f"Missing config key(s) in {args.path_config}: {', '.join(_missing_keys)}"
    )

# Numeric settings are cast explicitly so values that YAML delivered as
# strings are still accepted (presumably why the float settings were already
# cast, e.g. for values written like `1e-3` -- confirm against real configs).
WIDTH: int          = int(cfg["WIDTH"])         # base channel count of the CNN
SCALE_TYPE: str     = cfg["SCALE_TYPE"]         # "progressive" doubles channels per stage
BATCH_SIZE: int     = int(cfg["BATCH_SIZE"])
NUM_EPOCHS: int     = int(cfg["NUM_EPOCHS"])
N_WARMUP: int       = int(cfg["N_WARMUP"])      # warm-up length, in epochs
MIXUP_ALPHA: float  = float(cfg["MIXUP_ALPHA"])
CUTMIX_BETA: float  = float(cfg["CUTMIX_BETA"])
CUTMIX_PROB: float  = float(cfg["CUTMIX_PROB"])
LEARNING_RATE: float = float(cfg["LEARNING_RATE"])

# -------------------------------------------------------
# Paths / constants
# -------------------------------------------------------
# Seed for the generator that shuffles indices in the train/validation split.
SEED = 42

# Per-channel mean and standard deviation handed to T.Normalize in both
# transform pipelines.
MEAN: Tuple[float, float, float] = (0.4914, 0.4822, 0.4465)
STD: Tuple[float, float, float] = (0.2470, 0.2435, 0.2616)

# All model and augmentation work is placed on the first CUDA device.
DEVICE = torch.device("cuda:0")

# Output directory; created up front (including parents) so later writes
# cannot fail on a missing folder.
PATH_SAVE: Path = args.path_save
PATH_SAVE.mkdir(parents=True, exist_ok=True)

# -------------------------------------------------------
# Transforms (GPU)
# -------------------------------------------------------
# Per-sample conversion handed to the CIFAR10 datasets as their `transform`.
# NOTE(review): ToTensor appears to be deprecated in torchvision.transforms.v2
# -- confirm and plan a migration if so.
cpu_to_tensor = T.ToTensor()

# Training augmentation, applied to a whole batch after it has been moved to
# DEVICE.  Normalisation is the last step of the pipeline.
# NOTE(review): torchvision documents that random parameters are drawn once per
# call, so every image in a batch presumably receives the same crop / flip /
# erase -- confirm that this is intended.
train_gpu_tf = nn.Sequential(
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
    T.RandAugment(num_ops=2, magnitude=7),
    T.RandomErasing(p=0.25, scale=(0.05, 0.2), ratio=(0.3, 3.3)),
    T.Normalize(mean=MEAN, std=STD),
)
train_gpu_tf = train_gpu_tf.to(DEVICE)

# Validation / test inputs are only normalised.
val_gpu_tf = nn.Sequential(
    T.Normalize(mean=MEAN, std=STD),
)
val_gpu_tf = val_gpu_tf.to(DEVICE)

# -------------------------------------------------------
# Data loaders
# -------------------------------------------------------

def make_loaders(
    batch_size: int,
    val_size: int = 5_000,
    num_workers: int = 4,
    data_root: str = "./data",
):
    """
    Build the train / validation / test DataLoaders for CIFAR-10.

    The official training set is split per class (stratified), so every class
    contributes the same number of validation images.  With the defaults this
    is the 45 000 / 5 000 split.  Indices are shuffled with a generator seeded
    by SEED, which makes the split reproducible.

    Args:
        batch_size: Batch size used by all three loaders.
        val_size: Total number of validation images; it is divided evenly
            across classes (integer division).
        num_workers: Worker processes per DataLoader.
        data_root: Directory the dataset is read from / downloaded to.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.  Only the training
        loader shuffles.

    Raises:
        ValueError: If ``val_size`` is negative, or a class has too few images
            to leave at least one training sample after the split.
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")

    full_train = datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=cpu_to_tensor
    )

    targets = torch.as_tensor(full_train.targets)
    num_classes = int(targets.max().item()) + 1     # 10 for CIFAR-10
    n_val_per_class = val_size // num_classes       # 500 with the defaults

    g = torch.Generator().manual_seed(SEED)
    val_idx, train_idx = [], []

    for c in range(num_classes):
        cls_indices = torch.nonzero(targets == c, as_tuple=False).view(-1)
        if n_val_per_class >= len(cls_indices):
            raise ValueError(
                f"Class {c} has {len(cls_indices)} images, cannot hold out "
                f"{n_val_per_class} for validation"
            )
        # Shuffle within the class, then cut: head -> validation, tail -> train.
        cls_indices = cls_indices[torch.randperm(len(cls_indices), generator=g)]
        val_idx.extend(cls_indices[:n_val_per_class].tolist())
        train_idx.extend(cls_indices[n_val_per_class:].tolist())

    # Both subsets share the same underlying dataset (and its CPU transform).
    train_set = torch.utils.data.Subset(full_train, train_idx)
    val_set = torch.utils.data.Subset(full_train, val_idx)

    test_set = datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=cpu_to_tensor
    )

    # Settings common to all three loaders.
    loader_kwargs = dict(
        batch_size=batch_size, num_workers=num_workers, pin_memory=True
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader

# Build the loaders once at module level and report how many samples each
# split contains.
train_loader, val_loader, test_loader = make_loaders(BATCH_SIZE)
print(
    "Dataset sizes → "
    f"train: {len(train_loader.dataset)}, "
    f"val: {len(val_loader.dataset)}, "
    f"test: {len(test_loader.dataset)}"
)

# -------------------------------------------------------
# Model definition
# -------------------------------------------------------
class TwoByTwoCNN(nn.Module):
    """CNN built only from 2x2, stride-2 convolutions.

    Each stage is Conv2d(kernel 2, stride 2) -> BatchNorm2d -> ReLU and halves
    the spatial size.  Stages are stacked from a hard-coded 32-pixel side and
    3 input channels until the side reaches 1, which gives five stages.  The
    flattened output feeds a two-layer MLP classifier.

    Args:
        width: Output channels of the first stage.
        scale_type: ``"progressive"`` doubles the channels at every stage
            (width, 2*width, 4*width, ...); any other value keeps every stage
            at ``width`` channels.
        num_classes: Number of output logits.
    """

    def __init__(self, width=16, scale_type="progressive", num_classes: int = 10):
        super().__init__()
        progressive = scale_type == "progressive"

        blocks = []
        prev_ch, cur_ch = 3, width
        resolution, depth = 32, 0
        while resolution > 1:
            if depth > 0:
                # Later stages consume the previous stage's output.
                prev_ch = cur_ch
                cur_ch = width * (2 ** depth) if progressive else width
            blocks.extend(
                [
                    nn.Conv2d(prev_ch, cur_ch, kernel_size=2, stride=2),
                    nn.BatchNorm2d(cur_ch),
                    nn.ReLU(inplace=True),
                ]
            )
            depth += 1
            resolution //= 2

        self.features = nn.Sequential(*blocks)
        # At this point `cur_ch` is the channel count of the last stage, which
        # equals the flattened feature size because the map is 1x1.
        self.classifier = nn.Sequential(
            nn.Linear(cur_ch, cur_ch),
            nn.ReLU(inplace=True),
            nn.Linear(cur_ch, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        feats = self.features(x)
        return self.classifier(feats.flatten(1))


# The network is moved to DEVICE before the optimiser is created, so the
# optimiser holds references to the on-device parameters.
model = TwoByTwoCNN(width=WIDTH, scale_type=SCALE_TYPE)
model = model.to(DEVICE)

# -------------------------------------------------------
# Optimiser & LR schedule
# -------------------------------------------------------
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# The schedule is defined per optimiser step, so epoch counts from the config
# are converted to step counts.
steps_per_epoch = len(train_loader)
total_steps = NUM_EPOCHS * steps_per_epoch
warmup_steps = N_WARMUP * steps_per_epoch


def lr_lambda(step):
    """Return the learning-rate multiplier for optimiser step ``step``.

    The multiplier ramps linearly from 0 to 1 over ``warmup_steps`` and then
    follows a half-cosine decay from 1 down to 0 at ``total_steps``.  Progress
    through the decay is clamped to 1, so stepping past ``total_steps`` keeps
    the multiplier at 0 instead of letting the cosine rise again.

    Args:
        step: Zero-based count of scheduler steps taken so far.

    Returns:
        A float in [0, 1] that the scheduler multiplies with the base LR.
    """
    if step < warmup_steps:
        # Only reachable when warmup_steps > 0, so the division is safe.
        return step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    prog = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1 + math.cos(math.pi * prog))

# Multiplicative schedule driven by lr_lambda; train_epoch advances it once
# per optimiser step.
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# -------------------------------------------------------
# MixUp / CutMix helper
# -------------------------------------------------------

def _sample_mix_lambda(concentration):
    """Draw lambda from Beta(c, c); a non-positive ``c`` disables mixing (lambda = 1)."""
    if concentration <= 0:
        # Beta requires positive concentrations; treat <= 0 as "mixing off".
        return 1.0
    return torch.distributions.Beta(concentration, concentration).sample().item()


def mixup_cutmix(x, y):
    """Apply CutMix (with probability CUTMIX_PROB) or MixUp to a batch.

    Every sample is paired with another sample of the same batch through a
    random permutation.  CutMix pastes one rectangular patch (the same box for
    the whole batch) from the partner image; MixUp blends the two images
    linearly.  The input tensor is never modified in place.

    Args:
        x: Image batch of shape (N, C, H, W).
        y: Labels of the batch, indexable along dim 0 with N entries.

    Returns:
        Tuple ``(mixed_x, y, y_partner, lam)`` where ``lam`` is the weight of
        the original labels ``y`` and ``1 - lam`` the weight of ``y_partner``.
        For CutMix, ``lam`` is recomputed from the actual (clipped) box area.
    """
    device = x.device
    idx = torch.randperm(x.size(0), device=device)
    y2 = y[idx]
    if torch.rand(1, device=device) < CUTMIX_PROB:
        lam = _sample_mix_lambda(CUTMIX_BETA)
        H, W = x.shape[2:]
        # Box side lengths are chosen so the box covers about (1 - lam) of the image.
        cut_rat = math.sqrt(1 - lam)
        cw, ch = int(W * cut_rat), int(H * cut_rat)
        cx = torch.randint(W, (1,), device=device).item()
        cy = torch.randint(H, (1,), device=device).item()
        # Clip the box to the image borders.
        x1, x2 = max(cx - cw // 2, 0), min(cx + cw // 2, W)
        y1, y2p = max(cy - ch // 2, 0), min(cy + ch // 2, H)
        # Write into a copy so the caller's tensor is left untouched.
        mixed = x.clone()
        mixed[:, :, y1:y2p, x1:x2] = x[idx, :, y1:y2p, x1:x2]
        x = mixed
        # Clipping may have shrunk the box, so derive lam from its real area.
        lam = 1 - ((x2 - x1) * (y2p - y1) / (W * H))
    else:
        lam = _sample_mix_lambda(MIXUP_ALPHA)
        x = lam * x + (1 - lam) * x[idx]
    return x, y, y2, lam

# -------------------------------------------------------
# Metrics helpers
# -------------------------------------------------------

def accuracy(logits, labels):
    """Return top-1 accuracy of ``logits`` against ``labels`` as a percentage.

    The prediction for each row is the argmax over dim 1.
    """
    predictions = torch.argmax(logits, dim=1)
    hit_rate = predictions.eq(labels).float().mean()
    return hit_rate.item() * 100

# -------------------------------------------------------
# Train / evaluate
# -------------------------------------------------------

def train_epoch():
    """Train the model for one pass over ``train_loader``.

    Each batch is moved to DEVICE, augmented with ``train_gpu_tf``, mixed by
    ``mixup_cutmix`` and optimised with a lambda-weighted pair of
    cross-entropy losses.  The LR scheduler is advanced after every optimiser
    step.

    Returns:
        Training accuracy in percent, computed over all samples seen (so a
        smaller final batch is weighted correctly).  It compares predictions
        on the augmented and mixed inputs with the original labels, so it is
        only a rough progress signal.  Returns 0.0 for an empty loader.
    """
    model.train()
    # Kept on the device so counting hits does not force a sync every step.
    correct = torch.zeros((), dtype=torch.long, device=DEVICE)
    seen = 0
    for imgs_cpu, labels in train_loader:
        imgs = imgs_cpu.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        imgs = train_gpu_tf(imgs)
        imgs, y1, y2, lam = mixup_cutmix(imgs, labels)

        optimizer.zero_grad(set_to_none=True)
        logits = model(imgs)
        # Mixed-label objective: lam weights the original labels, 1 - lam the partners.
        loss = lam * F.cross_entropy(logits, y1) + (1 - lam) * F.cross_entropy(logits, y2)
        loss.backward()
        optimizer.step()
        scheduler.step()

        correct += (logits.detach().argmax(1) == labels).sum()
        seen += labels.size(0)
    if seen == 0:
        return 0.0
    return 100.0 * correct.item() / seen


def eval_loader(loader):
    """Run inference on ``loader`` and return top-1 accuracy in percent.

    Accuracy is the number of correct predictions divided by the number of
    samples, so a smaller final batch does not distort the result the way a
    mean of per-batch accuracies would.

    Args:
        loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a float percentage; 0.0 if the loader yields no samples.
    """
    model.eval()
    # Accumulated on the device to avoid a host sync per batch.
    correct = torch.zeros((), dtype=torch.long, device=DEVICE)
    seen = 0
    with torch.no_grad():
        for imgs_cpu, labels in loader:
            imgs = imgs_cpu.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            logits = model(val_gpu_tf(imgs))
            correct += (logits.argmax(1) == labels).sum()
            seen += labels.size(0)
    if seen == 0:
        return 0.0
    return 100.0 * correct.item() / seen

# -------------------------------------------------------
# Main training driver
# -------------------------------------------------------

def main():
    """Train for NUM_EPOCHS, keep the best checkpoint and write all artefacts.

    Files written to PATH_SAVE:
        final_model.pt   -- weights after the last epoch
        best_model.pt    -- weights of the epoch with the highest validation
                            accuracy (only if at least one epoch ran)
        logs.npz         -- per-epoch epoch index, LR, train and val accuracy
        summary.txt      -- best validation accuracy, test accuracy, run time
        used_config.yaml -- copy of the config file given on the command line

    Test accuracy is measured with the best-validation weights when they
    exist, otherwise with the final weights.
    """
    start_time = time.time()
    best_val, best_state = 0.0, None
    best_epoch = -1  # 1-indexed epoch of the best validation accuracy; -1 if none

    # Per-epoch logs
    log_epochs = []
    log_lrs = []
    log_train_accs = []
    log_val_accs = []

    epoch_bar = tqdm(range(NUM_EPOCHS), desc="Epochs")
    for epoch in epoch_bar:
        train_acc = train_epoch()
        val_acc = eval_loader(val_loader)
        # Read after the epoch's last scheduler step, i.e. the LR the next
        # optimiser step would use.
        current_lr = optimizer.param_groups[0]['lr']
        epoch_bar.set_postfix(train=f"{train_acc:.2f}%", val=f"{val_acc:.2f}%", lr=f"{current_lr:.1e}")

        log_epochs.append(epoch + 1)
        log_lrs.append(current_lr)
        log_train_accs.append(train_acc)
        log_val_accs.append(val_acc)

        # The first epoch always sets a baseline, even with 0 % accuracy.
        if best_state is None or val_acc > best_val:
            best_val = val_acc
            best_epoch = epoch + 1
            # clone() guarantees an independent snapshot: .cpu() alone returns
            # the live tensor when the weights already sit on the CPU.
            best_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }

    # Save final weights (this moves the model to the CPU).
    torch.save(model.cpu().state_dict(), PATH_SAVE / "final_model.pt")

    # Restore the best weights when available.
    if best_state is not None:
        torch.save(best_state, PATH_SAVE / "best_model.pt")
        model.load_state_dict(best_state)
    # Always move back to DEVICE: eval_loader puts its inputs there, so a
    # CPU-resident model would fail with a device mismatch.
    model.to(DEVICE)
    test_acc = eval_loader(test_loader)
    elapsed = time.time() - start_time

    # Save logs to npz file
    np.savez(
        PATH_SAVE / "logs.npz",
        epoch=np.array(log_epochs),
        lr=np.array(log_lrs),
        train_acc=np.array(log_train_accs),
        val_acc=np.array(log_val_accs)
    )

    # Write summary
    with open(PATH_SAVE / "summary.txt", "w", encoding="utf-8") as f:
        f.write(f"Best validation accuracy: {best_val:.2f}% (Epoch {best_epoch})\n")
        f.write(f"Test accuracy: {test_acc:.2f}%\n")
        f.write(f"Elapsed time (s): {elapsed:.1f}\n")

    # Keep a copy of the config next to the results.
    shutil.copy(args.path_config, PATH_SAVE / "used_config.yaml")

    print(
        f"Best val acc: {best_val:.2f}% | Test acc: {test_acc:.2f}% | Elapsed: {elapsed:.1f}s.\n"
        f"Weights and summary saved to {PATH_SAVE}"
    )


# Start training only when the file is executed as a script.
# NOTE: argument parsing, config loading and the construction of the data
# loaders, model and optimiser all happen at module level above, so they run
# on import regardless of this guard.
if __name__ == "__main__":
    main()
