import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torch import nn, optim

from src import save_load_utils as sl
from src import plotting_functions as pf
from src import train

def prune_model(model, amount=0.5):
    """
    Apply global unstructured L1 pruning to all Conv2d + Linear layers.

    The ``weight`` parameters of every ``nn.Conv2d`` and ``nn.Linear``
    module found in ``model`` are handed together to
    ``prune.global_unstructured`` with ``prune.L1Unstructured`` as the
    pruning method, so the pruning budget is shared across layers rather
    than applied to each layer separately.

    Args:
        model: Module whose Conv2d/Linear weights should be pruned. It is
            modified in place.
        amount: Forwarded unchanged to ``prune.global_unstructured``
            (see the PyTorch docs for the accepted float/int forms).

    Returns:
        The same ``model`` object that was passed in. If the model has no
        Conv2d or Linear layers it is returned untouched.
    """
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    # Nothing to prune: global_unstructured fails on an empty collection
    # with an opaque error, so treat this as a no-op instead.
    if not parameters_to_prune:
        return model

    # Global pruning: keep only (1 - amount) of the weights
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    return model

def _percentage(correct, total):
    """Return ``100 * correct / total``, or NaN when ``total`` is zero."""
    return 100. * correct / total if total else float('nan')


def finetune(model, train_loader, val_loader, device, epochs=10, lr=1e-3):
    """
    Train ``model`` with plain SGD and cross-entropy loss, printing metrics.

    For each epoch the model is trained on every batch of ``train_loader``
    and then evaluated (without gradients) on every batch of
    ``val_loader``. A one-line summary with the training loss, training
    accuracy and validation accuracy is printed per epoch.

    Args:
        model: Module to train; moved to ``device`` and updated in place.
        train_loader: Iterable yielding ``(inputs, targets)`` batches used
            for the optimisation steps.
        val_loader: Iterable yielding ``(inputs, targets)`` batches used
            only for measuring accuracy.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate passed to ``optim.SGD``.

    Returns:
        The same ``model`` object after training.

    An empty loader does not raise: the affected metrics are reported as
    NaN instead of triggering a ``ZeroDivisionError``.
    """
    # Move the model first so the optimizer is built from parameters that
    # already live on the target device.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so weight it by the batch
            # size to obtain a per-sample average over the whole epoch.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        train_acc = _percentage(correct, total)
        train_loss = running_loss / total if total else float('nan')

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        val_acc = _percentage(correct, total)

        print(f"Epoch {epoch+1}/{epochs}: "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

    return model


# Driver: for every saved run, load the model, prune 80% of its Conv2d/Linear
# weights, fine-tune it, and record test accuracy and achieved sparsity.

# Sorted so runs are processed in a reproducible order (glob makes no
# ordering guarantee).
txt_files = sorted(glob.glob("./runs/CIFAR10/ResNet18/SGD/*.txt"))
test_accuracies = []
sparsities = []

# The device does not depend on the run, so pick it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for f in txt_files:
    # Split the path with os.path so that OS-specific separators returned
    # by glob are handled. ROOT keeps its trailing separator.
    ROOT = os.path.join(os.path.dirname(f), '')
    LOAD_NAME = os.path.splitext(os.path.basename(f))[0]

    conf, reader, model = sl.load_results(LOAD_NAME, ROOT, include_model=True)

    print('Model has been loaded')

    # NOTE(review): the loaders are rebuilt for every run; if the train/val
    # split is deterministic this could be hoisted out of the loop — confirm
    # against train.create_datasets_and_loader before changing it.
    train_dataloader, val_dataloader, test_dataloader = train.create_datasets_and_loader(
        dataset='CIFAR10',
        batch_size=128,
        root='./data',
        download=True,
        max_samples=None,
        train_split=0.9,
        return_datasets=False,
        distributed=False,
    )

    model.to(device)

    model = prune_model(model, amount=0.8)

    model = finetune(model, train_dataloader, val_dataloader, device, epochs=20, lr=1e-2)

    # Final evaluation on the held-out test set
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = outputs.max(1)
            test_total += y.size(0)
            test_correct += predicted.eq(y).sum().item()

    # Report NaN rather than crash if the test loader yielded no samples
    test_acc = 100.0 * test_correct / test_total if test_total else float('nan')
    print(f"Test Accuracy: {test_acc:.2f}%")

    test_accuracies.append(test_acc)

    # Sparsity = share of zero entries across all pruning masks
    total = 0
    zero = 0
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            mask = getattr(module, 'weight_mask', None)
            if mask is not None:
                total += mask.numel()
                zero += torch.sum(mask == 0).item()

    # No masks means nothing was pruned, i.e. 0% sparsity
    sparsity = zero / total * 100 if total else 0.0
    print(f"Sparsity: {sparsity:.2f}%")

    sparsities.append(sparsity)

    print('-------------------------------------------')

print('-------------------------------------------')

n_runs = len(test_accuracies)
if n_runs == 0:
    print('No result files matched; nothing to summarize.')
else:
    # The sample standard deviation (ddof=1) is undefined for a single run;
    # report NaN directly instead of letting numpy emit a RuntimeWarning.
    acc_std = np.std(test_accuracies, ddof=1) if n_runs > 1 else float('nan')
    sparsity_std = np.std(sparsities, ddof=1) if n_runs > 1 else float('nan')

    print(f'Test accuracy mean: {np.mean(test_accuracies)}')
    print(f'Test accuracy std: {acc_std}')

    print('-------------------------------------------')

    print(f'Sparsity mean: {np.mean(sparsities)}')
    print(f'Sparsity std: {sparsity_std}')

