import os
import torch
import torchvision
import argparse
import numpy as np

from model import load_model, save_model
from modules import NT_Xent, InfoNCE, Transforms, Transforms_imagenet
from utils import mask_correlated_samples
from load_imagenet import imagenet, load_data
import random
# Reproducibility setup: seed every RNG this script imports (PyTorch CPU,
# PyTorch CUDA for the current and for all devices, NumPy, Python's `random`)
# with the same value, then switch cuDNN to deterministic mode with benchmark
# mode turned off.
seed = 123
for _seed_rng in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_rng(seed)
del _seed_rng  # loop variable only; do not leave it behind as a module attribute
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Command-line configuration. The arguments are parsed once, at import time,
# into the module-level ``args`` namespace that the functions in this file read.
parser = argparse.ArgumentParser(description='PyTorch Seen Testing Category Training')
parser.add_argument('--batch_size', default=256, type=int, metavar='B', help='training batch size')
parser.add_argument('--workers', default=12, type=int, help='workers')
parser.add_argument('--epochs', default=200, type=int, help='epochs')
parser.add_argument('--save_freq', default=50, type=int, help='save frequency')
parser.add_argument('--resnet', default="resnet18", type=str, help="resnet")
# ``--normalize`` is a store_true flag whose default is already True, so on its
# own it can never switch normalization off. It is kept for backward
# compatibility; ``--no_normalize`` writes False to the same destination.
parser.add_argument('--normalize', default=True, action='store_true', help='normalize')
parser.add_argument('--no_normalize', dest='normalize', action='store_false', help='disable normalize')
parser.add_argument('--projection_dim', default=128, type=int, help='projection_dim')
parser.add_argument('--lamb', default=1., type=float, help='weight of regularization term')
parser.add_argument('--zeta', default=0.1, type=float, help='variance')
parser.add_argument('--optimizer', default="Adam", type=str, help="optimizer")
parser.add_argument('--lr', default=3e-4, type=float, help='lr')
parser.add_argument('--weight_decay', default=1e-6, type=float, help='weight_decay')
parser.add_argument('--temperature', default=0.5, type=float, help='temperature')
parser.add_argument('--model_dir', default='output/checkpoint/', type=str, help='model save path')
parser.add_argument('--dataset', default='CIFAR10', help='[CIFAR10, CIFAR100, STL-10, tinyImagenet]')
args = parser.parse_args()

def train(train_loader, model, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Each batch is unpacked as ``((view_1, view_2), _)``; the second element is
    ignored. Both views are moved to the GPU with ``.cuda()`` and passed through
    ``model`` separately. ``model`` must return seven values per call: three
    ``mu`` outputs, three ``h`` outputs, and a final output that is handed to
    ``criterion``. (Presumably intermediate-layer statistics/features and the
    projection head output -- confirm against ``model``.)

    The per-step loss is ``criterion(z_1, z_2) - args.lamb * MI``, where ``MI``
    is a weighted sum (0.25, 0.50, 1) of ``InfoNCE`` evaluated on the three
    ``(mu, h)`` pairs after the two views are concatenated along dim 0.

    Args:
        train_loader: iterable of ``((view_1, view_2), _)`` batches; must
            support ``len()`` for the progress message.
        model: callable returning the seven outputs described above.
        criterion: callable taking the two final outputs and returning a loss.
        optimizer: optimizer stepped once per batch.

    Returns:
        The sum (not the mean) of ``loss.item()`` over all steps.
    """
    total_loss = 0.
    for step, (views, _) in enumerate(train_loader):
        first_view, second_view = views
        optimizer.zero_grad()

        # Forward the two views one after the other through the same model.
        mu2_a, mu3_a, mu4_a, h2_a, h3_a, h4_a, z_a = model(first_view.cuda())
        mu2_b, mu3_b, mu4_b, h2_b, h3_b, h4_b, z_b = model(second_view.cuda())

        # Join both views along dim 0 for every (mu, h) pair.
        mu2 = torch.cat([mu2_a, mu2_b], dim=0)
        h2 = torch.cat([h2_a, h2_b], dim=0)
        mu3 = torch.cat([mu3_a, mu3_b], dim=0)
        h3 = torch.cat([h3_a, h3_b], dim=0)
        mu4 = torch.cat([mu4_a, mu4_b], dim=0)
        h4 = torch.cat([h4_a, h4_b], dim=0)

        mi_estimate = (
            0.25 * InfoNCE(mu2, h2)
            + 0.50 * InfoNCE(mu3, h3)
            + InfoNCE(mu4, h4)
        )

        # The InfoNCE term is subtracted, scaled by the --lamb weight.
        loss = criterion(z_a, z_b) - args.lamb * mi_estimate
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {step_loss}")
        total_loss += step_loss
    return total_loss

def _build_train_dataset(name, root):
    """Return ``(train_dataset, data)`` for the dataset called ``name``.

    ``data`` is the tag later passed to ``load_model``: ``'imagenet'`` for
    tinyImagenet and ``'non_imagenet'`` for every other supported dataset.

    Raises:
        NotImplementedError: if ``name`` is not a supported dataset.
    """
    if name == "CIFAR10":
        return torchvision.datasets.CIFAR10(root, download=True, transform=Transforms()), 'non_imagenet'
    if name == "CIFAR100":
        return torchvision.datasets.CIFAR100(root, download=True, transform=Transforms()), 'non_imagenet'
    if name == "STL-10":
        dataset = torchvision.datasets.STL10(root, split='unlabeled', download=True, transform=Transforms(64))
        return dataset, 'non_imagenet'
    if name == "tinyImagenet":
        raw = load_data(os.path.join(root, 'tiny_imagenet.pickle'))
        return imagenet(raw, transform=Transforms_imagenet(size=224)), 'imagenet'
    raise NotImplementedError(f"unsupported dataset: {name!r}")


def main():
    """Train a model on ``args.dataset`` and write checkpoints and a loss log.

    Steps: build the dataset and a shuffling ``DataLoader`` (incomplete last
    batch dropped), obtain model/optimizer/scheduler from ``load_model``, then
    run ``train`` for ``args.epochs`` epochs. The mean per-step loss of every
    epoch is printed and appended to ``output/log/<dataset>_LBE/<suffix>.txt``.
    A checkpoint is saved every ``args.save_freq`` epochs and once more after
    the last epoch.

    Side effect: ``args.model_dir`` is extended in place with the
    ``<dataset>_LBE/`` sub-directory.

    Raises:
        NotImplementedError: if ``args.dataset`` is not supported.
    """
    root = "../datasets"
    train_dataset, data = _build_train_dataset(args.dataset, root)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        sampler=None)

    log_dir = os.path.join("output", "log", args.dataset + '_LBE')
    os.makedirs(log_dir, exist_ok=True)

    suffix = args.dataset + '_{}_batch_{}'.format(args.resnet, args.batch_size)
    suffix = suffix + '_proj_dim_{}'.format(args.projection_dim)
    log_path = os.path.join(log_dir, suffix + '.txt')

    model, optimizer, scheduler = load_model(args, data=data)
    if args.dataset == 'tinyImagenet':
        # Use up to four GPUs, but never ask for a device index that does not
        # exist on this machine.
        gpu_count = min(4, torch.cuda.device_count())
        model = torch.nn.DataParallel(model, device_ids=list(range(gpu_count)))

    # The trailing '' keeps a path separator at the end, so the directory can
    # still be used as a plain string prefix. makedirs also creates missing
    # parents (e.g. 'output/checkpoint/'), which os.mkdir would not.
    args.model_dir = os.path.join(args.model_dir, args.dataset + '_LBE', '')
    os.makedirs(args.model_dir, exist_ok=True)
    checkpoint_prefix = args.model_dir + suffix

    mask = mask_correlated_samples(args.batch_size)
    criterion = NT_Xent(args.batch_size, args.temperature, mask)

    # The context manager guarantees the log file is closed even if training
    # raises part-way through.
    with open(log_path, "w") as test_log_file:
        for epoch in range(args.epochs):
            loss_epoch = train(train_loader, model, criterion, optimizer)
            if scheduler:
                scheduler.step()
            if (epoch+1) % args.save_freq == 0:
                save_model(checkpoint_prefix, model, epoch+1)

            # train() returns the summed loss; report the per-step mean.
            print('Epoch {} loss: {}\n'.format(epoch, loss_epoch/len(train_loader)))
            print('Epoch {} loss: {}'.format(epoch, loss_epoch/len(train_loader)), file=test_log_file)
            test_log_file.flush()
    save_model(checkpoint_prefix, model, args.epochs)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()