"""
Code for loading the datasets/architectures
"""

import os
import json
import math
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset, Subset, random_split, TensorDataset
from torch.utils.data import Subset

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc

import ddu_dirty_mnist
#import deeplake

def dataloaders_mnist_subsample(batch_size, val_size=0.2, imbalance_factor=0, subsample_size=None, noise=False):
    """Create MNIST train/val/test loaders plus FashionMNIST and KMNIST loaders.

    Args:
        batch_size: Batch size of every loader except the DirtyMNIST test
            loader, which is fixed at 128.
        val_size: Fraction held out for validation. The plain split without
            subsampling ignores it and always uses 48000/12000 samples.
        imbalance_factor: 0 keeps all MNIST training data; any other value is
            passed to ``create_long_tail_distribution`` to thin out classes.
        subsample_size: If given (with ``imbalance_factor == 0`` and ``noise``
            false), number of training images drawn at random before the
            train/validation split.
        noise: When true, DirtyMNIST replaces MNIST for train/val/test.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, fmnist_loader,
        kmnist_loader)``.
    """
    root = "/data"
    num_workers = 4
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    fmnist_dataset = datasets.FashionMNIST(root=root, download=True, train=False, transform=transform)
    kmnist_dataset = datasets.KMNIST(root=root, download=True, train=False, transform=transform)

    if not noise:
        mnist_trainval_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
        mnist_test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=transform)

        if imbalance_factor == 0:
            if subsample_size is not None:
                # Draw the subsample as plain Python ints so Subset never has to
                # index the dataset with tensor elements.
                indices = torch.randperm(len(mnist_trainval_dataset))[:subsample_size].tolist()
                train_subset = Subset(mnist_trainval_dataset, indices)
                val_len = int(val_size * subsample_size)
                train_len = subsample_size - val_len
                train_dataset, val_dataset = random_split(train_subset, [train_len, val_len])
            else:
                # Fixed 48000/12000 split of the 60000 training images.
                train_dataset, val_dataset = random_split(mnist_trainval_dataset, [48000, 12000])

            mnist_trainloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
            mnist_validloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

        else:
            long_tail_indices = create_long_tail_distribution(mnist_trainval_dataset, imbalance_factor)
            mnist_lt_train_dataset = Subset(mnist_trainval_dataset, long_tail_indices)

            # Use separate names for the counts so the val_size fraction is
            # not overwritten.
            train_len = int((1 - val_size) * len(mnist_lt_train_dataset))
            val_len = len(mnist_lt_train_dataset) - train_len

            train_subset, val_subset = random_split(mnist_lt_train_dataset, [train_len, val_len])

            # random_split returns positions inside the long-tail subset, while
            # the validation helper looks labels up in the base dataset.
            # Translate to base-dataset indices first; otherwise the validation
            # set is built from unrelated samples that may overlap training.
            val_base_indices = [long_tail_indices[i] for i in val_subset.indices]
            val_in_base = Subset(mnist_trainval_dataset, val_base_indices)
            balanced_val_subset = create_balanced_validation_set(val_in_base, mnist_trainval_dataset)

            mnist_trainloader = DataLoader(train_subset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
            mnist_validloader = DataLoader(balanced_val_subset, batch_size=batch_size, shuffle=False)

        mnist_testloader = DataLoader(mnist_test_dataset, shuffle=False, batch_size=batch_size)

    else:
        dirty_mnist_train = ddu_dirty_mnist.DirtyMNIST(".", train=True, download=True)
        dirty_mnist_test = ddu_dirty_mnist.DirtyMNIST(".", train=False, download=True)

        train_len = int((1 - val_size) * len(dirty_mnist_train))
        val_len = len(dirty_mnist_train) - train_len

        train_subset, val_subset = random_split(dirty_mnist_train, [train_len, val_len])

        mnist_trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        mnist_validloader = DataLoader(val_subset, batch_size=batch_size, shuffle=True)
        mnist_testloader = DataLoader(dirty_mnist_test, batch_size=128, shuffle=False)

    fmnist_loader = DataLoader(fmnist_dataset, batch_size=batch_size, shuffle=True)
    kmnist_loader = DataLoader(kmnist_dataset, batch_size=batch_size, shuffle=True)

    return mnist_trainloader, mnist_validloader, mnist_testloader, fmnist_loader, kmnist_loader

def dataloaders_cifar10_subsample(batch_size, val_size, imbalance_factor=0, subsample_size=None, noise=False):
    """Create CIFAR-10 train/val/test loaders plus SVHN and CIFAR-100 loaders.

    Args:
        batch_size: Batch size used by every returned loader.
        val_size: Fraction of the (possibly subsampled or long-tailed)
            training data held out for validation.
        imbalance_factor: 0 keeps the full training set; any other value is
            passed to ``create_long_tail_distribution``.
        subsample_size: If given (with ``imbalance_factor == 0``), number of
            training images drawn at random before splitting.
        noise: When true, the training set is ``CIFAR10N`` with
            ``noise_type='aggre_label'``; ``CIFAR10N`` is not defined in this
            part of the file and must be provided elsewhere.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, svhn_loader,
        cifar100_loader)``.
    """
    num_workers = 4
    root = "./data"

    # Transforms for the two auxiliary (SVHN / CIFAR-100) test sets.
    cifar100_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5074, 0.4867, 0.4411), (0.2675, 0.2565, 0.2761))
    ])
    svhn_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4377, 0.4438, 0.4728], std=[0.1980, 0.2010, 0.1970])
    ])

    # CIFAR-10 transforms: augmentation for training only.
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    test_transform = valid_transform

    # Two views of the same training images: augmented (train) and plain (val).
    if not noise:
        cifar10_full_train = datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
    else:
        cifar10_full_train = CIFAR10N(noise_type='aggre_label', train=True, transform=train_transform)
    cifar10_valid_full = datasets.CIFAR10(root=root, train=True, download=True, transform=valid_transform)

    cifar10_test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)
    svhn_dataset = datasets.SVHN(root=root, split="test", download=True, transform=svhn_transform)
    cifar100_dataset = datasets.CIFAR100(root=root, train=False, download=True, transform=cifar100_transform)

    if imbalance_factor == 0:
        num_train = len(cifar10_full_train)

        if subsample_size is not None:
            # Plain Python ints keep Subset indexing independent of tensor types.
            indices = torch.randperm(num_train)[:subsample_size].tolist()
            val_count = int(val_size * subsample_size)
        else:
            indices = list(range(num_train))
            np.random.shuffle(indices)
            val_count = int(val_size * num_train)

        train_indices = indices[val_count:]
        val_indices = indices[:val_count]

        cifar10_train_dataset = Subset(cifar10_full_train, train_indices)
        cifar10_valid_dataset = Subset(cifar10_valid_full, val_indices)

        cifar10_trainloader = DataLoader(cifar10_train_dataset, batch_size=batch_size, shuffle=True)
        cifar10_validloader = DataLoader(cifar10_valid_dataset, batch_size=batch_size, shuffle=False)

    else:
        long_tail_indices = create_long_tail_distribution(cifar10_full_train, imbalance_factor)
        cifar10_lt_train_dataset = Subset(cifar10_full_train, long_tail_indices)

        # Use separate names for the counts so the val_size fraction is not
        # overwritten.
        train_len = int((1 - val_size) * len(cifar10_lt_train_dataset))
        val_len = len(cifar10_lt_train_dataset) - train_len

        train_subset, val_subset = random_split(cifar10_lt_train_dataset, [train_len, val_len])

        # random_split returns positions inside the long-tail subset; map them
        # back to CIFAR-10 indices before the helper looks up labels.
        # The validation samples are served from the non-augmented view, as in
        # the imbalance_factor == 0 branch.
        val_base_indices = [long_tail_indices[i] for i in val_subset.indices]
        val_in_base = Subset(cifar10_valid_full, val_base_indices)
        cifar10_valid_dataset = create_balanced_validation_set(val_in_base, cifar10_valid_full)

        cifar10_trainloader = DataLoader(train_subset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
        cifar10_validloader = DataLoader(cifar10_valid_dataset, batch_size=batch_size, shuffle=False)

    cifar10_testloader = DataLoader(cifar10_test_dataset, batch_size=batch_size, shuffle=False)
    svhn_loader = DataLoader(svhn_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    cifar100_loader = DataLoader(cifar100_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    return cifar10_trainloader, cifar10_validloader, cifar10_testloader, svhn_loader, cifar100_loader



def convert_to_native(val):
    """Turn NumPy scalars and PyTorch tensors into plain Python values.

    A NumPy scalar or a one-element tensor becomes a Python scalar; any other
    tensor becomes a (nested) list. Every other value is returned unchanged.
    """
    if isinstance(val, torch.Tensor):
        if val.numel() == 1:
            return val.item()
        return val.tolist()

    if isinstance(val, np.generic):
        return val.item()

    return val

def create_long_tail_distribution(dataset, imbalance_factor):
    """Pick sample indices so that class sizes decay geometrically.

    Classes are ranked by sorted label value. The class at rank ``r`` out of
    ``C`` keeps ``int(max_count * imbalance_factor ** (r / (C - 1)))`` randomly
    chosen samples, where ``max_count`` is the size of the largest class. A
    factor below 1 therefore shrinks later classes; with a factor of 1 or more
    the slice simply keeps every sample of the class.

    Args:
        dataset: Object exposing a ``targets`` sequence of class labels.
        imbalance_factor: Ratio applied to the last class relative to
            ``max_count``.

    Returns:
        List of selected indices into ``dataset``, grouped by class in sorted
        label order. Selection uses ``np.random.shuffle``, so seed NumPy for
        reproducibility.
    """
    targets = np.asarray(dataset.targets)

    # Rank the labels that actually occur instead of assuming they are
    # exactly 0..C-1; for contiguous labels rank and label coincide.
    classes, counts = np.unique(targets, return_counts=True)
    num_classes = len(classes)
    if num_classes == 0:
        return []

    max_samples = int(counts.max())
    class_indices = []
    for rank, cls in enumerate(classes):
        # With a single class there is nothing to decay (and no C - 1 to
        # divide by), so keep the exponent at 0.
        exponent = rank / (num_classes - 1.0) if num_classes > 1 else 0.0
        num_samples = int(max_samples * (imbalance_factor ** exponent))
        indices = np.where(targets == cls)[0]
        np.random.shuffle(indices)
        class_indices.extend(indices[:num_samples])

    return class_indices





def create_balanced_validation_set(val_subset, train_dataset):
    """Return a stratified 50% sample of a validation split.

    Despite the name, the result keeps the class proportions of ``val_subset``
    (it is a stratified split, not a class-balanced one).

    Args:
        val_subset: ``Subset`` holding the validation samples. It may be a
            subset of ``train_dataset`` directly, or of further ``Subset``
            wrappers around it (e.g. the output of ``random_split`` applied to
            a long-tail ``Subset``).
        train_dataset: Base dataset exposing ``targets``; the returned subset
            indexes into it.

    Returns:
        ``Subset`` of ``train_dataset`` with half of the validation samples.
    """
    val_indices = np.asarray(val_subset.indices, dtype=np.int64)

    # The indices of a Subset are relative to its immediate parent. Walk up
    # through nested Subsets until they refer to train_dataset itself;
    # otherwise labels and samples would be read from the wrong rows.
    parent = val_subset.dataset
    while parent is not train_dataset and isinstance(parent, Subset):
        val_indices = np.asarray(parent.indices, dtype=np.int64)[val_indices]
        parent = parent.dataset

    val_labels = np.array([train_dataset.targets[i] for i in val_indices])

    # Stratified split; only the "test" half is kept.
    _, val_indices_bal = train_test_split(val_indices, test_size=0.5, stratify=val_labels)
    balanced_val_subset = Subset(train_dataset, val_indices_bal)

    return balanced_val_subset


def load_model(ID_dataset, pretrained, index, dropout_rate, spect_norm, device):
    """Build the backbone for ``ID_dataset`` and move it to ``device``.

    Args:
        ID_dataset: One of ``"MNIST"``, ``"CIFAR-10"`` or ``"CIFAR-100"``.
        pretrained: Checkpoint loading is currently disabled, so for MNIST and
            CIFAR-10 this flag does not change the returned model. For
            CIFAR-100 no pretrained path exists.
        index: Checkpoint index; unused while checkpoint loading is disabled.
        dropout_rate: Dropout probability for the MNIST / CIFAR-10 models.
        spect_norm: Whether to build the spectrally normalised variant.
        device: Target device passed to ``model.to``.

    Returns:
        The constructed model, on ``device``.

    Raises:
        ValueError: If ``ID_dataset`` is not supported.
        NotImplementedError: If ``pretrained`` is requested for CIFAR-100.
    """
    if ID_dataset == "MNIST":
        model = conv_net(dropout_rate, spect_norm)
        # TODO: re-enable when checkpoints are available:
        # model.load_state_dict(torch.load(f"MNIST/saved_models/mnist_conv_daedl_{index+1}.pt"))

    elif ID_dataset == "CIFAR-10":
        num_classes = 10
        model = vgg16(num_classes, dropout_rate, spect_norm)
        # TODO: re-enable when checkpoints are available:
        # model.load_state_dict(torch.load(f"CIFAR-10/saved_models/cifar10_vgg_daedl_{index+1}"))

    elif ID_dataset == "CIFAR-100":
        if pretrained:
            # Previously this fell through and failed on an unbound `model`.
            raise NotImplementedError("Pretrained CIFAR-100 models are not available")
        # NOTE(review): resnet18 is defined outside this part of the file;
        # confirm that its first positional parameter is the spectral-norm flag.
        model = resnet18(spect_norm)

    else:
        raise ValueError(f"Unsupported dataset: {ID_dataset}")

    model.to(device)

    return model

import torch
import torch.nn as nn

class FEDL(nn.Module):
    """Backbone feature extractor with three heads producing ``(alpha, p, tau)``.

    The backbone chosen by ``ID_dataset`` is split into a feature extractor
    ``f`` and its original classifier ``g_alpha``. Two additional ``MLP``
    heads, ``g_p`` (one output per class) and ``g_tau`` (single output), read
    the same flattened features.
    """

    def __init__(self,
                 ID_dataset,
                 dropout_rate,
                 spect_norm,
                 device,
                 hidden_dim,
                 num_layers,
                 embedding_dim=None, fix_tau=False, fix_p=None):
        """Build the backbone and heads.

        Args:
            ID_dataset: ``"MNIST"``, ``"CIFAR-10"`` or ``"CIFAR-100"``.
            dropout_rate: Dropout probability for backbone and MLP heads.
            spect_norm: Whether to use spectral normalisation.
            device: Device all sub-modules are moved to.
            hidden_dim: Hidden width of the MLP heads.
            num_layers: Depth argument of the MLP heads.
            embedding_dim: Accepted for compatibility but ignored; the value is
                fixed per dataset (576 for MNIST, 512 otherwise).
            fix_tau: Accepted for compatibility but ignored here; pass it to
                ``forward`` instead.
            fix_p: Accepted for compatibility but ignored here; pass it to
                ``forward`` instead.

        Raises:
            ValueError: If ``ID_dataset`` is not supported.
        """
        super().__init__()

        if ID_dataset == "MNIST":
            num_classes = 10
            backbone = conv_net(dropout_rate, spect_norm)
            f, g_alpha = backbone.convolutions, backbone.linear
            embedding_dim = 576

        elif ID_dataset == "CIFAR-10":
            num_classes = 10
            backbone = vgg16(num_classes, dropout_rate, spect_norm)
            f, g_alpha = backbone.features, backbone.classifier
            embedding_dim = 512

        elif ID_dataset == "CIFAR-100":
            num_classes = 100
            backbone = resnet18(spect_norm)
            f, g_alpha = backbone.features, backbone.classifier
            embedding_dim = 512

        else:
            raise ValueError(f"Unsupported dataset: {ID_dataset}")

        # Move every sub-module to the same device. Previously only the MLP
        # heads were moved, leaving the backbone parts on their creation device.
        self.f = f.to(device)
        self.g_alpha = g_alpha.to(device)
        self.g_tau = MLP(embedding_dim, 1, hidden_dim,
                         num_layers, dropout_rate, spect_norm).to(device)
        self.g_p = MLP(embedding_dim, num_classes, hidden_dim,
                       num_layers, dropout_rate, spect_norm).to(device)

    def forward(self, x, fix_tau=False, fix_p=None):
        """Compute ``(alpha, p, tau)`` for a batch.

        Args:
            x: Input batch for the backbone.
            fix_tau: If true, ``tau`` is replaced by ones.
            fix_p: ``"dirichlet"`` sets ``p = alpha / alpha.sum(dim=1)``;
                ``"uniform"`` sets every entry of ``p`` to ``1 / num_classes``;
                any other value keeps the softmax output of ``g_p``.

        Returns:
            Tuple ``(alpha, p, tau)`` where ``alpha = exp(g_alpha(features))``,
            ``p`` is a softmax over dim 1 unless overridden, and ``tau`` is a
            softplus output unless fixed.
        """
        features = self.f(x)
        # Some backbones may return a tuple/list; the first element is used.
        if isinstance(features, (tuple, list)):
            features = features[0]
        if len(features.shape) > 2:
            features = torch.flatten(features, 1)

        alpha = torch.exp(self.g_alpha(features))
        p = F.softmax(self.g_p(features), dim=1)
        tau = F.softplus(self.g_tau(features))

        if fix_tau:
            tau = torch.ones_like(tau)

        if fix_p == "dirichlet":
            # Normalise alpha by its per-sample total.
            p = alpha / alpha.sum(dim=1, keepdim=True)
        elif fix_p == "uniform":
            p = torch.full_like(p, 1.0 / alpha.size(1))

        return alpha, p, tau




# Load Datasets
def load_datasets(ID_dataset, batch_size, val_size, imbalance_factor, noise):
    """Dispatch to the dataloader factory for ``ID_dataset``.

    Args:
        ID_dataset: ``"MNIST"``, ``"CIFAR-10"`` or ``"CIFAR-100"``.
        batch_size: Forwarded to the dataset-specific factory.
        val_size: Forwarded to the dataset-specific factory.
        imbalance_factor: Forwarded to the dataset-specific factory.
        noise: Forwarded to the dataset-specific factory.

    Returns:
        Tuple ``(trainloader, validloader, testloader, ood_loader1,
        ood_loader2)`` as produced by the selected factory.

    Raises:
        ValueError: If ``ID_dataset`` is not supported (previously this
            surfaced as an unbound-variable error at the return statement).
    """
    if ID_dataset == "MNIST":
        loaders = dataloaders_mnist(batch_size, val_size, imbalance_factor, noise)
    elif ID_dataset == "CIFAR-10":
        loaders = dataloaders_cifar10(batch_size, val_size, imbalance_factor, noise)
    elif ID_dataset == "CIFAR-100":
        loaders = dataloaders_cifar100(batch_size, val_size, imbalance_factor, noise)
    else:
        raise ValueError(f"Unsupported dataset: {ID_dataset}")

    trainloader, validloader, testloader, ood_loader1, ood_loader2 = loaders
    return trainloader, validloader, testloader, ood_loader1, ood_loader2


def dataloaders_mnist(batch_size, val_size, imbalance_factor, noise=False):
    """Create MNIST train/val/test loaders plus FashionMNIST and KMNIST loaders.

    Args:
        batch_size: Batch size of every loader except the DirtyMNIST test
            loader, which is fixed at 128.
        val_size: Validation fraction for the long-tail and DirtyMNIST
            branches. The plain MNIST branch ignores it and always splits
            48000/12000.
        imbalance_factor: 0 keeps all MNIST training data; any other value is
            passed to ``create_long_tail_distribution``.
        noise: When true, DirtyMNIST replaces MNIST for train/val/test.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, fmnist_loader,
        kmnist_loader)``.
    """
    root = "/data"
    num_workers = 4
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    fmnist_dataset = datasets.FashionMNIST(root=root, download=True, train=False, transform=transform)
    kmnist_dataset = datasets.KMNIST(root=root, download=True, train=False, transform=transform)

    if not noise:
        mnist_trainval_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
        mnist_test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=transform)

        if imbalance_factor == 0:           # MNIST
            mnist_train_dataset, mnist_val_dataset = random_split(mnist_trainval_dataset, [48000, 12000])
            mnist_trainloader = DataLoader(mnist_train_dataset, shuffle=True, batch_size=batch_size)
            mnist_validloader = DataLoader(mnist_val_dataset, shuffle=True, batch_size=batch_size)

        else:                               # MNIST-LT
            long_tail_indices = create_long_tail_distribution(mnist_trainval_dataset, imbalance_factor)
            mnist_lt_train_dataset = Subset(mnist_trainval_dataset, long_tail_indices)

            # Use separate names for the counts so the val_size fraction is
            # not overwritten.
            train_len = int((1 - val_size) * len(mnist_lt_train_dataset))
            val_len = len(mnist_lt_train_dataset) - train_len

            train_subset, val_subset = random_split(mnist_lt_train_dataset, [train_len, val_len])

            # random_split returns positions inside the long-tail subset, while
            # the validation helper looks labels up in the base dataset.
            # Translate to base-dataset indices first; otherwise the validation
            # set is built from unrelated samples that may overlap training.
            val_base_indices = [long_tail_indices[i] for i in val_subset.indices]
            val_in_base = Subset(mnist_trainval_dataset, val_base_indices)
            balanced_val_subset = create_balanced_validation_set(val_in_base, mnist_trainval_dataset)

            mnist_trainloader = DataLoader(train_subset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
            mnist_validloader = DataLoader(balanced_val_subset, batch_size=batch_size, shuffle=False)

        mnist_testloader = DataLoader(mnist_test_dataset, shuffle=False, batch_size=batch_size)

    else:
        dirty_mnist_train = ddu_dirty_mnist.DirtyMNIST(".", train=True, download=True)
        dirty_mnist_test = ddu_dirty_mnist.DirtyMNIST(".", train=False, download=True)

        train_len = int((1 - val_size) * len(dirty_mnist_train))
        val_len = len(dirty_mnist_train) - train_len

        train_subset, val_subset = random_split(dirty_mnist_train, [train_len, val_len])

        mnist_trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        mnist_validloader = DataLoader(val_subset, batch_size=batch_size, shuffle=True)
        mnist_testloader = DataLoader(dirty_mnist_test, batch_size=128, shuffle=False)

    fmnist_loader = DataLoader(fmnist_dataset, batch_size=batch_size, shuffle=True)
    kmnist_loader = DataLoader(kmnist_dataset, batch_size=batch_size, shuffle=True)

    return mnist_trainloader, mnist_validloader, mnist_testloader, fmnist_loader, kmnist_loader

def dataloaders_cifar10(batch_size, val_size, imbalance_factor, noise=False):
    """Create CIFAR-10 train/val/test loaders plus SVHN and CIFAR-100 loaders.

    Args:
        batch_size: Batch size used by every returned loader.
        val_size: Fraction of the (possibly long-tailed) training data held
            out for validation.
        imbalance_factor: 0 keeps the full training set; any other value is
            passed to ``create_long_tail_distribution``.
        noise: When true, the training set is ``CIFAR10N`` with
            ``noise_type='aggre_label'``; ``CIFAR10N`` is not defined in this
            part of the file and must be provided elsewhere.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, svhn_loader,
        cifar100_loader)``.
    """
    num_workers = 4
    root = "./data"

    cifar100_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5074, 0.4867, 0.4411), (0.2675, 0.2565, 0.2761)),
    ])
    svhn_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4377, 0.4438, 0.4728], std=[0.1980, 0.2010, 0.1970]),
    ])

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([transforms.ToTensor(), normalize])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # download=True on the training views as well: they are constructed before
    # the test set, so relying on the test set to fetch the archive fails on a
    # fresh data directory.
    if not noise:
        cifar10_train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
    else:
        cifar10_train_dataset = CIFAR10N(noise_type='aggre_label', train=True, transform=train_transform)
    # Same training images without augmentation, used for validation.
    cifar10_valid_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=valid_transform)

    cifar10_test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)

    svhn_dataset = datasets.SVHN(root=root, split="test", download=True, transform=svhn_transform)
    cifar100_dataset = datasets.CIFAR100(root=root, train=False, download=True, transform=cifar100_transform)

    if imbalance_factor == 0:
        num_train = len(cifar10_train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(val_size * num_train))
        np.random.shuffle(indices)

        train_subset = Subset(cifar10_train_dataset, indices[split:])
        valid_subset = Subset(cifar10_valid_dataset, indices[:split])

        cifar10_trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        cifar10_validloader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)

    else:
        long_tail_indices = create_long_tail_distribution(cifar10_train_dataset, imbalance_factor)
        cifar10_lt_train_dataset = Subset(cifar10_train_dataset, long_tail_indices)

        # Use separate names for the counts so the val_size fraction is not
        # overwritten.
        train_len = int((1 - val_size) * len(cifar10_lt_train_dataset))
        val_len = len(cifar10_lt_train_dataset) - train_len

        train_subset, val_subset = random_split(cifar10_lt_train_dataset, [train_len, val_len])

        # random_split returns positions inside the long-tail subset; map them
        # back to CIFAR-10 indices before the helper looks up labels.
        # The validation samples are served from the non-augmented view, as in
        # the imbalance_factor == 0 branch.
        val_base_indices = [long_tail_indices[i] for i in val_subset.indices]
        val_in_base = Subset(cifar10_valid_dataset, val_base_indices)
        valid_subset = create_balanced_validation_set(val_in_base, cifar10_valid_dataset)

        cifar10_trainloader = DataLoader(train_subset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
        cifar10_validloader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)

    cifar10_testloader = DataLoader(cifar10_test_dataset, batch_size=batch_size, shuffle=False)
    svhn_loader = DataLoader(svhn_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    cifar100_loader = DataLoader(cifar100_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    return cifar10_trainloader, cifar10_validloader, cifar10_testloader, svhn_loader, cifar100_loader


def dataloaders_cifar100(batch_size, val_size, imbalance_factor, noise):
    """Create CIFAR-100 train/val/test loaders plus SVHN and Tiny-ImageNet loaders.

    Args:
        batch_size: Batch size used by every returned loader.
        val_size: Fraction of the training set held out for validation.
        imbalance_factor: Accepted for signature parity with the other
            loaders; it is not used here.
        noise: When true, train and validation data come from ``CIFAR100N``
            with ``noise_type='aggre_label'``; ``CIFAR100N`` is not defined in
            this part of the file and must be provided elsewhere.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, svhn_loader,
        tin_loader)``. The Tiny-ImageNet loader reads an ``ImageFolder`` at
        ``./data/tiny-imagenet-200/test``, which must already exist.
    """
    num_workers = 8
    root = "./data"
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

    resize = transforms.Resize((32, 32))
    tensorize = transforms.ToTensor()

    # Augmentation is applied to the training view only.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        tensorize,
        normalize,
    ])
    valid_transform = transforms.Compose([tensorize, normalize])
    test_transform = transforms.Compose([tensorize, normalize])

    tin2cifar = transforms.Compose([
        resize,
        tensorize,
        transforms.Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2255]),
    ])
    svhn_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4377, 0.4438, 0.4728], std=[0.1980, 0.2010, 0.1970]),
    ])

    # Two views of the same training images: augmented (train) and plain (val).
    if not noise:
        cifar100_train_dataset = datasets.CIFAR100(root=root, train=True, download=True, transform=train_transform)
        cifar100_valid_dataset = datasets.CIFAR100(root=root, train=True, download=True, transform=valid_transform)
    else:
        cifar100_train_dataset = CIFAR100N(noise_type='aggre_label', train=True, transform=train_transform)
        cifar100_valid_dataset = CIFAR100N(noise_type='aggre_label', train=True, transform=valid_transform)

    # One shuffled permutation is shared by both views so that train and
    # validation indices are disjoint.
    num_train = len(cifar100_train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))
    np.random.shuffle(indices)

    cifar100_train_dataset = Subset(cifar100_train_dataset, indices[split:])
    cifar100_valid_dataset = Subset(cifar100_valid_dataset, indices[:split])
    cifar100_test_dataset = datasets.CIFAR100(root=root, train=False, download=True, transform=test_transform)

    svhn_dataset = datasets.SVHN(root=root, split="test", download=True, transform=svhn_transform)
    tin_test_dataset = datasets.ImageFolder(root="./data/tiny-imagenet-200/test", transform=tin2cifar)

    cifar100_trainloader = DataLoader(cifar100_train_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, pin_memory=True, shuffle=True)
    cifar100_validloader = DataLoader(cifar100_valid_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, pin_memory=True, shuffle=False)
    cifar100_testloader = DataLoader(cifar100_test_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, pin_memory=True, shuffle=False)

    svhn_loader = DataLoader(svhn_dataset, batch_size=batch_size, num_workers=num_workers)
    tin_loader = DataLoader(tin_test_dataset, batch_size=batch_size, num_workers=num_workers)

    return cifar100_trainloader, cifar100_validloader, cifar100_testloader, svhn_loader, tin_loader


class AvgPoolShortCut(nn.Module):
    """Parameter-free shortcut: average-pool spatially, zero-pad channels.

    Args:
        stride: Stride of the average pooling.
        out_c: Number of channels after padding.
        in_c: Number of channels of the input.
    """

    def __init__(self, stride, out_c, in_c):
        super().__init__()
        self.stride = stride
        self.out_c = out_c
        self.in_c = in_c

    def forward(self, x):
        """Downsample ``x`` and append ``out_c - in_c`` zero channels."""
        if x.shape[2] % 2 != 0:
            # Odd height: 1x1 pooling window, i.e. plain strided subsampling.
            x = F.avg_pool2d(x, 1, self.stride)
        else:
            x = F.avg_pool2d(x, self.stride, self.stride)

        # new_zeros inherits both device and dtype from x, so the padding
        # cannot silently promote or mismatch half/double inputs.
        pad = x.new_zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3])
        return torch.cat((x, pad), dim=1)

class SpectralLinear(nn.Module):
    """Spectrally normalised ``nn.Linear`` whose output is scaled by ``k_lipschitz``."""

    def __init__(self, input_dim, output_dim, k_lipschitz=1.0):
        super(SpectralLinear, self).__init__()
        dense = nn.Linear(input_dim, output_dim)
        self.k_lipschitz = k_lipschitz
        self.spectral_linear = spectral_norm(dense)

    def forward(self, x):
        """Apply the normalised linear map and scale the result."""
        return self.k_lipschitz * self.spectral_linear(x)


class SpectralConv(nn.Module):
    """Spectrally normalised ``nn.Conv2d`` whose output is scaled by ``k_lipschitz``."""

    def __init__(self, input_dim, output_dim, kernel_dim, padding, k_lipschitz=1.0):
        super(SpectralConv, self).__init__()
        conv = nn.Conv2d(input_dim, output_dim, kernel_dim, padding=padding)
        self.k_lipschitz = k_lipschitz
        self.spectral_conv = spectral_norm(conv)

    def forward(self, x):
        """Apply the normalised convolution and scale the result."""
        return self.k_lipschitz * self.spectral_conv(x)


def linear_sequential(input_dims, hidden_dims, output_dim, k_lipschitz=None, p_drop=None):
    """Build a fully connected stack as an ``nn.Sequential``.

    The input width is the product of ``input_dims``; ``hidden_dims`` (a list)
    gives the hidden widths. Every layer except the last is followed by ReLU
    and, when ``p_drop`` is given, Dropout. With ``k_lipschitz`` set, each
    layer is a ``SpectralLinear`` receiving the ``num_layers``-th root of
    ``k_lipschitz``.
    """
    widths = [np.prod(input_dims)] + hidden_dims + [output_dim]
    depth = len(widths) - 1

    modules = []
    for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
        if k_lipschitz is None:
            modules.append(nn.Linear(fan_in, fan_out))
        else:
            modules.append(SpectralLinear(fan_in, fan_out, k_lipschitz ** (1./depth)))

        # The final layer stays linear: no activation, no dropout.
        if position == depth - 1:
            continue
        modules.append(nn.ReLU())
        if p_drop is not None:
            modules.append(nn.Dropout(p=p_drop))

    return nn.Sequential(*modules)

def convolution_sequential(input_dims, hidden_dims, output_dim, kernel_dim, k_lipschitz=None, p_drop=None):
    """Build a convolutional stack as an ``nn.Sequential``.

    ``input_dims[2]`` is the input channel count and ``hidden_dims`` (a list)
    the output channels of each stage. A stage is conv -> ReLU -> optional
    Dropout -> 2x2 max-pool, with padding ``(kernel_dim - 1) // 2``. With
    ``k_lipschitz`` set, each conv is a ``SpectralConv`` receiving the
    ``num_layers``-th root of ``k_lipschitz``. ``output_dim`` is not used.
    """
    channels = [input_dims[2]] + hidden_dims
    depth = len(channels) - 1
    same_pad = (kernel_dim - 1) // 2

    blocks = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        if k_lipschitz is None:
            blocks.append(nn.Conv2d(c_in, c_out, kernel_dim, padding=same_pad))
        else:
            blocks.append(SpectralConv(c_in, c_out, kernel_dim, same_pad, k_lipschitz ** (1./depth)))
        blocks.append(nn.ReLU())
        if p_drop is not None:
            blocks.append(nn.Dropout(p=p_drop))
        blocks.append(nn.MaxPool2d(2, padding=0))

    return nn.Sequential(*blocks)


class ConvLinSeq(nn.Module):
    """Convolutional stack followed by a fully connected stack.

    ``input_dims`` is read as ``[height, width, channels]``. When
    ``k_lipschitz`` is given, its square root is handed to each of the two
    stacks. ``batch_size`` is accepted but not used.
    """

    def __init__(self, input_dims, linear_hidden_dims, conv_hidden_dims, output_dim, kernel_dim, batch_size, k_lipschitz, p_drop):
        super(ConvLinSeq, self).__init__()

        stage_k = k_lipschitz
        if stage_k is not None:
            stage_k = stage_k ** (1./2.)

        self.convolutions = convolution_sequential(
            input_dims=input_dims,
            hidden_dims=conv_hidden_dims,
            output_dim=output_dim,
            kernel_dim=kernel_dim,
            k_lipschitz=stage_k,
            p_drop=p_drop,
        )

        # Each conv stage halves height and width (floor division), so the
        # flattened feature size is last_channels * (H // 2^n) * (W // 2^n).
        shrink = 2 ** len(conv_hidden_dims)
        flat_features = conv_hidden_dims[-1] * (input_dims[0] // shrink) * (input_dims[1] // shrink)

        self.linear = linear_sequential(
            input_dims=[flat_features],
            hidden_dims=linear_hidden_dims,
            output_dim=output_dim,
            k_lipschitz=stage_k,
            p_drop=p_drop,
        )

    def forward(self, input):
        """Run both stacks; the flattened conv output is cached in ``self.feature``."""
        n_samples = input.size(0)
        conv_out = self.convolutions(input)

        # Detached copy of the flattened features, kept for later inspection.
        self.feature = conv_out.detach().clone().reshape(n_samples, -1)

        return self.linear(conv_out.reshape(n_samples, -1))


def convolution_linear_sequential(input_dims, linear_hidden_dims, conv_hidden_dims, output_dim, kernel_dim, batch_size, k_lipschitz, p_drop=None):
    """Factory for ``ConvLinSeq``; forwards every argument unchanged."""
    return ConvLinSeq(
        input_dims,
        linear_hidden_dims,
        conv_hidden_dims,
        output_dim,
        kernel_dim,
        batch_size,
        k_lipschitz,
        p_drop,
    )


import torch.nn as nn
import math

class VGG(nn.Module):
    """VGG-style network: a feature stack, global average pooling and a small classifier.

    The classifier is Dropout -> Linear(512, 256) -> ReLU -> BatchNorm1d ->
    Dropout -> Linear(256, output_dim). When ``k_lipschitz`` is given, both
    linear layers are ``SpectralLinear`` with that constant.
    """

    def __init__(self, features, output_dim, p_drop, k_lipschitz=None):
        super().__init__()
        self.features = features

        if k_lipschitz is not None:
            def dense(fan_in, fan_out):
                return SpectralLinear(fan_in, fan_out, k_lipschitz)
        else:
            dense = nn.Linear

        self.classifier = nn.Sequential(
            nn.Dropout(p=p_drop),
            dense(512, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
            nn.Dropout(p=p_drop),
            dense(256, output_dim),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise every Conv2d and Linear found in the module tree."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel_h * kernel_w * out_channels)).
                fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the classifier output; the pooled features are cached in ``self.feature``."""
        maps = self.features(x)
        pooled = F.adaptive_avg_pool2d(maps, (1, 1))  # global average pooling
        flat = pooled.view(pooled.size(0), -1)

        # Detached copy of the pooled features, kept for later inspection.
        self.feature = flat.detach().clone().reshape(flat.shape[0], -1)

        return self.classifier(flat)


    
def make_layers(cfg, dropout_rate, batch_norm=False, k_lipschitz=None):
    """Translate a VGG configuration list into an ``nn.Sequential``.

    Each integer entry adds a 3x3 convolution (padding 1) with that many
    output channels, optionally followed by BatchNorm2d, then ReLU. Each
    ``'M'`` entry adds a 2x2 max-pool followed by Dropout. The first
    convolution expects 3 input channels. With ``k_lipschitz`` set, the
    convolutions are ``SpectralConv`` layers using that constant.
    """
    modules = []
    channels_in = 3

    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            modules.append(nn.Dropout(p=dropout_rate))
            continue

        if k_lipschitz is None:
            conv = nn.Conv2d(channels_in, entry, kernel_size=3, padding=1)
        else:
            conv = SpectralConv(channels_in, entry, kernel_dim=3, padding=1, k_lipschitz=k_lipschitz)

        modules.append(conv)
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        channels_in = entry

    return nn.Sequential(*modules)




def vgg16_bn(output_dim, p_drop, k_lipschitz=None):
    """VGG 16-layer model (configuration "D") with batch normalization and modifications.

    When ``k_lipschitz`` is given, its 16th root is passed both to the feature
    layers and to the classifier.
    """
    per_layer_k = None if k_lipschitz is None else k_lipschitz ** (1. / 16.)
    feature_stack = make_layers(cfg['D'], p_drop, batch_norm=True, k_lipschitz=per_layer_k)
    return VGG(feature_stack, output_dim=output_dim, p_drop=p_drop, k_lipschitz=per_layer_k)


# VGG layer configurations consumed by make_layers: an integer adds a 3x3
# convolution with that many output channels, 'M' adds a 2x2 max-pool followed
# by dropout. vgg16_bn uses configuration 'D'.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


# Main Architectures : ConvNet / VGG16 / ResNet18
def conv_net(p_drop, spect_norm = False):
    """Build the small convolutional network with its fixed hyper-parameters.

    All sizes are hard-coded here; only the dropout probability and the
    spectral-normalisation switch are configurable. ``k_lipschitz`` is 1 when
    ``spect_norm`` equals ``True`` and ``None`` otherwise.
    """
    # Fixed architecture settings, passed positionally to the builder below.
    input_dims, output_dim = [28, 28, 1], 10
    conv_hidden_dims, kernel_dim = [64, 64, 64], 5
    linear_hidden_dims = [64, 64]
    batch_size = 64

    k_lipschitz = 1 if spect_norm == True else None

    return convolution_linear_sequential(
        input_dims,
        linear_hidden_dims,
        conv_hidden_dims,
        output_dim,
        kernel_dim,
        batch_size,
        k_lipschitz,
        p_drop,
    )

def vgg16(output_dim, p_drop, spect_norm):
    """Build the batch-normalised VGG-16 via ``vgg16_bn``.

    ``spect_norm`` equal to ``True`` requests ``k_lipschitz=1``; any other
    value requests no Lipschitz constraint (``k_lipschitz=None``).
    """
    k_lipschitz = 1 if spect_norm == True else None
    return vgg16_bn(output_dim, p_drop=p_drop, k_lipschitz=k_lipschitz)


def resnet(num_classes):
    """Build the project's default ResNet-18 for ``num_classes`` outputs.

    The remaining options are fixed: spectral normalisation on, ``mod=True``,
    temperature 1.0 and ``mnist=False``.
    """
    fixed_options = dict(spectral_normalization=True, mod=True, temp=1.0, mnist=False)
    return resnet18(num_classes=num_classes, **fixed_options)


"""
def vgg16(output_dim, k_lipschitz=None, p_drop=.5):
    VGG 16-layer model (configuration "D")
    if k_lipschitz is not None:
        k_lipschitz = k_lipschitz ** (1. / 16.)
    return VGG(make_layers(cfg['D'], k_lipschitz=k_lipschitz),
               output_dim=output_dim,
               k_lipschitz=k_lipschitz,
               p_drop=p_drop)
               
"""

class MLP(nn.Module):
    """Fully connected network using BatchNorm1d in its hidden stages.

    With ``num_layers == 0`` the network is a single linear map. Otherwise it
    is one input stage, ``num_layers - 2`` further hidden stages (none when
    ``num_layers`` is 1 or 2), and a dropout + linear output head. Each stage
    is Dropout -> Linear -> ReLU -> BatchNorm1d.

    ``spect_norm`` is accepted but not used by this class.

    NOTE: a second class named ``MLP`` is defined later in this file and
    replaces this one once the module has been imported.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_layers, dropout_rate, spect_norm):
        super().__init__()

        def stage(width_in):
            # One hidden stage ending in hidden_dim features.
            return [
                nn.Dropout(p=dropout_rate),
                nn.Linear(width_in, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
            ]

        if num_layers == 0:
            modules = [nn.Linear(input_dim, output_dim)]
        else:
            modules = stage(input_dim)
            for _ in range(num_layers - 2):
                modules.extend(stage(hidden_dim))
            modules.extend([
                nn.Dropout(p=dropout_rate),
                nn.Linear(hidden_dim, output_dim),
            ])

        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.network(x)

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

class MLP(nn.Module):
    """Fully connected network using LayerNorm in its hidden stages.

    With ``num_layers == 0`` the network is a single linear map. Otherwise it
    is one input stage, ``num_layers - 2`` further hidden stages (none when
    ``num_layers`` is 1 or 2), and a linear output head. Each stage is
    Dropout -> Linear -> ReLU -> LayerNorm.

    When ``spect_norm`` is truthy, only the output head is wrapped in
    ``torch.nn.utils.spectral_norm``; hidden layers are left as they are.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_layers,
                 dropout_rate=0.0, spect_norm=False):
        super().__init__()

        def stage(width_in):
            # One hidden stage ending in hidden_dim features.
            return [
                nn.Dropout(p=dropout_rate),
                nn.Linear(width_in, hidden_dim),
                nn.ReLU(inplace=True),
                nn.LayerNorm(hidden_dim),
            ]

        stack = []
        head_in = input_dim
        if num_layers != 0:
            stack.extend(stage(input_dim))
            for _ in range(num_layers - 2):
                stack.extend(stage(hidden_dim))
            head_in = hidden_dim

        head = nn.Linear(head_in, output_dim)
        if spect_norm:
            head = spectral_norm(head)
        stack.append(head)

        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.network(x)

    
    
def get_result(folder_path, experiment_type, num_experiments, f_edl, dist_shift):
    """Summarise the most recent experiment result files as a mean/std table.

    The JSON files in ``folder_path`` whose names start with
    ``experiment_type`` are ordered by modification time (newest first) and
    the first ``num_experiments`` of them are loaded. Each file must provide
    ``"Test Accuracy"``, ``"CONF AUPR"["AU"]`` and ``"OOD AUPR"[0|1]["AU"|"DU"]``.

    Args:
        folder_path: Directory containing the ``.json`` result files.
        experiment_type: File-name prefix selecting the experiment.
        num_experiments: Maximum number of (most recent) files to aggregate.
        f_edl: If truthy, ``"OOD AUPR"[0|1]["TU"]`` is read as well and shown
            as two extra columns. These values used to be read and then
            discarded.
        dist_shift: If truthy, ``"DIST AUPR"["1".."5"]["AU"]`` is read and
            shown as five extra columns.

    Returns:
        ``numpy`` array with the ``"Test Accuracy"`` of every loaded file,
        newest first.

    Raises:
        KeyError: If a loaded file lacks one of the required entries.
    """
    # Imported here so the function does not depend on a module-level pandas
    # import being present.
    import pandas as pd

    candidates = [
        name for name in os.listdir(folder_path)
        if name.startswith(experiment_type) and name.endswith('.json')
    ]
    candidates.sort(
        key=lambda name: os.path.getmtime(os.path.join(folder_path, name)),
        reverse=True,
    )

    results = []
    for file_name in candidates[:num_experiments]:
        with open(os.path.join(folder_path, file_name), 'r') as handle:
            results.append(json.load(handle))

    def collect(*keys):
        # Follow the same key path in every result and stack the values.
        values = []
        for run in results:
            node = run
            for key in keys:
                node = node[key]
            values.append(node)
        return np.array(values)

    def mean_std(values, scale=100):
        # [mean, std] rounded to two decimals; AUPR values are shown in percent.
        return [(values.mean() * scale).round(2), (values.std() * scale).round(2)]

    test_acc = collect("Test Accuracy")

    columns = {
        'Metric': ['Mean', 'Std'],
        # Test accuracy is reported as stored, without rescaling.
        'TEST ACC': mean_std(test_acc, scale=1),
        'CONF AUPR': mean_std(collect("CONF AUPR", "AU")),
        'OOD AUPR SVHN AU': mean_std(collect("OOD AUPR", 0, "AU")),
        'OOD AUPR SVHN DU': mean_std(collect("OOD AUPR", 0, "DU")),
        'OOD AUPR C100 AU': mean_std(collect("OOD AUPR", 1, "AU")),
        'OOD AUPR C100 DU': mean_std(collect("OOD AUPR", 1, "DU")),
    }

    if f_edl:
        columns['OOD AUPR SVHN TU'] = mean_std(collect("OOD AUPR", 0, "TU"))
        columns['OOD AUPR C100 TU'] = mean_std(collect("OOD AUPR", 1, "TU"))

    if dist_shift:
        for level in ("1", "2", "3", "4", "5"):
            columns[f'DIST AUPR C={level}'] = mean_std(collect("DIST AUPR", level, "AU"))

    df_summary = pd.DataFrame(columns)

    # `display` is only defined when running under IPython/Jupyter; fall back
    # to print so the function also works in a plain interpreter.
    try:
        show = display
    except NameError:
        show = print
    show(df_summary)

    return test_acc

    
    
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc


class ResNet(nn.Module):
    """ResNet backbone with optional spectral normalisation and dropout.

    Args:
        block: Residual block class. It is called as
            ``block(input_size, wrapped_conv, in_planes, planes, stride, mod)``
            and must expose an ``expansion`` class attribute.
        num_blocks: Four integers, the number of blocks in each stage.
        num_classes: Size of the classifier output.
        temp: The logits are divided by this value in ``forward``.
        spectral_normalization: If true, every convolution created through
            ``wrapped_conv`` is wrapped in ``torch.nn.utils.spectral_norm``.
        mod: Forwarded to the blocks; also selects ``leaky_relu`` (true) or
            ``relu`` (false) for ``self.activation``.
        coeff: Accepted for interface compatibility; not used by this
            implementation.
        n_power_iterations: Number of power iterations used by the spectral
            normalisation of each convolution.
        mnist: If true, the stem takes 1 input channel and the stages are
            sized for a 28-pixel input; otherwise 3 channels and 32 pixels.
        dropout_rate: Probability of the dropout applied after every stage
            and before the classifier. It is not forwarded to the blocks.
    """

    def __init__(
        self,
        block,
        num_blocks,
        num_classes=100,
        temp=1.0,
        spectral_normalization=True,
        mod=True,
        coeff=3,
        n_power_iterations=1,
        mnist=False,
        dropout_rate=0,
    ):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.mod = mod

        def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
            # `input_size` is part of the factory signature the blocks call,
            # but it is not needed to build the convolution.
            padding = 1 if kernel_size == 3 else 0
            conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)

            if spectral_normalization:
                conv = torch.nn.utils.spectral_norm(
                    conv, n_power_iterations=n_power_iterations
                )
            return conv

        self.wrapped_conv = wrapped_conv

        self.bn1 = nn.BatchNorm2d(64)
        # A single Dropout module is reused after every stage.
        self.dropout = nn.Dropout(p=dropout_rate)

        # Stem channel count and the spatial size fed to each stage.
        if mnist:
            stem_channels, stage_sizes = 1, (28, 28, 14, 7)
        else:
            stem_channels, stage_sizes = 3, (32, 32, 16, 8)

        self.conv1 = wrapped_conv(stage_sizes[0], stem_channels, 64, kernel_size=3, stride=1)
        self.layer1 = self._make_layer(block, stage_sizes[0], 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, stage_sizes[1], 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, stage_sizes[2], 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, stage_sizes[3], 512, num_blocks[3], stride=2)

        self.avgpool = nn.AvgPool2d(kernel_size=4)
        self.features = nn.Sequential(
            self.conv1,
            self.bn1,
            self.layer1,
            self.dropout,
            self.layer2,
            self.dropout,
            self.layer3,
            self.dropout,
            self.layer4,
            self.dropout,
            self.avgpool,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(512 * block.expansion, num_classes),
        )
        # NOTE(review): `activation` is stored but never applied in `forward`;
        # `features` goes from bn1 straight into layer1. Confirm this is
        # intended before changing it, since it affects trained models.
        self.activation = F.leaky_relu if self.mod else F.relu
        self.feature = None
        self.temp = temp

    def _make_layer(self, block, input_size, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(
                block(input_size, self.wrapped_conv, self.in_planes, planes, block_stride, self.mod)
            )
            self.in_planes = planes * block.expansion
            input_size = math.ceil(input_size / block_stride)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return temperature-scaled logits and cache the flattened features.

        ``self.feature`` is set to a detached copy of the flattened output of
        ``self.features``.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        self.feature = x.clone().detach()
        x = self.classifier(x) / self.temp
        return x
    
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each with batch norm.

    The shortcut is the identity unless the stride is not 1 or the channel
    count changes. In that case it is an ``AvgPoolShortCut`` when ``mod`` is
    true, or a 1x1 convolution followed by batch norm when it is false.
    Convolutions are created through the ``wrapped_conv`` factory that is
    passed in.
    """

    expansion = 1

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True, dropout_rate=0):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)

        # The second convolution sees the spatial size left by the first.
        reduced_size = math.ceil(input_size / stride)
        self.conv2 = wrapped_conv(reduced_size, planes, planes, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dropout2 = nn.Dropout(p=dropout_rate)

        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        shape_changes = stride != 1 or in_planes != out_planes
        if not shape_changes:
            self.shortcut = nn.Sequential()
        elif mod:
            self.shortcut = nn.Sequential(AvgPoolShortCut(stride, out_planes, in_planes))
        else:
            self.shortcut = nn.Sequential(
                wrapped_conv(input_size, in_planes, out_planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Apply conv-bn-act, conv-bn-dropout, add the shortcut, then activate."""
        hidden = self.activation(self.bn1(self.conv1(x)))
        # Dropout acts on the second batch-norm output, before the residual
        # sum and the final activation.
        hidden = self.dropout2(self.bn2(self.conv2(hidden)))
        hidden += self.shortcut(x)
        return self.activation(hidden)


def resnet18(spectral_normalization, mod=True, temp=1.0, mnist=False, dropout_rate=0, **kwargs):
    """Build a ResNet from ``BasicBlock`` with two blocks in each of four stages.

    ``spectral_normalization`` has no default here and must be supplied.
    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(
        BasicBlock,
        blocks_per_stage,
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        dropout_rate=dropout_rate,
        **kwargs
    )



def resnet50(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs):
    """Build a ResNet from ``Bottleneck`` blocks with stage depths 3, 4, 6, 3.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(
        Bottleneck,
        blocks_per_stage,
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )


def resnet101(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs):
    """Build a ResNet from ``Bottleneck`` blocks with stage depths 3, 4, 23, 3.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [3, 4, 23, 3]
    return ResNet(
        Bottleneck,
        blocks_per_stage,
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )


def resnet110(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs):
    """Build a ResNet from ``Bottleneck`` blocks with stage depths 3, 4, 26, 3.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [3, 4, 26, 3]
    return ResNet(
        Bottleneck,
        blocks_per_stage,
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )


def resnet152(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs):
    """Build a ResNet from ``Bottleneck`` blocks with stage depths 3, 8, 36, 3.

    Extra keyword arguments are forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [3, 8, 36, 3]
    return ResNet(
        Bottleneck,
        blocks_per_stage,
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )