##################################
# Acknowledgment:
# Part of this code is adapted from https://github.com/kahnchana/opl
# Part of this code is adapted from https://www.kaggle.com/code/yiweiwangau/cifar-100-resnet-pytorch-75-17-accuracy
##################################

# Commented out IPython magic to ensure Python compatibility.
import pandas as pd
import os
import torch
import time
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import torchvision.models as models
import matplotlib.pyplot as plt
from sklearn.metrics import *

#########################################

def seed_everything(seed: int):
    """Seed every RNG used by this script and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every visible CUDA
    device), sets ``PYTHONHASHSEED`` and configures cuDNN so repeated runs
    select the same kernels.

    Args:
        seed: value applied to every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Only affects subprocesses started later; the running interpreter's
    # hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick (possibly non-deterministic)
    # algorithms per run, which would defeat `deterministic = True` above.
    torch.backends.cudnn.benchmark = False
    
seed_everything(42)

#########################################

# %matplotlib inline

# Optimisation hyper-parameters shared by every training stage below.
batch_size = 400          # rows per mini-batch, for both training and evaluation
epochs = 120              # total epoch budget, split across the fit_one_cycle stages
max_lr = 1e-3             # peak learning rate of the first one-cycle stage
grad_clip = 1e-2          # element-wise threshold passed to clip_grad_value_
weight_decay = 1e-3       # forwarded to the optimizer's weight_decay argument
opt_func = torch.optim.Adam


# Device check and load model into device

def get_default_device():
    """Return the CUDA device when PyTorch can see a GPU, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor (or module), or a possibly nested list/tuple of them, to ``device``.

    Lists and tuples are rebuilt as lists; every leaf is moved with
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader():
    """Wrap an iterable of batches and move each batch to ``device`` as it is produced.

    Attributes are stored as given:
    ``dl`` is the wrapped iterable of batches (e.g. a ``DataLoader``), and
    ``device`` is the target device handed to ``to_device`` for each batch.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield every batch of the wrapped loader after moving it to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.dl)

###############################################################################
import json
import torch

# Cross-entropy on the classifier logits (mean over the batch, the default).
loss_function = nn.CrossEntropyLoss(reduction='mean')

# Multiplier of the orthogonal-projection term added to the cross-entropy loss.
opl_ratio = 1.0
# `gamma` of OrthogonalProjectionLoss: weight of its different-class penalty.
opl_gamma = 0.5

class OrthogonalProjectionLoss(nn.Module):
    """Orthogonal Projection Loss (OPL) over a batch of feature vectors.

    Encourages same-class features to have a mean pairwise dot product close
    to 1. It also encourages different-class features to have a mean
    absolute dot product close to 0:

        loss = (1 - mean_same_class_dot) + gamma * mean(|diff_class_dot|)

    With normalisation enabled, the dot products are cosine similarities.

    Args:
        no_norm: if True, skip L2-normalising the features before the dot products.
        weights_path: optional path to a JSON file whose parsed contents are
            stored in ``self.weights_dict`` (not used by ``forward``).
        use_attention: if True, replace every feature by a softmax-weighted mix
            of all features in the batch (weights from the feature Gram matrix)
            before computing the loss.
        gamma: weight of the different-class (negative pair) term.
    """

    def __init__(self, no_norm=False, weights_path=None, use_attention=False, gamma=2):
        super(OrthogonalProjectionLoss, self).__init__()
        self.weights_dict = None
        self.no_norm = no_norm
        self.gamma = gamma
        self.use_attention = use_attention
        if weights_path is not None:
            # Context manager so the file handle is closed deterministically.
            with open(weights_path, "r") as fh:
                self.weights_dict = json.load(fh)

    def forward(self, features, labels=None):
        """Compute the OPL for one batch.

        Args:
            features: 2-D tensor whose rows are per-sample feature vectors.
            labels: 1-D tensor with one class id per row of ``features``.
                Required in practice; the ``None`` default is kept only for
                signature compatibility.

        Returns:
            Tuple ``(loss, pos_pairs_mean, neg_pairs_mean)`` of scalar tensors.

        Raises:
            ValueError: if ``labels`` is None.
        """
        if labels is None:
            raise ValueError("OrthogonalProjectionLoss.forward requires `labels`.")
        # Use the features' own device (correct for cuda:N, not just 'cuda').
        device = features.device

        if self.use_attention:
            features_weights = torch.matmul(features, features.T)
            features_weights = F.softmax(features_weights, dim=1)
            features = torch.matmul(features_weights, features)

        if not self.no_norm:
            features = F.normalize(features, p=2, dim=1)

        labels = labels[:, None].to(device)          # column vector for pairwise comparison
        mask = torch.eq(labels, labels.t())          # mask[i, j]: rows i and j share a label
        eye = torch.eye(mask.shape[0], dtype=torch.bool, device=device)

        # Same-label pairs, excluding each sample paired with itself.
        mask_pos = mask.masked_fill(eye, False).float()
        mask_neg = (~mask).float()
        dot_prod = torch.matmul(features, features.t())

        # The 1e-6 keeps the division finite for batches without positive /
        # negative pairs.
        pos_pairs_mean = (mask_pos * dot_prod).sum() / (mask_pos.sum() + 1e-6)
        neg_pairs_mean = torch.abs(mask_neg * dot_prod).sum() / (mask_neg.sum() + 1e-6)

        loss = (1.0 - pos_pairs_mean) + (self.gamma * neg_pairs_mean)
        return loss, pos_pairs_mean, neg_pairs_mean
    
    
aux_loss = OrthogonalProjectionLoss(no_norm=False, use_attention=False, gamma=opl_gamma)
###############################################################################

device = get_default_device()
# Cast target for the numpy batches (`.type(dtype)`); each batch is then moved
# with `.to(device)`. A device-agnostic float32 dtype is used, not a CUDA
# tensor-type string, so the cast does not force CUDA when
# get_default_device() falls back to the CPU.
dtype = torch.float32

def accuracy(output, target, topk=(1, 5)):
    """Return precision@k, as a percentage tensor, for each k in ``topk``.

    ``output`` holds per-class scores along dim 1. ``target`` holds the true
    class index for each row. A row counts as correct at k when its target is
    among the k highest scores.
    """
    n_samples = target.size(0)
    _, top_indices = output.topk(max(topk), dim=1, largest=True, sorted=True)
    top_indices = top_indices.t()  # (maxk, batch): row r holds every sample's r-th best class
    hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))
    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples) for k in topk]


class AverageMeter(object):
    """Keep the latest value and the count-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations and refresh ``avg``."""
        self.sum += val * n
        self.count += n
        self.val = val
        self.avg = self.sum / self.count
        
def validate(model):
    """Evaluate ``model`` on the held-out arrays and print top-1 / top-5 precision.

    Iterates over the module-level ``opl_features_test`` / ``opl_labels_test``
    arrays in chunks of ``batch_size`` rows. Each chunk is moved to ``device``.

    Args:
        model: classifier mapping a batch of feature vectors to class scores.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` of sample-weighted averages.
        The two precisions are percentages; the loss is ``loss_function``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for batch_index in range(0, len(opl_features_test), batch_size):
            batch_slice = slice(batch_index, batch_index + batch_size)
            # float32 on `device` directly, so CPU-only hosts work too.
            features = torch.from_numpy(opl_features_test[batch_slice]).to(
                device=device, dtype=torch.float32)
            # Cast labels straight to int64 instead of round-tripping through
            # float32, which is exact only for values below 2**24.
            labels = torch.from_numpy(opl_labels_test[batch_slice]).to(
                device=device, dtype=torch.long)

            outputs = model(features)
            loss = loss_function(outputs, labels)

            # measure accuracy and record loss, weighted by the chunk size
            prec1, prec5 = accuracy(outputs.float(), labels)
            n = len(features)
            losses.update(loss.float().item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    print(' * Prec@1 {top1.avg:.3f} * Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg

def epoch_end(epoch, result):
    """Print one summary line per epoch: last learning rate, train/val loss and val accuracy."""
    last_lr = result['lrs'][-1]
    print(f"Epoch [{epoch}], last_lr: {last_lr:.5f}, train_loss: {result['train_loss']:.4f}, "
          f"val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
    
    
###############
# Pre-extracted features and labels (these could also be the train/test
# features and labels from DCR). np.load opens and closes each file itself.
opl_features_train = np.load("train_ICR_feat.npy")
opl_labels_train = np.load("train_ICR_lab.npy")
opl_features_test = np.load("test_ICR_feat.npy")
# NOTE(review): the double underscore in "test__ICR_lab.npy" differs from the
# other three file names; confirm it matches the file actually on disk.
opl_labels_test = np.load("test__ICR_lab.npy")
    

def set_up(ckpt, net):
    """Load classifier weights and bias from checkpoint ``ckpt`` into ``net``.

    Only the ``'classifier.weight'`` and ``'classifier.bias'`` entries of the
    loaded dict are used. They replace the parameters of ``net.classifier``.
    The network is then switched to eval mode.

    Args:
        ckpt: path to a file readable by ``torch.load`` containing a dict with
            the two keys above.
        net: module exposing a ``classifier`` sub-module with ``weight`` and ``bias``.

    Returns:
        ``net`` itself (modified in place).
    """
    # Load onto the CPU so this works on machines without a GPU, then move the
    # tensors to wherever ``net.classifier`` already lives.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pretrained_dict = torch.load(ckpt, map_location='cpu')
    target_device = net.classifier.weight.device
    net.classifier.weight = nn.Parameter(pretrained_dict['classifier.weight'].to(target_device))
    net.classifier.bias = nn.Parameter(pretrained_dict['classifier.bias'].to(target_device))
    net.eval()
    return net
###############

class LinearHead(nn.Module):
    """A single fully-connected layer mapping ``in_features`` inputs to ``out_features`` outputs.

    The layer is stored as ``self.classifier``, so the state-dict keys are
    ``classifier.weight`` and ``classifier.bias``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.classifier = nn.Linear(in_features, out_features)

    def forward(self, xb):
        """Apply the linear layer to the batch ``xb``."""
        return self.classifier(xb)


# Linear classifier: 1024-d input features -> 100 outputs, placed on `device`.
in_feat, out_feat = 1024, 100
model = to_device(LinearHead(in_feat, out_feat), device)

# Restore the classifier from a checkpoint; set_up returns its argument, so
# `net` and `model` are the same object.
ckpt = "OPL_BASE_1024D_model.h5"
net = set_up(ckpt, model)
    
    
    
# Training Setup

@torch.no_grad()
def evaluate(model):
    """Run ``validate`` on ``model`` and package the result for the training history.

    Returns:
        dict with ``'val_loss'`` (last element of validate's result) and
        ``'val_acc'`` (first element, the top-1 precision).
    """
    metrics = validate(model)
    val_acc, val_loss = metrics[0], metrics[-1]
    return {'val_loss': val_loss, 'val_acc': val_acc}

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    missing = object()
    first_group = next(iter(optimizer.param_groups), missing)
    return None if first_group is missing else first_group['lr']

def fit_one_cycle(epochs, max_lr, model,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train ``model`` for ``epochs`` epochs under a one-cycle learning-rate schedule.

    Uses the module-level ``opl_features_train`` / ``opl_labels_train`` arrays
    in unshuffled chunks of ``batch_size`` rows. The batch loss is
    ``loss_function`` (cross-entropy) plus ``opl_ratio`` times ``aux_loss``.
    After every epoch the model is scored with ``evaluate`` and a summary is
    printed via ``epoch_end``.

    Args:
        epochs: number of passes over the training arrays.
        max_lr: peak learning rate of the OneCycleLR schedule (also passed as
            the optimizer's initial learning rate).
        model: classifier to train, updated in place.
        weight_decay: forwarded to ``opt_func``.
        grad_clip: if truthy, gradients are clipped element-wise to
            ``[-grad_clip, grad_clip]`` before each optimizer step.
        opt_func: optimizer class, called as ``opt_func(params, max_lr, weight_decay=...)``.

    Returns:
        One dict per epoch with ``val_loss``, ``val_acc``, ``train_loss``
        (mean batch loss) and ``lrs`` (learning rate used for each batch).
    """
    torch.cuda.empty_cache()
    history = []

    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    # The scheduler is stepped once per mini-batch, so it needs the number of
    # batches per epoch. Passing the number of samples would stretch the cycle
    # ~batch_size times, and the LR would never leave the warm-up phase.
    batch_starts = range(0, len(opl_features_train), batch_size)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                steps_per_epoch=len(batch_starts))

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_losses = []
        lrs = []

        for batch_index in batch_starts:
            batch_slice = slice(batch_index, batch_index + batch_size)
            # float32 / int64 directly on `device`: CPU-safe, and no lossy
            # float round trip for the integer labels.
            features = torch.from_numpy(opl_features_train[batch_slice]).to(
                device=device, dtype=torch.float32)
            labels = torch.from_numpy(opl_labels_train[batch_slice]).to(
                device=device, dtype=torch.long)

            outputs = model(features)
            base_loss = loss_function(outputs, labels)
            # NOTE(review): OPL is computed on the fixed input features, which
            # do not depend on the model's parameters. This term therefore adds
            # a constant to the loss and contributes no gradient. Confirm
            # whether it was meant to act on a learnable representation.
            op_loss, _, _ = aux_loss(features, labels)
            loss = base_loss + opl_ratio * op_loss

            loss.backward()

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Detach so every batch's autograd graph is not kept alive until
            # the end of the epoch.
            train_losses.append(loss.detach())

            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(epoch, result)
        history.append(result)
    return history
     

# Initial evaluation
history = [evaluate(model)]

current_time = time.time()

# Staged training. Each entry is (divisor of `epochs`, divisor of `max_lr`),
# so the stages cover 1/4 + 1/4 + 1/8 + 1/8 + 1/4 of the epoch budget. Every
# fit_one_cycle call builds a fresh optimizer and one-cycle schedule.
training_stages = [(4, 1), (4, 10), (8, 100), (8, 1000), (4, 100)]
for epoch_divisor, lr_divisor in training_stages:
    history += fit_one_cycle(int(epochs / epoch_divisor), max_lr / lr_divisor, model,
                             grad_clip=grad_clip,
                             weight_decay=weight_decay,
                             opt_func=opt_func)