##################################
# Acknowledgment:
# Part of this code is adopted from https://github.com/kahnchana/opl
# Part of this code is adopted from https://www.kaggle.com/code/yiweiwangau/cifar-100-resnet-pytorch-75-17-accuracy
# Part of this code is adopted from https://github.com/deeplearning-wisc/cider
##################################


import argparse
import math
import os
import time
from datetime import datetime
import logging
 
import pprint

import torch
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim
import torch.utils.data
import numpy as np


####################
# Commented out IPython magic to ensure Python compatibility.
import pandas as pd
import os
import torch
import time
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import torchvision.models as models
import matplotlib.pyplot as plt
from sklearn.metrics import *
####################


from utils import (CompLoss, DisLoss, DisLPLoss, SupConLoss, 
                AverageMeter, adjust_learning_rate, warmup_learning_rate, 
                set_loader_small, set_loader_ImageNet)

############
import warnings
warnings.filterwarnings("ignore")
############
# Command-line interface.
#
# NOTE(review): the training code further down in this script drives the run
# from module-level constants (batch_size, epochs, max_lr, ...) rather than
# from the matching flags; `args` is also handed to the CIDER loss
# constructors from `utils`, which may read additional fields — confirm
# there before removing any option.
parser = argparse.ArgumentParser(
    description='Training with CIDER and SupCon Loss')

# Hardware / reproducibility
parser.add_argument('--gpu', default=7, type=int, help='which GPU to use')
parser.add_argument('--seed', default=4, type=int, help='random seed')

# Loss configuration
parser.add_argument('--w', default=2, type=float,
                    help='loss scale')
parser.add_argument('--proto_m', default=0.99, type=float,
                    help='weight of prototype update')
parser.add_argument('--feat_dim', default=128, type=int,
                    help='feature dim')

# Data and model
parser.add_argument('--in-dataset', default="CIFAR-100", type=str,
                    help='in-distribution dataset')
parser.add_argument('--id_loc', default="datasets/CIFAR100", type=str,
                    help='location of in-distribution dataset')
parser.add_argument('--model', default='resnet18', type=str,
                    help='model architecture: [resnet18, wrt40, wrt28, densenet100]')
parser.add_argument('--head', default='mlp', type=str,
                    help='either mlp or linear head')
parser.add_argument('--loss', default='cider', type=str,
                    choices=['supcon', 'cider'],
                    help='name of experiment')

# Schedule / bookkeeping
parser.add_argument('--epochs', default=500, type=int,
                    help='number of total epochs to run')
parser.add_argument('--trial', type=str, default='0',
                    help='id for recording multiple runs')
parser.add_argument('--save-epoch', default=100, type=int,
                    help='save the model every save_epoch')
parser.add_argument('--start-epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
# The help text previously advertised a default of 64, which did not match
# the actual default below.
parser.add_argument('-b', '--batch-size', default=512, type=int,
                    help='mini-batch size (default: 512)')
parser.add_argument('--learning_rate', default=0.5, type=float,
                    help='initial learning rate')
# if linear lr schedule
parser.add_argument('--lr_decay_epochs', type=str, default='100,150,180',
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                    help='decay rate for learning rate')
# if cosine lr schedule
parser.add_argument('--cosine', action='store_true',
                    help='using cosine annealing')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--temp', type=float, default=0.1,
                    help='temperature for loss function')
parser.add_argument('--warm', action='store_true',
                    help='warm-up for large batch training')
parser.add_argument('--normalize', action='store_true',
                    help='normalize feat embeddings')

# Attributes without a flag of their own; they always appear on `args`.
parser.set_defaults(bottleneck=True)
parser.set_defaults(augment=True)

# Parses sys.argv at import time (this file is a script, not a library).
args = parser.parse_args()


###############################################################################


# Fixed experiment settings. These unconditionally replace whatever was parsed
# for --gpu, --proto_m and --feat_dim, and attach the class count `n_cls`,
# which has no command-line flag of its own.
vars(args).update(gpu=0, proto_m=0.95, feat_dim=1024, n_cls=100)

###############################################################################

def seed_everything(seed: int) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and switches cuDNN to its
    reproducible configuration.

    Args:
        seed: Value used to seed every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # Only inherited by processes started afterwards; the hash seed of the
    # running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different convolution
    # algorithms from run to run, which works against the deterministic
    # setting above; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False


# NOTE: the seed is fixed here; the parsed --seed option is not consulted.
seed_everything(42)

###############################################################################

# %matplotlib inline

# ---- Training hyper-parameters (module-level constants) ----
batch_size = 400
epochs = 120
max_lr = 0.001
grad_clip = 0.01
weight_decay = 0.001
opt_func = torch.optim.Adam
opl_ratio = 1.0

# Untransformed training split, used only to estimate the per-channel
# normalisation statistics below.
train_data = torchvision.datasets.CIFAR100('./', train=True, download=True)

# Stack all images on top of each other into one (N*32) x 32 x 3 array
# (1600000 x 32 x 3 for this split), so reducing over the first two axes
# leaves one value per colour channel.
x = np.concatenate(
    [np.asarray(train_data[idx][0]) for idx in range(len(train_data))]
)

# Per-channel mean / std, rescaled from the 0-255 pixel range to 0-1 and
# converted to plain Python lists for the Normalize transform.
mean = (np.mean(x, axis=(0, 1)) / 255).tolist()
std = (np.std(x, axis=(0, 1)) / 255).tolist()

# Training pipeline: random crop with reflect padding + horizontal flip,
# then tensor conversion and normalisation.
transform_train = tt.Compose([
    tt.RandomCrop(32, padding=4, padding_mode='reflect'),
    tt.RandomHorizontalFlip(),
    tt.ToTensor(),
    tt.Normalize(mean, std, inplace=True),
])
# Evaluation pipeline: no augmentation.
transform_test = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR100(
    "./", train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

testset = torchvision.datasets.CIFAR100(
    "./", train=False, download=True, transform=transform_test)
# The evaluation loader uses twice the training batch size and keeps the
# default (unshuffled) order.
val_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size * 2,
    pin_memory=True,
    num_workers=2,
)

"""# Device check and load model into device"""

def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.

    Lists and tuples are processed element by element and always come back as
    a list; anything else is moved with a non-blocking ``.to(device)`` call.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for element in data:
        moved.append(to_device(element, device))
    return moved

class DeviceDataLoader():
    """Thin wrapper that moves every batch of a dataloader to a device.

    Attributes are stored as given:
      dl     -- the wrapped iterable of batches
      device -- the target passed to ``to_device`` for each batch
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after moving it."""
        target = self.device
        for batch in self.dl:
            moved = to_device(batch, target)
            yield moved

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        wrapped = self.dl
        return len(wrapped)

###############################################################################
 
# Choose the compute device once, then wrap both loaders so that every batch
# is moved to that device as it is yielded.
device = get_default_device()

train_loader = DeviceDataLoader(dl=train_loader, device=device)
val_loader = DeviceDataLoader(dl=val_loader, device=device)

"""# Layer Setup"""

def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    The result is a 0-dim tensor built from a Python float.
    """
    preds = torch.max(outputs, dim=1).indices
    num_correct = (preds == labels).sum().item()
    return torch.tensor(num_correct / len(preds))

class ImageClassificationBase(nn.Module):
    """Base module providing the train/validation steps and epoch reporting.

    Subclasses must implement ``forward`` so that ``self(images)`` returns a
    3-tuple ``(features, penultimate_features, logits)``.  The loss relies on
    the module-level objects ``criterion_dis``, ``criterion_comp``, ``args``
    and ``opl_ratio``.
    """

    def _total_loss(self, batch):
        """Run the model on `batch` and return ``(loss, logits, labels)``.

        loss = cross-entropy + opl_ratio * (args.w * comp_loss + dis_loss)
        """
        images, labels = batch
        features, _penultimate, logits = self(images)
        # The dispersion criterion is called first; the compactness criterion
        # then reads criterion_dis.prototypes (V2: EMA style).
        dispersion = criterion_dis(features, labels)
        compactness = criterion_comp(features, criterion_dis.prototypes, labels)
        cross_entropy = F.cross_entropy(logits, labels)
        cider = args.w * compactness + dispersion
        return cross_entropy + opl_ratio * cider, logits, labels

    def training_step(self, batch):
        """Return the combined training loss for one ``(images, labels)`` batch."""
        total, _logits, _labels = self._total_loss(batch)
        return total

    def validation_step(self, batch):
        """Return the detached loss and the accuracy for one validation batch.

        NOTE(review): the dispersion criterion is also invoked here; if its
        forward pass updates the prototypes (the "EMA style" remark suggests
        it might), validation would modify training state — confirm against
        utils.DisLoss.
        """
        total, logits, labels = self._total_loss(batch)
        batch_acc = accuracy(logits, labels)
        return {'val_loss': total.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch losses and accuracies into Python floats.

        Every batch carries the same weight regardless of its size.
        """
        mean_loss = torch.stack([step['val_loss'] for step in outputs]).mean()
        mean_acc = torch.stack([step['val_acc'] for step in outputs]).mean()
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the epoch from the `result` dict."""
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        last_lr = result['lrs'][-1]
        print(template.format(
            epoch, last_lr, result['train_loss'], result['val_loss'], result['val_acc']))
        
def conv_block(in_channels, out_channels, pool=False):
    """Build a 3x3 conv -> batch-norm -> ReLU stack, optionally followed by a 2x2 max-pool.

    The convolution uses padding=1, so it keeps the spatial size; only the
    optional pooling layer halves it.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    if pool:
        return nn.Sequential(conv, norm, activation, nn.MaxPool2d(2))
    return nn.Sequential(conv, norm, activation)


# Channel width of the last backbone stage; also the input size of the head.
dim_feat = 512


class ResNet9(ImageClassificationBase):
    """ResNet-9 style backbone with a projection head and a linear classifier.

    ``forward`` returns ``(features, penul_feat, logits)``:
      features   -- L2-normalised output of the projection head
                    (width ``args.feat_dim``)
      penul_feat -- L2-normalised flattened backbone output
      logits     -- classifier scores computed from ``features``

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output classes of the classifier.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.conv5 = conv_block(512, dim_feat, pool=True)
        self.res3 = nn.Sequential(conv_block(dim_feat, dim_feat),
                                  conv_block(dim_feat, dim_feat))

        # For 32x32 inputs the five 2x2 max-pools (conv2, conv3, conv4, conv5
        # and this one) shrink the map to 1x1, so the flattened vector has
        # exactly dim_feat entries, which is what the head expects.
        self.feat = nn.Sequential(nn.MaxPool2d(2),  # dim_feat x 1 x 1
                                  nn.Flatten())

        # Two-layer MLP projecting to the embedding used by the CIDER losses.
        self.head = nn.Sequential(
            nn.Linear(dim_feat, dim_feat),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feat, args.feat_dim)
        )
        # The classifier operates on the normalised embedding, not on the
        # backbone features.
        self.classifier = nn.Linear(args.feat_dim, num_classes)

    def forward(self, xb):
        """Compute ``(features, penul_feat, logits)`` for a batch of images."""
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.conv5(out)
        out = self.res3(out) + out
        # nn.Flatten already yields a 2-D (batch, features) tensor.  The
        # former trailing .squeeze() also removed the batch axis when the
        # batch held a single sample, which made normalize(dim=1) fail.
        features = self.feat(out)

        penul_feat = F.normalize(features, dim=1)

        unnorm_features = self.head(penul_feat)
        features = F.normalize(unnorm_features, dim=1)

        return features, penul_feat, self.classifier(features)
 
 
 
    
# Build the network for 3-channel inputs and 100 classes, then move it to
# the selected device.
model = to_device(ResNet9(in_channels=3, num_classes=100), device)

"""# Training Setup"""

@torch.no_grad()
def evaluate(model, val_loader):
    """Evaluate `model` on every batch of `val_loader` with gradients disabled.

    Puts the model in eval mode, collects ``model.validation_step(batch)`` for
    each batch and returns ``model.validation_epoch_end`` of the collected
    results.  The model is left in eval mode.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(model.validation_step(batch))
    return model.validation_epoch_end(step_results)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)



###############################################################################

# CIDER losses
#
# The criteria are placed on the same `device` as the model and the data
# instead of calling `.cuda()` unconditionally, which fails on a host without
# CUDA even though the rest of the script falls back to the CPU.
# NOTE(review): assumes the loss classes from `utils` support `.to()` the way
# torch modules/tensors do (they already expose `.cuda()`) — confirm.  Code
# inside `utils` may still require CUDA on its own.
criterion_supcon = SupConLoss(temperature=args.temp).to(device)
criterion_comp = CompLoss(args, temperature=args.temp).to(device)
# V1: learnable prototypes
# criterion_dis = DisLPLoss(args, model, val_loader, temperature=args.temp).cuda() # V1: learnable prototypes
# optimizer = torch.optim.SGD([ {"params": model.parameters()},
#                               {"params": criterion_dis.prototypes}
#                             ], lr = args.learning_rate,
#                             momentum=args.momentum,
#                             nesterov=True,
#                             weight_decay=args.weight_decay)

# V2: EMA style prototypes
# NOTE(review): the loader handed over here is `val_loader`, which in this
# script wraps the CIFAR-100 test split — confirm that this is the intended
# data source for the dispersion loss.
criterion_dis = DisLoss(args, model, val_loader, temperature=args.temp).to(device)  # V2: prototypes with EMA style update


###############################################################################
def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train `model` for `epochs` epochs under a one-cycle learning-rate policy.

    A fresh optimizer and a fresh ``OneCycleLR`` scheduler are created on
    every call, so optimizer state is not carried over between calls.

    Args:
        epochs: Number of epochs to train.
        max_lr: Peak learning rate of the cycle (also passed to the optimizer
            as its initial learning rate).
        model: Module exposing ``training_step``, ``validation_step``,
            ``validation_epoch_end`` and ``epoch_end``.
        train_loader: Iterable of training batches; its length sets the
            number of scheduler steps per epoch.
        val_loader: Iterable of validation batches, evaluated after each epoch.
        weight_decay: Weight decay handed to the optimizer.
        grad_clip: If truthy, gradients are clamped element-wise to
            ``[-grad_clip, grad_clip]`` before each optimizer step.
        opt_func: Optimizer factory called as
            ``opt_func(params, max_lr, weight_decay=...)``.

    Returns:
        A list with one dict per epoch: the result of ``evaluate`` extended
        with ``'train_loss'`` (mean batch loss) and ``'lrs'`` (learning rate
        recorded after every optimizer step).
    """
    torch.cuda.empty_cache()
    history = []

    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler (stepped once per batch)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            # Keep only the value for reporting: storing the attached tensor
            # would hold on to autograd history for every batch of the epoch.
            train_losses.append(loss.detach())
            loss.backward()

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history

# Initial evaluation of the untrained model; this is the first history entry.
history = [evaluate(model, val_loader)]

# Staged training: each stage is an independent one-cycle run (new optimizer
# and scheduler) with its own epoch budget and peak learning rate.
current_time = time.time()

training_stages = (
    (int(epochs / 4), max_lr),         # first 1/4 of the epochs
    (int(epochs / 4), max_lr / 10),    # second 1/4
    (int(epochs / 8), max_lr / 100),   # next 1/8
    (int(epochs / 8), max_lr / 1000),  # next 1/8
    (int(epochs / 4), max_lr / 100),   # final 1/4, learning rate raised again
)
for stage_epochs, stage_max_lr in training_stages:
    history += fit_one_cycle(stage_epochs, stage_max_lr, model,
                             train_loader, val_loader,
                             grad_clip=grad_clip,
                             weight_decay=weight_decay,
                             opt_func=opt_func)

# Print training time
print('Training time: {:.2f} s'.format(time.time() - current_time))

# Prediction: final evaluation pass, timed separately.
current_time = time.time()
result = evaluate(model, val_loader)
# This interval covers only the evaluation above, so it is reported as such
# (it used to be printed under the "Training time" label).
print('Evaluation time: {:.2f} s'.format(time.time() - current_time))

# Save the model weights.  Despite the .h5 extension this is a state dict
# written by torch.save; load it with torch.load / load_state_dict.
torch.save(model.state_dict(), 'CIDER_BASE_1024D_model.h5')

