import torch
import math

#torch.set_default_dtype(torch.float32)

def calc_lgop(model, loader, criterion, optimizer, config, out_dim, act_fn='relu'):
    """Accumulate two input-gradient matrices over every batch in ``loader``.

    For each batch the loss is back-propagated to the inputs and two
    matrices are accumulated, each weighted by the batch size:

    * ``inputs.grad.T @ inputs``, returned scaled by
      ``-1 / (N * config.weight_decay)``;
    * ``inputs.grad.T @ inputs.grad``, returned scaled by ``1 / N``;

    where ``N`` is the total number of samples seen across all batches.

    Args:
        model: Callable invoked as ``model(inputs, act_fn=act_fn)``.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs. ``.t()``
            is applied to the input gradient, so inputs must be at most 2-D.
        criterion: Loss callable. It is not used when ``out_dim == 1`` and
            ``config.loss == 'mse'`` (the squared error is computed inline).
        optimizer: Only ``optimizer.zero_grad()`` is called, before each
            backward pass; no optimisation step is taken.
        config: Object providing ``device``, ``loss`` and ``weight_decay``.
            ``weight_decay`` must be non-zero because it is a divisor.
        out_dim: Model output width; ``1`` selects the squeezed-output loss
            path, anything else the multi-output path.
        act_fn: Forwarded to the model as the ``act_fn`` keyword.

    Returns:
        Tuple ``(final_lgop, final_enfa)`` of detached tensors.

    Raises:
        ValueError: If ``loader`` yields no samples.

    Side effects:
        ``loss.backward()`` populates gradients of the model parameters; after
        the call they hold the gradients of the last batch.
    """
    total_n = 0
    final_lgop = 0.0
    final_enfa = 0.0
    for batch in loader:
        inputs, labels = (t.to(config.device) for t in batch)
        labels = labels.to(inputs.dtype)

        # Work on a fresh leaf tensor: ``.to(device)`` may return the loader's
        # own tensor, and enabling grad on it in place would mutate the
        # dataset and let ``.grad`` accumulate across repeated calls.
        inputs = inputs.detach().requires_grad_(True)
        optimizer.zero_grad()

        nsamps = inputs.size(0)
        total_n += nsamps

        outputs = model(inputs, act_fn=act_fn)

        if out_dim == 1:
            if config.loss == 'mse':
                loss = torch.pow(outputs.squeeze() - labels, 2).mean()
            else:
                loss = criterion(outputs.squeeze(), labels)
        else:
            if config.loss == 'mse':
                loss = criterion(outputs, labels)
            else:
                # Labels are reduced to class indices via argmax for
                # index-based criteria.
                loss = criterion(outputs, labels.argmax(-1).long())

        loss.backward()

        with torch.no_grad():
            # Weight each batch by its size so the final division by the
            # total sample count yields a per-sample average.
            final_lgop += (inputs.grad.t() @ inputs) * nsamps
            final_enfa += (inputs.grad.t() @ inputs.grad) * nsamps

    if total_n == 0:
        raise ValueError("calc_lgop: loader yielded no samples")

    final_lgop = -1.0 / (total_n * config.weight_decay) * final_lgop
    #final_lgop = 1.0/total_n * final_lgop
    final_enfa = 1.0 / total_n * final_enfa
    return final_lgop.detach(), final_enfa.detach()

def calc_full_agop(model, loader, config, calc_per_class_agops=False, detach=True):
    """Average the per-batch AGOP returned by ``calc_batch_agop`` over a loader.

    Each batch result is weighted by its batch size and the sum is divided by
    the total number of samples, so batches of unequal size contribute
    proportionally.

    Args:
        model: Object providing ``inp_dim``; forwarded to ``calc_batch_agop``.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs. Labels
            are moved to the device but otherwise unused here.
        config: Object providing ``agop_batch_size`` and ``device``; also
            forwarded to ``calc_batch_agop``.
        calc_per_class_agops: Forwarded to ``calc_batch_agop``; when true the
            per-class matrices it returns are averaged as well.
        detach: Forwarded to ``calc_batch_agop``; when true the averaged
            results are additionally moved to the CPU before being returned.

    Returns:
        Tuple ``(final_agop, final_per_class_agops)``. The second element is
        an empty list when no per-class matrices were produced.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Zero tensor handed to ``calc_batch_agop`` as the argument it
    # differentiates with respect to.
    # NOTE(review): it is sized with ``config.agop_batch_size`` while the
    # last batch of a loader may be smaller; confirm ``model.forward``
    # tolerates that mismatch.
    dumb1 = torch.zeros((config.agop_batch_size, model.inp_dim)).to(config.device)
    total_n = 0
    final_agop = 0.0
    final_per_class_agops = []
    for batch in loader:
        # Copy data to device if needed; the pair shape of the batch is
        # still enforced by the unpacking.
        inputs, _labels = (t.to(config.device) for t in batch)

        nsamps = inputs.size(0)
        total_n += nsamps

        agop, per_class_agops = calc_batch_agop(
            model, inputs, dumb1, config.device, config,
            calc_per_class_agops=calc_per_class_agops, detach=detach)
        final_agop += agop * nsamps
        for jdx, class_agop in enumerate(per_class_agops):
            weighted = class_agop * nsamps
            # Grow the accumulator on first sight of a class index and add
            # into it afterwards, based on the accumulator's own length.
            if jdx < len(final_per_class_agops):
                final_per_class_agops[jdx] += weighted
            else:
                final_per_class_agops.append(weighted)

    if total_n == 0:
        raise ValueError("calc_full_agop: loader yielded no samples")

    final_agop /= total_n
    # Normalise the accumulators themselves rather than relying on the
    # variable left over from the last loop iteration.
    final_per_class_agops = [acc / total_n for acc in final_per_class_agops]

    if detach:
        final_agop = final_agop.cpu()
        final_per_class_agops = [acc.cpu() for acc in final_per_class_agops]

    return final_agop, final_per_class_agops

def calc_batch_agop(model, inputs, dumb1, device, config, calc_per_class_agops=False, detach=True):
    """Compute the Jacobian outer-product matrix for a single batch.

    The Jacobian of ``model.forward(inputs, dumb1, config.act_fn)`` is taken
    with respect to its second positional argument (``dumb1``) using
    ``torch.func.jacrev``. It is flattened to rows of width ``model.inp_dim``
    and the result is ``J.T @ J / len(inputs)``.

    Args:
        model: Object providing ``forward`` and ``inp_dim``.
        inputs: Batch passed as the first argument of ``model.forward``.
        dumb1: Tensor the Jacobian is taken with respect to.
        device: Unused; kept for call compatibility.
        config: Object providing ``act_fn`` and, when per-class matrices are
            requested, ``prime`` (the number of class slices to extract).
        calc_per_class_agops: When true, also build one matrix per index
            ``c`` in ``range(config.prime)`` from the slice ``J[:, c, :, :]``
            (the Jacobian is therefore indexed as a 4-D tensor).
        detach: When true, the Jacobian is detached from the autograd graph
            before any products are formed.

    Returns:
        Tuple ``(agop, per_class_agops)``; the list is empty unless
        ``calc_per_class_agops`` is true.
    """
    jacobian_fn = torch.func.jacrev(model.forward, argnums=(1,))
    # ``argnums`` is a tuple, so the result is a 1-tuple; unwrap it.
    jacobian = jacobian_fn(inputs, dumb1, config.act_fn)[0]
    if detach:
        jacobian = jacobian.detach()

    batch_len = len(inputs)

    def _outer_product(block):
        # Flatten every leading axis into rows of width ``inp_dim``.
        rows = block.reshape(-1, model.inp_dim)
        return rows.t() @ rows / batch_len

    if calc_per_class_agops:
        per_class_agops = [
            _outer_product(jacobian[:, class_idx, :, :])
            for class_idx in range(config.prime)
        ]
    else:
        per_class_agops = []

    return _outer_product(jacobian), per_class_agops
