import os
import time
import argparse
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


# =====================
# Args
# =====================
def parse_args():
    """Parse the command-line options of a training run.

    Returns:
        argparse.Namespace: the parsed options. ``--exp-name`` is mandatory;
        every other option falls back to the default listed in the table
        below.
    """
    # One row per option: (flag, keyword arguments for ``add_argument``).
    # Rows are registered in the order they appear here.
    option_table = (
        ("--exp-name", {"type": str, "required": True}),
        (
            "--dataset",
            {
                "type": str,
                "choices": ["cifar", "mnist", "fashion"],
                "default": "cifar",
            },
        ),
        ("--single-timescale", {"action": "store_true"}),
        ("--seed", {"type": int, "default": 1}),
        ("--epochs", {"type": int, "default": 200}),
        ("--batch-size", {"type": int, "default": 128}),
        ("--lr-body", {"type": float, "default": 1e-4}),
        ("--lr-head", {"type": float, "default": 1e-3}),
        ("--lambda-head", {"type": float, "default": 1e-4}),
        ("--logdir", {"type": str, "default": "runs/classification"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# =====================
# Models
# =====================
class CIFARNet(nn.Module):
    """Small three-stage CNN split into a feature ``body`` and a linear ``head``.

    The body takes 3-channel images and produces a 256-dimensional feature
    vector; the head maps that vector to ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Channel widths: input -> stage 1 -> stage 2 -> stage 3.
        widths = (3, 64, 128, 256)
        final_stage = len(widths) - 2

        stages = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            if stage == final_stage:
                # Last stage: collapse whatever spatial extent remains to 1x1.
                stages.append(nn.AdaptiveAvgPool2d(1))
            else:
                # Halve the spatial size (32x32 input -> 16x16 -> 8x8).
                stages.append(nn.MaxPool2d(2))
        stages.append(nn.Flatten())

        self.body = nn.Sequential(*stages)
        self.head = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features = self.body(x)
        return self.head(features)



class MNISTNet(nn.Module):
    """Two-hidden-layer MLP split into a feature ``body`` and a linear ``head``.

    The body flattens its input to 28 * 28 = 784 values and maps it to a
    256-dimensional feature vector; the head maps that vector to
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layer widths: flattened input -> hidden 1 -> hidden 2.
        widths = (28 * 28, 512, 256)

        layers = [nn.Flatten()]
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())

        self.body = nn.Sequential(*layers)
        self.head = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        features = self.body(x)
        return self.head(features)


# =====================
# Utils
# =====================
def set_seed(seed):
    """Seed Python's ``random`` module and PyTorch (CPU and CUDA) with ``seed``."""
    # Applied in order: stdlib RNG, torch default generator, all CUDA devices.
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


def grad_norm(params):
    """Return the global L2 norm of the gradients of ``params``.

    Args:
        params: iterable of tensors; entries whose ``.grad`` is ``None``
            are ignored.

    Returns:
        float: square root of the summed squared per-parameter gradient
        norms (``0.0`` when no parameter has a gradient).
    """
    squared_norms = (
        p.grad.data.norm(2).item() ** 2
        for p in params
        if p.grad is not None
    )
    # Start from 0.0 so the result is a float even for an empty iterable.
    return sum(squared_norms, 0.0) ** 0.5


# =====================
# Main
# =====================
def _build_data_and_model(dataset_name, device):
    """Create the train/test datasets and the matching model.

    Args:
        dataset_name: ``"cifar"``, ``"mnist"``, or anything else, which is
            treated as Fashion-MNIST.
        device: torch device the model is moved to.

    Returns:
        tuple: ``(trainset, testset, model)``. Datasets are stored under
        ``./data`` and downloaded if missing.
    """
    if dataset_name == "cifar":
        normalize = transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2470, 0.2435, 0.2616),
        )
        # Augmentation is applied to the training split only.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
        testset = datasets.CIFAR10("./data", train=False, download=True, transform=transform_test)
        return trainset, testset, CIFARNet(num_classes=10).to(device)

    # MNIST and Fashion-MNIST share the pipeline and the model; only the
    # dataset class and the normalisation statistics differ.
    if dataset_name == "mnist":
        dataset_cls, mean, std = datasets.MNIST, (0.1307,), (0.3081,)
    else:  # Fashion-MNIST
        dataset_cls, mean, std = datasets.FashionMNIST, (0.2860,), (0.3530,)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    trainset = dataset_cls("./data", train=True, download=True, transform=transform)
    testset = dataset_cls("./data", train=False, download=True, transform=transform)
    return trainset, testset, MNISTNet().to(device)


def _build_optimizer(model, args):
    """Return an RMSprop optimizer with one or two learning rates.

    With ``--single-timescale`` every parameter uses ``args.lr_body``;
    otherwise the body uses ``args.lr_body`` and the head ``args.lr_head``.
    """
    if args.single_timescale:
        return optim.RMSprop(model.parameters(), lr=args.lr_body)
    return optim.RMSprop([
        {"params": model.body.parameters(), "lr": args.lr_body},
        {"params": model.head.parameters(), "lr": args.lr_head},
    ])


def _evaluate(model, loader, device):
    """Put ``model`` in eval mode and return its top-1 accuracy on ``loader``."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total


def main():
    """Train a classifier, evaluate it every epoch and log to TensorBoard.

    Reads all settings from the command line (see ``parse_args``). Per step
    it logs the total loss, the cross-entropy loss and the gradient norms of
    body and head; per epoch it logs test accuracy, mean training loss and
    elapsed wall-clock time, and prints a one-line summary.
    """
    args = parse_args()
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_name = f"{args.dataset}__{args.exp_name}__seed{args.seed}__{int(time.time())}"
    writer = SummaryWriter(os.path.join(args.logdir, run_name))

    # Everything after the writer is created runs under try/finally so the
    # event file is flushed and closed even if setup or training raises or
    # the run is interrupted.
    try:
        # =====================
        # Data
        # =====================
        trainset, testset, model = _build_data_and_model(args.dataset, device)

        train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)

        # =====================
        # Optimizer
        # =====================
        optimizer = _build_optimizer(model, args)

        # =====================
        # Training Loop
        # =====================
        global_step = 0
        start_time = time.time()

        for epoch in range(args.epochs):
            model.train()
            epoch_loss = 0.0

            for x, y in train_loader:
                x, y = x.to(device), y.to(device)

                logits = model(x)
                ce_loss = F.cross_entropy(logits, y)

                if args.single_timescale:
                    loss = ce_loss
                else:
                    # L2 penalty on the head parameters only.
                    l2_head = sum((p ** 2).sum() for p in model.head.parameters())
                    loss = ce_loss + args.lambda_head * l2_head

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() forces a host sync.
                loss_value = loss.item()

                writer.add_scalar("train/loss", loss_value, global_step)
                writer.add_scalar("train/ce_loss", ce_loss.item(), global_step)
                writer.add_scalar("grads/body_norm", grad_norm(model.body.parameters()), global_step)
                writer.add_scalar("grads/head_norm", grad_norm(model.head.parameters()), global_step)

                epoch_loss += loss_value
                global_step += 1

            # =====================
            # Eval
            # =====================
            acc = _evaluate(model, test_loader, device)
            avg_loss = epoch_loss / len(train_loader)

            print(
                f"[{args.dataset.upper()}] Epoch {epoch:03d} | "
                f"Loss {avg_loss:.4f} | Acc {acc:.4f}"
            )

            writer.add_scalar("eval/accuracy", acc, epoch)
            writer.add_scalar("charts/epoch_loss", avg_loss, epoch)
            writer.add_scalar("charts/elapsed_time", time.time() - start_time, epoch)
    finally:
        writer.close()


# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse the models) has no side effects.
if __name__ == "__main__":
    main()
