import torch
from train.training_handler import get_loss
import wandb
from utils.group_utils import get_gspace
import gc
def evaluate_classification_orbit(*,
                            model,
                            test_loader,
                            params,
                            wanbd_log=False,):
    """Evaluate ``model`` on every group-transformed copy of the test set.

    Each batch from ``test_loader`` is transformed by every element of
    ``gspace.fibergroup`` (the gspace is built by ``get_gspace`` from
    ``params``). Accuracy and loss are accumulated over all
    (batch, group element) pairs.

    Args:
        model: Classifier; put in eval mode and moved to ``params.device``.
        test_loader: Iterable of batches whose first two items are
            ``(inputs, labels)``.
        params: Config object; reads ``device``, ``in_group_type``,
            ``in_order``, ``in_feature``, ``in_representation`` and whatever
            ``get_loss`` reads.
        wanbd_log: If true, also log the metrics with ``wandb.log``
            (parameter name kept as-is for backward compatibility).

    Returns:
        ``(accuracy, mean_loss)`` over all transformed samples. Both are
        ``nan`` when no samples were evaluated (e.g. an empty loader).
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    loss_function = get_loss(params)
    device = params.device
    gspace = get_gspace(group_type=params.in_group_type,
                        order=params.in_order,
                        num_features=params.in_feature,
                        representation=params.in_representation)

    model = model.to(device)

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            for g in gspace.fibergroup.elements:
                img_t = gspace.transform(inputs, g)

                outputs = model(img_t)
                # NOTE(review): summing and later dividing by `total` gives a
                # per-sample mean only if get_loss returns unreduced
                # (reduction='none') losses -- confirm.
                loss = loss_function(outputs, labels).sum()
                total_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Dropping the references is enough to release these tensors.
                # The full garbage-collection pass runs once per batch below
                # rather than once per group element.
                del outputs, img_t, loss

            del inputs, labels
            torch.cuda.empty_cache()
            gc.collect()

    # Report nan instead of raising ZeroDivisionError when nothing was evaluated.
    accuracy = correct / total if total else float('nan')
    mean_loss = total_loss / total if total else float('nan')
    values_to_log = {
        'test_accuracy_orbit': accuracy,
        'test_loss_orbit': mean_loss,
    }
    print(values_to_log)
    if wanbd_log:
        wandb.log(values_to_log, commit=True)

    return accuracy, mean_loss

def evaluate_classification_smi(*,
                            model,
                            test_loader,
                            params,
                            wanbd_log=False,
                            random_noise=False,
                            eval_suffix="",
                            noise_std=0.1):
    """Evaluate ``model`` on group-transformed inputs near the identity.

    Only group elements whose rotation index is at most ``threshold`` steps
    away from the identity are used, in either direction (wrapping around
    ``size``). Here ``threshold = int(params.test_rotation / (360 / size))``.
    Only ``'dihedral'`` and ``'cycle'`` group types are supported.

    Args:
        model: Classifier; put in eval mode and moved to ``params.device``.
        test_loader: Iterable of batches whose first two items are
            ``(inputs, labels)``.
        params: Config object; reads ``device``, ``in_group_type``,
            ``in_order``, ``in_feature``, ``in_representation``,
            ``test_rotation`` and whatever ``get_loss`` reads.
        wanbd_log: If true, also log the metrics with ``wandb.log``
            (parameter name kept as-is for backward compatibility).
        random_noise: If true, add Gaussian noise to each transformed input.
        eval_suffix: Prefix prepended to the logged metric keys.
        noise_std: Standard deviation of the added noise (default 0.1).

    Returns:
        ``(accuracy, mean_loss)`` over all evaluated samples. Both are ``nan``
        when nothing was evaluated (empty loader, or every element filtered
        out, e.g. for a negative ``test_rotation``).

    Raises:
        ValueError: If ``params.in_group_type`` is neither ``'dihedral'``
            nor ``'cycle'``.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    loss_function = get_loss(params)
    device = params.device
    gspace = get_gspace(group_type=params.in_group_type,
                        order=params.in_order,
                        num_features=params.in_feature,
                        representation=params.in_representation)

    model = model.to(device)
    size = gspace.fibergroup.order()
    if params.in_group_type == 'dihedral':
        # Presumably the dihedral order is 2 * (number of rotations); keep
        # only the rotation count -- confirm against get_gspace.
        size = size // 2
    elif params.in_group_type != 'cycle':
        raise ValueError(f'Group type {params.in_group_type} not found')

    # Number of whole rotation steps (360 / size degrees each) within
    # params.test_rotation; int() truncates toward zero.
    threshold = int(params.test_rotation / (360 / size))

    print(f"Threshold: {threshold}")

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            for g in gspace.fibergroup.elements:
                elm = g._element
                if isinstance(elm, (list, tuple)):
                    # Presumably (flip, rotation) for dihedral elements; use
                    # the rotation index -- confirm.
                    elm = elm[1]

                # Skip rotations farther than `threshold` steps from identity.
                if threshold < elm < size - threshold:
                    continue

                img_t = gspace.transform(inputs.clone(), g)
                if random_noise:
                    img_t = img_t + torch.randn_like(img_t) * noise_std

                outputs = model(img_t)
                # NOTE(review): summing and later dividing by `total` gives a
                # per-sample mean only if get_loss returns unreduced
                # (reduction='none') losses -- confirm.
                loss = loss_function(outputs, labels).sum()
                total_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Dropping the references is enough to release these tensors.
                # The full garbage-collection pass runs once per batch below.
                del outputs, img_t, loss

            del inputs, labels
            torch.cuda.empty_cache()
            gc.collect()

    # Report nan instead of raising ZeroDivisionError when nothing was evaluated.
    accuracy = correct / total if total else float('nan')
    mean_loss = total_loss / total if total else float('nan')
    values_to_log = {
        eval_suffix + 'semi_test_accuracy': accuracy,
        eval_suffix + 'semi_test_loss': mean_loss,
    }
    print(values_to_log)
    if wanbd_log:
        wandb.log(values_to_log, commit=True)

    return accuracy, mean_loss


def evaluate_classification(*,
                            model,
                            test_loader,
                            params,
                            wanbd_log=False,
                            eval_suffix="",
                            random_noise=False,
                            noise_std=0.001):
    """Evaluate ``model`` on the untransformed test set.

    Args:
        model: Classifier; put in eval mode and moved to ``params.device``.
        test_loader: Iterable of batches whose first two items are
            ``(inputs, labels)``.
        params: Config object; reads ``device`` and whatever ``get_loss``
            reads.
        wanbd_log: If true, also log the metrics with ``wandb.log``
            (parameter name kept as-is for backward compatibility).
        eval_suffix: Prefix prepended to the logged metric keys.
        random_noise: If true, add Gaussian noise to each input batch.
        noise_std: Standard deviation of the added noise (default 0.001).

    Returns:
        ``(accuracy, mean_loss)`` over all samples. Both are ``nan`` when
        the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    loss_function = get_loss(params)
    device = params.device
    model = model.to(device)

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            if random_noise:
                inputs = inputs + torch.randn_like(inputs) * noise_std

            outputs = model(inputs)
            # NOTE(review): summing and later dividing by `total` gives a
            # per-sample mean only if get_loss returns unreduced
            # (reduction='none') losses -- confirm.
            loss = loss_function(outputs, labels).sum()
            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            del inputs, labels, outputs, loss
            torch.cuda.empty_cache()
            gc.collect()

    # Report nan instead of raising ZeroDivisionError on an empty loader.
    accuracy = correct / total if total else float('nan')
    mean_loss = total_loss / total if total else float('nan')
    values_to_log = {
        eval_suffix + 'test_accuracy': accuracy,
        eval_suffix + 'test_loss': mean_loss,
    }
    print(values_to_log)

    if wanbd_log:
        wandb.log(values_to_log, commit=True)

    return accuracy, mean_loss
