import torch
import time
import sys
import os
import logging
from tensorboardX import SummaryWriter
import torch

from network import create_network
from utils import ensure_dir, get_args, load_checkpoint, save_checkpoint, get_train_loader


def _setup_logging(output_dir):
    """Send INFO-level log records to stdout and to ``<output_dir>/log.txt``."""
    log_format = "[%(asctime)s] %(message)s"
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=log_format,
        datefmt="%d %I:%M:%S",
    )
    file_handler = logging.FileHandler(os.path.join(output_dir, "log.txt"))
    file_handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(file_handler)


def _build_optimizer(args):
    """Scale the "conv" parameters of ``args.model`` in place and build SGD.

    Parameters whose name contains ``"conv"`` are multiplied by
    ``args.enlarge_factor``; their group uses the base learning rate times
    ``enlarge_factor ** 2`` and the base weight decay divided by
    ``enlarge_factor ** 2``. All remaining parameters use the optimizer
    defaults (``args.learning_rate``, ``args.weight_decay``).

    Returns:
        The ``torch.optim.SGD`` instance; group 0 is the "conv" group.
    """
    factor_sq = args.enlarge_factor ** 2
    params_conv = []
    params_others = []
    for name, p in args.model.named_parameters():
        if "conv" in name:
            p.data.mul_(args.enlarge_factor)
            params_conv.append(p)
        else:
            params_others.append(p)
    param_groups = [
        {
            "params": params_conv,
            "lr": args.learning_rate * factor_sq,
            "weight_decay": args.weight_decay / factor_sq,
        },
        {
            "params": params_others,
        },
    ]
    return torch.optim.SGD(
        param_groups,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        momentum=args.momentum,
    )


def _log_conv_stats(args, tb_writer, iteration):
    """Write per-``Conv2d`` weight norm and relative step size to TensorBoard."""
    # NOTE(review): group 0 is selected by parameter *name* ("conv"), while the
    # modules below are selected by *type*; a Conv2d whose parameter name lacks
    # "conv" would actually train with the other group's lr -- confirm against
    # the network definition.
    lr = args.optimizer.param_groups[0]["lr"]
    for name, m in args.model.named_modules():
        if not isinstance(m, torch.nn.Conv2d):
            continue
        weight_norm = m.weight.data.norm().item()
        if args.momentum == 0:
            # Without momentum the step direction is the raw gradient.
            update = m.weight.grad.data.norm().item()
        else:
            update = args.optimizer.state[m.weight]["momentum_buffer"].norm().item()
        # Skip the ratio for an all-zero weight instead of raising
        # ZeroDivisionError from inside a logging-only code path.
        if weight_norm > 0:
            au = update * lr / weight_norm
            tb_writer.add_scalar(f"{name}/au", au, global_step=iteration)
        tb_writer.add_scalar(f"{name}/norm", weight_norm, global_step=iteration)


def _publish_checkpoint(args, iteration):
    """Save ``iter-<iteration>.pth`` and repoint ``model.ckpt`` at it."""
    save_checkpoint(args, args.output_dir + "/iter-{}.pth".format(iteration))
    link_path = f"{args.output_dir}/model.ckpt"
    # lexists() is also true for a dangling symlink, which exists() would miss
    # and which would then make os.symlink() fail with FileExistsError.
    if os.path.lexists(link_path):
        os.remove(link_path)
    # A relative symlink target is resolved against the directory containing
    # the link, so use the bare file name: it stays valid whether output_dir
    # is relative or absolute.
    os.symlink(f"iter-{iteration}.pth", link_path)


def main():
    """Train the network from ``create_network()`` with DistributedDataParallel.

    Reads the run configuration from ``get_args()``, initialises the NCCL
    process group, builds the model, SGD optimizer and MultiStepLR schedule,
    resumes from ``<output_dir>/model.ckpt`` when that file exists, and then
    runs the training loop up to and including iteration ``max_iters``.
    The process with ``local_rank == 0`` additionally writes logs, TensorBoard
    scalars and one checkpoint per epoch.
    """
    args = get_args()
    is_main = args.local_rank == 0
    if is_main:
        ensure_dir(args.output_dir)
        ensure_dir(args.tb_dir)

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if is_main:
        _setup_logging(args.output_dir)
        logging.info(args)

    args.iter = 0
    args.device = torch.device("cuda")
    args.model = create_network()
    args.model.to(args.device)
    args.optimizer = _build_optimizer(args)

    iters_per_epoch = 1280000 // args.batch_size
    max_iters = args.Epochs * iters_per_epoch
    # The scheduler is stepped once per iteration, so the epoch milestones are
    # expressed in iterations.
    args.scheduler = torch.optim.lr_scheduler.MultiStepLR(
        args.optimizer,
        [30 * iters_per_epoch, 60 * iters_per_epoch, 80 * iters_per_epoch],
        gamma=0.1,
    )

    if os.path.exists(args.output_dir + "/model.ckpt"):
        load_checkpoint(args, args.output_dir + "/model.ckpt")
    elif is_main:
        save_checkpoint(args, args.output_dir + "/start.ckpt")
    args.model = torch.nn.parallel.DistributedDataParallel(
        args.model, device_ids=[args.local_rank], broadcast_buffers=False
    )

    data_loader = get_train_loader(
        args.local_rank, args.batch_size, args.gpu_nums, args.train_dataset_dir
    )
    tb_writer = SummaryWriter(args.tb_dir) if is_main else None
    args.model.train()
    try:
        # Inclusive upper bound: the `iteration == max_iters` checks below (and
        # therefore the final checkpoint) are only reachable if the last
        # iteration is actually executed.
        for iteration in range(args.iter + 1, max_iters + 1):
            data, label = data_loader.next()
            data = data.cuda()
            label = label.type(torch.long).cuda()
            loss_dict, output_dict = args.model(data, label)
            loss = loss_dict["loss"]
            args.optimizer.zero_grad()
            loss.backward()
            args.optimizer.step()

            if is_main and (iteration % 20 == 0 or iteration == max_iters):
                log_str = "it:%d, lr:%.1e, " % (
                    iteration,
                    args.optimizer.param_groups[0]["lr"],
                )
                for key in ["Loss", "Err1", "Err5"]:
                    tb_writer.add_scalar(key, output_dict[key], global_step=iteration)
                    log_str += key + ": %.3f, " % float(output_dict[key])
                logging.info(log_str)
                _log_conv_stats(args, tb_writer, iteration)

            args.scheduler.step()
            if iteration % iters_per_epoch == 0 or iteration == max_iters:
                # Every rank records the progress; only rank 0 writes files.
                args.iter = iteration
                if is_main:
                    _publish_checkpoint(args, iteration)
    finally:
        # The script leaves via os._exit(), which runs no cleanup handlers, so
        # close the writer here to flush any queued TensorBoard events.
        if tb_writer is not None:
            tb_writer.close()

if __name__ == "__main__":
    main()
    # os._exit() ends the process immediately: no atexit handlers run and
    # stdio buffers are not flushed. Flush logging handlers and the standard
    # streams by hand so no trailing output is lost.
    # NOTE(review): the hard exit presumably avoids hanging on leftover
    # background threads/processes at shutdown -- confirm before replacing it.
    logging.shutdown()
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)
