import argparse
import copy
import os
import sys
import logging
import pandas as pd
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import timm
from torchvision import transforms
from PIL import Image

from sklearn.metrics import accuracy_score, roc_auc_score
from data.pairs_dataset.dataset import KneeFeatureDataset

# -----------------------------
# Utility functions for logging per-class label distributions of each split.
# -----------------------------
def compute_distribution(dataset):
    """Return per-class sample counts and prevalences for *dataset*.

    ``dataset.samples`` is iterated as 2-tuples whose second element is the
    class label; the first element is ignored.

    Returns:
        A ``(Counter, dict)`` pair: the Counter maps each label to its number
        of samples, and the dict maps each label to ``count / total``.
    """
    label_counts = Counter(lbl for _, lbl in dataset.samples)
    n_samples = sum(label_counts.values())
    shares = {lbl: cnt / n_samples for lbl, cnt in label_counts.items()}
    return label_counts, shares

def log_distribution(log_file, split_name, dataset):
    """Append a per-class count/prevalence table for *dataset* to *log_file*.

    Classes are written in ascending index order. Each line also shows the
    original label value, looked up as ``dataset.unique_labels[index]``.
    """
    counts, prevalences = compute_distribution(dataset)
    with open(log_file, "a") as fh:
        fh.write(f"\nDistribution for {split_name} set:\n")
        for cls in sorted(counts):
            count = counts[cls]
            fh.write(
                f"  Class {cls} (value {dataset.unique_labels[cls]}): {count} samples, prevalence {prevalences[cls]:.4f}\n"
            )

# -----------------------------
# Training and evaluation functions.
# -----------------------------
def evaluate(model, dataloader, device, num_classes):
    """Run *model* over *dataloader* without gradients and compute metrics.

    Args:
        model: classifier whose output is treated as per-class logits
            (softmax/argmax over dim 1, scored with cross-entropy).
        dataloader: yields ``(images, labels)`` batches; labels are integer
            class indices in ``[0, num_classes)``.
        device: device each batch is moved to before the forward pass.
        num_classes: number of output classes.

    Returns:
        ``(avg_loss, acc, macro_auc, per_class_acc, prevalence)``:

        - avg_loss: mean per-sample cross-entropy over all evaluated samples.
        - acc: overall accuracy.
        - macro_auc: unweighted mean of one-vs-rest ROC AUC over the classes
          whose AUC is defined in this split, i.e. classes with at least one
          positive and one negative sample. It is 0.0 when no class qualifies
          or the AUC cannot be computed.
        - per_class_acc: ``{class_index: accuracy}`` for classes present.
        - prevalence: ``{class_index: fraction of samples}`` for all classes.

    Raises:
        ValueError: if *dataloader* yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    batch_targets, batch_preds, batch_probs = [], [], []
    running_loss = 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # CrossEntropyLoss averages over the batch; re-weight by batch size
            # so the final average is per-sample even with a ragged last batch.
            running_loss += criterion(outputs, labels).item() * images.size(0)
            batch_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            batch_preds.append(torch.argmax(outputs, dim=1).cpu().numpy())
            batch_targets.append(labels.cpu().numpy())

    n_samples = sum(len(t) for t in batch_targets)
    if n_samples == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")

    targets = np.concatenate(batch_targets)
    preds = np.concatenate(batch_preds)
    probs = np.concatenate(batch_probs)

    # Divide by the samples actually seen (not len(dataset)) so drop_last or
    # samplers do not skew the average.
    avg_loss = running_loss / n_samples
    acc = accuracy_score(targets, preds)

    # One-vs-rest AUC per class, averaged with equal weight (macro). A class's
    # AUC is only defined when the split has both positives and negatives for
    # it, so such classes are skipped instead of zeroing the whole metric
    # (which previously made best-epoch selection meaningless whenever a rare
    # class was absent from the split).
    class_aucs = []
    try:
        for cls in range(num_classes):
            is_cls = targets == cls
            if is_cls.any() and not is_cls.all():
                class_aucs.append(roc_auc_score(is_cls.astype(int), probs[:, cls]))
    except ValueError as err:
        # e.g. non-finite scores after divergence; keep the best-effort 0.0.
        logging.getLogger(__name__).warning("Could not compute macro AUC: %s", err)
        class_aucs = []
    macro_auc = float(np.mean(class_aucs)) if class_aucs else 0.0

    per_class_acc = {}
    prevalence = {}
    for cls in range(num_classes):
        mask = targets == cls
        n_cls = int(mask.sum())
        prevalence[cls] = n_cls / n_samples
        if n_cls:
            per_class_acc[cls] = float(accuracy_score(targets[mask], preds[mask]))
    return avg_loss, acc, macro_auc, per_class_acc, prevalence

def train_model(model, train_loader, val_loader, device, num_classes, num_epochs, log_file, lr=1e-5):
    """Fine-tune *model* with AdamW and cross-entropy, keeping the best-val-AUC weights.

    After every epoch the model is evaluated on *val_loader*. The metrics are
    printed and appended to *log_file*. A snapshot of the weights is kept for
    the epoch with the highest validation macro AUC, and those weights are
    loaded back into *model* before returning.

    Args:
        model: network to train in place (already on *device*).
        train_loader: yields ``(images, labels)`` training batches.
        val_loader: yields ``(images, labels)`` validation batches.
        device: device each batch is moved to.
        num_classes: number of output classes (forwarded to ``evaluate``).
        num_epochs: number of passes over *train_loader*.
        log_file: path the per-epoch metrics are appended to.
        lr: AdamW learning rate (default 1e-5, the previously hard-coded value).

    Returns:
        *model* with the best-epoch weights loaded. If ``num_epochs`` is 0 the
        weights are left unchanged.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    best_val_macro_auc = -float("inf")
    best_model_state = None
    best_epoch = None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_seen = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
            n_seen += images.size(0)
        # Average over the samples actually seen (robust to drop_last / samplers).
        avg_loss = running_loss / n_seen if n_seen else float("nan")

        # Evaluate on validation set.
        val_loss, val_acc, val_macro_auc, val_per_class_acc, val_prevalence = evaluate(
            model, val_loader, device, num_classes
        )

        log_str = (
            f"\nEpoch {epoch+1} Training Loss: {avg_loss:.4f}\n"
            f"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, Macro AUC: {val_macro_auc:.4f}\n"
            f"Per Class Accuracy: {val_per_class_acc}\n"
            f"Class Prevalence: {val_prevalence}\n"
        )
        print(log_str)
        with open(log_file, "a") as f:
            f.write(log_str)

        # Save best model based on macro AUC.
        if val_macro_auc > best_val_macro_auc:
            best_val_macro_auc = val_macro_auc
            best_epoch = epoch + 1
            # state_dict() holds references to the live parameter/buffer
            # tensors, which later optimizer steps overwrite in place; without
            # a deep copy the "best" snapshot would silently become the last epoch.
            best_model_state = copy.deepcopy(model.state_dict())

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        summary = (
            f"\nRestored weights from epoch {best_epoch} "
            f"(validation Macro AUC: {best_val_macro_auc:.4f})\n"
        )
        print(summary)
        with open(log_file, "a") as f:
            f.write(summary)
    return model

# -----------------------------
# Main function.
# -----------------------------
def main():
    """Train and test a classifier for one knee feature.

    Reads ``train.csv``/``val.csv``/``test.csv`` from ``data/pairs_dataset`` and
    fine-tunes a pretrained timm ``efficientformerv2_l.snap_dist_in1k`` with a
    new classification head. Metrics go to stdout and to
    ``checkpoints/<feature>/train_<feature>.log``. The weights selected by
    validation macro AUC are saved to
    ``checkpoints/<feature>/<feature>_model.pth``.
    """
    parser = argparse.ArgumentParser(description="Pretrain a knee feature detection model.")
    parser.add_argument("feature", type=str, help="Feature to train on (e.g., JSN_Lateral)")
    parser.add_argument("--epochs", type=int, default=30, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for torch/numpy RNGs (default: unseeded)")
    args = parser.parse_args()

    feature = args.feature
    num_epochs = args.epochs
    batch_size = args.batch_size
    num_workers = args.num_workers

    if args.seed is not None:
        # torchvision random transforms and DataLoader shuffling draw from the
        # torch RNG; numpy is seeded for any numpy-based randomness.
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # Set up log file.
    log_file = f"checkpoints/{feature}/train_{feature}.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    with open(log_file, "w") as f:
        f.write(f"Training log for feature: {feature}\n")

    # CSV directory.
    base_dir = "data/pairs_dataset"
    train_csv = os.path.join(base_dir, "train.csv")
    val_csv = os.path.join(base_dir, "val.csv")
    test_csv = os.path.join(base_dir, "test.csv")

    # ImageNet statistics, matching the pretrained backbone.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training transforms: resize plus light geometric/photometric augmentation.
    data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # Evaluation transforms: deterministic resize + normalize only.
    eval_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # Create datasets.
    train_dataset = KneeFeatureDataset(train_csv, "train", feature, transform=data_transforms)
    val_dataset = KneeFeatureDataset(val_csv, "val", feature, transform=eval_transforms)
    test_dataset = KneeFeatureDataset(test_csv, "test", feature, transform=eval_transforms)

    # Log data distributions (log_distribution writes its own section header).
    for split_name, split_dataset in (
        ("Train", train_dataset),
        ("Validation", val_dataset),
        ("Test", test_dataset),
    ):
        log_distribution(log_file, split_name, split_dataset)

    # Each dataset carries its own `unique_labels` index->value mapping, and the
    # head is sized from the training one. NOTE(review): presumably each split
    # derives it from its own CSV (TODO confirm in KneeFeatureDataset). If a
    # split lacks a class, its indices may not line up with train's, so warn.
    train_label_values = list(train_dataset.unique_labels)
    for split_name, split_dataset in (("Validation", val_dataset), ("Test", test_dataset)):
        split_label_values = list(split_dataset.unique_labels)
        if split_label_values != train_label_values:
            warning = (
                f"\nWARNING: {split_name} label mapping {split_label_values} differs from "
                f"Train mapping {train_label_values}; class indices may not align.\n"
            )
            print(warning)
            with open(log_file, "a") as f:
                f.write(warning)

    # Create DataLoaders.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Determine number of classes.
    num_classes = len(train_dataset.unique_labels)
    print(f"Number of classes: {num_classes}")

    # Set device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained 'efficientformerv2_l.snap_dist_in1k' timm model and
    # replace its classifier head with one sized for num_classes.
    model = timm.create_model('efficientformerv2_l.snap_dist_in1k', pretrained=True)
    model.reset_classifier(num_classes)
    model = model.to(device)

    # Train the model, keeping the weights with the best validation macro AUC.
    model = train_model(model, train_loader, val_loader, device, num_classes, num_epochs, log_file)

    # Evaluate on test set.
    test_loss, test_acc, test_macro_auc, test_per_class_acc, test_prevalence = evaluate(
        model, test_loader, device, num_classes
    )
    test_log = (
        "\nTest Set Evaluation:\n"
        f"Test Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}, Macro AUC: {test_macro_auc:.4f}\n"
        f"Per Class Accuracy: {test_per_class_acc}\n"
        f"Class Prevalence: {test_prevalence}\n"
    )
    print(test_log)
    with open(log_file, "a") as f:
        f.write(test_log)

    # Save the weights returned by train_model (selected by macro AUC).
    model_save_path = f"checkpoints/{feature}/{feature}_model.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Best model saved to {model_save_path}")

if __name__ == "__main__":
    # Script entry point; importing this module only defines the functions above.
    # main() returns None, so the process exit status is 0 on success.
    sys.exit(main())
