'''Train an encoder using Contrastive Learning.'''

import argparse
import os
import subprocess
import sys
import logging

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from lars import LARS
from tqdm import tqdm

from configs import get_datasets
from critic import LinearCritic
from fare_head import FareHead
from fare_sparse import *
from compute_lsh_conditional_att import compute_lsh_conditional_attention
from evaluate import save_checkpoint, encode_train_set, train_clf, test, encode_conditional_train_set, test_conditional
from models import *
from build_auxdatasets import *
from scheduler import CosineAnnealingWithLinearRampLR
from criterion import ConditionalSamplingLoss
from threshold_annealing import thresholdAnnealing
from logger import txt_logger

# ---------------------------------------------------------------------------
# Command-line interface.
# Every option below ends up as an attribute on the module-level ``args``.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch Contrastive Learning.')

# Optimisation
parser.add_argument('--base-lr', default=0.25, type=float, help='base learning rate, rescaled by batch_size/256')
parser.add_argument("--momentum", default=0.9, type=float, help='SGD momentum')
parser.add_argument('--resume', '-r', type=str, default='', help='resume from checkpoint with this filename')
parser.add_argument('--dataset', '-d', type=str, default='colorMNIST', help='dataset')
parser.add_argument('--temperature', type=float, default=0.5, help='InfoNCE temperature')
parser.add_argument("--batch-size", type=int, default=256, help='Training batch size')
parser.add_argument("--num-epochs", type=int, default=100, help='Number of training epochs')
parser.add_argument("--cosine-anneal", action='store_true', help="Use cosine annealing on the learning rate")
parser.add_argument("--no_color_distor", action='store_true', help="Use color distoration or not")
parser.add_argument("--arch", type=str, default='resnet50', help='Encoder architecture')
parser.add_argument("--num-workers", type=int, default=2, help='Number of threads for data loaders')
# The help text is built from adjacent string literals, so each fragment must
# carry its own separator; otherwise the sentences run together in --help.
parser.add_argument("--test-freq", type=int, default=10, help='Frequency to fit a linear clf with L-BFGS for testing. '
                                                              'Not appropriate for large datasets. Set 0 to avoid '
                                                              'classifier only training here.')

# Output naming
parser.add_argument("--filename", type=str, default='ckpt.pth', help='Output file name')
parser.add_argument("--run-name", type=str, default='new_run', help='Run file name in saved checkpoints')

# Conditional-sampling loss / z-head options
parser.add_argument("--lambda_", type=float, default=0.01, help='')
# No default is given here, so args.temp_z is None unless the flag is passed.
parser.add_argument("--temp_z", type=float, help='fare attention softmax temperature')
parser.add_argument("--scale_z", type=float, default=1, help='')
parser.add_argument("--lowrank-approximation", type=int, default=None, help='low rank approximation dimension via svd ')
parser.add_argument("--kz_warmup_epoch", type=int, default=0, help='special case for Kz')

# Threshold annealing
parser.add_argument("--warmup_percent", type=float, default=0.33, help='warmup percent')
parser.add_argument("--start_high_threshold", type=float, default=1., help='annealing')
parser.add_argument("--end_high_threshold", type=float, default=0.6, help='annealing')
parser.add_argument("--start_low_threshold", type=float, default=0., help='annealing')
parser.add_argument("--end_low_threshold", type=float, default=0.4, help='annealing')

# Saving / evaluation
parser.add_argument("--save_path", type=str, default="train_related", help="save root folder that will store the results")
parser.add_argument("--z_head", type=str, default="", help="specify z head to learn a better representation to condition on")
parser.add_argument("--save_freq", type=int, default=25)
parser.add_argument("--reg_weight", type=float, default=1e-6)
parser.add_argument("--weight_clip_threshold", type=float, default=1e-6, help="clip for computing the weight")

# LSH / attention
parser.add_argument("--lsh", type=str, default='default', help='have lsh off, default, or in reverse mode')
parser.add_argument("--lsh-adjacency", type=str, default="adjacent", help='LSH scheme attends either intra-bucket, across adjacent buckets, or one back')
parser.add_argument("--attention-sim", action='store_true', help='use sample attention instead of generic cosine similarity in critic')
parser.add_argument("--z-init", type=str, default='default', help='Initialize linear projections in z-head with specific values, either normal or identity')
parser.add_argument("--projection-mode", type=str, default='linear', help='Whether to project the z-head with linear or nonlinear projection')
parser.add_argument("--attention-weightdecay", type=float, default=1e-4, help='instantiate separate optimizer for attention component with its own weight decay')
parser.add_argument("--CelebA_UTKFace", type=str, help='Dummy argument, do not specify')


args = parser.parse_args()
# Effective learning rate: the base rate scaled linearly with batch size
# relative to a reference batch of 256.
args.lr = args.base_lr * (args.batch_size / 256)

# ---------------------------------------------------------------------------
# Output locations: results are written to <save_folder>/<model_name>.
# ---------------------------------------------------------------------------
_run_dir = f"bz_{args.batch_size}_ep_{args.num_epochs}"
if args.CelebA_UTKFace is not None:
    # The parser above defines no ``condition_mode`` option, so reading
    # ``args.condition_mode`` directly raises AttributeError. Fall back to a
    # fixed segment so the path can still be built.
    # NOTE(review): confirm which value this path segment is meant to carry.
    _condition_mode = getattr(args, "condition_mode", "default")
    args.save_folder = (
        f"{args.save_path}/{_condition_mode}/{args.CelebA_UTKFace}/{args.dataset}"
        f"/{_run_dir}/lsh_{args.lsh}"
    )
else:
    args.save_folder = f"{args.save_path}/{args.dataset}/{_run_dir}"

# The model name encodes the main hyper-parameters; the pieces are joined
# with "_" in a fixed order so the resulting directory name stays stable.
_name_parts = [
    f"original_{args.arch}_lr_{args.base_lr}_cosineaneal_{args.cosine_anneal}"
    f"_lambda_{args.lambda_}_lshadjacency_{args.lsh_adjacency}",
]
if args.no_color_distor:
    _name_parts.append("nocolordistoration")
_name_parts.append(f"ztemp_{args.temp_z}_temperature_{args.temperature}")
_name_parts.append(
    f"zinit_{args.z_init}_projectionmode_{args.projection_mode}"
    f"_attweightdecay_{args.attention_weightdecay}"
)
args.model_name = "_".join(_name_parts)

args.save_location = f"{args.save_folder}/{args.model_name}"
os.makedirs(args.save_location, exist_ok=True)
# args.git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
# args.git_diff = subprocess.check_output(['git', 'diff'])

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Bookkeeping shared by the resume logic and the epoch loop below.
start_epoch = 0  # first epoch to run; overwritten when resuming from a checkpoint
best_acc = 0  # best test accuracy seen so far
best_mse = 100000000000  # running "best" test MSE; starts at a large sentinel
clf = None  # most recently fitted linear classifier, if any

# logger
print('===> Preparing Logger...')
# The third argument records the command line that launched this run.
scalar_logger = txt_logger(args.save_location, args, f"python {' '.join(sys.argv)}")

print('==> Preparing data..')

if args.dataset == 'colorMNIST':
    # get_datasets returns the test and classifier-train entries as
    # (plain, conditional) pairs, which are split apart right below.
    trainset, _test_pair, _clf_pair, num_classes, stem = get_datasets(
        args.dataset, no_color_distor=args.no_color_distor)
    testset, test_conditional_set = _test_pair
    clftrainset, clf_conditional_trainset = _clf_pair

    # All evaluation-style loaders share the same settings.
    _eval_kwargs = dict(batch_size=1000, shuffle=False,
                        num_workers=args.num_workers, pin_memory=True)

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    testloader = torch.utils.data.DataLoader(testset, **_eval_kwargs)
    clftrainloader = torch.utils.data.DataLoader(clftrainset, **_eval_kwargs)
    clf_conditional_trainloader = torch.utils.data.DataLoader(clf_conditional_trainset, **_eval_kwargs)
    test_conditional_loader = torch.utils.data.DataLoader(test_conditional_set, **_eval_kwargs)

# Selecting one of the face datasets by name switches on the face-data path.
if args.dataset in ('UTKFace', 'CelebA'):
    args.CelebA_UTKFace = args.dataset

if args.CelebA_UTKFace is not None:
    _face_kwargs = dict(dataset=args.CelebA_UTKFace, num_workers=args.num_workers)
    _face_eval_kwargs = dict(batch_size=500, two_crop=False, **_face_kwargs)

    # Training loader: two crops per sample, with the conditioning variable.
    trainloader = get_loader(batch_size=args.batch_size, split='train', conditional=True,
                             shuffle=True, two_crop=True, **_face_kwargs)

    # NOTE(review): the test loader is shuffled, unlike the other evaluation
    # loaders; kept as-is.
    testloader = get_loader(split='test', conditional=False, shuffle=True, **_face_eval_kwargs)

    clftrainloader = get_loader(split='train', conditional=False, shuffle=False, **_face_eval_kwargs)

    clf_conditional_trainloader = get_loader(split='train', conditional=True, shuffle=False,
                                             **_face_eval_kwargs)

    test_conditional_loader = get_loader(split='test', conditional=True, shuffle=False,
                                         **_face_eval_kwargs)

# Model
print('==> Building model..')
##############################################################
# Encoder
##############################################################
# NOTE(review): in this file `stem` is only assigned by the colorMNIST data
# branch; for other datasets the ResNet branches rely on it being provided
# elsewhere (e.g. by one of the star imports) — confirm.
if args.arch == 'resnet18':
    net = ResNet18(stem=stem)
elif args.arch == 'resnet34':
    net = ResNet34(stem=stem)
elif args.arch == 'resnet50':
    net = ResNet50(stem=stem)
elif args.arch == 'LeNet':
    net = LeNet()
else:
    raise ValueError("Bad architecture specification")
net = net.to(device)

##############################################################
# Critic
##############################################################

critic = LinearCritic(net.representation_dim, temperature=args.temperature,
                      attention_sim=args.attention_sim).to(device)

# Input width handed to the z-head (presumably the size of the conditioning
# variable — confirm against the dataset definitions).
in_dim = 8 if args.CelebA_UTKFace == 'UTKFace' else 3

# NOTE(review): bucket_size is computed with true division and is therefore a
# float (e.g. 64.0); confirm SparseAttention accepts a non-integer value.
if args.lsh == 'default':
    farehead = FareHead(in_dim=in_dim, temperature=args.temp_z, lsh=args.lsh,
                        weight_init=args.z_init, projection_mode=args.projection_mode).to(device)
    lsh_attention = SparseAttention(adjacency=args.lsh_adjacency, softmax_temp=args.temp_z,
                                    bucket_size=args.batch_size/4).to(device)
elif args.lsh == 'off':
    farehead = FareHead(in_dim=in_dim, temperature=args.temp_z, weight_init=args.z_init,
                        lsh=args.lsh).to(device)
    lsh_attention = None
elif args.lsh == 'reverse':
    # when lsh is in reverse mode, we don't instantiate the z-critic as this is performed within the SparseAttention module
    lsh_attention = SparseAttention(reverse=True, softmax_temp=args.temp_z, adjacency=args.lsh_adjacency,
                                    bucket_size=args.batch_size/4).to(device)
    farehead = None
else:
    # Fail here with a clear message instead of a NameError further down,
    # where `farehead` / `lsh_attention` would otherwise be unbound.
    raise ValueError(f"Bad lsh mode {args.lsh!r}: expected 'default', 'off' or 'reverse'")


if device == 'cuda':
    # DataParallel hides attributes of the wrapped module, so copy the
    # representation size onto the wrapper for the code that reads it later.
    repr_dim = net.representation_dim
    net = torch.nn.DataParallel(net)
    net.representation_dim = repr_dim
    cudnn.benchmark = True

# Settings shared by every LARS optimizer built here.
_lars_kwargs = dict(lr=args.lr, eta=1e-3, momentum=args.momentum, max_epoch=200)

# The encoder optimizer always covers the encoder and the critic.
_encoder_params = list(net.parameters()) + list(critic.parameters())

if args.lsh == 'reverse':
    # Reverse mode: the sparse-attention module gets its own optimizer with a
    # separately configurable weight decay.
    attention_optimizer = LARS(list(lsh_attention.parameters()),
                               weight_decay=args.attention_weightdecay, **_lars_kwargs)
else:
    # Other modes: the z-head is trained together with the encoder and critic.
    _encoder_params = _encoder_params + list(farehead.parameters())
    attention_optimizer = None

encoder_optimizer = LARS(_encoder_params, weight_decay=1e-4, **_lars_kwargs)

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    resume_from = args.resume
    # Validate the path that is actually loaded (rather than an unrelated
    # 'checkpoint' directory) and raise explicitly so the check is not
    # stripped when Python runs with -O.
    if not os.path.isfile(resume_from):
        raise FileNotFoundError(f"Error: no checkpoint file found at '{resume_from}'")
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    checkpoint = torch.load(resume_from, map_location=device)
    net.load_state_dict(checkpoint['net'])
    critic.load_state_dict(checkpoint['critic'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    encoder_optimizer.load_state_dict(checkpoint['optim'])
    if args.lsh == 'reverse':
        attention_optimizer.load_state_dict(checkpoint['attention_optim'])
    # NOTE(review): the z-head / sparse-attention module weights are not
    # restored here — confirm whether the checkpoint stores them.

# Loss configuration, taken straight from the command-line options.
_loss_config = {
    'lambda_': args.lambda_,
    'temp_z': args.temp_z,
    'scale': args.scale_z,
    'weight_clip_threshold': args.weight_clip_threshold,
    'lsh': args.lsh,
}
criterion = ConditionalSamplingLoss(**_loss_config)

# The scheduler only exists when cosine annealing was requested; it drives
# the encoder optimizer only.
if args.cosine_anneal:
    scheduler = CosineAnnealingWithLinearRampLR(encoder_optimizer, args.num_epochs)

# Training
def train(epoch, args, high_threshold, low_threshold):
    """Run one training epoch over ``trainloader`` and return the mean loss.

    Uses the module-level ``net``, ``critic``, ``farehead``, ``lsh_attention``,
    ``criterion`` and optimizers. The behaviour depends on ``args.lsh``:

    * ``'off'``: ``farehead`` turns the condition into a score that is passed
      to the criterion directly.
    * ``'default'``: ``farehead`` projects the condition, then
      ``compute_lsh_conditional_attention`` produces the criterion input.
    * ``'reverse'``: the raw condition is passed to
      ``compute_lsh_conditional_attention`` and the attention module is
      updated by its own optimizer.

    Args:
        epoch: Epoch index, used only for the progress message.
        args: Parsed command-line options (reads ``lsh`` and ``CelebA_UTKFace``).
        high_threshold: Upper threshold forwarded to the criterion.
        low_threshold: Lower threshold forwarded to the criterion.

    Returns:
        The average of the per-batch ``loss.item()`` values for this epoch.

    Raises:
        RuntimeError: If ``trainloader`` yields no batches.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    critic.train()

    reverse_mode = args.lsh == 'reverse'
    if reverse_mode:
        lsh_attention.train()
    else:
        farehead.train()

    train_loss = 0
    num_batches = 0

    t = tqdm(enumerate(trainloader), desc='Loss: **** ', total=len(trainloader), bar_format='{desc}{bar}{r_bar}')
    for batch_idx, (inputs, _, condition) in t:
        if args.CelebA_UTKFace == 'UTKFace':
            condition, _ = condition  # condition on the embedding

        x1, x2 = inputs
        x1, x2 = x1.to(device), x2.to(device)
        condition = condition.to(device)

        encoder_optimizer.zero_grad()
        if reverse_mode:
            attention_optimizer.zero_grad()

        representation1, representation2 = net(x1), net(x2)
        raw_scores, pseudotargets = critic(representation1, representation2)
        # Both views of a sample share the same conditioning variable.
        z1 = z2 = condition

        if args.lsh == 'off':
            z_score = farehead(z1, z2)
            loss = criterion(raw_scores, z_score, high_threshold=high_threshold, low_threshold=low_threshold)
        else:
            if args.lsh == 'default':
                # In default mode the z-head only applies its projections;
                # the sparse and conditional attention are computed inside
                # compute_lsh_conditional_attention().
                z1, z2 = farehead(z1, z2)

            # In reverse mode the condition reaches this call unprojected.
            ca_out = compute_lsh_conditional_attention(z1, z2, raw_scores, lsh_attention)

            loss = criterion(raw_scores, ca_out, high_threshold=high_threshold, low_threshold=low_threshold)

        loss.backward()
        encoder_optimizer.step()
        if reverse_mode:
            attention_optimizer.step()

        train_loss += loss.item()
        num_batches = batch_idx + 1

        t.set_description('Loss: %.3f ' % (train_loss / num_batches))

    if num_batches == 0:
        # Without this guard the return below would fail on an unbound
        # loop variable, hiding the real problem.
        raise RuntimeError('Training loader produced no batches; cannot compute the epoch loss.')

    return train_loss / num_batches



# Main loop: train one epoch, log, periodically evaluate with linear probes,
# and periodically save a checkpoint.
for epoch in range(start_epoch, start_epoch + args.num_epochs):

    # Thresholds for this epoch, computed from the epoch index and the options.
    high_threshold, low_threshold = thresholdAnnealing(epoch, args)

    loss = train(epoch, args, high_threshold, low_threshold)

    scalar_logger.log_value(epoch, ('loss', loss),
                                    ('high_threshold', high_threshold),
                                    ('low_threshold', low_threshold),
                                    ('learning_rate', encoder_optimizer.param_groups[0]['lr'])
                                    )

    # Evaluate on the last epoch of every `test_freq`-sized window.
    if (args.test_freq > 0) and (epoch % args.test_freq == (args.test_freq - 1)):
        # --- Probe 1: fit a classifier on the encoded training set. ---
        X, y = encode_train_set(clftrainloader, device, net)
        if args.CelebA_UTKFace == 'CelebA':
            num_classes = 1
        elif args.CelebA_UTKFace == 'UTKFace':
            num_classes = 5
        else:
            num_classes = 10
        clf = train_clf(X, y, net.representation_dim, num_classes, device, reg_weight=args.reg_weight, CelebA_UTKFace=args.CelebA_UTKFace)
        acc = test(testloader, device, net, clf, CelebA_UTKFace=args.CelebA_UTKFace)

        if acc > best_acc:
            best_acc = acc
        scalar_logger.log_value(epoch, ('Best acc', best_acc),
                                       ('acc', acc))
        del X, y

        # --- Probe 2: fit a continuous predictor of the conditioning variable. ---
        X, y = encode_conditional_train_set(clf_conditional_trainloader, device, net, args.CelebA_UTKFace)
        if args.CelebA_UTKFace == 'UTKFace':
            num_classes = 2
        else:
            num_classes = 3
        clf_conditional = train_clf(X, y, net.representation_dim, num_classes, device, reg_weight=args.reg_weight, continuous=True, CelebA_UTKFace=args.CelebA_UTKFace)
        if args.CelebA_UTKFace == 'CelebA':
            # when celeba on, return two metrics
            mse_loss, mse_loss2, mae_loss = test_conditional(test_conditional_loader, device, net, clf_conditional, CelebA_UTKFace=args.CelebA_UTKFace)
        else:
            mse_loss = test_conditional(test_conditional_loader, device, net, clf_conditional, CelebA_UTKFace=args.CelebA_UTKFace)

        # best_mse starts at a large sentinel, so the running best must track
        # the minimum: with ">" it would never move off the sentinel.
        if mse_loss < best_mse:
            best_mse = mse_loss
        scalar_logger.log_value(epoch, ('Best MSE', best_mse),
                                       ('MSE', mse_loss))
        del X, y

    # Save every `save_freq` epochs and on the final epoch, but never at epoch 0.
    if (epoch % args.save_freq == 0) or (epoch == start_epoch + args.num_epochs - 1):
        if (epoch > 0):
            save_checkpoint(net, clf, critic, farehead, lsh_attention, epoch, args.run_name, args, best_acc, scalar_logger, os.path.basename(__file__), encoder_optimizer,
                            attention_optimizer)
    if args.cosine_anneal:
        scheduler.step()
