import os
import json
import argparse
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from torch.cuda.amp import GradScaler, autocast

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

from foveated_dataloaders import MultiGlimpseDistortionAwareDataset

# Configure logging once at import time: INFO level, with each record
# formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger used by main() for progress and status messages.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the multi-glimpse training script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the original behaviour; passing an explicit list lets callers
            and tests drive the parser programmatically.

    Returns:
        argparse.Namespace: one attribute per option, with dashes in the
        option name replaced by underscores (e.g. ``--data-root`` becomes
        ``data_root``). ``resize_size`` is a two-element list of ints.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="FocL multi-glimpse training (distortion-aware)")

    # Data locations (all required).
    parser.add_argument("--data-root", type=str, required=True, help="Path to ImageNet training folder")
    parser.add_argument("--annotation-folder", type=str, required=True, help="Directory with XML annotations")
    parser.add_argument("--splits-json", type=str, required=True, help="JSON file with 'train' and 'val' index lists")

    # Glimpse / crop generation options forwarded to the dataset wrapper.
    parser.add_argument("--num-glimpses", type=int, default=3, help="Number of glimpses per image")
    parser.add_argument("--offset-fraction", type=float, default=0.2, help="Offset fraction for crops")
    parser.add_argument("--scale-jitter", type=float, default=0.1, help="Scale jitter fraction for crops")
    parser.add_argument("--area-threshold", type=float, default=0.2, help="Area threshold for distortion-aware expansion")
    parser.add_argument(
        "--augmentation-mode", choices=["conservative", "medium", "aggressive"],
        default="medium", help="Augmentation mode"
    )
    parser.add_argument("--max-crop-ratio", type=float, default=0.2, help="Max crop ratio for distortion-aware expansion")
    parser.add_argument("--resize-size", type=int, nargs=2, default=[224, 224], help="Resize dimensions H W")

    # Optimisation schedule.
    parser.add_argument("--epochs", type=int, default=90, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
    parser.add_argument("--learning-rate", type=float, default=0.1, help="Initial learning rate")
    parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument("--step-size", type=int, default=30, help="LR scheduler step size (in epochs)")
    parser.add_argument("--gamma", type=float, default=0.1, help="LR scheduler gamma")

    # Output, logging and runtime options.
    parser.add_argument("--save-dir", type=str, default="./checkpoints", help="Directory to save models")
    parser.add_argument("--project-name", type=str, default="focl_multi_glimpse", help="W&B project name")
    parser.add_argument("--use-wandb", action="store_true", help="Enable Weights & Biases logging")
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    return parser.parse_args(argv)


def seed_worker(worker_id):
    """Seed the ``random`` and NumPy generators from the current torch seed.

    Used as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted to
    satisfy that callback signature but is not used: the seed is taken from
    ``torch.initial_seed()`` and reduced to 32 bits before being applied.
    """
    worker_seed = torch.initial_seed() % (1 << 32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)


def get_transforms(resize_size):
    """Build the training and validation transform pipelines.

    Args:
        resize_size: two-element sequence giving the output size; it is
            converted to a tuple and used as the target size of both
            pipelines.

    Returns:
        tuple: ``(train_transform, val_transform)``, both
        ``transforms.Compose`` objects. The training pipeline applies a
        random resized crop, horizontal flip, colour jitter and random
        grayscale before tensor conversion and normalisation; the validation
        pipeline only resizes, converts and normalises.
    """
    target_size = tuple(resize_size)
    # Per-channel statistics used by the Normalize step of both pipelines.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_steps = [
        transforms.RandomResizedCrop(target_size, scale=(0.8, 1.0), ratio=(0.85, 1.15)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(norm_mean), std=list(norm_std)),
    ]
    val_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(norm_mean), std=list(norm_std)),
    ]
    return transforms.Compose(train_steps), transforms.Compose(val_steps)


def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    """Run one optimisation pass over ``loader`` and return mean loss/accuracy.

    Every batch must be a 4-tuple; only its second item (crops, shaped
    ``[B, N, C, H, W]``) and third item (labels, shaped ``[B]``) are used.
    The N glimpses of each image are flattened into the batch dimension and
    each glimpse is trained against its image's label.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable yielding ``(_, crops, labels, _)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer stepped once per batch through ``scaler``.
        scaler: gradient scaler used for the backward pass and the step.
        device: device (``torch.device`` or string) the batch is moved to.

    Returns:
        tuple: ``(mean_loss, accuracy)`` averaged over all ``B * N`` glimpses
        seen. Returns ``(0.0, 0.0)`` when the loader yields no batches
        instead of dividing by zero.
    """
    model.train()
    # Only request CUDA autocast when actually running on CUDA; on CPU it
    # would be disabled anyway, with a warning.
    use_amp = torch.device(device).type == "cuda"
    total_loss, total_correct, total_samples = 0.0, 0, 0
    for _, crops, labels, _ in loader:
        # crops: [B,N,C,H,W]
        B, N, C, H, W = crops.shape
        # reshape (rather than view) also accepts non-contiguous batches.
        inputs = crops.reshape(B * N, C, H, W)
        # Repeat each image label once per glimpse: [B] -> [B*N].
        labs = labels.unsqueeze(1).expand(B, N).reshape(-1)
        inputs = inputs.to(device, non_blocking=True)
        labs = labs.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        preds = outputs.argmax(dim=1)
        # Weight the batch loss by its size so the final mean is per-sample.
        total_loss += loss.item() * inputs.size(0)
        total_correct += (preds == labs).sum().item()
        total_samples += inputs.size(0)

    if total_samples == 0:
        return 0.0, 0.0
    return total_loss / total_samples, total_correct / total_samples


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Every batch must be a 4-tuple; only its second item (crops) and third
    item (labels) are used. Crops may be either ``[B, C, H, W]`` (one crop
    per image) or ``[B, N, C, H, W]`` (N glimpses per image); in the latter
    case the glimpses are flattened into the batch dimension and each one is
    scored against its image's label.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable yielding ``(_, crops, labels, _)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        device: device (``torch.device`` or string) the batch is moved to.

    Returns:
        tuple: ``(mean_loss, accuracy)`` averaged over every evaluated crop.
        Returns ``(0.0, 0.0)`` when the loader yields no batches instead of
        dividing by zero.
    """
    model.eval()
    # Only request CUDA autocast when actually running on CUDA; on CPU it
    # would be disabled anyway, with a warning.
    use_amp = torch.device(device).type == "cuda"
    total_loss, total_correct, total_samples = 0.0, 0, 0
    with torch.no_grad():
        for _, crops, labels, _ in loader:
            if crops.ndim == 5:
                # multi-glimpse batch: crops shape [B,N,C,H,W]
                B, N, C, H, W = crops.shape
                # reshape (rather than view) also accepts non-contiguous batches.
                inputs = crops.reshape(B * N, C, H, W)
                labs = labels.unsqueeze(1).expand(B, N).reshape(-1)
            else:
                # single-crop batch: crops shape [B,C,H,W]
                inputs, labs = crops, labels
            inputs = inputs.to(device, non_blocking=True)
            labs = labs.to(device, non_blocking=True)
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labs)
            preds = outputs.argmax(dim=1)
            # Weight the batch loss by its size so the final mean is per-sample.
            total_loss += loss.item() * inputs.size(0)
            total_correct += (preds == labs).sum().item()
            total_samples += inputs.size(0)

    if total_samples == 0:
        return 0.0, 0.0
    return total_loss / total_samples, total_correct / total_samples


def main():
    """Train a ResNet-50 on multi-glimpse crops and save checkpoints.

    Steps: parse the CLI arguments, seed the RNGs, optionally start a W&B
    run, build the train/val datasets from the index lists in the splits
    JSON, then train for ``args.epochs`` epochs with SGD, a StepLR schedule
    and a gradient scaler. After each epoch the model is validated; the
    weights with the best validation accuracy so far are written to
    ``best_model.pth`` and the final weights to ``final_model.pth``, both
    inside ``args.save_dir``.

    Raises:
        ValueError: if the splits JSON has no (or an empty) ``train`` or
            ``val`` index list.
    """
    args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    # Reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    gen = torch.Generator()
    gen.manual_seed(args.seed)

    # W&B is used only when it was both requested and importable; tell the
    # user when the request cannot be honoured instead of ignoring it.
    use_wandb = args.use_wandb and WANDB_AVAILABLE
    if args.use_wandb and not WANDB_AVAILABLE:
        logger.warning("--use-wandb was given but the wandb package is not installed; W&B logging is disabled")
    if use_wandb:
        wandb.init(project=args.project_name, config=vars(args))
        logger.info("W&B logging enabled")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    train_transform, val_transform = get_transforms(args.resize_size)

    with open(args.splits_json, 'r', encoding='utf-8') as f:
        splits = json.load(f)
    train_indices = splits.get('train', [])
    val_indices = splits.get('val', [])
    # Fail fast: an empty split would otherwise only surface later as an
    # unrelated error (for 'val', after a full training epoch).
    if not train_indices or not val_indices:
        raise ValueError(
            f"{args.splits_json} must contain non-empty 'train' and 'val' index lists "
            f"(got {len(train_indices)} train and {len(val_indices)} val indices)"
        )

    full_ds = datasets.ImageFolder(root=args.data_root)
    train_sub = Subset(full_ds, train_indices)
    # NOTE(review): a second ImageFolder is built for validation; kept as-is
    # in case the dataset wrapper relies on separate underlying instances —
    # confirm before sharing full_ds.
    val_sub = Subset(datasets.ImageFolder(root=args.data_root), val_indices)

    train_dataset = MultiGlimpseDistortionAwareDataset(
        subset=train_sub,
        annotation_folder=args.annotation_folder,
        crop_transform=train_transform,
        resize_size=tuple(args.resize_size),
        train_mode=True,
        offset_fraction=args.offset_fraction,
        scale_jitter=args.scale_jitter,
        area_threshold=args.area_threshold,
        augmentation_mode=args.augmentation_mode,
        num_glimpses=args.num_glimpses,
        max_crop_ratio=args.max_crop_ratio,
        multi_crop=True
    )
    val_dataset = MultiGlimpseDistortionAwareDataset(
        subset=val_sub,
        annotation_folder=args.annotation_folder,
        crop_transform=val_transform,
        resize_size=tuple(args.resize_size),
        train_mode=False,
        offset_fraction=args.offset_fraction,
        scale_jitter=args.scale_jitter,
        area_threshold=args.area_threshold,
        augmentation_mode=args.augmentation_mode,
        num_glimpses=args.num_glimpses,
        max_crop_ratio=args.max_crop_ratio,
        multi_crop=False
    )

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory,
        worker_init_fn=seed_worker, generator=gen
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory,
        worker_init_fn=seed_worker, generator=gen
    )

    model = models.resnet50(pretrained=False, num_classes=len(full_ds.classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scaler = GradScaler()

    best_val_acc = 0.0
    for epoch in range(1, args.epochs+1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        # The LR schedule advances once per epoch, after the optimizer steps.
        scheduler.step()
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        logger.info(f"Epoch {epoch}/{args.epochs} | Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | "
                    f"Val loss: {val_loss:.4f}, acc: {val_acc:.4f}")
        if use_wandb:
            wandb.log({"train_loss": train_loss, "train_acc": train_acc,
                       "val_loss": val_loss, "val_acc": val_acc}, step=epoch)

        # Keep only the weights with the best validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(args.save_dir, 'best_model.pth'))

    torch.save(model.state_dict(), os.path.join(args.save_dir, 'final_model.pth'))
    if use_wandb:
        wandb.save(os.path.join(args.save_dir, 'final_model.pth'))

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
