import torch
import torch.nn as nn
import os
import glob
import time

def IMG_test(model, sc_config, loader, device, slicer_ftn=None):
    """Evaluate a split model on a classification loader.

    Each batch goes through ``model.split_edge_output`` up to
    ``sc_config.split_layer``. The resulting intermediate features are
    optionally transformed by ``slicer_ftn``. The batch is then finished with
    ``model.split_cloud_output``.

    Args:
        model: object exposing ``eval()``, ``split_edge_output(x, layer)`` and
            ``split_cloud_output(IF, layer)``.
        sc_config: config object; ``sc_config.split_layer`` is read here and
            the whole object is forwarded to ``slicer_ftn``.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to.
        slicer_ftn: optional callable ``slicer_ftn(IF, sc_config) -> IF``
            applied to the intermediate features.

    Returns:
        ``(accuracy, avg_loss)``: ``accuracy`` is a percentage in [0, 100].
        ``avg_loss`` is the mean cross-entropy per sample over all batches.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    total = 0
    loss = 0.0
    correct = 0
    # Sum per-sample losses so that dividing by ``total`` gives a true
    # per-sample mean, even when batches have different sizes.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            IF = model.split_edge_output(inputs, sc_config.split_layer)
            if slicer_ftn is not None:
                IF = slicer_ftn(IF, sc_config)
            outputs = model.split_cloud_output(IF, sc_config.split_layer)

            loss += criterion(outputs, targets).item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

    return 100. * correct / total, loss / total

def IMG_test_with_comp_time(model, sc_config, loader, device, slicer_ftn=None):
    """Evaluate a split model and collect slicer timing / bit-usage stats.

    This works like ``IMG_test``, with one difference: ``slicer_ftn`` must
    return a pair ``(IF, bits)``. The wall time of each slicer call and each
    batch's ``bits`` are recorded and summarised with population statistics.

    Args:
        model: object exposing ``eval()``, ``split_edge_output(x, layer)`` and
            ``split_cloud_output(IF, layer)``.
        sc_config: config object; ``sc_config.split_layer`` is read here and
            the whole object is forwarded to ``slicer_ftn``.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to (anything ``torch.device``
            accepts).
        slicer_ftn: optional callable ``slicer_ftn(IF, sc_config) ->
            (IF, bits)``. ``bits`` is presumably a number or a fixed-length
            list of numbers per batch; all batches must give the same length
            so they can be stacked into one tensor.

    Returns:
        dict with the following keys:

        - ``accuracy``: percentage correct.
        - ``avg_loss``: mean cross-entropy per sample.
        - ``time_mean``, ``time_std``, ``time_var``: slicer time in seconds.
        - ``bit_mean``, ``bit_std``, ``bit_var``: bit usage per partition.

        Std and var are population statistics (``unbiased=False``). Time
        stats are 0.0 and bit stats are empty lists when no slicer ran.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    total = 0
    loss_accum = 0.0
    correct = 0
    # Sum per-sample losses so ``loss_accum / total`` is a per-sample mean.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    bit_records = []     # one entry per batch: the ``bits`` returned by the slicer
    time_records = []    # one float per batch: slicer time in seconds

    # CUDA kernels launch asynchronously. Synchronize around the slicer so the
    # measured interval covers exactly its own work, and not pending
    # edge-side kernels or work still in flight after it returns.
    dev = torch.device(device)
    sync_cuda = dev.type == 'cuda' and torch.cuda.is_available()

    def _sync():
        if sync_cuda:
            torch.cuda.synchronize(dev)

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)

            # 1) run split and optional slicing, measuring time & bits
            IF = model.split_edge_output(inputs, sc_config.split_layer)
            if slicer_ftn is not None:
                _sync()
                t0 = time.perf_counter()  # monotonic, high-resolution clock
                IF, bits = slicer_ftn(IF, sc_config)
                _sync()
                time_records.append(time.perf_counter() - t0)
                bit_records.append(bits)

            # 2) forward and compute loss
            outputs = model.split_cloud_output(IF, sc_config.split_layer)
            loss_accum += criterion(outputs, targets).item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

    # aggregate accuracy and loss
    accuracy = 100. * correct / total
    avg_loss = loss_accum / total

    # ---- SLICER TIME STATS ----
    if time_records:
        time_tensor = torch.tensor(time_records, dtype=torch.float32)
        time_mean = time_tensor.mean().item()
        time_std = time_tensor.std(unbiased=False).item()
        time_var = time_tensor.var(unbiased=False).item()
    else:
        time_mean = time_std = time_var = 0.0

    # ---- BIT USAGE STATS ----
    if bit_records:
        # stack per-batch bit records; statistics are taken across batches
        bit_tensor = torch.tensor(bit_records, dtype=torch.float32)
        bit_mean = bit_tensor.mean(dim=0).tolist()
        bit_std = bit_tensor.std(dim=0, unbiased=False).tolist()
        bit_var = bit_tensor.var(dim=0, unbiased=False).tolist()
    else:
        # distinct lists so the returned dict values are not aliased
        bit_mean, bit_std, bit_var = [], [], []

    return {
        "accuracy": accuracy,
        "avg_loss": avg_loss,
        # slicing time
        "time_mean": time_mean,
        "time_std": time_std,
        "time_var": time_var,
        # bit usage per partition
        "bit_mean": bit_mean,
        "bit_std": bit_std,
        "bit_var": bit_var,
    }
    
def IMG_get_IF(model, sc_config, loader, n_samples, save_path, device):
    """Save the edge-side intermediate features (IF) of the first batches.

    For each of the first ``n_samples`` batches of ``loader``, this runs
    ``model.split_edge_output(inputs, sc_config.split_layer)``. The result is
    saved as a CPU tensor at
    ``<save_path>/<ModelClassName>/SL<split_layer>/<batch_idx>.pt``.

    Note: ``n_samples`` counts *batches*, not individual samples.

    Args:
        model: object exposing ``eval()`` and
            ``split_edge_output(x, layer)``.
        sc_config: config object; only ``sc_config.split_layer`` is read.
        loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        n_samples: number of batches to save (nothing is saved if <= 0).
        save_path: root directory; subdirectories are created as needed.
        device: device inputs are moved to before the edge forward pass.
    """
    model.eval()

    model_name = model.__class__.__name__  # e.g. 'ResNet', 'VGG'

    save_dir = os.path.join(save_path, model_name, f"SL{sc_config.split_layer}")
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for batch_idx, (inputs, _targets) in enumerate(loader):
            # Check before saving so exactly ``n_samples`` batches are written.
            if batch_idx >= n_samples:
                break

            inputs = inputs.to(device)
            IF = model.split_edge_output(inputs, sc_config.split_layer)

            file_name = os.path.join(save_dir, f"{batch_idx}.pt")
            torch.save(IF.detach().cpu(), file_name)
            


def load_all_IF_features(save_path, model_name, split_layer):
    """Load every saved intermediate-feature file for one model/split layer.

    Files are read from ``<save_path>/<model_name>/SL<split_layer>/*.pt``.
    They are ordered by the integer batch index encoded in each file name
    (``0.pt``, ``1.pt``, ``10.pt``...), so the order is numeric rather than
    lexicographic.

    Args:
        save_path: root directory the features were saved under.
        model_name: model subdirectory name.
        split_layer: split layer used to build the ``SL<split_layer>`` folder.

    Returns:
        list of objects returned by ``torch.load``, one per file, in batch
        order. The list is empty if the directory is missing or holds no
        ``.pt`` files.

    Raises:
        ValueError: if a ``.pt`` file name is not an integer index.
    """
    feature_dir = os.path.join(save_path, model_name, f"SL{split_layer}")

    def _batch_index(path):
        # "12.pt" -> 12
        return int(os.path.basename(path).replace('.pt', ''))

    ordered_paths = sorted(glob.glob(os.path.join(feature_dir, '*.pt')),
                           key=_batch_index)
    return [torch.load(path) for path in ordered_paths]