import os
import json
import argparse
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from torch.cuda.amp import GradScaler, autocast

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

from foveated_dataloaders import FocLVanillaBoxDataset, MultiGlimpseDistortionAwareDataset

# Root logging setup: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-scoped logger used throughout this script


def parse_args(argv=None):
    """Parse command-line options for FocL single-crop training.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). Passing a list makes the parser
            usable from tests or other Python code; the default ``None`` keeps
            the original behaviour of reading the process arguments.

    Returns:
        ``argparse.Namespace`` with one attribute per option below. Dashes in
        flag names become underscores (``--data-root`` -> ``data_root``), and
        ``resize_size`` is a two-element list of ints.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value fails to parse.
    """
    parser = argparse.ArgumentParser(
        description="FocL single-crop training (vanilla or multi-glimpse)"
    )

    # --- Data locations -------------------------------------------------
    parser.add_argument(
        "--data-root", type=str, required=True,
        help="Root directory of dataset (ImageFolder structure)"
    )
    parser.add_argument(
        "--annotation-folder", type=str, required=True,
        help="Directory containing XML annotations"
    )
    parser.add_argument(
        "--splits-json", type=str, required=True,
        help="JSON file with 'train' and 'val' index lists"
    )

    # --- Dataset variant and its crop/glimpse options ----------------------
    parser.add_argument(
        "--dataset-type", choices=["vanilla", "multiglimpse"], default="vanilla",
        help="Which FocL dataset to use"
    )
    parser.add_argument(
        "--multi-glimpse-num", type=int, default=3,
        help="Number of glimpses for multi-glimpse dataset"
    )
    parser.add_argument(
        "--offset-fraction", type=float, default=0.2,
        help="Offset fraction for multi-glimpse"
    )
    parser.add_argument(
        "--scale-jitter", type=float, default=0.1,
        help="Scale jitter for multi-glimpse"
    )
    parser.add_argument(
        "--area-threshold", type=float, default=0.2,
        help="Area threshold for distortion-aware expand"
    )
    parser.add_argument(
        "--augmentation-mode", choices=["conservative", "medium", "aggressive"], default="medium",
        help="Augmentation mode for multi-glimpse"
    )
    parser.add_argument(
        "--max-crop-ratio", type=float, default=0.2,
        help="Max crop ratio for distortion-aware expand"
    )
    parser.add_argument(
        "--resize-size", type=int, nargs=2, default=[224, 224],
        help="Resize size (height width)"
    )

    # --- Optimisation schedule --------------------------------------------
    parser.add_argument("--epochs", type=int, default=90, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=128, help="Batch size")
    parser.add_argument("--learning-rate", type=float, default=0.1, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument("--step-size", type=int, default=30, help="LR scheduler step size")
    parser.add_argument("--gamma", type=float, default=0.1, help="LR scheduler gamma")

    # --- Output, logging and runtime --------------------------------------
    parser.add_argument(
        "--save-dir", type=str, default="./checkpoints",
        help="Directory to save checkpoints"
    )
    parser.add_argument(
        "--project-name", type=str, default="focl_single_crop",
        help="W&B project name"
    )
    parser.add_argument(
        "--use-wandb", action="store_true",
        help="Enable Weights & Biases logging"
    )
    parser.add_argument("--num-workers", type=int, default=4, help="Number of DataLoader workers")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # argv=None makes argparse fall back to sys.argv[1:], preserving CLI use.
    return parser.parse_args(argv)


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: re-seed NumPy and ``random`` per worker.

    The worker's torch seed is folded into 32 bits (the range NumPy's legacy
    seeding accepts) and applied to both Python's ``random`` module and
    ``np.random``. ``worker_id`` is required by the DataLoader callback
    signature but is not used here.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)


def _build_crop_transforms(resize_size):
    """Return ``(train_transform, test_transform)`` for the per-sample crops.

    Both pipelines end in ``ToTensor`` plus ImageNet mean/std normalisation.
    The training pipeline also applies a mild RandomResizedCrop, horizontal
    flip and colour jitter. The test pipeline only resizes.
    """
    size = tuple(resize_size)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(0.85, 1.15)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def _load_split_indices(splits_json):
    """Read and return ``(train_indices, val_indices)`` from *splits_json*.

    Raises:
        ValueError: if the ``train`` or ``val`` list is missing or empty. This
            fails fast here; otherwise an empty train split would crash inside
            the shuffling sampler, and an empty val split would cause a
            division by zero when averaging metrics.
    """
    with open(splits_json, "r") as f:
        splits = json.load(f)
    missing = [name for name in ("train", "val") if not splits.get(name)]
    if missing:
        raise ValueError(f"{splits_json}: split list(s) {missing} are missing or empty")
    return splits["train"], splits["val"]


def _build_focl_datasets(args, train_sub, val_sub, train_transform, test_transform):
    """Instantiate the train/val FocL datasets selected by ``--dataset-type``.

    ``"vanilla"`` uses FocLVanillaBoxDataset. Any other value uses
    MultiGlimpseDistortionAwareDataset in single-crop mode
    (``multi_crop=False``), with ``train_mode`` set to True for the train
    split only.
    """
    resize_size = tuple(args.resize_size)
    if args.dataset_type == "vanilla":
        ds_class = FocLVanillaBoxDataset
        shared = {
            "annotation_folder": args.annotation_folder,
            "resize_size": resize_size,
        }
        train_kwargs = {"subset": train_sub, "crop_transform": train_transform, **shared}
        val_kwargs = {"subset": val_sub, "crop_transform": test_transform, **shared}
    else:
        ds_class = MultiGlimpseDistortionAwareDataset
        shared = {
            "annotation_folder": args.annotation_folder,
            "resize_size": resize_size,
            "offset_fraction": args.offset_fraction,
            "scale_jitter": args.scale_jitter,
            "area_threshold": args.area_threshold,
            "augmentation_mode": args.augmentation_mode,
            "num_glimpses": args.multi_glimpse_num,
            "max_crop_ratio": args.max_crop_ratio,
            "multi_crop": False,
        }
        train_kwargs = {
            "subset": train_sub,
            "crop_transform": train_transform,
            "train_mode": True,
            **shared,
        }
        val_kwargs = {
            "subset": val_sub,
            "crop_transform": test_transform,
            "train_mode": False,
            **shared,
        }
    return ds_class(**train_kwargs), ds_class(**val_kwargs)


def _make_loader(dataset, args, *, shuffle, generator, pin_memory):
    """Build a DataLoader with the script's batch size, worker count and seeding."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        worker_init_fn=seed_worker,
        generator=generator,
    )


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp):
    """Run one optimisation pass over *loader*.

    Each batch is unpacked as a 4-tuple. Only the second element (the crop)
    is fed to the model, and the third is the label.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over ``len(loader.dataset)``.
    """
    model.train()
    total_loss, total_correct = 0.0, 0
    for _full_imgs, crop_imgs, labels, _ in loader:
        inputs = crop_imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # GradScaler is a pass-through when constructed with enabled=False.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * inputs.size(0)
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    num_samples = len(loader.dataset)
    return total_loss / num_samples, total_correct / num_samples


@torch.no_grad()
def _evaluate(model, loader, criterion, device, use_amp):
    """Evaluate *model* on *loader* without gradients.

    Batches have the same 4-tuple layout as in training.

    Returns:
        ``(mean_loss, accuracy)``.
    """
    model.eval()
    total_loss, total_correct = 0.0, 0
    for _full_imgs, crop_imgs, labels, _ in loader:
        inputs = crop_imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        total_loss += loss.item() * inputs.size(0)
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    num_samples = len(loader.dataset)
    return total_loss / num_samples, total_correct / num_samples


def main():
    """Train a randomly initialised ResNet-50 on single FocL crops.

    Steps:
      * Read CLI options and seed all RNGs.
      * Build FocL train/val datasets from an ImageFolder root plus a JSON
        file of split indices.
      * Train with SGD and StepLR, using mixed precision on CUDA.
      * Write ``best_model.pth`` (highest val accuracy) and
        ``final_model.pth`` state dicts into ``--save-dir``.
      * Optionally log per-epoch metrics to Weights & Biases.
    """
    args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    # Reproducibility: seed Python, NumPy and torch (CPU and all GPUs).
    # `g` drives DataLoader shuffling and the worker base seeds that
    # seed_worker folds into NumPy/random.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    g = torch.Generator()
    g.manual_seed(args.seed)

    # Resolve the W&B switch once. If the package is missing, warn instead of
    # silently ignoring the flag.
    use_wandb = args.use_wandb and WANDB_AVAILABLE
    if args.use_wandb and not WANDB_AVAILABLE:
        logger.warning("--use-wandb was given but wandb is not installed; W&B logging disabled")
    if use_wandb:
        wandb.init(project=args.project_name, config=vars(args))
        logger.info("Weights & Biases logging enabled")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision and pinned host memory only apply on CUDA. On CPU, the
    # torch.cuda.amp helpers and pin_memory would just emit warnings.
    use_cuda = device.type == "cuda"
    logger.info(f"Using device: {device}")

    train_crop_transform, test_crop_transform = _build_crop_transforms(args.resize_size)
    train_indices, val_indices = _load_split_indices(args.splits_json)

    # Base dataset (no transform of its own; the FocL wrappers apply crop transforms).
    full_ds = datasets.ImageFolder(root=args.data_root)
    train_sub = Subset(full_ds, train_indices)
    # NOTE(review): this rescans data_root a second time. Reusing full_ds would
    # be cheaper unless the FocL datasets mutate their subset's underlying
    # ImageFolder — confirm before sharing it.
    val_sub = Subset(datasets.ImageFolder(root=args.data_root), val_indices)

    train_dataset, val_dataset = _build_focl_datasets(
        args, train_sub, val_sub, train_crop_transform, test_crop_transform
    )
    train_loader = _make_loader(train_dataset, args, shuffle=True, generator=g, pin_memory=use_cuda)
    val_loader = _make_loader(val_dataset, args, shuffle=False, generator=g, pin_memory=use_cuda)

    # Model, loss, optimizer, scheduler.
    model = models.resnet50(pretrained=False, num_classes=len(full_ds.classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scaler = GradScaler(enabled=use_cuda)

    best_model_path = os.path.join(args.save_dir, "best_model.pth")
    final_model_path = os.path.join(args.save_dir, "final_model.pth")
    # Start at -inf (not 0.0) so best_model.pth is written even if val accuracy stays at 0.
    best_val_acc = float("-inf")
    for epoch in range(1, args.epochs + 1):
        # LR actually used for this epoch (captured before scheduler.step()).
        epoch_lr = optimizer.param_groups[0]["lr"]
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, use_cuda
        )
        scheduler.step()
        logger.info(f"Epoch {epoch}/{args.epochs} - Train loss: {train_loss:.4f}, acc: {train_acc:.4f}")

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device, use_cuda)
        logger.info(f"Epoch {epoch}/{args.epochs} - Val loss: {val_loss:.4f}, acc: {val_acc:.4f}")

        if use_wandb:
            # Log once per epoch with an explicit step so that train and val
            # metrics share the same x-axis value.
            wandb.log(
                {
                    "epoch": epoch,
                    "lr": epoch_lr,
                    "train_loss": train_loss,
                    "train_accuracy": train_acc,
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                },
                step=epoch,
            )

        # Keep the checkpoint with the highest validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            logger.info(f"Epoch {epoch}/{args.epochs} - New best val acc: {val_acc:.4f}")

    # Final save.
    torch.save(model.state_dict(), final_model_path)
    if use_wandb:
        wandb.save(final_model_path)
        wandb.finish()

if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when imported.
    main()
