##################################
# Acknowledgment:
# Part of this code is adopted from https://github.com/kahnchana/opl
# Part of this code is adopted from https://www.kaggle.com/code/yiweiwangau/cifar-100-resnet-pytorch-75-17-accuracy
# Part of this code is adopted from https://github.com/deeplearning-wisc/cider
##################################

import argparse
import math
import os
import time
from datetime import datetime
import logging
import pprint

import torch
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim
import torch.utils.data
import numpy as np


####################
# Commented out IPython magic to ensure Python compatibility.
import pandas as pd
import os
import torch
import time
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import torchvision.models as models
import matplotlib.pyplot as plt
from sklearn.metrics import *
####################

 

 
import faiss
from tqdm import tqdm

 
from utils.detection_util import set_ood_loader_ImageNet, obtain_feature_from_loader, set_ood_loader_small, get_and_print_results
from utils.util import set_loader_ImageNet, set_loader_small, set_model
from utils.display_results import  plot_distribution, print_measures, save_as_dataframe

from utils import (CompLoss, DisLoss, DisLPLoss, SupConLoss, 
                 adjust_learning_rate, warmup_learning_rate, 
                set_loader_small, set_loader_ImageNet)

############
import warnings

# Silence every Python warning for the rest of the run (no message, category
# or module restriction).
warnings.simplefilter("ignore")
############
# NOTE: a training-time argument parser (CIDER/SupCon hyper-parameters) used to
# be built here. It was dead code: `parser` was rebound to a new
# ArgumentParser before `parse_args()` was ever called on it, so none of its
# options or defaults could reach `args`.

########################
# Evaluation-time command-line interface; the `args` parsed at the end of this
# section is the configuration the rest of the script reads.


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"False"/"1"/"0".

    `type=bool` would turn any non-empty string (including "False") into True.
    """
    return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')


parser = argparse.ArgumentParser(description='Evaluates OOD Detector', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--in_dataset', default="CIFAR-100", type=str, help='in-distribution dataset')
parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
parser.add_argument('--epoch', default="500", type=str, help='which epoch to test')
parser.add_argument('--gpu', default=0, type=int, help='which GPU to use')
parser.add_argument('--loss', default='cider', type=str, choices=['supcon', 'cider'],
                    help='loss of experiment')
parser.add_argument('--name', type=str, default='ckpt_c100')
parser.add_argument('--id_loc', default="datasets/CIFAR100", type=str, help='location of in-distribution dataset')
parser.add_argument('--ood_loc', default="datasets/small_OOD_dataset", type=str, help='location of ood datasets')

parser.add_argument('--score', default='knn', type=str, help='score options: knn|maha|msp|odin|energy')
parser.add_argument('--K', default=300, type=int, help='K in KNN score')
parser.add_argument('--subset', default=False, type=_str2bool, help='whether to use subset for KNN')
parser.add_argument('--multiplier', default=1, type=float,
                    help='norm multipler to help solve numerical issues with precision matrix')
parser.add_argument('--model', default='resnet34', type=str, help='model architecture')
parser.add_argument('--embedding_dim', default=512, type=int, help='encoder feature dim')
parser.add_argument('--feat_dim', default=128, type=int, help='head feature dim')
parser.add_argument('--head', default='mlp', type=str, help='either mlp or linear head')
parser.add_argument('--normalize', action='store_true',
                    help='normalize feat embeddings')
parser.add_argument('--out_as_pos', action='store_true', help='if OOD data defined as positive class.')
parser.add_argument('--T', default=1000, type=float, help='temperature: energy|Odin')
# CIDER loss hyper-parameters. `args.temp` is read when the loss criteria are
# constructed and `args.w` by the training/validation loss; without these
# options the script fails with AttributeError.
parser.add_argument('--w', default=2, type=float,
                    help='loss scale')
parser.add_argument('--temp', type=float, default=0.1,
                    help='temperature for loss function')

args = parser.parse_args()


###############################################################################


# Hard-coded overrides applied after parsing: these win over any CLI value.
# feat_dim=1024 presumably has to match the "1024D" checkpoint loaded in main()
# and n_cls=100 the CIFAR-100 label set — confirm before changing either.
vars(args).update(gpu=0, proto_m=0.95, feat_dim=1024, n_cls=100)

###############################################################################

def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: value used for every random number generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Only processes started after this point pick this up; the running
    # interpreter's str-hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (silently ignored when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick (possibly nondeterministic)
    # algorithms per input shape, which defeats `deterministic=True` above.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

###############################################################################

# %matplotlib inline

# Notebook hyper-parameters. In this script only `batch_size` (loaders below)
# and `opl_ratio` (loss weighting) are read; the rest are kept for parity with
# the training notebook.
batch_size = 400
epochs = 120
max_lr = 0.001
grad_clip = 0.01
weight_decay = 0.001
opt_func = torch.optim.Adam
opl_ratio = 1.0

train_data = torchvision.datasets.CIFAR100('./', train=True, download=True)

# Per-channel mean/std over every training pixel, rescaled to [0, 1].
# In torchvision's CIFAR datasets `.data` is the raw uint8 image array laid out
# (N, H, W, C), so one vectorised reduction over all axes except the channel
# axis replaces converting each training image to an array one at a time.
pixels = train_data.data
mean = (np.mean(pixels, axis=(0, 1, 2)) / 255).tolist()
std = (np.std(pixels, axis=(0, 1, 2)) / 255).tolist()

transform_train = tt.Compose([tt.RandomCrop(32, padding=4, padding_mode='reflect'),
                              tt.RandomHorizontalFlip(),
                              tt.ToTensor(),
                              tt.Normalize(mean, std, inplace=True)])
transform_test = tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)])

trainset = torchvision.datasets.CIFAR100("./",
                                         train=True,
                                         download=True,
                                         transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size, shuffle=True, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR100("./",
                                        train=False,
                                        download=True,
                                        transform=transform_test)
val_loader = torch.utils.data.DataLoader(
    testset, batch_size * 2, pin_memory=True, num_workers=2)

# Device check and load model into device

def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
def to_device(data, device):
    """Move a tensor, or a (nested) list/tuple of tensors, to ``device``.

    Sequences are rebuilt as lists; tensors are moved with ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader:
    """Iterable wrapper that relocates every batch of ``dl`` onto ``device``."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after ``to_device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the wrapped loader's number of batches."""
        return len(self.dl)

###############################################################################
 
device = get_default_device()

# Re-wrap both loaders so every batch they yield is already on `device`.
train_loader, val_loader = (
    DeviceDataLoader(train_loader, device),
    DeviceDataLoader(val_loader, device),
)

# Layer Setup

class ImageClassificationBase(nn.Module):
    """Base ``nn.Module`` adding CE + CIDER training and validation steps.

    Subclasses implement ``forward(images)`` returning a 3-tuple
    ``(features, penultimate_feat, logits)``. The loss relies on module-level
    objects defined elsewhere in this script: ``criterion_dis``,
    ``criterion_comp``, ``args.w`` and ``opl_ratio``.
    """

    def _combined_loss(self, images, labels):
        """Return ``(loss, logits)`` with loss = CE + opl_ratio * (args.w * comp + dis)."""
        features, _penultimate_feat, out = self(images)
        # The dispersion loss runs first and the compactness loss then reads
        # `criterion_dis.prototypes`; if DisLoss updates its prototypes while
        # computing the loss ("V2: EMA style"), comp sees the updated ones —
        # presumably intended, confirm against utils.DisLoss.
        dis_loss = criterion_dis(features, labels)
        comp_loss = criterion_comp(features, criterion_dis.prototypes, labels)
        ce_loss = F.cross_entropy(out, labels)
        cider_loss = args.w * comp_loss + dis_loss
        return ce_loss + opl_ratio * cider_loss, out

    def training_step(self, batch):
        """Return the combined training loss for one ``(images, labels)`` batch."""
        images, labels = batch
        loss, _ = self._combined_loss(images, labels)
        return loss

    def validation_step(self, batch):
        """Return ``{'val_loss', 'val_acc'}`` for one batch.

        ``val_acc`` is the top-1 accuracy as a fraction in [0, 1].
        NOTE(review): this also calls ``criterion_dis``; if DisLoss updates
        prototypes in its forward pass, validation batches update them too.
        """
        images, labels = batch
        loss, out = self._combined_loss(images, labels)
        # `accuracy` returns one percentage tensor per requested k; keep only
        # top-1 (as a fraction) so validation_epoch_end can torch.stack them.
        acc = accuracy(out, labels, topk=(1,))[0] / 100.0
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``val_loss``/``val_acc`` tensors into Python floats."""
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line epoch summary from ``result`` (needs lrs, train_loss, val_loss, val_acc)."""
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))
        
def conv_block(in_channels, out_channels, pool=False):
    """Build a 3x3 Conv -> BatchNorm -> ReLU stage, optionally followed by 2x2 max-pooling.

    padding=1 keeps the spatial size; the optional pool halves it.
    """
    stage = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    tail = [nn.MaxPool2d(2)] if pool else []
    return nn.Sequential(*stage, *tail)


# Channel width of the backbone's pooled feature vector (input to `head`).
dim_feat = 512


class ResNet9(ImageClassificationBase):
    """ResNet-9 style CNN with an MLP projection head and a linear classifier.

    Args:
        in_channels: number of input image channels.
        num_classes: number of classifier outputs.

    ``forward`` returns ``(features, penul_feat, logits)``: ``penul_feat`` is
    the L2-normalised pooled backbone output (``dim_feat`` wide), ``features``
    the L2-normalised head output (``args.feat_dim`` wide) and ``logits`` the
    classifier applied to ``features``.

    Sub-module names are part of the state-dict keys, so keep them stable for
    checkpoint loading.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.conv5 = conv_block(512, dim_feat, pool=True)
        self.res3 = nn.Sequential(conv_block(dim_feat, dim_feat), conv_block(dim_feat, dim_feat))

        # For 32x32 inputs, conv2/conv3/conv4/conv5 each halve the map (32 -> 2)
        # and this final pool reaches 1x1, so Flatten yields (batch, dim_feat).
        self.feat = nn.Sequential(nn.MaxPool2d(2),
                                  nn.Flatten())

        self.head = nn.Sequential(
            nn.Linear(dim_feat, dim_feat),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feat, args.feat_dim)
        )
        self.classifier = nn.Linear(args.feat_dim, num_classes)

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.conv5(out)
        out = self.res3(out) + out
        # `feat` already flattens to 2-D. No `.squeeze()` here: for a batch of
        # one it would also drop the batch axis and break normalize(dim=1).
        pooled = self.feat(out)

        penul_feat = F.normalize(pooled, dim=1)

        unnorm_features = self.head(penul_feat)
        features = F.normalize(unnorm_features, dim=1)

        return features, penul_feat, self.classifier(features)
 
    
###################
class AverageMeter:
    """Track the latest value and the sample-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` (weighted by ``n`` samples) and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1, 5)):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: (batch, classes) score tensor.
        target: (batch,) tensor of true class indices.
        topk: the k values to evaluate.

    Returns:
        A list with one scalar tensor per k, in the order of ``topk``.
    """
    n_samples = target.size(0)
    ranked = output.topk(max(topk), 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples) for k in topk]

def validate(val_loader, model):
    """Evaluate and print top-1 / top-5 accuracy of ``model`` on ``val_loader``.

    Args:
        val_loader: iterable of ``(images, targets)`` batches.
        model: module whose forward returns a 3-tuple with the logits last.

    Returns:
        ``(top1, top5)``: sample-weighted average accuracies, in percent.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    # Send batches to the model's own device instead of hard-coding `.cuda()`,
    # so evaluation also runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        for images, target in val_loader:
            images = images.to(device)
            target = target.to(device)

            _, _, output = model(images)
            prec1, prec5 = accuracy(output.float(), target)
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

    print(' * Prec@1 {top1.avg:.3f} * Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg


    
def get_features(net, trainloader, testloader, feat_dir):
    """Extract features for the train and test loaders and cache them as .npy files.

    Writes ``train_feat.npy``, ``train_lab.npy``, ``test_feat.npy`` and
    ``test_lab.npy`` into ``feat_dir`` (created if missing). Extraction uses
    ``obtain_feature_from_loader`` with ``embedding_dim=args.feat_dim``.

    Returns:
        ``(ftrain, ftrain_labels, ftest, ftest_labels)`` as NumPy arrays.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(feat_dir, exist_ok=True)

    ftrain, ftrain_labels = obtain_feature_from_loader(net, trainloader, embedding_dim=args.feat_dim)
    np.save(os.path.join(feat_dir, 'train_feat.npy'), ftrain)
    np.save(os.path.join(feat_dir, 'train_lab.npy'), ftrain_labels)

    ftest, ftest_labels = obtain_feature_from_loader(net, testloader, embedding_dim=args.feat_dim)
    np.save(os.path.join(feat_dir, 'test_feat.npy'), ftest)
    np.save(os.path.join(feat_dir, 'test_lab.npy'), ftest_labels)

    return ftrain, ftrain_labels, ftest, ftest_labels

 


    ####################################### 
def obtain_feature_from_loader(net, loader, embedding_dim=None):
    """Run ``net`` over ``loader`` and return its first output and the targets as NumPy arrays.

    This definition shadows the ``obtain_feature_from_loader`` imported from
    ``utils.detection_util`` at the top of the file.

    Args:
        net: module whose forward returns a 3-tuple; only the first element is kept.
        loader: iterable of ``(data, target)`` batches.
        embedding_dim: feature width; only used to shape the (empty) result
            when ``loader`` yields no batches.

    Returns:
        ``(features, labels)``. Both are promoted to at least float32, so
        integer class labels come back as float32.
    """
    net.eval()
    # Run on whatever device the model lives on instead of hard-coding CUDA.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Gather per-batch tensors and concatenate once; concatenating inside the
    # loop re-copies everything collected so far on every batch.
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            out_feature, _, _ = net(data)
            feature_chunks.append(out_feature)
            label_chunks.append(target)

    if not feature_chunks:
        width = 0 if embedding_dim is None else embedding_dim
        return torch.zeros((0, width)).numpy(), torch.zeros((0,)).numpy()

    features = torch.cat(feature_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0)
    features = features.to(torch.promote_types(features.dtype, torch.float32))
    labels = labels.to(torch.promote_types(labels.dtype, torch.float32))
    return features.cpu().numpy(), labels.cpu().numpy()





# Build the backbone (3 input channels, 100 output classes) on `device`.
net = to_device(ResNet9(in_channels=3, num_classes=100), device)

###############################################################################

#CIDER Losses
# All criteria below are built with `temperature=args.temp` and moved to CUDA
# via `.cuda()`, so `args` must expose `temp` here and a GPU must be present.
# CompLoss/DisLoss also receive the whole `args` object; which attributes they
# read is defined in `utils`, not in this file.

criterion_supcon = SupConLoss(temperature=args.temp).cuda()
criterion_comp = CompLoss(args, temperature=args.temp).cuda()
# V1: learnable prototypes
# criterion_dis = DisLPLoss(args, model, val_loader, temperature=args.temp).cuda() # V1: learnable prototypes
# optimizer = torch.optim.SGD([ {"params": model.parameters()},
#                               {"params": criterion_dis.prototypes}  
#                             ], lr = args.learning_rate,
#                             momentum=args.momentum,
#                             nesterov=True,
#                             weight_decay=args.weight_decay)

# V2: EMA style prototypes
# NOTE(review): DisLoss is handed `net` and `val_loader`, which wraps the
# CIFAR-100 *test* split. If DisLoss uses that loader to initialise its class
# prototypes, the initialisation sees test data — confirm against
# utils.DisLoss; the training split may have been intended.
criterion_dis = DisLoss(args, net, val_loader, temperature=args.temp).cuda() # V2: prototypes with EMA style update


###############################################################################

def set_up(ckpt, model=None):
    """Load a checkpoint's state dict into a model and switch it to eval mode.

    Args:
        ckpt: path to a file holding a ``state_dict`` (loaded onto the CPU).
        model: module to load into; defaults to the module-level ``net``.

    Returns:
        The loaded model, in eval mode.
    """
    if model is None:
        model = net
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model


# When True, main() also reports top-1/top-5 accuracy on `val_loader`.
eval_test = True


def _mean_pairwise_class_cosine(features, labels):
    """Return the mean cosine similarity over all pairs of distinct class means.

    One mean feature vector is computed per unique label, each is L2-normalised,
    and the cosine similarities in the strict upper triangle of their Gram
    matrix are averaged. Returns NaN when fewer than two classes are present.

    Args:
        features: 2-D array with one feature row per sample.
        labels: 1-D array of per-sample labels aligned with ``features``.
    """
    classes = np.unique(labels)
    if len(classes) < 2:
        return np.nan
    class_means = np.stack([features[labels == k].mean(axis=0) for k in classes])
    class_means /= np.linalg.norm(class_means, axis=1, keepdims=True)
    sims = class_means @ class_means.T
    # Index the upper triangle explicitly: filtering `!= 0` after np.triu would
    # also drop genuinely orthogonal class pairs (cosine exactly 0).
    rows, cols = np.triu_indices(len(classes), k=1)
    return np.mean(sims[rows, cols])


def main():
    """Evaluate the trained CIDER checkpoint.

    Loads ``CIDER_BASE_1024D_model.h5`` via ``set_up``, caches train/test
    features under ``CIDER_features/``, prints the mean cosine similarity
    between distinct class-mean training features (lower = better separated
    classes) and, when ``eval_test`` is set, prints top-1/top-5 test accuracy.
    """
    ckpt = "CIDER_BASE_1024D_model.h5"
    net = set_up(ckpt)

    path2save = "CIDER_features"
    ftrain, ftrain_labels, _ftest, _ftest_labels = get_features(net, train_loader, val_loader, path2save)

    print(_mean_pairwise_class_cosine(ftrain, ftrain_labels))

    if eval_test:
        validate(val_loader, net)


if __name__ == '__main__':
    main()