import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.models.snip_utils import compare_flops_and_params

### Standard Classification test function
def test(net, testloader, device: str, accuracy_scale=1., freeze_bn_buffer=True, jit_augment=None):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0.0, 0.0, 0.0
    if freeze_bn_buffer:
        net.eval()
    else:
        net.train()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)

            if images.size(1) == 1: # single-channel
                images = images.expand(-1, 3, images.shape[2], images.shape[3])

            if jit_augment is not None:
                images = jit_augment(images)
            outputs = net(images)

            if type(outputs) == list:
                assert len(outputs) == 1
                outputs = outputs[0]
            loss += criterion(outputs, labels).item() * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    loss = loss / total
    
    # converting to the required range
    accuracy = accuracy * accuracy_scale

    return loss, accuracy, total

def ree_early_exit_test(
    model,
    max_early_exit_layer,
    testloader,
    device="cpu",
    ensemble=True,
):
    """Evaluate each early exit of a multi-exit model and, optionally, their ensemble.

    The model is called as ``model(images)`` and may return a single logits
    tensor or a sequence of per-exit logits. ``None`` entries in that sequence
    are dropped before indexing, so exit index ``k`` below refers to the k-th
    non-``None`` output.

    Args:
        model: Model under evaluation; switched to ``eval()`` mode.
        max_early_exit_layer: Highest exit index to evaluate (clamped to the
            number of exits the model actually returns on the first batch).
        testloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to; also forwarded to
            ``compare_flops_and_params``.
        ensemble: When True, also score the mean of the logits of exits
            ``0..max`` for every batch.

    Returns:
        Dict mapping each exit index to
        ``{"loss", "correct", "total", "accuracy"}`` (loss is NaN and accuracy
        0.0 for an exit that never received samples), plus
        ``"ensemble_accuracy"`` when ``ensemble`` is True and at least one
        sample was scored. An empty loader yields ``{"ensemble_accuracy": 0.0}``.

    Raises:
        RuntimeError: If the probe forward pass returns no usable exit.
    """
    model.eval()
    orig_flops, orig_params = compare_flops_and_params(model, device=device)
    if orig_flops:
        print(f"New model:  {orig_flops} | {orig_params}")
    criterion = torch.nn.CrossEntropyLoss()

    # Probe once to learn exit count
    probe_images = None
    for imgs, _ in testloader:
        probe_images = imgs.to(device)
        break
    if probe_images is None:
        return {"ensemble_accuracy": 0.0}

    # The probe must see the same input layout as the full pass below:
    # single-channel batches are broadcast to 3 channels there, so do it here too.
    if probe_images.size(1) == 1:
        probe_images = probe_images.expand(-1, 3, probe_images.shape[2], probe_images.shape[3])

    with torch.inference_mode():
        outs = model(probe_images)
    if torch.is_tensor(outs):
        outs = [outs]
    outs = [o for o in outs if o is not None]  # robust to missing heads
    n_exits = len(outs)
    if n_exits == 0:
        raise RuntimeError("ree_early_exit_test: model returned no exits")

    max_ee = min(int(max_early_exit_layer), n_exits - 1)
    results = {ee: {"loss": 0.0, "correct": 0, "total": 0} for ee in range(max_ee + 1)}
    ensemble_correct, ensemble_total = 0, 0

    with torch.inference_mode():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)

            if images.size(1) == 1:
                images = images.expand(-1, 3, images.shape[2], images.shape[3])

            outputs = model(images)
            if torch.is_tensor(outputs):
                outputs = [outputs]
            outputs = [o for o in outputs if o is not None]  # robust to missing heads

            if not outputs:
                continue  # skip batch if model produced nothing (shouldn't happen)

            # A batch may expose fewer exits than the probe did; never index past them.
            cur_max_ee = min(max_ee, len(outputs) - 1)
            batch_size = labels.size(0)

            for ee in range(cur_max_ee + 1):
                logits = outputs[ee]
                # criterion returns the batch mean; re-weight for a per-sample average.
                results[ee]["loss"] += criterion(logits, labels).item() * batch_size
                pred = torch.argmax(logits, dim=1)
                results[ee]["correct"] += (pred == labels).sum().item()
                results[ee]["total"] += batch_size

            if ensemble and cur_max_ee >= 0:
                # Ensemble = arithmetic mean of the raw logits of the evaluated exits.
                stacked = torch.stack(outputs[:cur_max_ee + 1], dim=0)
                logits_ens = torch.mean(stacked, dim=0)
                pred_ens = torch.argmax(logits_ens, dim=1)
                ensemble_correct += (pred_ens == labels).sum().item()
                ensemble_total += batch_size

    for ee in range(max_ee + 1):
        if results[ee]["total"] > 0:
            results[ee]["loss"] /= results[ee]["total"]
            results[ee]["accuracy"] = results[ee]["correct"] / results[ee]["total"]
        else:
            results[ee]["loss"] = float("nan")
            results[ee]["accuracy"] = 0.0

    if ensemble and ensemble_total > 0:
        results["ensemble_accuracy"] = ensemble_correct / ensemble_total

    return results

def __ree_early_exit_test(
    model,
    max_early_exit_layer,
    testloader,
    device="cpu",
    ensemble=True,
):
    """Variant of the early-exit evaluation that does not filter ``None`` heads.

    The model is called as ``model(images)`` and may return a single logits
    tensor or a sequence of per-exit logits; every entry up to the evaluated
    maximum must be a tensor (a ``None`` head is not tolerated here).

    Args:
        model: Model under evaluation; switched to ``eval()`` mode.
        max_early_exit_layer: Highest exit index to evaluate (clamped to the
            number of exits returned for the first batch).
        testloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to; also forwarded to
            ``compare_flops_and_params``.
        ensemble: When True, also score the mean of the logits of exits
            ``0..max`` for every batch.

    Returns:
        Dict mapping each exit index to
        ``{"loss", "correct", "total", "accuracy"}`` (loss is NaN and accuracy
        0.0 for an exit that never received samples), plus
        ``"ensemble_accuracy"`` when ``ensemble`` is True and at least one
        sample was scored. An empty loader yields ``{"ensemble_accuracy": 0.0}``.

    Raises:
        RuntimeError: If the probe forward pass returns an empty sequence.
    """
    model.eval()
    orig_flops, orig_params = compare_flops_and_params(model, device=device)
    if orig_flops:
        print(f"New model:  {orig_flops} | {orig_params}")
    criterion = torch.nn.CrossEntropyLoss()

    # --- Probe the model to find how many exits it returns ---
    probe_images = None
    for imgs, _ in testloader:
        probe_images = imgs.to(device)
        break
    if probe_images is None:
        return {"ensemble_accuracy": 0.0}

    # Give the probe the same grayscale -> 3-channel treatment as the full
    # pass, otherwise the probe feeds the model a different channel count.
    if probe_images.size(1) == 1:
        probe_images = probe_images.expand(-1, 3, probe_images.shape[2], probe_images.shape[3])

    with torch.inference_mode():
        outs = model(probe_images)
    if torch.is_tensor(outs):
        outs = [outs]
    n_exits = len(outs)
    if n_exits == 0:
        raise RuntimeError("ree_early_exit_test: model returned no exits")

    max_ee = min(int(max_early_exit_layer), n_exits - 1)
    results = {ee: {"loss": 0.0, "correct": 0, "total": 0} for ee in range(max_ee + 1)}
    ensemble_correct, ensemble_total = 0, 0

    # --- Full pass ---
    with torch.inference_mode():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)

            # match the other eval path for grayscale datasets
            if images.size(1) == 1:
                images = images.expand(-1, 3, images.shape[2], images.shape[3])

            outputs = model(images)
            if torch.is_tensor(outputs):
                outputs = [outputs]

            # A batch may expose fewer exits than the probe did; never index past them.
            cur_max_ee = min(max_ee, len(outputs) - 1)
            batch_size = labels.size(0)

            for ee in range(cur_max_ee + 1):
                logits = outputs[ee]
                # criterion returns the batch mean; re-weight for a per-sample average.
                results[ee]["loss"] += criterion(logits, labels).item() * batch_size
                pred = torch.argmax(logits, dim=1)
                results[ee]["correct"] += (pred == labels).sum().item()
                results[ee]["total"] += batch_size

            if ensemble and cur_max_ee >= 0:
                # Ensemble = arithmetic mean of the raw logits of the evaluated exits.
                stacked = torch.stack([outputs[i] for i in range(cur_max_ee + 1)], dim=0)
                logits_ens = torch.mean(stacked, dim=0)
                pred_ens = torch.argmax(logits_ens, dim=1)
                ensemble_correct += (pred_ens == labels).sum().item()
                ensemble_total += batch_size

    for ee in range(max_ee + 1):
        if results[ee]["total"] > 0:
            results[ee]["loss"] /= results[ee]["total"]
            results[ee]["accuracy"] = results[ee]["correct"] / results[ee]["total"]
        else:
            results[ee]["loss"] = float("nan")
            results[ee]["accuracy"] = 0.0

    if ensemble and ensemble_total > 0:
        results["ensemble_accuracy"] = ensemble_correct / ensemble_total

    return results

def epochs_to_batches(num_epochs, dataset_len, batch_size, drop_last=False):
    """Convert a number of epochs into a number of batches.

    The per-epoch batch count is ``dataset_len / batch_size`` rounded down when
    ``drop_last`` is set (the incomplete trailing batch is not counted) and
    rounded up otherwise. The product with ``num_epochs`` is truncated to int.
    """
    rounding = np.floor if drop_last else np.ceil
    batches_per_epoch = rounding(dataset_len / int(batch_size))
    return int(batches_per_epoch * num_epochs)

def update_module(module, updates=None, memo=None):
    """
    Replace every parameter/buffer that carries a pending ``.update`` attribute
    with ``tensor + tensor.update``, recursively over all submodules.

    When ``updates`` is given, its entries are first attached (in order) as
    ``.update`` on ``module.parameters()``. ``memo`` maps an original tensor to
    its already-computed replacement so that a tensor reachable through several
    places is replaced by one shared result. Returns ``module``.

    Taken from https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py
    """
    memo = {} if memo is None else memo

    if updates is not None:
        params = list(module.parameters())
        if len(updates) != len(params):
            # Only warn; zip() below silently stops at the shorter sequence.
            print(
                'WARNING:update_module(): Parameters and updates have different length. ('
                + str(len(params)) + ' vs ' + str(len(updates)) + ')'
            )
        for param, delta in zip(params, updates):
            param.update = delta

    def _swap_in_updated(store):
        # Shared logic for the module's own parameter and buffer dictionaries.
        for key in store:
            current = store[key]
            if current is None or getattr(current, 'update', None) is None:
                continue
            if current not in memo:
                memo[current] = current + current.update
            store[key] = memo[current]

    # Parameters first, then buffers
    _swap_in_updated(module._parameters)
    _swap_in_updated(module._buffers)

    # Children are processed with the same memo and without a fresh updates list
    children = module._modules
    for child_name in children:
        children[child_name] = update_module(
            children[child_name],
            updates=None,
            memo=memo,
        )

    if hasattr(module, 'flatten_parameters'):
        module._apply(lambda x: x)
    return module

class ReLUSTE(torch.autograd.Function):
    """ReLU in the forward pass; the backward pass hands the gradient through unchanged."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        # Nothing is saved on ctx because backward ignores the input entirely.
        rectified = torch.relu(input)
        return rectified

    @staticmethod
    def backward(ctx, grad_output):
        # The incoming gradient is returned as-is, also where the input was <= 0.
        return grad_output

class Identity(torch.autograd.Function):
    """No-op autograd function: the input and the gradient both pass through unchanged.

    Defined so it can stand in wherever one of the other custom functions in
    this file is expected (compatibility/scalability).
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        # Forward is the identity map.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Backward is the identity map as well.
        return grad_output

class Step(torch.autograd.Function):
    """Binary step in the forward pass; the backward pass hands the gradient through unchanged."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        # 1.0 where the input is strictly positive, 0.0 elsewhere; the result
        # goes through long and then float, so it is float32 whatever the input dtype.
        is_positive = torch.gt(input, 0.)
        return is_positive.long().float()

    @staticmethod
    def backward(ctx, grad_output):
        # The incoming gradient is returned as-is.
        return grad_output

class SoftArgmax(torch.autograd.Function):
    """One-hot of the argmax along dim 0 in the forward pass; pass-through gradient in backward."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        # Pick the winning index along dim 0 (kept as a size-1 dim so it can index scatter_).
        probabilities = F.softmax(input, dim=0)
        winner = torch.argmax(probabilities, dim=0, keepdims=True)
        # Build a tensor shaped like the input with 1.0 at the winner and 0.0 elsewhere.
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(0, winner, 1.)
        return one_hot

    @staticmethod
    def backward(ctx, grad_output):
        # The incoming gradient is returned as-is.
        return grad_output

def split_weight_decay_parameters(model):
    """Split a model's parameters into a weight-decay group and a no-decay group.

    The ``weight`` of every submodule accepted by ``allow_weight_decay`` goes
    into the decay group; every other parameter reported by
    ``model.named_parameters()`` goes into the no-decay group.

    Args:
        model: The ``nn.Module`` whose parameters are split. It may itself be
            one of the decayed layer types.

    Returns:
        Tuple ``(wd_params_names, wd_params, no_wd_params)``: the qualified
        names of the decayed weights, the decayed weights themselves, and the
        remaining parameters.

    Raises:
        AssertionError: If the two groups do not add up to
            ``len(list(model.parameters()))``.
    """
    wd_params = []
    wd_params_names = []
    for module_name, module in model.named_modules():
        if allow_weight_decay(module):
            wd_params.append(module.weight)
            # named_modules() reports the root module under the empty name; its
            # weight is called 'weight' (not '.weight') by named_parameters().
            wd_params_names.append(f'{module_name}.weight' if module_name else 'weight')

    # Set membership keeps the filter linear in the number of parameters.
    wd_name_lookup = set(wd_params_names)
    no_wd_params = [p for n, p in model.named_parameters() if n not in wd_name_lookup]
    assert len(wd_params) + len(no_wd_params) == len(list(model.parameters())), "Sanity check failed."
    return wd_params_names, wd_params, no_wd_params

def allow_weight_decay(module):
    """Return True when ``module`` is a linear, convolution or transposed-convolution layer."""
    decayed_layer_types = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )
    return isinstance(module, decayed_layer_types)

class Hypergrad:
    """
    Hypergradient of a validation loss w.r.t. meta-parameters, using a
    truncated Neumann-series approximation of the inverse-Hessian-vector
    product.

    Gradients that autograd reports as unused (``None``) are treated as zeros
    throughout, so parameters or meta-parameters that do not take part in one
    of the losses no longer break the computation.

    Credit: "Optimizing Millions of Hyperparameters by Implicit Differentiation"
    (https://arxiv.org/pdf/1911.02590.pdf)
    """

    def __init__(self, learning_rate=.1, truncate_iter=3):
        # learning_rate: scale applied to each Hessian-vector product in the series.
        # truncate_iter: number of series terms added after the initial one.
        self.learning_rate = learning_rate
        self.truncate_iter = truncate_iter

    @staticmethod
    def _as_tuple(tensors):
        """Normalize a single tensor or any iterable of tensors to a tuple."""
        return (tensors,) if torch.is_tensor(tensors) else tuple(tensors)

    @staticmethod
    def _zero_fill(grads, like):
        """Replace each ``None`` gradient with zeros shaped like its reference tensor."""
        return [torch.zeros_like(ref) if g is None else g for g, ref in zip(grads, like)]

    @classmethod
    def _vjp(cls, outputs, inputs, grad_outputs, retain_graph):
        """Vector-Jacobian product of ``outputs`` w.r.t. ``inputs``, ``None``-safe.

        Entries of ``outputs`` that are ``None`` or carry no graph contribute
        nothing and are skipped together with their weight in ``grad_outputs``.
        """
        usable = [
            (out, weight)
            for out, weight in zip(outputs, grad_outputs)
            if out is not None and out.requires_grad
        ]
        if not usable:
            return [torch.zeros_like(ref) for ref in inputs]
        outs, weights = zip(*usable)
        grads = torch.autograd.grad(
            outs,
            inputs,
            grad_outputs=weights,
            retain_graph=retain_graph,
            allow_unused=True
        )
        return cls._zero_fill(grads, inputs)

    def grad(self, loss_val, loss_train, meta_params, params):
        """Return the negated mixed-derivative product for each meta-parameter.

        Args:
            loss_val: Scalar validation loss, differentiable w.r.t. ``params``.
            loss_train: Scalar training loss, differentiable w.r.t. ``params``
                and ``meta_params``.
            meta_params: Tensor or iterable of tensors to get hypergradients for.
            params: Tensor or iterable of tensors the losses are differentiated by.

        Returns:
            A list with one tensor per meta-parameter (zeros where the training
            gradient does not depend on that meta-parameter).
        """
        # Materialize once: both collections are used by several autograd calls.
        params = self._as_tuple(params)
        meta_params = self._as_tuple(meta_params)

        dloss_val_dparams = self._zero_fill(
            torch.autograd.grad(
                loss_val,
                params,
                retain_graph=True,
                allow_unused=True
            ),
            params,
        )

        # create_graph=True keeps the training gradient differentiable so it
        # can be differentiated again below.
        dloss_train_dparams = torch.autograd.grad(
            loss_train,
            params,
            allow_unused=True,
            create_graph=True,
        )

        v2 = self._approx_inverse_hvp(dloss_val_dparams, dloss_train_dparams, params)

        # Last use of the training-gradient graph, so it does not need to be retained.
        v3 = self._vjp(dloss_train_dparams, meta_params, v2, retain_graph=False)

        return [-g for g in v3]

    def _approx_inverse_hvp(self, dloss_val_dparams, dloss_train_dparams, params):
        """Approximate (inverse Hessian) x (validation gradient) with a truncated series.

        Starting from ``v = p = dloss_val_dparams``, each iteration does
        ``v <- v - learning_rate * H v`` and ``p <- p + v``, where ``H v`` is
        obtained by differentiating ``dloss_train_dparams`` w.r.t. ``params``.
        """
        params = self._as_tuple(params)
        v = self._zero_fill(dloss_val_dparams, params)
        p = list(v)

        for _ in range(self.truncate_iter):
            hvp = self._vjp(dloss_train_dparams, params, v, retain_graph=True)

            # scale: this is key for convergence
            scaled = [g * self.learning_rate for g in hvp]

            v = [curr_v - curr_g for (curr_v, curr_g) in zip(v, scaled)]
            p = [curr_p + curr_v for (curr_p, curr_v) in zip(p, v)]

        return p
