import os
import argparse
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import timm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from src.vit_classification import PostPoolingClassifier
from utils import set_seed


def train_one_epoch(model, loader, optimizer, criterion, device):
    """
        Run one optimisation pass over `loader` and report epoch statistics.

        For every batch the gradients are cleared, the loss returned by
        `criterion` is back-propagated and `optimizer` takes one step.

        Args:
            model: module called as `model(inputs)`; switched to train mode here.
            loader: iterable yielding `(inputs, targets)` batches.
            optimizer: optimizer whose `zero_grad()` / `step()` are called per batch.
            criterion: callable `criterion(outputs, targets)` returning the batch loss.
            device: device the batches are moved to before the forward pass.

        Returns:
            Tuple `(epoch_loss, epoch_acc)`: the batch losses weighted by batch
            size and divided by the number of targets seen, and the percentage
            of highest-scoring predictions (along dim 1) that equal the targets.

        Raises:
            ZeroDivisionError: if `loader` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    progress = tqdm(loader, desc="Training", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the final division yields a
        # per-sample average (assumes `criterion` returns a batch mean).
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.max(1)[1]
        num_correct += (predicted == targets).sum().item()
        num_seen += targets.size(0)

    return loss_sum / num_seen, 100. * num_correct / num_seen

def evaluate(model, loader, criterion, device):
    """
        Score `model` on every batch of `loader` without tracking gradients.

        Args:
            model: module called as `model(inputs)`; switched to eval mode here.
            loader: iterable yielding `(inputs, targets)` batches.
            criterion: callable `criterion(outputs, targets)` returning the batch loss.
            device: device the batches are moved to before the forward pass.

        Returns:
            Tuple `(epoch_loss, epoch_acc)`: the batch losses weighted by batch
            size and divided by the number of targets seen, and the percentage
            of highest-scoring predictions (along dim 1) that equal the targets.

        Raises:
            ZeroDivisionError: if `loader` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        progress = tqdm(loader, desc="Evaluating", leave=False)
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Weight the batch loss by its size so the final division yields
            # a per-sample average (assumes `criterion` returns a batch mean).
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.max(1)[1]
            num_correct += (predicted == targets).sum().item()
            num_seen += targets.size(0)

    return loss_sum / num_seen, 100. * num_correct / num_seen

def run_experiment(pooling_type, train_loader, test_loader, device, num_classes, epochs, lr, model_type):
    """
        Train a pooling classifier on top of a frozen pre-trained ViT and
        report its accuracy on `test_loader`.

        Args:
            pooling_type: pooling method name forwarded to PostPoolingClassifier.
            train_loader: loader used for the optimisation passes.
            test_loader: loader used for the per-epoch evaluation; the accuracy
                of the last evaluation is the returned result.
            device: device the models and the batches are moved to.
            num_classes: number of target classes.
            epochs: number of training epochs.
            lr: learning rate handed to Adam.
            model_type: "base_vit" or "small_vit".

        Returns:
            Accuracy (in percent) on `test_loader` after the last epoch.

        Raises:
            ValueError: if `model_type` is not one of the supported names.
    """
    timm_model_names = {
        "base_vit": "vit_base_patch16_224",
        "small_vit": "vit_small_patch16_224",
    }
    if model_type not in timm_model_names:
        raise ValueError("Unsupported model type. Choose 'base_vit' or 'small_vit'.")
    timm_model_name = timm_model_names[model_type]

    # Load the pre-trained model.
    # NOTE(review): timm.create_model appears to drop keyword arguments whose
    # value is None, so `global_pool=None` presumably leaves the model's
    # default pooling in place rather than disabling it (that would be
    # `global_pool=''`). Confirm against what PostPoolingClassifier expects.
    base_model = timm.create_model(
        timm_model_name,
        pretrained=True,
        num_classes=num_classes,
        global_pool=None
    )
    base_model.to(device)

    # Freeze every parameter of the timm model, so gradient updates are
    # limited to whatever parameters PostPoolingClassifier adds on top.
    for param in base_model.parameters():
        param.requires_grad = False

    model = PostPoolingClassifier(base_model,
                                  num_classes=num_classes,
                                  pooling_type=pooling_type,
                                  num_iterations=1)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Accuracy of the most recent evaluation on `test_loader`; stays None
    # only when no epoch is run.
    test_acc = None
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} | Pooling: {pooling_type}")
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, test_acc = evaluate(model, test_loader, criterion, device)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | " +
              f"Val Loss: {val_loss:.4f}, Val Acc: {test_acc:.2f}%")

    # The per-epoch evaluation already runs on `test_loader` and the model is
    # not updated afterwards, so the last epoch's accuracy is the final test
    # accuracy. A dedicated pass is only needed when there were no epochs.
    if test_acc is None:
        _, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Accuracy for pooling '{pooling_type}': {test_acc:.2f}%")
    return test_acc

# Per-channel (mean, std) used to normalise the two CIFAR datasets.
_CIFAR10_STATS = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
_CIFAR100_STATS = ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))


def _cifar_transforms(mean, std):
    """Return the (train, test) transform pair used for the CIFAR datasets."""
    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    return train_transform, test_transform


def _square_resize_transform():
    """Return the 224x224 resize + 0.5/0.5 normalisation transform shared by the non-CIFAR datasets."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])


def _build_datasets(name):
    """
        Build the train/test datasets for the dataset called `name`.

        Returns:
            Tuple `(num_classes, train_dataset, test_dataset)`.

        Raises:
            ValueError: if `name` is not a supported dataset.
    """
    if name == "CIFAR10":
        transform_train, transform_test = _cifar_transforms(*_CIFAR10_STATS)
        # NOTE(review): CIFAR10 is read from '../data' and never downloaded,
        # while every other dataset lives under './data' -- kept as is;
        # confirm this difference is intended.
        train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=False, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False, download=False, transform=transform_test)
        return 10, train_dataset, test_dataset

    if name == "CIFAR100":
        transform_train, transform_test = _cifar_transforms(*_CIFAR100_STATS)
        train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
        return 100, train_dataset, test_dataset

    if name == "OxfordIIITPet":
        transform = _square_resize_transform()
        train_dataset = torchvision.datasets.OxfordIIITPet(root='./data', split='trainval', download=False, transform=transform)
        test_dataset = torchvision.datasets.OxfordIIITPet(root='./data', split='test', download=False, transform=transform)
        return 37, train_dataset, test_dataset

    if name == "StanfordCars":
        transform = _square_resize_transform()
        train_dataset = torchvision.datasets.StanfordCars(root="./data/torchvision-stanford-cars/", split='train', transform=transform, download=False)
        test_dataset = torchvision.datasets.StanfordCars(root="./data/torchvision-stanford-cars/", split='test', transform=transform, download=False)
        return 196, train_dataset, test_dataset

    if name == "OxfordFlowers102":
        transform = _square_resize_transform()
        train_dataset = torchvision.datasets.Flowers102(root='./data', split='train', transform=transform, download=True)
        test_dataset = torchvision.datasets.Flowers102(root='./data', split='test', transform=transform, download=True)
        return 102, train_dataset, test_dataset

    raise ValueError("Unsupported dataset selected.")


def main(argv=None):
    """
        Parse the command line, build the data loaders and run the repeated
        experiment for the selected pooling method.

        Each repeat is seeded with `--seed + repeat_index`; the mean and the
        standard deviation of the per-repeat test accuracies are printed.

        Args:
            argv: optional list of command-line arguments; when None the
                process arguments are used.
    """
    parser = argparse.ArgumentParser(description="Fine-tune ViT with different pooling methods")
    parser.add_argument("--dataset", type=str,
                        choices=["CIFAR10", "CIFAR100", "OxfordIIITPet", "StanfordCars", "OxfordFlowers102"],
                        default="CIFAR10",
                        help="Dataset to use")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--model_type", type=str, default="base_vit",
                        choices=["base_vit", "small_vit"],
                        help="Type of ViT model to use")
    parser.add_argument("--pooling", type=str, default="sum",
                        choices=["sum", "avg", "max", "weighted_avg", "attention"],
                        help="Pooling method applied on top of the ViT")
    # Hyperparameters: the defaults are the values that used to be hard-coded.
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size of the train and test loaders")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of training epochs per repeat")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="Learning rate")
    parser.add_argument("--repeats", type=int, default=2,
                        help="Number of independent repeats (seed, seed+1, ...)")
    args = parser.parse_args(argv)

    # Reject values that would only fail later with an obscure error
    # (e.g. the mean over an empty list of accuracies).
    if args.batch_size < 1:
        parser.error("--batch_size must be at least 1")
    if args.epochs < 0:
        parser.error("--epochs must not be negative")
    if args.repeats < 1:
        parser.error("--repeats must be at least 1")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Dataset-specific settings.
    num_classes, train_dataset, test_dataset = _build_datasets(args.dataset)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    accuracies = []
    print(f"\nRunning experiments for pooling: {args.pooling}")
    for repeat in range(args.repeats):
        # A different seed per repeat so the repeats are independent runs.
        set_seed(args.seed + repeat)
        print(f"\n--- Repeat {repeat+1}/{args.repeats} for pooling: {args.pooling} ---")
        test_acc = run_experiment(args.pooling, train_loader, test_loader, device,
                                  num_classes, args.epochs, args.lr, args.model_type)
        accuracies.append(test_acc)
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    print(f"\nPooling {args.pooling} -> Mean Test Acc: {mean_acc:.2f}%, Std: {std_acc:.2f}%\n")

# Run the experiments only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
