#!/usr/bin/env python3
"""
Full-finetuning script for MobileNetV4 on Caltech-256 dataset
This script provides a complete training pipeline with data augmentation,
learning rate scheduling, early stopping, and model checkpointing.
Supports both single dataset and split dataset structures.
"""

import os
import argparse
import time
import copy
import json
from pathlib import Path

import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import get_model
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


class Caltech256Dataset:
    """Map-style image dataset over a ``root_dir/<class_name>/*.jpg`` tree.

    Every immediate subdirectory of ``root_dir`` is treated as one class and
    class indices follow the alphabetical order of the directory names. Only
    files matching ``*.jpg`` directly inside a class directory are collected.

    The class implements ``__len__`` and ``__getitem__`` so it can be handed
    to ``torch.utils.data.DataLoader`` as a map-style dataset.

    Args:
        root_dir: Directory containing one subdirectory per class.
        transform: Optional callable applied to the loaded RGB PIL image.
        train: Stored on the instance as ``self.train``; not used by this
            class itself.
    """

    def __init__(self, root_dir, transform=None, train=True):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.train = train

        # Class names are the sorted subdirectory names; index == sort position.
        self.classes = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Parallel lists: images[i] is the file path, labels[i] its class index.
        self.images = []
        self.labels = []

        for class_name in self.classes:
            label = self.class_to_idx[class_name]
            # glob() yields files in filesystem order, which differs between
            # machines and runs; sorting makes index -> sample reproducible.
            for img_path in sorted((self.root_dir / class_name).glob("*.jpg")):
                self.images.append(str(img_path))
                self.labels.append(label)

        print(f"Found {len(self.images)} images across {len(self.classes)} classes")

    def __len__(self):
        """Return the number of collected images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened with PIL, converted to RGB and passed through
        ``self.transform`` when one was supplied.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # Imported lazily so that merely constructing the dataset does not
        # require Pillow.
        from PIL import Image

        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert(), instead of leaving the handle
        # open until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label


def load_split_dataset(data_root, transform, split_name):
    """
    Build the dataset for one split of a pre-split dataset tree.

    The split is expected to live in ``<data_root>/<split_name>``.

    Args:
        data_root (str): Root directory containing the split subdirectories
        transform: Data transforms forwarded to the dataset
        split_name (str): Name of the split ('train', 'val', 'test')

    Returns:
        Caltech256Dataset: Dataset rooted at the split directory; its
        ``train`` flag is True only for the 'train' split

    Raises:
        ValueError: If the split directory does not exist
    """
    split_dir = Path(data_root).joinpath(split_name)

    if split_dir.exists():
        is_train_split = split_name == 'train'
        return Caltech256Dataset(
            root_dir=split_dir,
            transform=transform,
            train=is_train_split,
        )

    raise ValueError(f"Split directory {split_dir} does not exist!")


def get_transforms(crop_size=384, resize_size=384):
    """Build the training and validation preprocessing pipelines.

    Both pipelines resize the image to ``resize_size x resize_size`` and then
    crop to ``crop_size x crop_size``, so training and validation tensors have
    the same spatial size. Training takes a random crop followed by flip,
    rotation and colour-jitter augmentation; validation takes the centre crop
    with no augmentation.

    Previously the training pipeline centre-cropped the raw image *before*
    any resize (zero-padding images smaller than the crop) and finished with
    a resize to ``resize_size``, so training inputs had a different size than
    validation inputs whenever ``resize_size != crop_size``.

    Args:
        crop_size: Side length of the final square crop.
        resize_size: Side length the image is resized to before cropping.

    Returns:
        tuple: ``(train_transform, val_transform)``.
    """
    # Channel statistics applied after ToTensor() in both pipelines.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    bicubic = transforms.InterpolationMode.BICUBIC

    # Training transforms with data augmentation
    train_transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size),
                          interpolation=bicubic,
                          antialias=True),
        # pad_if_needed keeps this from raising when resize_size < crop_size,
        # mirroring CenterCrop's padding behaviour in the validation pipeline.
        transforms.RandomCrop(crop_size, pad_if_needed=True),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Validation transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size),
                          interpolation=bicubic,
                          antialias=True),
        transforms.CenterCrop((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    return train_transform, val_transform


def create_model(num_classes, add_two_fc=True,
                add_fc=False,
                pretrained=True,
                freeze_backbone=False,
                num_layers_to_unfreeze=None):
    """Create a MobileNetV4 model from timm and attach a new classifier head.

    Args:
        num_classes: Number of output classes of the new head.
        add_two_fc: If True, the head is Linear-ReLU-Linear-ReLU-Linear with
            1024 hidden units. Takes precedence over ``add_fc``.
        add_fc: If True (and ``add_two_fc`` is False), the head is
            Linear-ReLU-Linear with 512 hidden units.
        pretrained: Forwarded to ``timm.create_model``. This argument used to
            be ignored (the call was hard-coded to ``pretrained=True``).
        freeze_backbone: If True, ``requires_grad`` is set to False on every
            parameter of the loaded model before the head is replaced.
        num_layers_to_unfreeze: If not None, re-enables gradients on up to
            this many direct child modules, walking the children from last to
            first and skipping the final two.

    Returns:
        The configured model. ``model.transforms`` holds the timm evaluation
        transform resolved from the model's data config.
    """
    model = timm.create_model(
        'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k',
        pretrained=pretrained,
    )
    data_config = timm.data.resolve_model_data_config(model)
    model.transforms = timm.data.create_transform(**data_config, is_training=False)
    if pretrained:
        print("Loaded pretrained MobileNetV4 with ImageNet weights")
    else:
        print("Created MobileNetV4 with randomly initialized weights")

    # Freeze backbone layers if requested
    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False
        print("Frozen backbone layers")

    if num_layers_to_unfreeze is not None:
        # Children are visited last-to-first; the last two are skipped.
        # NOTE(review): presumably those two are the classifier (replaced
        # below) and a parameter-free flatten module - confirm against the
        # timm model definition.
        unfrozen_layers = 0
        for name, child in list(model.named_children())[::-1][2:]:
            if unfrozen_layers >= num_layers_to_unfreeze:
                break
            for param in child.parameters():
                param.requires_grad = True
            unfrozen_layers += 1
            print(f"Unfrozen layer {name}")

    # The replacement head is built from fresh nn.Linear layers, so its
    # parameters require gradients even when the backbone was frozen above.
    num_features = model.classifier.in_features
    if add_two_fc:
        model.classifier = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes)
        )
        print("Added two fully connected layers")
    elif add_fc:
        model.classifier = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )
        print("Added fully connected layer")
    else:
        # Replace the final classification layer
        model.classifier = nn.Linear(num_features, num_classes)

    return model


def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``dataloader``.

    Integer targets are expanded to one-hot float vectors before being passed
    to ``criterion`` together with the raw model outputs.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` where the loss is the mean of the
        per-batch losses and the accuracy is a percentage over all samples.
    """
    model.train()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    progress = tqdm(dataloader, desc=f'Epoch {epoch+1} - Training')

    for step, (inputs, targets) in enumerate(progress, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Reset gradients, then forward / backward / update
        optimizer.zero_grad()
        outputs = model(inputs)
        one_hot_targets = F.one_hot(targets, num_classes=outputs.shape[1]).float()
        loss = criterion(outputs, one_hot_targets)
        loss.backward()
        optimizer.step()

        # Running statistics
        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += predicted.eq(targets).sum().item()

        # Show running mean loss and accuracy on the progress bar
        progress.set_postfix({
            'Loss': f'{loss_sum/step:.4f}',
            'Acc': f'{100.*num_correct/num_seen:.2f}%'
        })

    return loss_sum / len(dataloader), 100. * num_correct / num_seen


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model on ``dataloader`` without updating any weights.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()``. Targets are one-hot encoded for ``criterion`` in the
    same way as during training.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` where the loss is the mean of the
        per-batch losses and the accuracy is a percentage over all samples.
    """
    model.eval()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation')

        for step, (inputs, targets) in enumerate(progress, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            one_hot_targets = F.one_hot(targets, num_classes=outputs.shape[1]).float()
            loss = criterion(outputs, one_hot_targets)

            # Running statistics
            loss_sum += loss.item()
            predicted = outputs.max(1)[1]
            num_seen += targets.size(0)
            num_correct += predicted.eq(targets).sum().item()

            # Show running mean loss and accuracy on the progress bar
            progress.set_postfix({
                'Loss': f'{loss_sum/step:.4f}',
                'Acc': f'{100.*num_correct/num_seen:.2f}%'
            })

    return loss_sum / len(dataloader), 100. * num_correct / num_seen


def test_epoch(model, dataloader, criterion, device):
    """Evaluate the model on ``dataloader`` and collect per-sample predictions.

    Works like validation (eval mode, no gradients, one-hot targets for
    ``criterion``) and additionally records the predicted and true class
    index of every sample.

    Returns:
        tuple: ``(epoch_loss, epoch_acc, all_predictions, all_targets)``.
        The loss is the mean of the per-batch losses, the accuracy is a
        percentage, and the two lists hold plain Python ints in dataloader
        order.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Testing')

        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            targets_one_hot = F.one_hot(targets, num_classes=outputs.shape[1]).float()
            loss = criterion(outputs, targets_one_hot)

            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # tolist() yields built-in ints. Extending with a NumPy array
            # stored NumPy scalars, which json cannot serialize and which a
            # `default=str` fallback silently turned into strings.
            all_predictions.extend(predicted.cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc, all_predictions, all_targets


def save_checkpoint(model, optimizer, scheduler, epoch, best_acc, save_path):
    """Write a resumable training checkpoint to ``save_path``.

    The saved dict contains the epoch, the model and optimizer state dicts,
    the scheduler state dict (None when no scheduler is given) and the best
    accuracy so far.
    """
    scheduler_state = scheduler.state_dict() if scheduler else None

    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler_state,
        best_acc=best_acc,
    )

    torch.save(state, save_path)
    print(f"Checkpoint saved to {save_path}")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, num_layers_to_unfreeze=None):
    """Restore training state from a checkpoint file.

    Model weights are always loaded. Optimizer, scheduler, epoch and best
    accuracy are restored when present, so weights-only files (no optimizer
    state, no epoch) can also be used as a starting point.

    Args:
        model: Model whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer to restore. Skipped when
            ``num_layers_to_unfreeze`` is not None or when the checkpoint has
            no optimizer state.
        scheduler: Optional scheduler to restore on a best-effort basis.
        checkpoint_path: Path of the checkpoint file.
        num_layers_to_unfreeze: When not None the optimizer state is not
            loaded.

    Returns:
        tuple: ``(start_epoch, best_acc)``. ``start_epoch`` is the saved epoch
        plus one, or 0 when the checkpoint has no usable epoch; ``best_acc``
        defaults to 0.0 when missing.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # NOTE: torch.load unpickles the file - only resume from trusted
    # checkpoints. map_location='cpu' lets a checkpoint written on one device
    # type be loaded on another; load_state_dict then copies the values into
    # the model's existing parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    optimizer_state = checkpoint.get('optimizer_state_dict')
    if num_layers_to_unfreeze is not None:
        print("Optimizer state dict not loaded due to num_layers_to_unfreeze")
    elif optimizer_state is None:
        print("Checkpoint has no optimizer state; optimizer starts fresh")
    else:
        optimizer.load_state_dict(optimizer_state)

    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler and scheduler_state:
        # Restoring the scheduler is best-effort: a failure is reported
        # instead of being silently swallowed, and training continues.
        try:
            scheduler.load_state_dict(scheduler_state)
        except Exception as exc:
            print(f"Could not restore scheduler state ({exc}); scheduler starts fresh")

    saved_epoch = checkpoint.get('epoch')
    start_epoch = saved_epoch + 1 if isinstance(saved_epoch, int) else 0
    best_acc = checkpoint.get('best_acc', 0.0)
    print(f"Loaded checkpoint from epoch {start_epoch} with best acc {best_acc:.2f}%")
    return start_epoch, best_acc


def plot_training_curves(train_losses, train_accs, val_losses, val_accs, save_path):
    """Plot loss and accuracy curves side by side and save them to ``save_path``.

    Args:
        train_losses: Per-epoch training losses.
        train_accs: Per-epoch training accuracies (percent).
        val_losses: Per-epoch validation losses.
        val_accs: Per-epoch validation accuracies (percent).
        save_path: Destination of the saved figure.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot losses
    ax1.plot(train_losses, label='Training Loss', color='blue')
    ax1.plot(val_losses, label='Validation Loss', color='red')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot accuracies
    ax2.plot(train_accs, label='Training Accuracy', color='blue')
    ax2.plot(val_accs, label='Validation Accuracy', color='red')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure; pyplot otherwise keeps it registered (and its
    # memory allocated) for the rest of the process.
    plt.close(fig)


def _build_arg_parser():
    """Build the command-line parser for the finetuning script."""
    parser = argparse.ArgumentParser(description='Finetune MobileNetV4 on Caltech-256')
    parser.add_argument('--data_dir', type=str, required=True,
                       help='Path to Caltech-256 dataset directory (or split dataset root)')
    parser.add_argument('--output_dir', type=str, default='./CalTech-256/mbnet_v4_checkpoints',
                       help='Directory to save checkpoints and logs')
    parser.add_argument('--batch_size', type=int, default=64,
                       help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=50,
                       help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=0.001,
                       help='Initial learning rate')
    parser.add_argument('--lr_not_last_layer', type=float, default=0.00001,
                       help='Use a different learning rate for the Non-last layer')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                       help='Weight decay')
    parser.add_argument('--crop_size', type=int, default=448,
                       help='Input image crop size')
    parser.add_argument('--resize_size', type=int, default=471,
                       help='Input image resize size')
    parser.add_argument('--num_workers', type=int, default=2,
                       help='Number of data loading workers')
    parser.add_argument('--pretrained', action='store_true', default=True,
                       help='Use pretrained ImageNet weights')
    parser.add_argument('--freeze_backbone', action='store_true',
                       help='Freeze backbone layers during training')
    parser.add_argument('--resume', type=str, default=None,
                       help='Path to checkpoint to resume from')
    parser.add_argument('--num_layers_to_unfreeze', type=int, default=None,
                       help='Number of layers to unfreeze from the end of the model')
    parser.add_argument('--early_stopping_patience', type=int, default=10,
                       help='Early stopping patience')
    parser.add_argument('--use_split_dataset', action='store_true', default=True,
                       help='Use split dataset structure (train/val/test directories)')
    parser.add_argument('--test_on_completion', action='store_true', default=True,
                       help='Run test evaluation after training completion')
    parser.add_argument('--add_fc', action='store_true', default=False,
                       help='Add a fully connected layer after the last layer of the model')
    parser.add_argument('--add_two_fc', action='store_true', default=False,
                       help='Add two fully connected layers after the last layer of the model')

    # The three flags above are store_true with default=True, so on their own
    # they can never be switched off. These companions write False to the
    # same destinations; the original flags keep working unchanged.
    parser.add_argument('--no_pretrained', dest='pretrained', action='store_false',
                       help='Do not load pretrained ImageNet weights')
    parser.add_argument('--no_use_split_dataset', dest='use_split_dataset', action='store_false',
                       help='Treat data_dir as a single dataset directory (no train/val/test split)')
    parser.add_argument('--no_test_on_completion', dest='test_on_completion', action='store_false',
                       help='Skip the test evaluation after training completion')
    return parser


def _build_optimizer(model, args):
    """Create the AdamW optimizer, with per-layer learning rates when partially unfreezing."""
    if args.num_layers_to_unfreeze is None:
        return optim.AdamW(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay
        )

    # Children from last to first. The last child is the head and trains at
    # args.lr. The remaining groups use the same selection create_model
    # applies when unfreezing (skip the final two children, then take
    # num_layers_to_unfreeze of them) and train at args.lr_not_last_layer.
    # Previously the groups were children [0 .. N], which is offset from the
    # unfrozen children [2 .. N+1], so the deepest unfrozen layer was never
    # handed to the optimizer.
    children = list(model.named_children())[::-1]
    selected = [(children[0][0], children[0][1], args.lr)]
    for name, child in children[2:2 + max(args.num_layers_to_unfreeze, 0)]:
        selected.append((name, child, args.lr_not_last_layer))

    param_groups = []
    for name, module, lr in selected:
        params = [p for p in module.parameters() if p.requires_grad]
        if not params:
            # Parameter-free (or fully frozen) modules contribute nothing.
            continue
        param_groups.append({'params': params, 'lr': lr})
        print(f"LR schedule: {name} -> lr={lr} ({len(params)} parameter tensors)")

    return optim.AdamW(param_groups, weight_decay=args.weight_decay)


def main():
    """Command-line entry point: train, validate, optionally test, and save artifacts.

    Steps: parse arguments, build datasets and dataloaders, create the model,
    train with ReduceLROnPlateau scheduling and early stopping while
    checkpointing the best model, restore the best weights, optionally run
    the test split, then save the final model, training curves and history
    into ``--output_dir``.
    """
    args = _build_arg_parser().parse_args()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Set device
    if torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Get transforms
    train_transform, val_transform = get_transforms(args.crop_size, args.resize_size)

    # Create datasets
    print("Loading datasets...")

    # The test split only exists in the split-dataset layout.
    run_test = args.test_on_completion and args.use_split_dataset
    test_dataset = None

    if args.use_split_dataset:
        # Use split dataset structure
        print("Using split dataset structure...")
        train_dataset = load_split_dataset(args.data_dir, train_transform, 'train')
        val_dataset = load_split_dataset(args.data_dir, val_transform, 'val')

        if run_test:
            test_dataset = load_split_dataset(args.data_dir, val_transform, 'test')
    else:
        # Use single dataset (for backward compatibility)
        print("Using single dataset structure...")
        print("Warning: training and validation read the same images in this mode; "
              "validation accuracy is not a held-out estimate")
        train_dataset = Caltech256Dataset(
            root_dir=args.data_dir,
            transform=train_transform,
            train=True
        )
        val_dataset = Caltech256Dataset(
            root_dir=args.data_dir,
            transform=val_transform,
            train=False
        )

    # Create dataloaders. Pinned host memory only helps CUDA transfers, so it
    # is enabled only when training on a CUDA device.
    loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'pin_memory': device.type == 'cuda',
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = None
    if run_test:
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    if run_test:
        print(f"Test samples: {len(test_dataset)}")
    print(f"Number of classes: {len(train_dataset.classes)}")

    # Create model
    model = create_model(
        num_classes=len(train_dataset.classes),
        add_fc=args.add_fc,
        add_two_fc=args.add_two_fc,
        pretrained=args.pretrained,
        freeze_backbone=args.freeze_backbone,
        num_layers_to_unfreeze=args.num_layers_to_unfreeze
    )
    model = model.to(device)

    # Loss function and optimizer. The epoch functions feed one-hot targets
    # to this criterion.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = _build_optimizer(model, args)

    # Learning rate scheduler (steps on validation accuracy, hence mode='max')
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=5)

    # Training variables
    start_epoch = 0
    best_acc = 0.0
    patience_counter = 0

    # Resume from checkpoint if specified
    if args.resume:
        start_epoch, best_acc = load_checkpoint(
            model, optimizer, scheduler, args.resume, args.num_layers_to_unfreeze
        )

    # Snapshot the weights only AFTER a possible resume. Taking the snapshot
    # earlier meant that a resumed run which never beat the checkpoint's
    # best_acc restored the freshly created (un-resumed) weights at the end.
    best_model_wts = copy.deepcopy(model.state_dict())

    # Training history
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    print("Starting training...")
    start_time = time.time()

    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        print("-" * 20)

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )

        # Validate
        val_loss, val_acc = validate_epoch(
            model, val_loader, criterion, device
        )

        # Update learning rate
        scheduler.step(val_acc)

        # Save history
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Print epoch summary
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            patience_counter = 0

            # Save checkpoint
            checkpoint_path = output_dir / f'best_model_epoch_{epoch+1}_acc_{best_acc:.2f}.pth'
            save_checkpoint(model, optimizer, scheduler, epoch, best_acc, checkpoint_path)
        else:
            patience_counter += 1

        # Save regular checkpoint
        if (epoch + 1) % 10 == 0:
            checkpoint_path = output_dir / f'checkpoint_epoch_{epoch+1}.pth'
            save_checkpoint(model, optimizer, scheduler, epoch, best_acc, checkpoint_path)

        # Early stopping (disabled when the patience is 0 or negative)
        if args.early_stopping_patience > 0:
            if patience_counter >= args.early_stopping_patience:
                print(f"Early stopping triggered after {args.early_stopping_patience} epochs without improvement")
                break

    # Training completed
    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time/3600:.2f} hours")
    print(f"Best validation accuracy: {best_acc:.2f}%")

    # Load best model weights
    model.load_state_dict(best_model_wts)

    # Test evaluation if requested
    if run_test:
        print("\nRunning test evaluation...")
        test_loss, test_acc, test_predictions, test_targets = test_epoch(
            model, test_loader, criterion, device
        )
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

        # Save test results. Values are converted to built-in numbers so the
        # JSON holds real ints/floats; the previous `default=str` fallback
        # wrote NumPy scalars out as strings.
        test_results = {
            'test_loss': float(test_loss),
            'test_accuracy': float(test_acc),
            'test_predictions': [int(p) for p in test_predictions],
            'test_targets': [int(t) for t in test_targets]
        }
        test_results_path = output_dir / 'test_results.json'
        with open(test_results_path, 'w') as f:
            json.dump(test_results, f, indent=2)
        print(f"Test results saved to: {test_results_path}")

    # Save final model
    final_model_path = output_dir / 'final_model.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'best_acc': best_acc,
        'num_classes': len(train_dataset.classes),
        'class_names': train_dataset.classes,
        'class_to_idx': train_dataset.class_to_idx
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

    # Plot training curves
    plot_path = output_dir / 'training_curves.png'
    plot_training_curves(train_losses, train_accs, val_losses, val_accs, plot_path)

    # Save training history
    history = {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_acc': best_acc,
        'total_epochs': len(train_losses)
    }
    history_path = output_dir / 'training_history.npz'
    np.savez(history_path, **history)
    print(f"Training history saved to {history_path}")

    print("Training script completed successfully!")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main() 