##################################
# Acknowledgment:
# Part of this code is adopted from https://github.com/kahnchana/opl
# Part of this code is adopted from https://www.kaggle.com/code/yiweiwangau/cifar-100-resnet-pytorch-75-17-accuracy
##################################

# Commented out IPython magic to ensure Python compatibility.
import pandas as pd
import os
import torch
import time
import tqdm
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import torchvision.models as models
import matplotlib.pyplot as plt
from sklearn.metrics import *

#########################################

def seed_everything(seed: int) -> None:
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), exports ``PYTHONHASHSEED`` (this only affects Python
    processes started afterwards; the running interpreter's hash seed is
    fixed at start-up), and configures cuDNN for deterministic execution.

    Args:
        seed: Value used for every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (manual_seed only seeds the
    # current one); it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick kernels by timing, which can
    # differ between runs and defeats ``deterministic`` above.
    torch.backends.cudnn.benchmark = False
    
# Fix every RNG seed once, before any data loading or model construction below.
seed_everything(42)

#########################################

# %matplotlib inline

# Hyperparameters. Only `batch_size` is read in the visible code (by the
# loaders below). NOTE(review): epochs / max_lr / grad_clip / weight_decay /
# opt_func appear to be leftovers from the training script — confirm before
# removing.
batch_size = 400
epochs = 120
max_lr = 0.001
grad_clip = 0.01
weight_decay =0.001
opt_func = torch.optim.Adam

# Raw (untransformed) CIFAR-100 training split, downloaded into ./ if missing.
# It is only used here to compute the per-channel normalisation statistics.
train_data = torchvision.datasets.CIFAR100('./', train=True, download=True)

# Stack every training image (32 x 32 x 3, HWC) along axis 0, giving a
# (50000 * 32) x 32 x 3 = 1600000 x 32 x 3 array.
x = np.concatenate([np.asarray(train_data[i][0]) for i in range(len(train_data))])

# Reduce over axes (0, 1) -> one value per RGB channel; dividing by 255 puts
# the statistics in the [0, 1] range produced by tt.ToTensor().
mean = np.mean(x, axis=(0, 1))/255
std = np.std(x, axis=(0, 1))/255
# Convert the per-channel NumPy arrays to plain lists for tt.Normalize.
mean=mean.tolist()
std=std.tolist()

# Train-time pipeline: reflect-padded random 32x32 crop, random horizontal
# flip, tensor conversion, then per-channel normalisation.
transform_train = tt.Compose([tt.RandomCrop(32, padding=4,padding_mode='reflect'), 
                         tt.RandomHorizontalFlip(), 
                         tt.ToTensor(), 
                         tt.Normalize(mean,std,inplace=True)])
# Test-time pipeline: no augmentation, same normalisation.
transform_test = tt.Compose([tt.ToTensor(), tt.Normalize(mean,std)])

# Shuffled, augmented training loader. NOTE(review): in this file it is only
# used by main() for feature extraction, so the saved train features include
# random crop/flip augmentation and shuffled order — confirm that is intended.
trainset = torchvision.datasets.CIFAR100("./",
                                         train=True,
                                         download=True,
                                         transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size, shuffle=True, num_workers=2,pin_memory=True)

# Test loader: not shuffled (DataLoader default), twice the training batch size.
testset = torchvision.datasets.CIFAR100("./",
                                        train=False,
                                        download=True,
                                        transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size*2,pin_memory=True, num_workers=2)

# Appears to be a notebook section header left over from the .ipynb export
# (a no-op string expression).
"""# Device check and load model into device"""

def get_default_device():
    """Return the CUDA device when PyTorch can see one, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or an arbitrarily nested list/tuple of tensors, to ``device``.

    Every list or tuple is rebuilt as a plain ``list`` (tuples included).
    Anything else is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader:
    """Iterable wrapper that relocates each batch of a loader onto a target device."""

    def __init__(self, dl, device):
        # Wrapped iterable of batches (e.g. a torch DataLoader).
        self.dl = dl
        # Destination handed to ``to_device`` for every batch.
        self.device = device

    def __iter__(self):
        """Yield each batch of the wrapped loader after ``to_device`` has moved it."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Report the wrapped loader's batch count."""
        return len(self.dl)


# Device every batch and the model are moved to.
device = get_default_device()
# Bare expression: displayed the device in the notebook; a no-op in a script.
device

# Wrap both loaders so each batch is moved onto `device` as it is yielded.
trainloader = DeviceDataLoader(trainloader, device)
testloader = DeviceDataLoader(testloader, device)

# Appears to be a notebook section header left over from the .ipynb export
# (a no-op string expression).
"""# Layer Setup"""

# NOTE(review): dead code. It is superseded by the top-k `accuracy` defined
# later in this file, which returns a list of percentages rather than this
# single fraction.
# def accuracy(outputs, labels):
#     _, preds = torch.max(outputs, dim=1)
#     return torch.tensor(torch.sum(preds == labels).item() / len(preds))

###############################################################################
import json
import torch

# Weight of the OPL term in the total loss: CE + opl_ratio * OPL.
opl_ratio = 1.0
# `gamma` passed to OrthogonalProjectionLoss: weight of its negative-pair
# (different-class) term.
opl_gamma = 0.5

class OrthogonalProjectionLoss(nn.Module):
    """Orthogonal Projection Loss (OPL) over a batch of feature vectors.

    Pushes the mean dot product of same-class pairs towards 1 and the mean
    absolute dot product of different-class pairs towards 0:

        loss = (1 - mean_pos) + gamma * mean_abs_neg

    Args:
        no_norm: If True, skip L2-normalising features before the dot products.
        weights_path: Optional path to a JSON file whose content is stored in
            ``self.weights_dict`` (``forward`` does not use it).
        use_attention: If True, replace each feature by a softmax(F F^T)-weighted
            combination of the batch's features before computing the loss.
        gamma: Weight of the different-class (orthogonality) term.
    """

    def __init__(self, no_norm=False, weights_path=None, use_attention=False, gamma=2):
        super(OrthogonalProjectionLoss, self).__init__()
        self.weights_dict = None
        self.no_norm = no_norm
        self.gamma = gamma
        self.use_attention = use_attention
        if weights_path is not None:
            # Context manager closes the handle deterministically
            # (``json.load(open(...))`` leaked it).
            with open(weights_path, "r") as fh:
                self.weights_dict = json.load(fh)

    def forward(self, features, labels=None):
        """Compute the OPL loss for one batch.

        Args:
            features: 2-D tensor, one feature row per sample (assumed
                (batch, dim)).
            labels: 1-D tensor of class ids, one per row of ``features``.
                Required; the ``None`` default is kept only for interface
                compatibility.

        Returns:
            Tuple ``(loss, pos_pairs_mean, neg_pairs_mean)`` of scalar tensors.

        Raises:
            ValueError: If ``labels`` is None.
        """
        if labels is None:
            raise ValueError("OrthogonalProjectionLoss requires `labels`.")
        # Use the features' own device, including its index (e.g. cuda:1).
        # torch.device('cuda') always meant the current/default GPU.
        device = features.device

        if self.use_attention:
            features_weights = torch.matmul(features, features.T)
            features_weights = F.softmax(features_weights, dim=1)
            features = torch.matmul(features_weights, features)

        if not self.no_norm:
            features = F.normalize(features, p=2, dim=1)

        labels = labels.to(device)[:, None]  # column vector for broadcasting
        mask = torch.eq(labels, labels.t())  # same-class pairs, diagonal included
        eye = torch.eye(mask.shape[0], mask.shape[1], device=device).bool()

        mask_pos = mask.masked_fill(eye, False).float()  # same class, i != j
        mask_neg = (~mask).float()                       # different class
        dot_prod = torch.matmul(features, features.t())

        # The 1e-6 keeps batches without positive (or negative) pairs finite.
        pos_pairs_mean = (mask_pos * dot_prod).sum() / (mask_pos.sum() + 1e-6)
        neg_pairs_mean = torch.abs(mask_neg * dot_prod).sum() / (mask_neg.sum() + 1e-6)

        loss = (1.0 - pos_pairs_mean) + (self.gamma * neg_pairs_mean)
        return loss, pos_pairs_mean, neg_pairs_mean
    
    
# Module-level OPL instance used by ImageClassificationBase.training_step and
# validation_step on the penultimate features (L2-normalised, no attention).
aux_loss = OrthogonalProjectionLoss(no_norm=False, use_attention=False, gamma=opl_gamma)
###############################################################################

class ImageClassificationBase(nn.Module):
    """Training/validation helpers for classifiers trained with CE + OPL.

    Subclasses implement ``forward`` returning ``(features, logits)``. The
    OPL term uses the module-level ``aux_loss`` instance weighted by the
    module-level ``opl_ratio``.
    """

    def _combined_loss(self, batch):
        """Run the model on ``batch``; return ``(logits, labels, CE + opl_ratio * OPL)``."""
        images, labels = batch
        penul_feat, out = self(images)
        op_loss, _, _ = aux_loss(penul_feat, labels)  # OPL on penultimate features
        ce_loss = F.cross_entropy(out, labels)
        return out, labels, ce_loss + opl_ratio * op_loss

    @staticmethod
    def _top1_accuracy(outputs, labels):
        """Return the fraction (0..1) of rows whose arg-max equals the label, as a 0-dim tensor."""
        preds = outputs.argmax(dim=1)
        return (preds == labels).float().mean()

    def training_step(self, batch):
        """Return the combined CE + OPL loss for one training batch."""
        _, _, loss = self._combined_loss(batch)
        return loss

    def validation_step(self, batch):
        """Return the detached combined loss and top-1 accuracy (fraction) for one batch."""
        out, labels, loss = self._combined_loss(batch)
        # Computed locally: the module-level `accuracy` in this file returns a
        # list [top1, top5] in percent, which made torch.stack in
        # validation_epoch_end fail.
        acc = self._top1_accuracy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average per-batch losses and accuracies (unweighted) into Python floats."""
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line epoch summary; ``result`` needs lrs/train_loss/val_loss/val_acc."""
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))
        
def conv_block(in_channels, out_channels, pool=False):
    """Build Conv3x3(padding=1) -> BatchNorm2d -> ReLU, optionally followed by MaxPool2d(2).

    The child indices of the returned ``nn.Sequential`` (0=conv, 1=bn,
    2=relu, 3=pool) appear in state_dict keys, so their order must not change.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU(inplace=True)
    stages = [conv, norm, act] + ([nn.MaxPool2d(2)] if pool else [])
    return nn.Sequential(*stages)

# Width of the penultimate feature vector: output channels of conv5/res3 and
# input size of the classifier. get_features() also passes it as the
# embedding size of the extracted feature arrays.
dim_feat = 1024
class ResNet9(ImageClassificationBase):
    """ResNet-9 style classifier whose forward can also return penultimate features.

    Layout: conv1 -> conv2(pool) -> res1 (+skip) -> conv3(pool) -> conv4(pool)
    -> res2 (+skip) -> conv5(pool) -> res3 (+skip) -> max-pool + flatten
    -> linear classifier.

    NOTE: attribute names become state_dict keys and registration order
    fixes parameter order; changing either breaks existing checkpoints.
    """
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True) 
        # Residual pair; forward adds its input back (same 128 channels).
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128)) 
        
        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True) 
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512)) 
        self.conv5 = conv_block(512, dim_feat , pool=True) 
        self.res3 = nn.Sequential(conv_block(dim_feat , dim_feat ), conv_block(dim_feat , dim_feat ))  
        
        # For 32x32 inputs (as produced by this script's transforms) the four
        # pooled blocks leave a 2x2 map. This 2x2 max-pool reduces it to
        # dim_feat x 1 x 1, flattened to one dim_feat vector per sample.
        self.feat = nn.Sequential(nn.MaxPool2d(2), # dim_feat x 1 x 1 (for 32x32 inputs)
                                        nn.Flatten())
        
        self.classifier =   nn.Linear(dim_feat, num_classes)
        
    def forward(self, xb,  get_feat=True):
        """Run the network on an image batch.

        Args:
            xb: Image batch; presumably (batch, in_channels, H, W). The shape
                comments above assume 32x32 CIFAR images.
            get_feat: If True (default), return ``(features, logits)``;
                otherwise return only the logits.
        """
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.conv5(out)
        out = self.res3(out) + out
        # Penultimate (pre-classifier) features, used by the OPL loss and
        # by feature extraction.
        features = self.feat(out)
        if get_feat:
            return features, self.classifier(features)
        out = self.classifier(features)
        return out
    
    
    
###################
class AverageMeter(object):
    """Track the latest value together with a running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running total and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as standing for ``n`` samples and recompute the mean.

        Raises ZeroDivisionError if the accumulated count is 0 afterwards.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1, 5)):
    """Return top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor, assumed (batch, num_classes).
        target: 1-D tensor of true class indices, one per row of ``output``.
        topk: k values to evaluate; ``max(topk)`` must not exceed the class count.

    Returns:
        List of 0-dim float tensors, one per k, in the order given by ``topk``.
    """
    n_samples = target.size(0)
    _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
    # Transposed so that row r holds every sample's (r+1)-th best guess.
    ranked = top_idx.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0) * scale for k in topk]

def validate(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` and print/return top-1 and top-5 accuracy.

    Switches ``model`` to eval mode, runs every batch without gradients and
    averages per-batch precision weighted by batch size. Batches are moved to
    the device of the model's parameters (CPU if it has none).

    Args:
        val_loader: Iterable of ``(inputs, targets)`` batches.
        model: Module whose forward returns ``(features, logits)``.

    Returns:
        ``(top1, top5)`` average accuracies in percent.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    # Follow the model's device instead of hard-coding .cuda(), which crashed
    # on CPU-only machines.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device('cpu')

    with torch.no_grad():
        for inputs, target in val_loader:
            inputs = inputs.to(device)
            target = target.to(device)

            _, output = model(inputs)
            output = output.float()

            prec1, prec5 = accuracy(output, target)
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

    print(' * Prec@1 {top1.avg:.3f} * Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg


    
def get_features(net, trainloader, testloader, feat_dir):
    """Extract features for the train and test loaders and save them as .npy files.

    Writes ``train_feat.npy``, ``train_lab.npy``, ``test_feat.npy`` and
    ``test_lab.npy`` into ``feat_dir``, creating the directory if needed.

    Args:
        net: Model whose forward returns ``(features, logits)``.
        trainloader: Iterable of ``(data, target)`` training batches.
        testloader: Iterable of ``(data, target)`` test batches.
        feat_dir: Output directory.

    Returns:
        ``(ftrain, ftrain_labels, ftest, ftest_labels)`` as NumPy arrays.
    """
    # exist_ok replaces the racy os.path.exists() + os.makedirs() pair.
    os.makedirs(feat_dir, exist_ok=True)

    ftrain, ftrain_labels = obtain_feature_from_loader(net, trainloader, embedding_dim=dim_feat)
    _save_npy(feat_dir, 'train_feat.npy', ftrain)
    _save_npy(feat_dir, 'train_lab.npy', ftrain_labels)

    ftest, ftest_labels = obtain_feature_from_loader(net, testloader, embedding_dim=dim_feat)
    _save_npy(feat_dir, 'test_feat.npy', ftest)
    _save_npy(feat_dir, 'test_lab.npy', ftest_labels)

    return ftrain, ftrain_labels, ftest, ftest_labels


def _save_npy(directory, filename, array):
    """Write ``array`` to ``directory/filename`` in NumPy .npy format."""
    with open(os.path.join(directory, filename), 'wb') as f:
        np.save(f, array)

 


    ####################################### 
def obtain_feature_from_loader(net, loader, embedding_dim = None):
    """Run ``net`` over ``loader`` and collect its feature output plus the targets.

    ``net`` is switched to eval mode and run without gradients. Its forward
    must return ``(features, logits)``. Each batch is moved to the device of
    ``net``'s parameters (CPU if it has none) rather than a hard-coded CUDA
    device.

    Args:
        net: Model returning ``(features, logits)``.
        loader: Iterable of ``(data, target)`` batches.
        embedding_dim: Feature width. Only used to shape the (empty) result
            when ``loader`` yields no batches; treated as 0 if None.

    Returns:
        ``(features, labels)`` as NumPy arrays. Labels are float32, the dtype
        the previous implementation (and any .npy files it wrote) produced.
    """
    net.eval()
    param = next(net.parameters(), None)
    device = param.device if param is not None else torch.device('cpu')

    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for data, target in loader:
            out_feature, _ = net(data.to(device))
            feature_chunks.append(out_feature)
            label_chunks.append(target.to(device))

    if not feature_chunks:
        return (np.zeros((0, embedding_dim or 0), dtype=np.float32),
                np.zeros((0,), dtype=np.float32))

    # Concatenate once: a per-batch torch.cat re-copies everything collected
    # so far, i.e. O(n^2) in the number of batches.
    features = torch.cat(feature_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0).float()
    return features.cpu().numpy(), labels.cpu().numpy()





# Global model: ResNet9 for 3-channel input and 100 output classes, moved to
# `device`. set_up() loads checkpoint weights into this object.
net = to_device(ResNet9(3, 100), device)
# Bare expression: displayed the model in the notebook; a no-op in a script.
net


def set_up(ckpt, model=None):
    """Load a state_dict checkpoint into a model and switch it to eval mode.

    Args:
        ckpt: Path to a file written with ``torch.save(model.state_dict(), ...)``.
        model: Module to load the weights into; defaults to the module-level
            ``net``.

    Returns:
        The model the weights were loaded into.

    Notes:
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        Checkpoints saved from an ``nn.DataParallel``/DDP wrapper prefix
        every key with ``module.``. That prefix is stripped when *all* keys
        carry it; checkpoints without it are loaded unchanged.
    """
    if model is None:
        model = net
    pretrained_dict = torch.load(ckpt, map_location='cpu')
    prefix = 'module.'
    if pretrained_dict and all(isinstance(k, str) and k.startswith(prefix)
                               for k in pretrained_dict):
        pretrained_dict = {k[len(prefix):]: v for k, v in pretrained_dict.items()}
    model.load_state_dict(pretrained_dict)
    model.eval()
    return model


# When True, main() also reports top-1/top-5 accuracy on the test loader.
eval_test = True


def _mean_abs_class_mean_cosine(features, labels):
    """Return the mean |cosine similarity| over all pairs of distinct class-mean features.

    Args:
        features: 2-D array, one feature row per sample.
        labels: 1-D array of class ids aligned with the rows of ``features``.

    Returns:
        NumPy scalar; NaN when fewer than two classes are present.
    """
    classes = np.unique(labels)
    class_means = np.stack([features[labels == k].mean(axis=0) for k in classes])
    class_means /= np.linalg.norm(class_means, axis=1, keepdims=True)
    sims = class_means @ class_means.T
    # Strict upper triangle: each unordered pair exactly once, no diagonal.
    # Filtering np.triu output with `!= 0` would also drop genuine zero
    # similarities.
    rows, cols = np.triu_indices(len(classes), k=1)
    return np.mean(np.abs(sims[rows, cols]))


def main():
    """Load the checkpoint, dump features, print class-mean orthogonality and test accuracy."""
    ckpt = "OPL_BASE_1024D_model.h5"
    net = set_up(ckpt)

    path2save = "1024_DIM/OPL_features"
    ftrain, ftrain_labels, ftest, ftest_labels = get_features(net, trainloader, testloader, path2save)

    # Lower is closer to mutually orthogonal class means, the direction the
    # OPL term pushes different-class features towards.
    print(_mean_abs_class_mean_cosine(ftrain, ftrain_labels))

    if eval_test:
        validate(testloader, net)


if __name__ == '__main__':
    main()
