import argparse
import os
import time

import timm.optim
import torch
import torch.nn as nn

import yaml

import models
import tools
from tools.ops import fast_exp, polar, so_proj, l1_normalize



def get_args():
    """Build the command-line interface of the training script and parse it.

    Returns:
        argparse.Namespace with the attributes ``config``, ``work_dir``,
        ``ckpt_prefix``, ``depth``, ``width``, ``launcher`` and
        ``local_rank``.
    """
    cli = argparse.ArgumentParser('Training Globally-Robust Neural Networks')

    cli.add_argument('--config',
                     type=str,
                     default='configs/imagenet.yaml',
                     help='path to the config yaml file')

    # Checkpoint location and file-name prefix.
    for flag, fallback in (('--work_dir', './checkpoint/'),
                           ('--ckpt_prefix', '')):
        cli.add_argument(flag, default=fallback, type=str)

    # Optional overrides of the model config; the default of -1 means
    # "not given" (main() only applies values that are > 0).
    for flag in ('--depth', '--width'):
        cli.add_argument(flag, default=-1, type=int)

    # Distributed-training options. Both spellings of the local-rank flag
    # are accepted and stored as ``local_rank``.
    cli.add_argument('--launcher',
                     default='pytorch',
                     type=str,
                     help='should be either `slurm` or `pytorch`')
    cli.add_argument('--local_rank', '--local-rank', type=int, default=0)

    namespace = cli.parse_args()
    return namespace



def _evaluate(model, val_loader, eps, lc):
    """Run one pass over ``val_loader`` and count correct predictions.

    Args:
        model: network handed unchanged to ``models.trades_loss``.
        val_loader: iterable yielding ``(inputs, targets)`` batches.
        eps: value forwarded as ``eps`` to ``models.trades_loss``.
        lc: value forwarded as ``lc`` to ``models.trades_loss`` (callers pass
            the result of ``model.module.sub_lipschitz()``).

    Returns:
        Tuple ``(correct_vra, correct, total)`` of floats: the number of
        samples whose arg-max of the second output ``y_`` matches the label,
        the same for the first output ``y``, and the number of samples seen.
    """
    correct_vra = correct = total = 0.
    for inputs, targets in val_loader:
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        with torch.no_grad():
            y, y_, _ = models.trades_loss(model,
                                          x=inputs,
                                          label=targets,
                                          eps=eps,
                                          lc=lc,
                                          return_loss=False)

        correct += y.argmax(1).eq(targets).sum().item()
        correct_vra += y_.argmax(1).eq(targets).sum().item()
        total += targets.size(0)
    return correct_vra, correct, total


def main():
    """Train a model from a YAML config under DistributedDataParallel.

    Steps, in order: parse CLI arguments and the config, initialise
    distributed training, build data loaders / model / optimizer, run the
    training loop (with a hand-written update for ``model.module.layers.weights``
    that is sharded across ranks), evaluate periodically, run a final
    evaluation at three fixed eps values, and save a checkpoint on rank 0.

    Raises:
        ValueError: if the number of entries in ``layers.weights`` is not a
            multiple of the number of GPUs (the per-rank sharding below would
            otherwise overwrite the leftover entries with zeros).
    """
    args = get_args()

    # Read the config text once; it is parsed here and echoed on rank 0 below.
    with open(args.config, 'r') as f:
        config_text = f.read()
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. Only
    # use trusted config files, or switch to yaml.safe_load if the configs
    # contain plain data only.
    cfg = yaml.load(config_text, Loader=yaml.Loader)

    model_cfg = cfg['model']
    train_cfg = cfg['training']
    dataset_cfg = cfg['dataset']
    gloro_cfg = cfg['gloro']

    # Command-line overrides are applied only when a positive value is given.
    if args.depth > 0:
        model_cfg['depth'] = args.depth

    if args.width > 0:
        model_cfg['width'] = args.width

    if args.ckpt_prefix == '':
        depth, width = model_cfg['depth'], model_cfg['width']
        prefix = f"{dataset_cfg['name']}-{depth}x{width}"
        args.ckpt_prefix = prefix

    rank, local_rank, num_gpus = tools.init_DDP(args.launcher)
    print(f'Inited distributed training with {num_gpus} GPUs!')

    if rank == 0:
        # Echo the config from Python instead of shelling out to `cat`, which
        # breaks on paths with spaces/shell metacharacters and on non-POSIX
        # systems.
        print(config_text, end='')

    print(f'\nUse checkpoint prefix: {args.ckpt_prefix}')

    train_loader, train_sampler, val_loader, _ = tools.data_loader(
        data_name=dataset_cfg['name'],
        batch_size=train_cfg['batch_size'] // num_gpus,
        num_classes=dataset_cfg['num_classes'],
        input_size=dataset_cfg['input_size'],
        seed=dataset_cfg.get('seed', 2025))  # if seed is not given, use 2025

    model = models.lip_net(**model_cfg, num_classes=dataset_cfg['num_classes'])

    print(model)
    model = model.cuda()
    # Device used for the small metric tensors that are all-reduced later.
    device = next(model.parameters()).device

    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[local_rank],
                                                output_device=local_rank)

    optimizer = torch.optim.NAdam(model.parameters(),
                                  lr=train_cfg['lr'],
                                  weight_decay=0.0)

    if train_cfg['lookahead']:
        optimizer = timm.optim.Lookahead(optimizer)

    # Extra step-size factor for the hand-written `layers.weights` update,
    # scaled by the configured depth and width.
    local_lr = 0.1 * (32 / model_cfg['depth']) ** .5 * (3072 / model_cfg['width'])
    scheduler = tools.lr_scheduler(iter_per_epoch=len(train_loader),
                                   max_epoch=train_cfg['epochs'],
                                   warmup_epoch=train_cfg['warmup_epochs'])

    def eps_fn(epoch):
        """Return the training eps for `epoch`.

        The multiplier ramps linearly from `min_eps` to `max_eps` over the
        first half of training and stays at `max_eps` afterwards.
        """
        ratio = min(epoch / train_cfg['epochs'] * 2, 1)
        ratio = gloro_cfg['min_eps'] + (gloro_cfg['max_eps'] -
                                        gloro_cfg['min_eps']) * ratio
        return gloro_cfg['eps'] * ratio

    os.makedirs(args.work_dir, exist_ok=True)

    train_fn = getattr(models, gloro_cfg['loss_type'])

    # Each rank updates a contiguous slice of `layers.weights`; the slices are
    # recombined with an all-reduce after every step.
    num_layers = model.module.layers.weights.shape[0]
    if num_layers % num_gpus != 0:
        # Leftover entries would belong to no rank and would be overwritten
        # with zeros by the gather/copy at the end of each step.
        raise ValueError(
            f'layers.weights has {num_layers} entries, which is not divisible '
            f'by the number of GPUs ({num_gpus})')
    r = num_layers // num_gpus
    layer_index = range(rank * r, rank * r + r)

    # Running first/second moment estimates for this rank's slice.
    layer_m = torch.zeros_like(model.module.layers.weights.data[layer_index])
    layer_v = torch.ones_like(layer_m) / layer_m.shape[-1] ** 2

    global_step = 1
    # "Slow" copy of this rank's slice, used by the periodic sync below.
    layer_slow = model.module.layers.weights.data[layer_index].clone()

    layer_gather = torch.zeros_like(model.module.layers.weights.data)
    # Sum of the updates applied since the last sync (becomes a tensor on the
    # first `+=`).
    layer_move = 0

    print('Begin training at ' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    training_logs = []
    start_time = time.time()

    for epoch in range(train_cfg['epochs']):
        eps = eps_fn(epoch)
        train_sampler.set_epoch(epoch)

        model.train()
        correct_vra = correct = total = 0.
        for idx, (inputs, targets) in enumerate(train_loader):
            bs = inputs.shape[0]
            sub_lipschitz = model.module.sub_lipschitz()

            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                y, y_, loss = train_fn(model,
                                       x=inputs,
                                       label=targets,
                                       lc=sub_lipschitz,
                                       eps=eps,
                                       return_loss=True)

            loss.backward()

            if train_cfg['grad_clip']:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         train_cfg['grad_clip_val'])

            lr = scheduler.step(optimizer)

            # layers
            # Take this rank's gradient slice and clear the gradient so that
            # `optimizer.step()` below does not touch `layers.weights`.
            grad = model.module.layers.weights.grad[layer_index]
            model.module.layers.weights.grad = None

            X = model.module.layers.weights.data[layer_index]
            # NOTE(review): so_proj presumably maps the gradient to a
            # skew-symmetric direction relative to X -- confirm in tools.ops.
            grad = so_proj(X, grad)

            # Exponential moving averages of the gradient and its square.
            layer_m += (grad - layer_m) * 0.1
            layer_v += (grad ** 2 - layer_v) * 1e-3
            update = - layer_m / (layer_v.sqrt() + 1e-8) * lr * local_lr

            layer_move += update

            if epoch < train_cfg['epochs'] // 2 and global_step % 6 == 5:
                # Sync step (first half of training, every 6th step): advance
                # the slow weights by half of the accumulated move and restart
                # from them.
                layer_move /= 2
                layer_slow = layer_slow.double() @ fast_exp(layer_move.double())
                layer_move.zero_()
                # Cast back to the weight dtype; use X here because newX is
                # not yet assigned in this iteration.
                newX = layer_slow.to(X.dtype)
            else:
                # Regular step: multiply by fast_exp(update) in float64, then
                # cast back to the weight dtype.
                update = fast_exp(update.double())
                newX = (X.double() @ update).to(X.dtype)

            if idx == len(train_loader) - 1:
                # Last batch of the epoch: pass the weights through `polar`
                # and restart the slow copy from the result.
                # NOTE(review): layer_move is not reset here, so moves
                # accumulated before this point are applied again at the next
                # sync step -- confirm this is intended.
                newX = polar(newX)
                layer_slow = newX

            # Every rank writes only its own slice into a zeroed buffer, so
            # the (sum) all-reduce reassembles the full weight tensor.
            layer_gather.zero_()
            layer_gather[layer_index] = newX
            handle = torch.distributed.all_reduce(layer_gather, async_op=True)

            # Update the remaining parameters while the all-reduce is running.
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1
            correct += y.argmax(1).eq(targets)[:bs].sum().item()
            correct_vra += y_.argmax(1).eq(targets)[:bs].sum().item()
            total += bs

            handle.wait()
            model.module.layers.weights.data.copy_(layer_gather)

        if hasattr(optimizer, 'sync_lookahead'):
            optimizer.sync_lookahead()

        # Re-impose (anti-)symmetry on the moment estimates once per epoch.
        layer_m = 0.5 * (layer_m - layer_m.mT)
        layer_v = 0.5 * (layer_v + layer_v.mT)

        # Validate every 5th epoch and on each of the last epochs.
        if epoch > train_cfg['epochs'] - 50 or epoch % 5 == 0:
            model.eval()
            # Use more iterations for the value used in validation, then
            # restore the training setting.
            model.module.set_iter(num_iter=500)
            sub_lipschitz = model.module.sub_lipschitz().item()
            model.module.set_iter(num_iter=10)

            val_correct_vra, val_correct, val_total = _evaluate(
                model, val_loader, gloro_cfg['eps'], sub_lipschitz)
        else:
            sub_lipschitz = model.module.sub_lipschitz().item()
            val_correct_vra = val_correct = 0.
            # Dummy count so the reported validation accuracy is ~0 rather
            # than a division by zero.
            val_total = 1

        collect_info = [
            correct_vra, correct, total, val_correct_vra, val_correct,
            val_total
        ]
        collect_info = torch.tensor(collect_info,
                                    dtype=torch.float32,
                                    device=device).clamp_min(1e-9)
        # Sum the per-rank counters over all processes.
        torch.distributed.all_reduce(collect_info)

        acc_train = 100. * collect_info[1] / collect_info[2]
        acc_val = 100. * collect_info[4] / collect_info[5]

        acc_vra_train = 100. * collect_info[0] / collect_info[2]
        acc_vra_val = 100. * collect_info[3] / collect_info[5]

        used = time.time() - start_time
        start_time = time.time()

        string = (f'Epoch {epoch}: '
                  f'Train acc{acc_train: .2f}%,{acc_vra_train: .2f}%; '
                  f'val acc{acc_val: .2f}%,{acc_vra_val: .2f}%. '
                  f'sub_lipschitz:{sub_lipschitz: .2f}. '
                  f'Time:{used / 60: .2f} mins.')

        print(string)
        training_logs.append(string)

    # Final evaluation at three fixed eps values.
    model.module.set_iter(num_iter=500)
    sub_lipschitz = model.module.sub_lipschitz().item()
    for EPS in [36 / 255, 72 / 255, 108 / 255]:
        val_correct_vra, val_correct, val_total = _evaluate(
            model, val_loader, EPS, sub_lipschitz)

        collect_info = [val_correct_vra, val_correct, val_total]
        collect_info = torch.tensor(collect_info,
                                    dtype=torch.float32,
                                    device=device).clamp_min(1e-9)
        torch.distributed.all_reduce(collect_info)

        acc_val = 100. * collect_info[1] / collect_info[2]
        acc_vra_val = 100. * collect_info[0] / collect_info[2]
        print(f"eps:{int(EPS * 255): d} / 255, acc:{acc_val: .2f}%,{acc_vra_val: .2f}%")

    if rank == 0:
        # Only rank 0 writes the checkpoint; the file name carries the index
        # of the last epoch.
        state = dict(backbone=model.module.state_dict(),
                     training_logs=training_logs,
                     configs=cfg)

        path = f'{args.work_dir}/{args.ckpt_prefix}_{epoch}.pth'
        torch.save(state, path)

    print("Finish training at " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))



if __name__ == '__main__':
    # Script entry point: set global backend flags, then start training.
    # Let cuDNN benchmark and pick convolution algorithms.
    torch.backends.cudnn.benchmark = True
    # Allow TF32 for CUDA matrix multiplications.
    torch.backends.cuda.matmul.allow_tf32 = True
    # Optional (disabled): torch.backends.cuda.preferred_linalg_library('magma')
    main()

