import torch
import torch.nn.functional as F
import gc
import random
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Batch

def set_seed(seed):
    """
    Seed every random number generator used here so runs can be repeated.

    ``seed`` is passed, in order, to Python's ``random``, PyTorch
    (``manual_seed`` plus both CUDA seeding calls) and NumPy. cuDNN is then
    switched to deterministic mode with benchmarking turned off.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def train_vision_model(model, train_loader, optimizer, device, classification=True):
    """
    Train a vision model for one epoch.

    Args:
        model: Module called as ``model(images)``; put into train mode here.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.
        classification: If True, optimise cross-entropy and report accuracy;
            otherwise optimise MSE (labels cast to float) and report the
            mean absolute error.

    Returns:
        ``(avg_loss, accuracy)`` when ``classification`` is True, otherwise
        ``(avg_loss, mean_absolute_error)``. All values are averaged over
        the number of samples seen.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    model.train()
    total_loss, total = 0.0, 0
    correct = 0
    total_abs_error = 0.0

    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        if classification:
            loss = F.cross_entropy(outputs, labels)
        else:
            labels = labels.float()
            # Squeeze only dim 1 so the batch axis survives a batch of one;
            # the same tensor feeds both the loss and the MAE below.
            preds = outputs.squeeze(1)
            loss = F.mse_loss(preds, labels)
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; weight it by batch size so the epoch
        # average stays correct when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        total += labels.size(0)

        # Metric bookkeeping does not need gradients.
        with torch.no_grad():
            if classification:
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
            else:
                total_abs_error += torch.abs(preds - labels).sum().item()

    avg_loss = total_loss / total
    if classification:
        return avg_loss, correct / total
    return avg_loss, total_abs_error / total


@torch.no_grad()
def test_vision_model(model, test_loader, device, classification=True):
    """
    Evaluate the vision model.

    Args:
        model: Module called as ``model(images)``; put into eval mode here.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        classification: If True, predictions are the argmax over dim 1 and
            the metric is accuracy; otherwise the raw outputs are the
            predictions and the metric is the mean absolute error.

    Returns:
        ``(metric, all_preds, all_labels)`` where the two lists hold the
        per-sample predictions and labels as plain Python values.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.eval()
    all_preds, all_labels = [], []
    total = 0

    if classification:
        correct = 0
        for images, labels in tqdm(test_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
        accuracy = correct / total
        return accuracy, all_preds, all_labels

    # Regression branch (no progress bar, as before).
    total_abs_error = 0.0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Drop only a trailing singleton dimension. A bare squeeze() would also
        # remove the batch axis of a single-sample batch, leaving a 0-d tensor
        # whose tolist() is a float that list.extend() cannot consume.
        if outputs.dim() > 1:
            outputs = outputs.squeeze(-1)
        total_abs_error += torch.abs(outputs - labels).sum().item()
        total += labels.size(0)
        all_preds.extend(outputs.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
    mean_abs_error = total_abs_error / total
    return mean_abs_error, all_preds, all_labels


# The name starts with "test_", so pytest would try to collect this helper as a
# test case whenever it is imported into a test module; opt out explicitly.
test_vision_model.__test__ = False
    

def train_gnn_model(model, train_loader, optimizer, device, classification=True):
    """
    Train a GNN model for one epoch.

    Args:
        model: Module called as ``model(data)``; put into train mode here.
        train_loader: Iterable yielding batch objects that expose ``to()``,
            ``y`` and ``num_graphs``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.
        classification: If True, optimise cross-entropy against ``data.y``
            and report accuracy; otherwise optimise MSE and report the mean
            absolute error.

    Returns:
        ``(avg_loss, accuracy)`` when ``classification`` is True, otherwise
        ``(avg_loss, mae)``. All values are averaged over the number of
        graphs seen.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no graphs.
    """
    model.train()
    total_loss = 0.0
    total = 0
    correct = 0
    total_abs_error = 0.0

    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        if classification:
            loss = F.cross_entropy(out, data.y)
        else:
            # Drop only a trailing singleton dimension: a bare squeeze() also
            # removes the batch axis when the batch holds a single graph,
            # which makes the prediction 0-d and mismatched against data.y.
            # NOTE(review): assumes data.y is shaped (num_graphs,) - confirm.
            if out.dim() > 1:
                out = out.squeeze(-1)
            loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; weight it by the number of graphs so
        # the epoch average stays correct for a smaller final batch.
        total_loss += loss.item() * data.num_graphs
        total += data.num_graphs

        # Metric bookkeeping does not need gradients.
        with torch.no_grad():
            if classification:
                pred = out.argmax(dim=1)
                correct += (pred == data.y).sum().item()
            else:
                total_abs_error += torch.abs(out - data.y).sum().item()

    avg_loss = total_loss / total
    if classification:
        return avg_loss, correct / total
    return avg_loss, total_abs_error / total

@torch.no_grad()
def test_gnn_model(model, test_loader, device, classification=True):
    """
    Evaluate a GNN model.

    Args:
        model: Module called as ``model(data)``; put into eval mode here.
        test_loader: Iterable yielding batch objects that expose ``to()``,
            ``y`` and ``num_graphs``.
        device: Device each batch is moved to before the forward pass.
        classification: If True, predictions are the argmax over dim 1 and
            the metric is accuracy; otherwise the raw outputs are the
            predictions and the metric is the mean absolute error.

    Returns:
        ``(metric, all_preds, all_labels)`` where the two lists hold the
        per-graph predictions and labels as plain Python values.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no graphs.
    """
    model.eval()
    total = 0
    correct = 0
    total_abs_error = 0.0
    all_preds, all_labels = [], []

    for data in tqdm(test_loader):
        data = data.to(device)
        out = model(data)
        if classification:
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
        else:
            # Drop only a trailing singleton dimension. A bare squeeze() would
            # also remove the batch axis of a one-graph batch, leaving a 0-d
            # tensor whose tolist() is a float that list.extend() rejects.
            # NOTE(review): assumes data.y is shaped (num_graphs,) - confirm.
            pred = out.squeeze(-1) if out.dim() > 1 else out
            total_abs_error += torch.abs(pred - data.y).sum().item()
        total += data.num_graphs
        all_preds.extend(pred.cpu().tolist())
        all_labels.extend(data.y.cpu().tolist())

    if classification:
        return correct / total, all_preds, all_labels
    return total_abs_error / total, all_preds, all_labels


# The name starts with "test_", so pytest would try to collect this helper as a
# test case whenever it is imported into a test module; opt out explicitly.
test_gnn_model.__test__ = False

def release_resources(model, optimizer):
    """
    Release memory used by the model and optimizer (useful for large models or GPU cleanup).

    The model is moved to the CPU in place, this function's references to
    both objects are dropped, a garbage-collection pass is run, and finally
    PyTorch's cached CUDA memory is released.

    Note: ``del`` here only removes the local names. The caller still holds
    its own references and must drop them too (e.g. ``del model, optimizer``)
    before the objects themselves can be freed.

    Returns:
        None.
    """
    model.cpu()
    del model
    del optimizer
    # Collect before emptying the cache: memory freed by the collector goes
    # back to PyTorch's caching allocator first, and only empty_cache() hands
    # those cached blocks back to the device.
    gc.collect()
    torch.cuda.empty_cache()