from pathlib import Path
import argparse
import os
import sys
import random
import subprocess
import time
import json
import math
import numpy as np

from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms

from utils import * 
from augmentation import Multimask_Transform

# Command-line interface of the training script.
#
# Option names and defaults are unchanged. The one behavioral difference is
# that --data is now required: it has no usable default, and omitting it used
# to surface only later, inside every spawned worker, as a TypeError from
# `None / 'train'`.
parser = argparse.ArgumentParser(description='SimCLR Training')
parser.add_argument('--data', type=Path, metavar='DIR', required=True,
                    help='path to dataset (must contain a train/ subdirectory)')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--warm-up', default=10, type=int, help="warm up epochs")
parser.add_argument('--batch-size', default=4096, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--learning-rate', default=0.3, type=float, metavar='LR',
                    help='base learning rate')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--print-freq', default=10, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--checkpoint-dir', default='./', type=Path,
                    metavar='DIR', help='path to checkpoint directory')
parser.add_argument('--name', type=str, default='test',
                    help='run name; checkpoints go to <checkpoint-dir>/<name>')

# SimCLR parameters
parser.add_argument('--temp', default=0.2, type=float,
                    help='temperature of the InfoNCE loss')

# Masking hyperparameters. They are not read anywhere in this file;
# presumably Multimask_Transform consumes them, since it receives the whole
# args namespace -- confirm against augmentation.py.
parser.add_argument("--mask-ratio", type=float, default=0.3)
parser.add_argument("--grid-size", type=int, default=32)
parser.add_argument("--focal", type=float, default=0.2)
parser.add_argument("--color", type=float, default=0.7)


def main():
    """Parse the CLI, derive the distributed topology and spawn the workers.

    One ``main_worker`` process is started per visible CUDA device. The
    following attributes are added to the parsed namespace before spawning:

    * ``ngpus_per_node`` -- number of visible CUDA devices.
    * ``rank`` -- rank of this node's first process; each worker adds its
      local GPU index to it.
    * ``world_size`` -- total number of processes in the job.
    * ``dist_url`` -- TCP rendezvous address for ``init_process_group``.

    Raises:
        RuntimeError: if no CUDA device is visible.
        KeyError: if running under SLURM and a required SLURM variable
            is missing from the environment.
    """
    args = parser.parse_args()
    args.ngpus_per_node = torch.cuda.device_count()
    if args.ngpus_per_node < 1:
        # Without a GPU the spawn below would be asked for zero processes,
        # and the workers use the NCCL backend anyway, so fail loudly.
        raise RuntimeError('no CUDA device visible; training requires '
                           'at least one GPU')

    if 'SLURM_JOB_ID' in os.environ:
        # Multi-node job: the first host of the allocation is the rendezvous
        # point. Indexing os.environ (rather than getenv) names the missing
        # variable if the environment is incomplete.
        stdout = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
        host_name = stdout.decode().splitlines()[0]
        args.rank = int(os.environ['SLURM_NODEID']) * args.ngpus_per_node
        args.world_size = int(os.environ['SLURM_NNODES']) * args.ngpus_per_node
        args.dist_url = f'tcp://{host_name}:58478'
    else:
        # Single-node job: rendezvous on localhost. The port is drawn at
        # random from the dynamic/private range.
        args.rank = 0
        args.dist_url = f'tcp://localhost:{random.randrange(49152, 65535)}'
        args.world_size = args.ngpus_per_node

    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)


def main_worker(gpu, args):
    """Train the model on one GPU as one rank of the distributed job.

    Args:
        gpu: Local CUDA device index; ``torch.multiprocessing.spawn`` passes
            the process index here.
        args: Parsed command-line namespace, extended by ``main`` with
            ``rank``, ``world_size``, ``dist_url`` and ``ngpus_per_node``.
            ``rank`` and ``checkpoint_dir`` are updated in place.

    Raises:
        ValueError: if ``args.batch_size`` is not divisible by
            ``args.world_size``.
    """
    # args.rank arrives as the rank of this node's first process.
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    # Every run gets its own sub-directory named after --name.
    args.checkpoint_dir = args.checkpoint_dir / args.name
    if args.rank == 0:
        args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        for key, value in vars(args).items():
            print(key, value)

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    model = SimCLR(args).cuda(gpu)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # lr=0 is a placeholder: adjust_learning_rate sets the real value
    # before every optimizer step.
    optimizer = LARS(model.parameters(), lr=0, weight_decay=args.weight_decay,
                     weight_decay_filter=True,
                     lars_adaptation_filter=True)

    # Resume from the rolling checkpoint if one exists.
    checkpoint_path = args.checkpoint_dir / 'checkpoint.pth'
    if checkpoint_path.is_file():
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    else:
        start_epoch = 0

    tf = Multimask_Transform(args)

    dataset = torchvision.datasets.ImageFolder(args.data / 'train', tf)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, drop_last=True)
    # --batch-size is the global batch; each rank processes an equal share.
    # Raised explicitly (not asserted) so the check survives `python -O`.
    if args.batch_size % args.world_size != 0:
        raise ValueError(
            f'batch size {args.batch_size} is not divisible by '
            f'world size {args.world_size}')
    per_device_batch_size = args.batch_size // args.world_size
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=per_device_batch_size, num_workers=args.workers,
        pin_memory=True, sampler=sampler)

    start_time = time.time()
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(start_epoch, args.epochs):
        # Re-seed the sampler so every epoch uses a different shuffle.
        sampler.set_epoch(epoch)

        # `step` is the global step index across epochs; the learning-rate
        # schedule depends on it.
        for step, ((y1, y2, y3, y4), labels) in enumerate(loader, start=epoch * len(loader)):
            y1 = y1.cuda(gpu, non_blocking=True)
            y2 = y2.cuda(gpu, non_blocking=True)
            y3 = y3.cuda(gpu, non_blocking=True)
            y4 = y4.cuda(gpu, non_blocking=True)
            labels = labels.cuda(gpu, non_blocking=True)

            lr = adjust_learning_rate(args, optimizer, loader, step)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                # Call the module rather than .forward() so that registered
                # hooks run.
                loss, acc = model(y1, y2, y3, y4, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            if step % args.print_freq == 0:
                # Collective call: every rank must take part. Each rank
                # contributes acc / world_size, so rank 0 ends up with the
                # mean accuracy. The logged loss is rank 0's local loss.
                torch.distributed.reduce(acc.div_(args.world_size), 0)
                if args.rank == 0:
                    stats = dict(epoch=epoch, step=step,
                                 acc=acc.item() * 100,
                                 lr=lr,
                                 loss=loss.item(),
                                 time=int(time.time() - start_time))
                    print(json.dumps(stats))

        if args.rank == 0:
            # Rolling checkpoint used for resuming.
            state = dict(epoch=epoch + 1, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, checkpoint_path)

            # Export of the backbone and online head. It is rewritten after
            # every epoch, so the file left after the last epoch holds the
            # final model. checkpoint_dir is a pathlib.Path, so the file name
            # must be joined with `/`; `Path + str` raises TypeError.
            torch.save(dict(backbone=model.module.backbone.state_dict(),
                            online_head=model.module.online_head.state_dict()),
                       args.checkpoint_dir / f'{args.name}.pth')


def adjust_learning_rate(args, optimizer, loader, step):
    """Set and return the learning rate for global step ``step``.

    The schedule is a linear warm-up from 0 to the peak rate over
    ``args.warm_up`` epochs, followed by a cosine decay from the peak down
    to 0.1% of the peak at the end of ``args.epochs`` epochs. The peak is
    ``args.learning_rate * args.batch_size / 256``.

    Args:
        args: Namespace providing ``epochs``, ``warm_up``, ``learning_rate``
            and ``batch_size``.
        optimizer: Object whose ``param_groups`` entries get their ``'lr'``
            key overwritten.
        loader: Sized iterable; its length is the number of steps per epoch.
        step: Global step index, counted from the start of training.

    Returns:
        The learning rate written to every parameter group.
    """
    steps_per_epoch = len(loader)
    total_steps = args.epochs * steps_per_epoch
    ramp_steps = args.warm_up * steps_per_epoch
    peak_lr = args.learning_rate * args.batch_size / 256

    if step >= ramp_steps:
        # Cosine phase, measured from the end of the warm-up.
        elapsed = step - ramp_steps
        decay_steps = total_steps - ramp_steps
        q = 0.5 * (1 + math.cos(math.pi * elapsed / decay_steps))
        floor_lr = peak_lr * 0.001
        lr = peak_lr * q + floor_lr * (1 - q)
    else:
        # Linear warm-up phase.
        lr = peak_lr * step / ramp_steps

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


class SimCLR(nn.Module):
    """Four-view contrastive model with an online linear classifier.

    Components:

    * ``backbone`` -- torchvision ResNet-50 with its ``fc`` layer replaced
      by an identity, so it outputs 2048-d features.
    * ``projector`` -- Linear/BatchNorm/ReLU MLP mapping 2048 -> 2048 -> 256,
      ending in a BatchNorm layer.
    * ``online_head`` -- linear classifier (2048 -> 1000) trained on
      detached backbone features, so it does not influence the backbone.
    """

    def __init__(self, args):
        """Build the backbone, the online head and the projector.

        Args:
            args: Namespace; ``args.temp`` is read in ``forward`` as the
                InfoNCE temperature.
        """
        super().__init__()
        self.args = args

        resnet = torchvision.models.resnet50(zero_init_residual=True)
        resnet.fc = nn.Identity()
        self.backbone = resnet

        self.online_head = nn.Linear(2048, 1000)

        # Projector: every layer but the last is Linear -> BN -> ReLU; the
        # last is Linear -> BN with no activation (marked "NNCLR" in the
        # original code).
        dims = [2048, 2048, 256]
        widths = list(zip(dims[:-1], dims[1:]))
        modules = []
        for d_in, d_out in widths[:-1]:
            modules += [
                nn.Linear(d_in, d_out, bias=False),
                nn.BatchNorm1d(d_out),
                nn.ReLU(inplace=True),
            ]
        d_in, d_out = widths[-1]
        modules += [nn.Linear(d_in, d_out, bias=False), nn.BatchNorm1d(d_out)]
        self.projector = nn.Sequential(*modules)

    def forward(self, y1, y2, y3, y4, labels):
        """Compute the training loss and the online classification accuracy.

        The contrastive part averages eight InfoNCE terms: each of views 1
        and 2 is paired with each of views 3 and 4, in both directions.
        There is no term between views 1 and 2, nor between 3 and 4.

        Args:
            y1, y2, y3, y4: The four input batches, one per view.
                NOTE(review): which views are masked is decided by the
                transform, not visible here.
            labels: Class indices for the online classifier.

        Returns:
            ``(loss, acc)``: the contrastive loss plus the classifier's
            cross-entropy, and the classifier's accuracy on this batch.
        """
        normalize = torch.nn.functional.normalize
        cross_entropy = torch.nn.functional.cross_entropy
        views = (y1, y2, y3, y4)

        # Stage by stage (all backbone passes, then all projector passes),
        # in the same order as the views.
        feats = [self.backbone(view) for view in views]
        embeds = [self.projector(feat) for feat in feats]
        embeds = [normalize(emb, dim=1) for emb in embeds]
        # gather_from_all comes from utils; presumably it collects the
        # embeddings of every rank -- confirm there.
        z1, z2, z3, z4 = [gather_from_all(emb) for emb in embeds]

        tau = self.args.temp
        pairs = (
            (z1, z3), (z2, z3), (z3, z1), (z3, z2),
            (z1, z4), (z2, z4), (z4, z1), (z4, z2),
        )
        (first_q, first_k), *remaining = pairs
        loss = infoNCE(first_q, first_k, temperature=tau) / 8
        for query, key in remaining:
            loss = loss + infoNCE(query, key, temperature=tau) / 8

        # Online classifier on the mean feature of views 1 and 2; detach()
        # stops its gradient from reaching the backbone.
        mean_feat = (feats[0] + feats[1]) / 2
        logits = self.online_head(mean_feat.detach())
        cls_loss = cross_entropy(logits, labels)
        hits = (logits.argmax(dim=1) == labels).sum()
        acc = hits / logits.size(0)

        return loss + cls_loss, acc


def infoNCE(nn, p, temperature=0.1):
    """Return the InfoNCE loss between two aligned batches of embeddings.

    Row ``i`` of ``nn`` is scored against every row of ``p`` by dot product;
    row ``i`` of ``p`` is the positive and all other rows are negatives. The
    loss is the mean cross-entropy of those scores, divided by
    ``temperature``, against the diagonal targets.

    Args:
        nn: Query embeddings, 2-D, with as many rows as ``p``. The name
            shadows the ``torch.nn`` import inside this function; it is kept
            so that keyword callers keep working.
        p: Key embeddings, 2-D.
        temperature: Divisor applied to the similarity scores.

    Returns:
        Scalar tensor on the same device as the inputs.
    """
    logits = (nn @ p.T) / temperature
    # Build the targets on the logits' device instead of calling .cuda():
    # this also works on CPU and avoids a host-to-device copy.
    targets = torch.arange(p.shape[0], dtype=torch.long, device=logits.device)
    return torch.nn.functional.cross_entropy(logits, targets)


# Start training only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
