import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.utils import get_optimizer, get_dataloaders
from src.utils.checkpointing import save_evaluation_checkpoint, load_evaluation_checkpoint
from src.models import build_encoder


def linear_probing(args, logger, encoder, device):
    """Train a linear classifier on frozen encoder features and return test accuracy.

    The encoder is switched to eval mode and all of its parameters have
    ``requires_grad`` set to ``False`` *in place*; neither change is undone
    before returning, so pass a copy if the encoder is still being trained.

    Args:
        args: Configuration object. This function reads ``feature_dim``,
            ``num_classes``, ``evaluation_lr``, ``evaluation_weight_decay``,
            ``momentum``, ``evaluation_epochs`` and ``evaluation_batch_size``;
            ``args`` is also forwarded to ``get_optimizer`` and
            ``get_dataloaders``.
        logger: Logger used for the periodic training-loss messages.
        encoder: Module mapping an image batch to features that
            ``nn.Linear(args.feature_dim, ...)`` can consume.
        device: Device the classifier and every batch are moved to.

    Returns:
        Test accuracy as a fraction (``correct / total``), not a percentage.

    Raises:
        ValueError: If the test loader yields no samples.
    """
    # Freeze the encoder: only the linear head below is optimised.
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False

    # Training
    classifier = nn.Linear(args.feature_dim, args.num_classes).to(device)
    optimizer = get_optimizer(args, classifier, args.evaluation_lr, args.evaluation_weight_decay, args.momentum, evaluation=True)

    train_loader, test_loader = get_dataloaders(args, logger, args.evaluation_batch_size, augmentation=False, supervised=True)

    for epoch in range(args.evaluation_epochs):
        classifier.train()
        epoch_losses = []

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # The encoder is frozen, so no autograd graph is needed for it.
            with torch.no_grad():
                features = encoder(images)

            outputs = classifier(features)
            loss = F.cross_entropy(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

        # Report the mean loss on every 10th epoch only.
        if epoch % 10 == 0:
            # An empty train loader is reported as nan without tripping
            # numpy's "mean of empty slice" warning.
            average_train_loss = np.mean(epoch_losses) if epoch_losses else float("nan")
            logger.info(f"Epoch-{epoch} | Loss: {average_train_loss:.4f}")

    # Validation
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = classifier(encoder(images))
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("linear_probing: the test loader produced no samples, accuracy is undefined")

    return correct / total


def supervised_evaluation(args, logger, model, test_loader, device):
    """Compute the top-1 accuracy of ``model`` over ``test_loader``.

    The model is put into eval mode and is left in eval mode on return;
    callers that continue training afterwards must call ``model.train()``.

    Args:
        args: Not used by this function; kept for a uniform signature.
        logger: Not used by this function; kept for a uniform signature.
        model: Module mapping an image batch to per-class logits with the
            class scores along dimension 1.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a fraction (``correct / total``), not a percentage.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Index of the highest logit per sample is the predicted class.
            predicted = model(images).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("supervised_evaluation: the test loader produced no samples, accuracy is undefined")

    return correct / total


def evaluate_model(args, logger, device, run_paths, model, epoch, evaluation_type="linear_probing", test_loader=None):
    """Evaluate ``model`` after ``epoch`` and record the accuracy on the model.

    In supervised mode (``args.supervised`` truthy) the model is scored
    directly on ``test_loader`` and the percentage is stored in
    ``model.evaluation_history``. Otherwise, for
    ``evaluation_type == "linear_probing"``, fresh encoders are built, an
    evaluation checkpoint is saved from ``model`` and loaded into them, and
    each encoder is linearly probed. Results go to
    ``model.E_linear_probing_history`` / ``model.M_linear_probing_history``
    when ``args.model_name == "SDMI"`` and to ``model.linear_probing_history``
    otherwise. All histories are keyed by ``f"Epoch-{epoch}"`` and hold
    percentages.

    Args:
        args: Configuration object; ``supervised`` and ``model_name`` are read
            here and ``args`` is forwarded to the helpers.
        logger: Logger for progress and result messages.
        device: Device the encoders are moved to and evaluation runs on.
        run_paths: Object exposing ``checkpoint_directory``.
        model: Model under evaluation; must carry the history dict(s) above.
        epoch: Epoch number used in history keys and log messages.
        evaluation_type: Evaluation protocol for non-supervised runs. Only
            ``"linear_probing"`` is implemented; any other value logs a
            warning and returns without evaluating.
        test_loader: Test data for supervised evaluation; required when
            ``args.supervised`` is truthy and ignored otherwise.

    Raises:
        ValueError: If ``args.supervised`` is truthy and ``test_loader`` is
            ``None``.
    """
    history_key = f"Epoch-{epoch}"

    if args.supervised:
        if test_loader is None:
            # Fail early instead of with a TypeError from iterating None.
            raise ValueError("evaluate_model: supervised evaluation requires a test_loader, got None")

        accuracy = supervised_evaluation(args, logger, model, test_loader, device) * 100
        model.evaluation_history[history_key] = accuracy
        logger.info(f"Supervised evaluation accuracy after epoch-{epoch}: {accuracy:.2f}%")
        logger.info("Model evaluation completed.")
        return

    if evaluation_type != "linear_probing":
        # Nothing was evaluated, so do not report completion.
        logger.warning(f"Unsupported evaluation type '{evaluation_type}'; skipping evaluation.")
        return

    logger.info("Evaluating model using linear probing...")

    # SDMI is probed with two encoders (E and M); every other model with one.
    is_sdmi = args.model_name == "SDMI"
    encoder_list = [build_encoder(args).to(device) for _ in range(2 if is_sdmi else 1)]

    # Round-trip through the checkpoint directory so the probes run on
    # separate encoder instances rather than on the live model
    # (presumably copies of its current weights; confirm against the
    # checkpointing helpers).
    save_evaluation_checkpoint(args, logger, model, run_paths.checkpoint_directory)
    load_evaluation_checkpoint(args, logger, encoder_list, run_paths.checkpoint_directory, device)

    accuracies = [linear_probing(args, logger, encoder, device) * 100 for encoder in encoder_list]

    if is_sdmi:
        accuracy_E, accuracy_M = accuracies
        model.E_linear_probing_history[history_key] = accuracy_E
        model.M_linear_probing_history[history_key] = accuracy_M
        logger.info(f"Linear probing accuracy after epoch-{epoch}: E-encoder: {accuracy_E:.2f}% | M-encoder: {accuracy_M:.2f}%")
    else:
        (accuracy,) = accuracies
        model.linear_probing_history[history_key] = accuracy
        logger.info(f"Linear probing accuracy after epoch-{epoch}: {accuracy:.2f}%")

    logger.info("Model evaluation completed.")
