import os
import numpy as np
import torch
import sklearn.metrics as skm


import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

def print_size(net):
    """Report how many trainable parameters *net* has, in millions.

    Only parameters with ``requires_grad=True`` are counted. Anything that is
    not a ``torch.nn.Module`` (including ``None``) is silently ignored and
    nothing is printed.
    """
    if not isinstance(net, torch.nn.Module):
        return

    n_trainable = 0
    for param in net.parameters():
        if param.requires_grad:
            n_trainable += np.prod(param.size())

    print("{} Parameters: {:.6f}M".format(
        net.__class__.__name__, n_trainable / 1e6), flush=True)


class CustomDataset(Dataset):
    """Minimal ``Dataset`` adapter over any indexable, sized container.

    ``len()`` and integer/slice indexing are forwarded unchanged to the wrapped
    ``tensor_data`` object.
    """

    def __init__(self, tensor_data):
        # Kept public under its original name; callers may read it directly.
        self.tensor_data = tensor_data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        n_items = len(self.tensor_data)
        return n_items

    def __getitem__(self, idx):
        """Return the element of the wrapped container at ``idx``."""
        item = self.tensor_data[idx]
        return item
    

def marginal_prob_std(t, sigma, device):
    """Standard deviation of the noise perturbation at time(s) ``t``.

    Computes ``sqrt((sigma**(2*t) - 1) / (2 * ln(sigma)))``, i.e. the square
    root of the integral of ``(sigma**s)**2`` for ``s`` from 0 to ``t``.

    Args:
        t: Time value(s) as a Python scalar, sequence, or tensor.
        sigma: Base of the exponential noise schedule. ``ln(sigma)`` is the
            divisor, so ``sigma`` must be positive and different from 1.
        device: Device on which the result tensor is placed.

    Returns:
        A tensor with the same shape as ``t``.
    """
    # torch.as_tensor reuses an existing tensor (only moving it if needed).
    # torch.tensor(<tensor>) would copy-construct it and emit a UserWarning,
    # which happens every call when this is used inside a training loop.
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma, device):
    """Diffusion coefficient ``sigma**t`` as a tensor on ``device``.

    Args:
        t: Time value(s) as a Python scalar or tensor.
        sigma: Base of the exponential noise schedule.
        device: Device on which the result tensor is placed.

    Returns:
        A tensor containing ``sigma**t``.
    """
    # When t is a tensor, sigma**t is already a tensor. torch.tensor() would
    # copy-construct it and warn; torch.as_tensor reuses it instead.
    return torch.as_tensor(sigma**t, device=device)



def loss_fn_norm(model, x, marginal_prob_std, eps=1e-5):
    """Noise-prediction loss for a time-conditioned model.

    Each sample is perturbed as ``x + z * std(t)`` with ``z ~ N(0, I)``. The
    model output is penalised by ``||score + z||^2``, so the model is trained
    to predict ``-z``. The squared error is summed over every non-batch
    dimension and averaged over the batch.

    Args:
        model: Callable ``model(perturbed_x, t)``. Its output must broadcast
            against ``x``.
        x: Batch of clean samples; dimension 0 is the batch. Any rank >= 1 is
            supported (the original version only handled 3-D input).
        marginal_prob_std: Callable mapping the 1-D tensor of sampled times to
            per-sample noise standard deviations.
        eps: Lower bound of the uniform time sample before rescaling. Avoids
            ``t == 0``.

    Returns:
        Scalar loss tensor.
    """
    # Uniform t in [eps, 1), then rescaled by 999. Presumably this matches the
    # model's discrete timestep range; TODO confirm against the model.
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    random_t = random_t * 999
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    # Shape std as (batch, 1, ..., 1) so it broadcasts over every non-batch
    # dimension of x, whatever its rank.
    std = std.reshape(-1, *([1] * (x.dim() - 1)))
    perturbed_x = x + z * std
    score = model(perturbed_x, random_t)
    # Per-sample sum over all non-batch elements, then mean over the batch.
    per_sample = ((score + z) ** 2).reshape(x.shape[0], -1).sum(dim=1)
    loss = torch.mean(per_sample)
    return loss

def get_data_paths(directory, seed):
    """Recursively collect every file named ``seed_{seed}.pkl`` under *directory*.

    Paths are returned in "natural" order: runs of digits compare numerically,
    so ``run2`` sorts before ``run10``.

    Args:
        directory: Root directory to walk.
        seed: Seed value used to build the target file name.

    Returns:
        List of full file paths. The list is empty if nothing matches.
    """
    import re

    target = f"seed_{seed}.pkl"

    def natural_sort_key(path):
        # Alternate non-digit / digit chunks; digit chunks become ints so they
        # compare numerically rather than lexicographically.
        pieces = re.split(r'(\d+)', path)
        return [int(piece) if piece.isdigit() else piece for piece in pieces]

    matches = [
        os.path.join(root, filename)
        for root, _subdirs, filenames in os.walk(directory)
        for filename in filenames
        if filename == target
    ]
    matches.sort(key=natural_sort_key)
    return matches

def low_density_anomalies(test_log_probs, num_anomalies):
    """Flag the ``num_anomalies`` smallest entries of *test_log_probs* (helper for F1).

    Returns:
        A float array with the same length as *test_log_probs*. It holds 1.0 at
        the positions of the ``num_anomalies`` lowest values and 0.0 elsewhere.
    """
    n_samples = len(test_log_probs)
    # argpartition moves the num_anomalies smallest values (in arbitrary
    # order) into the first num_anomalies slots.
    lowest = np.argpartition(test_log_probs, num_anomalies - 1)[:num_anomalies]
    return np.isin(np.arange(n_samples), lowest).astype(float)

def calculate_scores(images, label):
    """Score samples by summed intensity and evaluate them against binary labels.

    Each sample's score is the sum of channel 0 over axis 1 of the remaining
    slice. A higher score is treated as more anomalous (label 1).

    Args:
        images: Array indexed as ``images[:, 0, :, :]``. The slice summed over
            axis 1 must leave a size-1 axis, because of the ``squeeze(axis=1)``.
            Presumably the shape is (batch, channels, length, 1); TODO confirm.
        label: Binary ground-truth labels (1 = anomaly), one per sample. Lists,
            NumPy arrays and other array-likes are accepted.

    Returns:
        Tuple ``(f1_score, aucroc, aucpr, preds)``. ``preds`` is the 1-D array
        of per-sample scores.
    """
    # Convert to an ndarray: with a plain list, ``label == 1`` is a single
    # False rather than an elementwise mask, which made the anomaly count 0.
    label = np.asarray(label)

    preds = np.squeeze(np.sum(images[:, 0, :, :], axis=1), axis=1)

    # Predict as many anomalies as there are positive labels. The helper picks
    # the lowest values, so negate the scores to select the highest ones.
    num_anomalies = int(np.sum(label == 1))
    p = low_density_anomalies(-preds, num_anomalies)

    f1_score = skm.f1_score(label, p)
    aucroc = skm.roc_auc_score(y_true=label, y_score=preds)
    aucpr = skm.average_precision_score(y_true=label, y_score=preds, pos_label=1)

    return f1_score, aucroc, aucpr, preds

class GaussianFourierProjection(nn.Module):
    """Random Fourier-feature embedding of scalar inputs (presumably time steps).

    A fixed frequency vector ``W`` of length ``embed_dim // 2`` is drawn from
    ``N(0, scale**2)``. An input ``x`` maps to
    ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` concatenated along the last dimension.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        # Frozen Parameter: stored in state_dict as "W" but never optimised.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Embed ``x``; a 1-D input of length B yields shape (B, 2 * (embed_dim // 2))."""
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat((angles.sin(), angles.cos()), dim=-1)