import torch
from escnn.group import *
import matplotlib.pyplot as plt
from layers.downsampling import SubgroupDownsample
from utils.group_utils import *
import matplotlib.pyplot as plt
import wandb
import gc

def equivariance_tester_classification(*,
                        model,
                        params,
                        dataloader= None,
                        device = 'cuda',
                        number_of_samples = 100,
                        wanbd_log= True):
    """Measure how much ``model.get_feature`` changes under input-group actions.

    For every batch drawn from ``dataloader`` the features of the original
    input are compared with the features of the input transformed by each
    element of the input group. Squared L2 distances (taken over the last
    feature dimension) are accumulated and averaged over
    ``count * |G_in|``, where ``count`` is the number of samples seen.

    NOTE(review): the normaliser assumes ``get_feature`` returns one vector
    per sample, i.e. a ``(batch, features)`` tensor — confirm against the
    model. The relative error divides by the squared norm of the transformed
    features, so an all-zero feature vector yields ``inf``/``nan``.

    Args:
        model: Module exposing ``get_feature(x)``; it is moved to ``device``
            and switched to eval mode.
        params: Object providing ``in_group_type``, ``in_order``,
            ``in_feature``, ``in_representation``, ``out_feature`` and
            ``out_representation``.
        dataloader: Iterable of batches whose first item is the input tensor.
            Required despite the ``None`` default.
        device: Device on which the evaluation runs.
        number_of_samples: Evaluation stops after the first batch that brings
            the number of processed samples to at least this value.
        wanbd_log: When true, the metrics are also sent to ``wandb.log``.
            (The misspelt name is kept for backward compatibility.)

    Returns:
        dict with keys ``'ab_eq_error'`` and ``'rel_eq_error'`` holding the
        averaged absolute and relative squared errors as Python floats.

    Raises:
        ValueError: If ``dataloader`` is ``None`` or yields no samples.
    """
    if dataloader is None:
        raise ValueError("dataloader must be provided to measure the equivariance error")

    model = model.to(device)
    model.eval()

    in_group_type = params.in_group_type
    in_order = int(params.in_order)
    in_feature = int(params.in_feature)
    input_representation = params.in_representation
    out_feature = int(params.out_feature)
    output_representation = params.out_representation

    G_in = get_group(in_group_type, in_order)

    print(in_feature, input_representation, out_feature, output_representation)
    gspace_in = get_gspace(group_type=in_group_type, order=in_order,
                           num_features=in_feature, representation=input_representation)

    ab_error = 0.0
    rel_error = 0.0
    count = 0

    # Pure evaluation: without no_grad every forward pass would keep its
    # autograd graph alive through the running error sums.
    with torch.no_grad():
        for v in dataloader:
            x = v[0].to(device)
            feature_x = model.get_feature(x)
            count += x.shape[0]

            for g in G_in.elements:
                # Clone in case the transform modifies its argument in place.
                x_t = gspace_in.transform(x.clone(), g)
                feature_x_t = model.get_feature(x_t)

                sq_error = torch.norm(feature_x - feature_x_t, p=2, dim=-1) ** 2
                ab_error += sq_error.sum()
                rel_error += (sq_error / torch.norm(feature_x_t, p=2, dim=-1) ** 2).sum()

                del x_t, feature_x_t, sq_error

            # Release the batch before loading the next one; a single cache
            # flush per batch is enough once no graph is being retained.
            del x, feature_x
            gc.collect()
            torch.cuda.empty_cache()

            # ">=" so that hitting the budget exactly does not pull in one
            # more batch.
            if count >= number_of_samples:
                break

    if count == 0:
        raise ValueError("dataloader yielded no samples; cannot compute the equivariance error")

    normaliser = count * G_in.order()
    # Convert to plain floats so printing and logging never hold device tensors.
    values_to_log = {'ab_eq_error': float(ab_error / normaliser),
                     'rel_eq_error': float(rel_error / normaliser)}
    print(values_to_log)
    if wanbd_log:
        wandb.log(values_to_log, commit=True)
    return values_to_log