import argparse
import os
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models

from models.resnet import ResNet9, ResNet18, CifarResNet18
from models.lenet import LeNet
from datasets.load_datasets import load_dataset


def parse_args(argv=None):
    """Parse command-line options for the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior of the original zero-argument call. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        argparse.Namespace with the attributes ``model``, ``dataset``,
        ``batch_size``, ``epochs``, ``lr``, ``momentum``, ``weight_decay``,
        ``save_dir`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Training module')

    # NOTE(review): 'allcnn' is accepted here, but get_model() has no branch
    # for it and raises ValueError("Unsupported model: allcnn"). The choice is
    # kept so the CLI surface stays unchanged.
    parser.add_argument('--model', type=str, default='resnet9',
                        choices=['resnet9', 'resnet18', 'allcnn', 'lenet'],
                        help='Select model architecture')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['mnist', 'cifar10', 'cifar100', 'svhn', 'tinyimagenet'],
                        help='Training dataset')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Training epochs')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='Learning rate')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='Momentum')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='Weight decay')
    parser.add_argument('--save-dir', type=str, default='checkpoints',
                        help='Directory to save models')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    return parser.parse_args(argv)


def get_dataset(dataset_name, batch_size=128):
    """Build the data loaders for ``dataset_name`` and report input metadata.

    Delegates to ``load_dataset`` (always with ``num_workers=4``) and adds the
    number of image channels: 1 for ``'mnist'``, 3 for every other name.

    Returns:
        Tuple ``(train_loader, test_loader, in_channels, num_classes)``.
    """
    loaders_and_classes = load_dataset(
        dataset_name, batch_size=batch_size, num_workers=4,
    )
    train_loader, test_loader, num_classes = loaders_and_classes

    # Only MNIST is treated as single-channel; everything else is 3-channel.
    in_channels = 1 if dataset_name == 'mnist' else 3

    return train_loader, test_loader, in_channels, num_classes


def get_model(model_name, in_channels, num_classes, dataset_name):
    """Instantiate the network identified by ``model_name``.

    Args:
        model_name: One of ``'resnet9'``, ``'resnet18'`` or ``'lenet'``.
        in_channels: Number of input channels passed to the project models.
            Not used for the torchvision ResNet18 branch.
        num_classes: Size of the classification output.
        dataset_name: Only inspected for ``'resnet18'``, where
            ``'tinyimagenet'`` selects the torchvision model.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'resnet9':
        return ResNet9(num_classes=num_classes, in_channels=in_channels)

    if model_name == 'lenet':
        return LeNet(num_classes=num_classes, in_channels=in_channels)

    if model_name != 'resnet18':
        raise ValueError(f"Unsupported model: {model_name}")

    if dataset_name != 'tinyimagenet':
        # Every non-TinyImageNet dataset uses the project's CIFAR-style ResNet18.
        return CifarResNet18(num_classes=num_classes, in_channels=in_channels)

    # TinyImageNet: start from torchvision's ImageNet weights and replace the
    # final fully connected layer so it emits ``num_classes`` logits.
    pretrained = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    pretrained.fc = nn.Linear(pretrained.fc.in_features, num_classes)
    return pretrained


def init_weights(m):
    """Apply Kaiming-normal initialization to a conv or linear module.

    Intended for use with ``model.apply``: modules that are neither
    ``nn.Conv2d`` nor ``nn.Linear`` are left untouched. Weights are drawn
    with ``mode='fan_out'`` and ``nonlinearity='relu'``; an existing bias is
    set to zero.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    bias = m.bias
    if bias is not None:
        nn.init.constant_(bias, 0)


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    The model is switched to training mode, and for every batch the gradients
    are cleared, the loss is back-propagated and the optimizer is stepped.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` where the loss is the
        sum of per-batch losses divided by the number of batches and the
        accuracy is the percentage of samples whose arg-max prediction
        matched the target.
    """
    model.train()

    loss_sum = 0
    num_correct = 0
    num_seen = 0

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Book-keeping uses the logits computed before this step's update.
        loss_sum += batch_loss.item()
        predictions = logits.max(1)[1]
        num_seen += batch_targets.size(0)
        num_correct += predictions.eq(batch_targets).sum().item()

    accuracy = 100.0 * num_correct / num_seen
    mean_loss = loss_sum / len(train_loader)
    return mean_loss, accuracy


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    The model is switched to evaluation mode and left in that mode.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` where the loss is the
        sum of per-batch losses divided by the number of batches and the
        accuracy is the percentage of samples whose arg-max prediction
        matched the target.
    """
    model.eval()

    loss_sum = 0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_targets).item()

            predictions = logits.max(1)[1]
            num_seen += batch_targets.size(0)
            num_correct += predictions.eq(batch_targets).sum().item()

    accuracy = 100.0 * num_correct / num_seen
    mean_loss = loss_sum / len(val_loader)
    return mean_loss, accuracy


def _warmup_cosine_factor(epoch, total_epochs, warmup_epochs):
    """Return the LR multiplier for linear warmup followed by cosine decay.

    Args:
        epoch: Zero-based epoch index supplied by ``LambdaLR``.
        total_epochs: Total number of training epochs.
        warmup_epochs: Number of epochs over which the LR ramps up linearly.

    Returns:
        A factor in ``[0, 1]`` that ``LambdaLR`` multiplies with the base LR.
    """
    if epoch < warmup_epochs:
        # Use (epoch + 1): LambdaLR evaluates this at epoch 0 on construction,
        # and a factor of 0 would make the whole first epoch train with LR 0.
        return (epoch + 1) / warmup_epochs

    # Clamp the decay length so total_epochs <= warmup_epochs cannot divide
    # by zero when the scheduler is stepped after the final epoch.
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))


def main():
    """Train a model end to end according to the command-line options.

    Steps: parse arguments, seed torch, build the data loaders and model,
    choose loss / optimizer / LR scheduler from the model and dataset names,
    then run ``args.epochs`` epochs of training and evaluation. Whenever the
    test accuracy improves, a checkpoint is written to
    ``<save_dir>/<model>_<dataset>_best.pth``.

    Note that ``--lr`` is ignored for LeNet (Adam with 0.001) and for
    ResNet18 on cifar100 / tinyimagenet (fixed 0.1 / 0.01 respectively), and
    that ``--weight-decay`` is overridden with 1e-3 for SGD on cifar100 /
    tinyimagenet.
    """
    args = parse_args()

    torch.manual_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print(f"Loading dataset: {args.dataset}")
    train_loader, test_loader, in_channels, num_classes = get_dataset(
        args.dataset, args.batch_size)

    print(f"Creating model: {args.model}")
    model = get_model(args.model, in_channels, num_classes, args.dataset)

    # The TinyImageNet ResNet18 carries pretrained weights; re-initializing
    # it would throw them away, so only the other models are initialized.
    is_finetune = args.model == 'resnet18' and args.dataset == 'tinyimagenet'
    if not is_finetune:
        model.apply(init_weights)
    model = model.to(device)

    is_cifar100_or_tiny = args.dataset in ['cifar100', 'tinyimagenet']

    if is_cifar100_or_tiny:
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        if is_finetune:
            initial_lr = 0.01  # Lower LR for fine-tuning
        else:
            initial_lr = 0.1 if args.model == 'resnet18' else args.lr
    else:
        criterion = nn.CrossEntropyLoss()
        initial_lr = args.lr

    if args.model == 'lenet':
        optimizer = optim.Adam(model.parameters(), lr=0.001,
                               weight_decay=args.weight_decay)
        # `verbose` is deprecated in recent PyTorch releases; the current LR
        # is already printed every epoch below, so it is not needed.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.1, patience=5)
    else:
        # Decide the weight decay first so the optimizer is built only once.
        weight_decay = 1e-3 if is_cifar100_or_tiny else args.weight_decay
        optimizer = optim.SGD(model.parameters(), lr=initial_lr,
                              momentum=args.momentum, weight_decay=weight_decay)

        if is_cifar100_or_tiny and args.model == 'resnet18':
            warmup_epochs = 5
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer,
                lr_lambda=lambda epoch: _warmup_cosine_factor(
                    epoch, args.epochs, warmup_epochs))
        else:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.epochs)

    best_acc = 0.0
    for epoch in range(args.epochs):
        start_time = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device)

        test_loss, test_acc = validate(model, test_loader, criterion, device)

        # Read the LR before stepping the scheduler so the log reports the
        # rate that was actually used for this epoch, not the next one.
        epoch_lr = optimizer.param_groups[0]['lr']

        if args.model == 'lenet':
            # ReduceLROnPlateau (mode='max') needs the monitored metric.
            scheduler.step(test_acc)
        else:
            scheduler.step()

        epoch_time = time.time() - start_time
        print(f"Epoch: {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
        print(f"Learning Rate: {epoch_lr:.6f}")

        if test_acc > best_acc:
            best_acc = test_acc
            save_path = os.path.join(
                args.save_dir, f"{args.model}_{args.dataset}_best.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
            }, save_path)
            print(f"Model saved to {save_path}")

    print(f"Training completed! Best accuracy: {best_acc:.2f}%")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
