import torch

def compute_layer_importance(model, loss_function, data_loader, optimizer, device):
    """Score each trainable parameter using its gradients over ``data_loader``.

    For every batch the model is run forward and backward. Each trainable
    parameter that received a gradient gets the per-batch score

        sum(|grad|) + sum(|param - grad|)

    The per-batch scores are averaged over the batches in which that
    parameter had a gradient. No optimizer step is taken, so the weights are
    left unchanged. The gradients from the last batch remain in ``.grad``.

    Args:
        model: ``torch.nn.Module`` to score; it is switched to train mode.
        loss_function: callable ``loss_function(output, target)`` returning a
            tensor on which ``.backward()`` can be called without arguments.
        data_loader: iterable of ``(input, target)`` or
            ``(input, target, extra)`` batches; ``extra`` is ignored.
        optimizer: only used to clear gradients before each batch.
        device: device the batch tensors are moved to.

    Returns:
        dict mapping parameter name -> averaged importance (float).
        Parameters that never received a gradient are omitted. An empty
        ``data_loader`` yields an empty dict.

    Raises:
        ValueError: if a batch does not have 2 or 3 elements.
    """
    totals = {}
    counts = {}
    model.train()

    for batch in data_loader:
        if len(batch) == 3:
            inputs, target, _ = batch
        elif len(batch) == 2:
            inputs, target = batch
        else:
            raise ValueError(f"Unexpected batch format with {len(batch)} elements")

        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_function(output, target)
        loss.backward()

        # The scores are plain statistics; no autograd graph is needed for them.
        with torch.no_grad():
            for name, param in model.named_parameters():
                # Parameters not reached by the forward pass (or reset to None
                # by zero_grad) have no gradient to score; torch.abs(None) would raise.
                if not param.requires_grad or param.grad is None:
                    continue
                grad_change = param.grad.abs().sum().item()
                param_change = (param.data - param.grad).abs().sum().item()
                totals[name] = totals.get(name, 0.0) + grad_change + param_change
                counts[name] = counts.get(name, 0) + 1

    # Average over batches so every batch contributes, not just the last one.
    return {name: totals[name] / counts[name] for name in totals}

def freeze_layers(model, importance_dict, threshold=0.1):
    """Freeze parameters whose importance score is below ``threshold``.

    A parameter is frozen (``requires_grad = False``, with a message printed)
    when it appears in ``importance_dict`` with a score strictly less than
    ``threshold``. Every other parameter, including those missing from the
    dict, has ``requires_grad`` set back to ``True``. That re-enables any
    parameter frozen earlier.

    Args:
        model: ``torch.nn.Module`` whose parameters are updated in place.
        importance_dict: mapping parameter name -> importance score.
        threshold: scores below this value cause the parameter to be frozen.
    """
    for name, param in model.named_parameters():
        # Same test save_model uses to choose frozen_layers.pth. The scores are
        # unnormalized sums of absolute values, so comparing (1 - score) to the
        # threshold would freeze almost every layer.
        if name in importance_dict and importance_dict[name] < threshold:
            param.requires_grad = False
            print(f"Layer {name} frozen.")
        else:
            param.requires_grad = True


def save_model(model, save_path, importance_dict, threshold=0.1):
    """Write the model's parameters to two checkpoint files in ``save_path``.

    Parameters whose importance score is strictly below ``threshold`` go to
    ``frozen_layers.pth``. All others, including parameters missing from
    ``importance_dict``, go to ``non_frozen_layers.pth``. Each tensor is
    moved to CPU before saving. ``save_path`` must already exist; no
    directory is created here.
    """
    frozen, trainable = {}, {}
    for pname, tensor in model.named_parameters():
        is_low_importance = pname in importance_dict and importance_dict[pname] < threshold
        bucket = frozen if is_low_importance else trainable
        bucket[pname] = tensor.data.cpu()

    torch.save(frozen, f"{save_path}/frozen_layers.pth")
    torch.save(trainable, f"{save_path}/non_frozen_layers.pth")

def load_model(model, load_path, importance_dict, threshold=0.1):
    """Copy parameters from the two checkpoint files in ``load_path`` into ``model``.

    ``non_frozen_layers.pth`` is applied first, then ``frozen_layers.pth``.
    If a name appears in both files, the frozen copy wins. Model parameters
    that appear in neither file are left untouched. ``importance_dict`` and
    ``threshold`` are accepted but not used by this function.
    """
    # NOTE(review): torch.load unpickles its input; only point load_path at
    # trusted checkpoints.
    frozen_params = torch.load(f"{load_path}/frozen_layers.pth")
    non_frozen_params = torch.load(f"{load_path}/non_frozen_layers.pth")

    named = dict(model.named_parameters())
    # Order matters: frozen values are written last so they take precedence.
    for saved in (non_frozen_params, frozen_params):
        for pname, tensor in named.items():
            if pname in saved:
                tensor.data.copy_(saved[pname])

def training_process(model, data_loader, loss_function, optimizer, device, save_path, args):
    """Score parameter importance, freeze low-importance ones, and optionally save.

    Args:
        model: ``torch.nn.Module`` to analyse and freeze in place.
        data_loader: batches forwarded to ``compute_layer_importance``.
        loss_function: loss callable forwarded to ``compute_layer_importance``.
        optimizer: optimizer whose ``zero_grad`` is used while scoring.
        device: device the batches are moved to.
        save_path: existing directory for checkpoints when saving is requested.
        args: object exposing ``freeze_threshold`` and ``target`` attributes
            (presumably an argparse namespace; verify against the caller).
            The model is saved only when ``args.target == 'learning'``.
    """
    # Positional order must match compute_layer_importance's signature
    # (optimizer comes before device); omitting optimizer raised TypeError.
    importance_dict = compute_layer_importance(model, loss_function, data_loader, optimizer, device)

    threshold = args.freeze_threshold
    freeze_layers(model, importance_dict, threshold)

    if args.target == 'learning':
        save_model(model, save_path, importance_dict, threshold)
