from pathlib import Path
import argparse
import os
import sys
import random
import subprocess
import time
import json
import math

from PIL import Image
from torch import optim
import torch
import torchvision
import torchvision.transforms as transforms

from utils import count_parameters_in_MB
from trainer import GuidedSimCLR, GuidedSimSiam
from datasets import *

def _parse_lr_lut(text):
    """Convert the ``--LUT_lr`` command-line string into its runtime value.

    Without a ``type`` argparse would hand the raw string to the training
    loop, where it neither compares equal to the integer ``0`` nor unpacks
    into ``(stepvalue, lr)`` pairs.

    Args:
        text: JSON text; either ``0`` or a non-empty list of two-element lists.

    Returns:
        ``0`` or the decoded list of ``[stepvalue, lr]`` pairs.

    Raises:
        argparse.ArgumentTypeError: if the text is not valid JSON or does not
            have one of the two accepted shapes.
    """
    expected = 'expected 0 or a JSON list of [stepvalue, lr] pairs'
    try:
        value = json.loads(text)
    except ValueError as err:
        raise argparse.ArgumentTypeError(f'{expected}, got {text!r}') from err
    if isinstance(value, (int, float)) and value == 0:
        return 0
    well_formed = (isinstance(value, list) and len(value) > 0
                   and all(isinstance(pair, list) and len(pair) == 2
                           for pair in value))
    if not well_formed:
        raise argparse.ArgumentTypeError(f'{expected}, got {text!r}')
    return value


# Command-line interface of the pretraining script.
parser = argparse.ArgumentParser(description='RotNet Training')
parser.add_argument('--data', default='./data/stl10',type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch-size', default=512, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--learning-rate', default=4.8, type=float, metavar='LR',
                    help='base learning rate')
# The integer default 0 is not passed through `type`; it selects the
# warmup/cosine schedule in the training loop.
parser.add_argument('--LUT_lr', default=0, type=_parse_lr_lut,
                    help='multistep to decay learning rate: 0 keeps the '
                         'warmup/cosine schedule, otherwise a JSON list of '
                         '[epoch, lr] pairs such as "[[100, 0.1], [200, 0.01]]"')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--print-freq', default=10, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--checkpoint-dir', default='./experiments/guided_escnn18',type=Path,
                    metavar='DIR', help='path to checkpoint directory')
parser.add_argument('--rotation', default=0.0, type=float,
                    help="coefficient of rotation loss")
parser.add_argument('--alpha', default=0.0, type=float)
parser.add_argument('--beta', default=0.1, type=float)
parser.add_argument('--arch', default='resnet18', type=str, help='model architecture',
                    choices=['resnet18', 'resnet50', 'escnn18', 'escnn50', 'nin'])
# NOTE(review): the default 'resnet18' is not one of the choices below.
# argparse does not validate defaults against `choices`, so it reaches the
# model unchanged when --connector is omitted - confirm the intended default.
parser.add_argument('--connector', default='resnet18', type=str, help='equivariance connection map',
                    choices=['softmax', 'identity', 'tanh', 'shift'])
parser.add_argument('--pretrain-set', default='stl10', type=str, help='pretraining dataset',
                    choices=['stl10', 'stl10-R', 'imagenet100', 'caltech256', 'cifar10', 'cifar10-essl'])
parser.add_argument('--use_gpool', action='store_true')
parser.add_argument('--ssl', default='simclr', type=str, help='ssl method',
                    choices=['simclr', 'simsiam', 'moco'])
parser.add_argument('--stop_gradient', default=False, action='store_true')
parser.add_argument('--loader_drop_last', default=False, action='store_true')
parser.add_argument('--moco_projector', default=False, action='store_true')
parser.add_argument('--rotation_degree', default=0, type=float)
parser.add_argument('--affine', default=False, action='store_true')
parser.add_argument('--N', default=4, type=int)
parser.add_argument('--circular_transform', default=False, action='store_true')
parser.add_argument('--circular_range', default=10, type=int)
parser.add_argument('--cifar_layer', default=3, type=int, choices=[0,1,2,3])


def main():
    """Parse the command line and start one training worker per CUDA device.

    The parsed namespace is extended with ``ngpus_per_node``, ``rank``,
    ``dist_url`` and ``world_size`` before being handed to every worker.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    args = parser.parse_args()
    args.ngpus_per_node = torch.cuda.device_count()
    if args.ngpus_per_node < 1:
        # With zero devices no worker process would be started and the
        # script would finish without training or reporting anything.
        raise RuntimeError('no CUDA device is visible; training needs at least one GPU')

    args.rank = 0
    # Random rendezvous port so that several jobs on one host rarely collide.
    args.dist_url = f'tcp://localhost:{random.randrange(49152, 65535)}'
    args.world_size = args.ngpus_per_node
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)


def main_worker(gpu, args):
    """Entry point of one spawned training process.

    Joins the NCCL process group, lets rank 0 create the checkpoint directory
    and open ``stats.txt`` in append mode, runs the training loop, and always
    closes the stats file afterwards.

    Args:
        gpu: index of this process as supplied by ``spawn``; used both as the
            CUDA device index and as the offset added to ``args.rank``.
        args: parsed command-line namespace, extended by ``main`` with
            ``rank``, ``world_size`` and ``dist_url``.
    """
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    stats_file = None
    if args.rank == 0:
        args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
    try:
        _run_training(gpu, args, stats_file)
    finally:
        # The log handle used to stay open until interpreter shutdown.
        if stats_file is not None:
            stats_file.close()


def _build_model(args, gpu):
    """Instantiate the SSL wrapper selected by ``args.ssl`` on device ``gpu``."""
    if args.ssl == 'simclr':
        return GuidedSimCLR(args).cuda(gpu)
    if args.ssl == 'simsiam':
        return GuidedSimSiam(args).cuda(gpu)
    if args.ssl == 'moco':
        # NOTE(review): GuidedMocoV2 is not among the names this file imports
        # explicitly from `trainer`; unless a star import provides it, this
        # branch raises NameError - confirm where it is defined.
        return GuidedMocoV2(args).cuda(gpu)
    raise ValueError(f'unsupported ssl method: {args.ssl!r}')


def _build_optimizer(args, model):
    """Create the optimizer matching the SSL method and pretraining set."""
    if args.ssl == 'simclr':
        if args.pretrain_set != 'cifar10':
            # lr starts at 0; the schedule overwrites it before every step.
            return LARS(model.parameters(), lr=0, weight_decay=args.weight_decay,
                        weight_decay_filter=exclude_bias_and_norm,
                        lars_adaptation_filter=exclude_bias_and_norm)
        return torch.optim.SGD(model.parameters(), args.learning_rate,
                               momentum=0.9,
                               weight_decay=5e-4,
                               nesterov=True)
    # simsiam and moco use the same plain SGD configuration.
    return torch.optim.SGD(model.parameters(), args.learning_rate,
                           momentum=0.9,
                           weight_decay=1e-4)


def _final_state_dict(args, model):
    """Collect the sub-module weights stored in ``final.pth``.

    ``model`` is the DistributedDataParallel wrapper, so the sub-modules are
    reached through ``model.module``.
    """
    net = model.module
    if args.ssl == 'moco':
        state_dict = dict(backbone_q=net.backbone_q.state_dict(),
                          backbone_k=net.backbone_k.state_dict(),
                          projector_q=net.projector_q.state_dict(),
                          projector_k=net.projector_k.state_dict())
    else:
        state_dict = dict(backbone=net.backbone.state_dict(),
                          projector=net.projector.state_dict())

    # The equivariance predictors are exported only when alpha or beta is non-zero.
    if args.alpha or args.beta:
        if args.ssl == 'moco':
            state_dict['predictor_eqv_q'] = net.predictor_eqv_q.state_dict()
            state_dict['predictor_eqv_k'] = net.predictor_eqv_k.state_dict()
        else:
            state_dict['predictor_eqv'] = net.predictor_eqv.state_dict()
    return state_dict


def _run_training(gpu, args, stats_file):
    """Build model, optimizer and loader, then train and write checkpoints.

    Args:
        gpu: CUDA device index of this process.
        args: parsed command-line namespace (see ``main_worker``).
        stats_file: open text file for the log on rank 0; unused on other ranks.
    """
    is_main = args.rank == 0
    if is_main:
        command_line = ' '.join(sys.argv)
        print(command_line)
        print(command_line, file=stats_file)

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    model = _build_model(args, gpu)
    if is_main:
        if args.ssl == 'moco':
            encoder_size = (count_parameters_in_MB(model.backbone_q)
                            + count_parameters_in_MB(model.backbone_k))
        else:
            encoder_size = count_parameters_in_MB(model.backbone)
        summary = 'encoder params: {}M'.format(encoder_size)
        print(summary)
        print(summary, file=stats_file)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    optimizer = _build_optimizer(args, model)

    # automatically resume from checkpoint if it exists
    checkpoint_path = args.checkpoint_dir / 'checkpoint.pth'
    if checkpoint_path.is_file():
        # NOTE(review): with LARS the optimizer state holds the filter
        # callables; PyTorch releases whose torch.load defaults to
        # weights_only=True may refuse to unpickle them - confirm.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    else:
        start_epoch = 0

    dataset = load_pretrain_datasets(args)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, drop_last=True)
    assert args.batch_size % args.world_size == 0
    per_device_batch_size = args.batch_size // args.world_size

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=per_device_batch_size, num_workers=args.workers,
        pin_memory=True, sampler=sampler, drop_last=args.loader_drop_last)
    steps_per_epoch = len(loader)

    start_time = time.time()
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(start_epoch, args.epochs):
        # Reseed the sampler so every epoch uses a different shuffle.
        sampler.set_epoch(epoch)

        # `step` is a global counter that continues across epochs.
        for step, ((y1, y2), _) in enumerate(loader, start=epoch * steps_per_epoch):
            y1 = y1.cuda(gpu, non_blocking=True)
            y2 = y2.cuda(gpu, non_blocking=True)

            if args.LUT_lr == 0:
                lr = adjust_learning_rate(args, optimizer, loader, step)
            else:
                lr = adjust_learning_rate_LUT(optimizer, epoch, args.LUT_lr)
            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                # Call the module itself rather than .forward so hooks run.
                loss, con_loss, ori_loss = model(y1, y2, beta=args.beta)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if is_main and step % args.print_freq == 0:
                loss_value = loss.item()
                con_value = con_loss.item()
                ori_value = ori_loss.item()
                print(f'epoch={epoch}, step={step}, loss={loss_value}, con_loss={con_value}, ori_loss={ori_value}')
                stats = dict(epoch=epoch, step=step, learning_rate=lr,
                             loss=loss_value,
                             con_loss=con_value,
                             ori_loss=ori_value,
                             time=int(time.time() - start_time))
                print(json.dumps(stats), file=stats_file)

        if is_main:
            # save checkpoint; `epoch + 1` is the epoch a resumed run starts from
            state = dict(epoch=epoch + 1, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, args.checkpoint_dir / 'checkpoint.pth')

    if is_main:
        # save final model
        torch.save(_final_state_dict(args, model), args.checkpoint_dir / 'final.pth')


def adjust_learning_rate(args, optimizer, loader, step):
    """Apply the linear-warmup / cosine-decay schedule for global step ``step``.

    The first ten epochs' worth of steps ramp the rate linearly from 0 to
    ``args.learning_rate``; afterwards it follows a half cosine down to
    0.001 times that value at ``args.epochs * len(loader)`` steps.

    Args:
        args: namespace providing ``epochs`` and ``learning_rate``.
        optimizer: object whose ``param_groups`` dicts get their 'lr' overwritten.
        loader: sized iterable; ``len(loader)`` is the number of steps per epoch.
        step: global step index, counted across epochs.

    Returns:
        The learning rate written into every parameter group.
    """
    steps_per_epoch = len(loader)
    warmup_steps = 10 * steps_per_epoch
    peak_lr = args.learning_rate #* args.batch_size / 256

    if step < warmup_steps:
        lr = peak_lr * step / warmup_steps
    else:
        # Position and length measured from the end of the warmup phase.
        elapsed = step - warmup_steps
        decay_steps = args.epochs * steps_per_epoch - warmup_steps
        cosine = 0.5 * (1 + math.cos(math.pi * elapsed / decay_steps))
        floor_lr = peak_lr * 0.001
        lr = peak_lr * cosine + floor_lr * (1 - cosine)

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

def adjust_learning_rate_LUT(optimizer, iters, LUT):
    """Set the learning rate from a piecewise-constant look-up table.

    The table is scanned in order and the rate of the first entry whose
    ``stepvalue`` is greater than ``iters`` is used. When ``iters`` is at or
    beyond every ``stepvalue`` the last entry's rate is kept; previously this
    case failed with an unbound ``lr``.

    Args:
        optimizer: object whose ``param_groups`` dicts get their 'lr' overwritten.
        iters: current position in the schedule (the training loop in this
            file passes the epoch index).
        LUT: non-empty iterable of ``(stepvalue, base_lr)`` pairs.

    Returns:
        The learning rate written into every parameter group.

    Raises:
        ValueError: if ``LUT`` contains no entries.
    """
    lr = None
    for stepvalue, base_lr in LUT:
        # Remember the most recent rate so it survives running off the table.
        lr = base_lr
        if iters < stepvalue:
            break
    if lr is None:
        raise ValueError('LUT must contain at least one (stepvalue, lr) pair')

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


class LARS(optim.Optimizer):
    """SGD with momentum whose per-parameter update is rescaled by a norm ratio.

    For every parameter the update is the gradient plus ``weight_decay`` times
    the parameter, multiplied by ``eta * ||param|| / ||update||`` when both
    norms are positive. Each of the two filters is an optional callable taking
    the parameter; when it returns a truthy value the corresponding stage
    (weight decay or rescaling) is skipped for that parameter.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        super().__init__(params, {
            'lr': lr,
            'weight_decay': weight_decay,
            'momentum': momentum,
            'eta': eta,
            'weight_decay_filter': weight_decay_filter,
            'lars_adaptation_filter': lars_adaptation_filter,
        })

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            skip_decay = group['weight_decay_filter']
            skip_rescale = group['lars_adaptation_filter']

            for param in group['params']:
                if param.grad is None:
                    continue
                update = param.grad

                if skip_decay is None or not skip_decay(param):
                    update = update.add(param, alpha=group['weight_decay'])

                if skip_rescale is None or not skip_rescale(param):
                    weight_norm = torch.norm(param)
                    grad_norm = torch.norm(update)
                    unit = torch.ones_like(weight_norm)
                    # Fall back to a factor of 1 when either norm is zero.
                    ratio = torch.where(
                        weight_norm > 0.,
                        torch.where(grad_norm > 0,
                                    (group['eta'] * weight_norm / grad_norm), unit),
                        unit)
                    update = update.mul(ratio)

                # Momentum buffer is created lazily on first use.
                state = self.state[param]
                if 'mu' not in state:
                    state['mu'] = torch.zeros_like(param)
                velocity = state['mu']
                velocity.mul_(group['momentum']).add_(update)

                param.add_(velocity, alpha=-group['lr'])


def exclude_bias_and_norm(p):
    """Return True when ``p`` has exactly one dimension.

    Used as the LARS filter callback, so one-dimensional parameters are left
    out of weight decay and of the norm-based rescaling.
    """
    is_one_dimensional = (1 == p.ndim)
    return is_one_dimensional

# Start training only when the file is executed as a script, so that
# importing this module does not launch training.
if __name__ == '__main__':
    main()
