import os
import json
import argparse
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from torch.cuda.amp import GradScaler, autocast

# Weights & Biases is an optional dependency: try to import it and record the
# outcome in WANDB_AVAILABLE instead of failing when the package is missing.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

# Configure logging: INFO level, records formatted as "time - level - message".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command-line options of the training script.

    Options are registered from a single table, in the order they should
    appear in ``--help``. ``--data-root`` and ``--splits-json`` are required;
    every other option has a default.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, with one
        attribute per option (dashes replaced by underscores).
    """
    cli = argparse.ArgumentParser(
        description="Standard ResNet50 training with optional W&B integration"
    )

    # (flag, add_argument keyword arguments) -- order defines the help layout.
    option_table = (
        ("--data-root", dict(type=str, required=True,
                             help="Path to ImageNet training folder")),
        ("--splits-json", dict(type=str, required=True,
                               help="JSON file with 'train' and 'val' index lists")),
        ("--save-dir", dict(type=str, default="./checkpoints",
                            help="Directory to save checkpoints")),
        ("--project-name", dict(type=str, default="imagenet_resnet50",
                                help="Weights & Biases project name")),
        ("--epochs", dict(type=int, default=90,
                          help="Number of training epochs")),
        ("--batch-size", dict(type=int, default=128,
                              help="Batch size")),
        ("--learning-rate", dict(type=float, default=0.1,
                                 help="Initial learning rate")),
        ("--weight-decay", dict(type=float, default=1e-4,
                                help="Weight decay factor")),
        ("--step-size", dict(type=int, default=30,
                             help="LR scheduler step size (in epochs)")),
        ("--gamma", dict(type=float, default=0.1,
                         help="LR scheduler gamma")),
        ("--use-wandb", dict(action="store_true",
                             help="Enable Weights & Biases logging")),
        ("--num-workers", dict(type=int, default=4,
                               help="Number of DataLoader workers")),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def _build_transforms():
    """Return the ``(train_transform, test_transform)`` preprocessing pipelines."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    # NOTE(review): Resize((256, 256)) resizes to a fixed square regardless of
    # the input aspect ratio; kept as in the original -- confirm this is intended.
    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def _load_split_indices(splits_path):
    """Read the ``'train'`` and ``'val'`` index lists from the splits JSON file.

    Args:
        splits_path: Path to a JSON file holding an object with ``'train'``
            and ``'val'`` keys.

    Returns:
        A ``(train_indices, val_indices)`` tuple.

    Raises:
        ValueError: If the file does not contain a JSON object, or if either
            split is missing or empty. An empty split would otherwise only
            surface later as a division by zero when averaging the metrics.
    """
    with open(splits_path, "r", encoding="utf-8") as f:
        splits = json.load(f)
    if not isinstance(splits, dict):
        raise ValueError(f"Splits file {splits_path!r} must contain a JSON object")

    loaded = []
    for split_name in ("train", "val"):
        indices = splits.get(split_name)
        if not indices:
            raise ValueError(
                f"Splits file {splits_path!r} has no non-empty '{split_name}' index list"
            )
        loaded.append(indices)
    return loaded[0], loaded[1]


def _run_epoch(model, loader, criterion, device, use_amp, optimizer=None, scaler=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch through ``scaler``; otherwise the model is put in eval
    mode and gradients are disabled.

    Args:
        model: Network to train or evaluate.
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: Loss function applied to ``(outputs, labels)``.
        device: Device the batches are moved to.
        use_amp: Whether to run the forward pass under autocast.
        optimizer: Optimizer for a training pass, ``None`` for evaluation.
        scaler: GradScaler used for the backward pass; required whenever
            ``optimizer`` is given.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            batch_size = imgs.size(0)
            # Weight by batch size so the epoch value is a per-sample mean
            # even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    return loss_sum / seen, correct / seen


def main():
    """Train a ResNet50 from scratch on an ImageFolder dataset.

    Reads the command-line options, builds train/validation subsets of the
    same image folder from the index lists in the splits JSON file, then runs
    SGD with a StepLR schedule for the requested number of epochs. The weights
    with the best validation accuracy are written to ``best_model.pth`` and
    the last-epoch weights to ``final_model.pth`` in the save directory.
    Metrics are sent to Weights & Biases when ``--use-wandb`` is given and the
    package is installed.

    Raises:
        ValueError: If the splits file lacks a non-empty ``'train'`` or
            ``'val'`` list.
    """
    args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    use_wandb = args.use_wandb and WANDB_AVAILABLE
    if use_wandb:
        wandb.init(project=args.project_name, config=vars(args))
        logger.info("Weights & Biases logging enabled")
    elif args.use_wandb:
        # Do not drop the request silently: the user asked for tracking.
        logger.warning("--use-wandb was given but wandb is not installed; continuing without it")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    # torch.cuda.amp and pinned host memory are only useful with CUDA; gating
    # them keeps CPU-only runs free of the related warnings.
    on_cuda = device.type == "cuda"

    # Data augmentation and transforms
    train_transform, test_transform = _build_transforms()

    # Load split indices
    train_indices, val_indices = _load_split_indices(args.splits_json)

    # Datasets and subsets: the same folder is opened twice so that each
    # split gets its own transform pipeline.
    train_ds = datasets.ImageFolder(root=args.data_root, transform=train_transform)
    val_ds = datasets.ImageFolder(root=args.data_root, transform=test_transform)
    train_subset = Subset(train_ds, train_indices)
    val_subset = Subset(val_ds, val_indices)

    train_loader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=on_cuda)
    val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=on_cuda)

    # Model, loss, optimizer, scheduler. No pretrained weights are requested:
    # that is the default, so the deprecated ``pretrained`` keyword is omitted.
    model = models.resnet50(num_classes=len(train_ds.classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
                          momentum=0.9, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scaler = GradScaler(enabled=on_cuda)

    best_val_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        # Training phase
        epoch_train_loss, epoch_train_acc = _run_epoch(
            model, train_loader, criterion, device, on_cuda,
            optimizer=optimizer, scaler=scaler,
        )
        scheduler.step()
        logger.info(f"Epoch {epoch}/{args.epochs} - Train loss: {epoch_train_loss:.4f}, acc: {epoch_train_acc:.4f}")

        # Validation phase
        epoch_val_loss, epoch_val_acc = _run_epoch(
            model, val_loader, criterion, device, on_cuda,
        )
        logger.info(f"Epoch {epoch}/{args.epochs} - Val loss: {epoch_val_loss:.4f}, acc: {epoch_val_acc:.4f}")

        if use_wandb:
            # One call per epoch so train and validation metrics share the
            # same W&B step instead of landing on two consecutive steps.
            wandb.log({
                "epoch": epoch,
                "train_loss": epoch_train_loss,
                "train_accuracy": epoch_train_acc,
                "val_loss": epoch_val_loss,
                "val_accuracy": epoch_val_acc,
            })

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), os.path.join(args.save_dir, "best_model.pth"))

    # Final save
    final_path = os.path.join(args.save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    if use_wandb:
        wandb.save(final_path)


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
