"""
SimCLR Contrastive Learning Training

Train ANN or SNN models using SimCLR (Simple Contrastive Learning of Representations).

Usage:
    # Train ANN
    python main_simclr.py --arch resnet18 --dataset-name cifar10
    
    # Train SNN
    python main_simclr.py --spiking --arch resnet18 --timesteps 4 --dataset-name cifar10
"""

import os
import sys
import argparse
import logging
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.backends.cudnn as cudnn

from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset, ContrastiveLearningDatasetSNN
from util import save_config_file, save_checkpoint
from networks.resnet_ann import SupConResNet
from networks.resnet_snn import SupConResNetSNN
from losses import SimCLRLoss, SimCLRSNNLoss


def parse_option():
    """Parse command-line options and derive dataset / log-path settings.

    Besides the raw CLI flags, the returned namespace gains:
      * ``data_folder`` -- dataset root for ``--dataset-name``, with ``~`` expanded;
      * ``img_size`` -- input resolution (32 for cifar10, 64 for tinyimagenet);
      * ``log_path`` -- ``<save_path>/<run name>``, created if it does not exist.

    Exits through ``parser.error`` (``SystemExit``) when ``--n-views`` is
    below 2, since a contrastive positive pair needs at least two views.

    Returns:
        argparse.Namespace holding all options.
    """
    parser = argparse.ArgumentParser('SimCLR Contrastive Learning')

    # Model settings
    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4, help='Number of timesteps for SNN')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50'])

    # Dataset settings
    parser.add_argument('--dataset-name', type=str, default='cifar10',
                        choices=['cifar10', 'tinyimagenet'])
    parser.add_argument('-j', '--workers', default=16, type=int,
                        help='Number of data loading workers')

    # Training settings
    parser.add_argument('--epochs', default=200, type=int, help='Number of training epochs')
    parser.add_argument('-b', '--batch-size', default=512, type=int, help='Batch size')
    parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, dest='lr',
                        help='Learning rate')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay',
                        help='Weight decay')

    # SimCLR hyperparameters
    parser.add_argument('--temperature', default=0.07, type=float,
                        help='Temperature for contrastive loss')
    parser.add_argument('--n-views', default=2, type=int,
                        help='Number of augmented views per image')

    # Logging settings
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--log-every-n-steps', default=100, type=int)
    parser.add_argument('--save-path', type=str, default='./save/SimCLR')
    parser.add_argument('--gpu-ids', type=str, default='0,1,2,3')

    opt = parser.parse_args()

    if opt.n_views < 2:
        parser.error('--n-views must be at least 2 for contrastive learning')

    # Dataset root and input resolution. '~' is expanded here so every
    # consumer (dataset loaders, the saved config) gets a usable path.
    if opt.dataset_name == 'cifar10':
        opt.data_folder = os.path.expanduser("~/projects/Datasets/CIFAR10/")
        opt.img_size = 32
    elif opt.dataset_name == 'tinyimagenet':
        opt.data_folder = os.path.expanduser("~/projects/Datasets/tiny-imagenet-200/")
        opt.img_size = 64

    # The run directory name encodes the main hyperparameters so different
    # configurations log to separate directories.
    if opt.spiking:
        run_name = f'SNN_{opt.dataset_name}_{opt.arch}_T{opt.timesteps}_bsz{opt.batch_size}_lr{opt.lr}'
    else:
        run_name = f'ANN_{opt.dataset_name}_{opt.arch}_bsz{opt.batch_size}_lr{opt.lr}'
    opt.log_path = os.path.join(opt.save_path, run_name)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(opt.log_path, exist_ok=True)

    return opt


def set_model(opt):
    """Create the ResNet encoder (ANN or SNN) and its matching SimCLR loss.

    Args:
        opt: Parsed options; reads ``spiking``, ``arch``, ``timesteps``,
            ``batch_size``, ``n_views`` and ``temperature``.

    Returns:
        ``(model, criterion)``, both moved with ``.cuda()``. ``model.encoder``
        is wrapped in ``torch.nn.DataParallel`` before the move.
    """
    # Loss hyperparameters common to the ANN and SNN loss variants.
    shared_loss_args = {
        'batch_size': opt.batch_size,
        'n_views': opt.n_views,
        'temperature': opt.temperature,
    }

    if not opt.spiking:
        model = SupConResNet(name=opt.arch)
        criterion = SimCLRLoss(**shared_loss_args)
    else:
        model = SupConResNetSNN(name=opt.arch, timestep=opt.timesteps)
        criterion = SimCLRSNNLoss(timestep=opt.timesteps, **shared_loss_args)

    # Only the encoder submodule is wrapped for multi-GPU data parallelism.
    model.encoder = torch.nn.DataParallel(model.encoder)
    model, criterion = model.cuda(), criterion.cuda()

    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    cudnn.benchmark = True
    return model, criterion


def _stack_snn_views(images, timesteps):
    """Merge a multi-view SNN batch into one ``(n_views * B, T, ...)`` tensor.

    ``images`` is indexed as ``images[view][t]``: one view's batch at
    timestep ``t``. Views are concatenated along dim 0 in view order. This is
    the same layout the ANN branch gets from ``torch.cat(images, dim=0)``.
    Timesteps are stacked along dim 1. Every view in ``images`` is used, so
    ``--n-views`` values other than 2 work.
    """
    per_step = [torch.cat([view[t] for view in images], dim=0) for t in range(timesteps)]
    # Cast to the default dtype, matching the torch.zeros buffer previously used here.
    return torch.stack(per_step, dim=1).to(torch.get_default_dtype())


def train(model, train_loader, simclr, optimizer, scheduler, opt, writer):
    """Run SimCLR training for ``opt.epochs`` epochs, then save a final checkpoint.

    Args:
        model: Network returned by ``set_model``; called as ``model(batch)``.
        train_loader: Iterable of ``(images, labels)``; labels are ignored.
            For SNN runs ``images[view][t]`` is one view's batch at timestep
            ``t``; for ANN runs ``images`` is a sequence of per-view batches.
        simclr: Contrastive loss, called as ``simclr(features)`` and returning
            ``(loss, top1, top5)``.
        optimizer: Optimizer over the model parameters.
        scheduler: LR scheduler, stepped once per epoch after the first 10 epochs.
        opt: Parsed options (``spiking``, ``timesteps``, ``epochs``,
            ``log_every_n_steps``, ``arch``).
        writer: TensorBoard ``SummaryWriter``; its ``log_dir`` also receives
            the final checkpoint.
    """
    model.train()
    n_iter = 0

    for epoch_counter in range(opt.epochs):
        # Reset per epoch so an empty loader never logs a stale/undefined loss.
        loss = top1 = None
        for images, _ in tqdm(train_loader, desc=f'Epoch {epoch_counter}'):
            if opt.spiking:
                images_times = _stack_snn_views(images, opt.timesteps).cuda()
                features = model(images_times)
            else:
                images = torch.cat(images, dim=0).cuda()
                features = model(images)

            loss, top1, top5 = simclr(features)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if n_iter % opt.log_every_n_steps == 0:
                writer.add_scalar('loss', loss.item(), global_step=n_iter)
                writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
                writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
                writer.add_scalar('learning_rate', scheduler.get_last_lr()[0], global_step=n_iter)
            n_iter += 1

        # Not a ramp-up warmup: the LR is held at its initial value for the
        # first 10 epochs, after which the scheduler advances once per epoch.
        if epoch_counter >= 10:
            scheduler.step()

        if loss is not None:
            logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss.item()}\tTop1 accuracy: {top1[0]}")

    logging.info("Training has finished.")

    # Save final checkpoint
    checkpoint_name = f'checkpoint_{opt.epochs:04d}.pth.tar'
    save_checkpoint({
        'epoch': opt.epochs,
        'arch': opt.arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, is_best=False, filename=os.path.join(writer.log_dir, checkpoint_name))
    logging.info(f"Model checkpoint saved at {writer.log_dir}.")


def main():
    """Script entry point: parse options, build data/model/optimizer, and train."""
    opt = parse_option()
    # Restrict visible GPUs before set_model() initialises CUDA via .cuda().
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    writer = SummaryWriter(log_dir=opt.log_path)

    logging.basicConfig(
        filename=os.path.join(writer.log_dir, 'training.log'),
        level=logging.DEBUG
    )
    save_config_file(writer.log_dir, opt)

    # Setup dataset
    if opt.spiking:
        dataset = ContrastiveLearningDatasetSNN(
            opt.data_folder, opt.dataset_name, timesteps=opt.timesteps
        )
    else:
        dataset = ContrastiveLearningDataset(opt.data_folder, opt.dataset_name)
    train_dataset = dataset.get_dataset(opt.n_views)

    # drop_last keeps every batch at exactly opt.batch_size samples.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.workers, pin_memory=True, drop_last=True
    )

    # Setup model and optimizer
    model, simclr = set_model(opt)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr, weight_decay=opt.weight_decay)

    # train() holds the LR for the first 10 epochs and then calls
    # scheduler.step() once per epoch, so T_max must count those remaining
    # epochs. Using len(train_loader) (batches per epoch) here would drive the
    # LR to eta_min after that many *epochs* and then climb back up.
    hold_epochs = 10  # must match the hold period hard-coded in train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(opt.epochs - hold_epochs, 1), eta_min=0, last_epoch=-1
    )

    logging.info(f"Start SimCLR training for {opt.epochs} epochs.")
    train(model, train_loader, simclr, optimizer, scheduler, opt, writer)


if __name__ == '__main__':
    # Start training only when run as a script, not when the module is imported.
    main()
