import argparse
import logging
import os
import sys
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from utils import (
    AvgrageMeter,
    accuracy,
    save_checkpoint,
)
from tensorboardX import SummaryWriter
from resnet import resnet18

def get_args():
    """Parse the command-line options of the training script.

    Options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence.

    Returns:
        argparse.Namespace: parsed values for ``batch_size``,
        ``learning_rate``, ``weight_decay``, ``momentum``, ``dataset_dir``,
        ``output_dir``, ``exp_type`` and ``enlarge_factor``.
    """
    option_table = (
        ("--batch_size", dict(type=int, default=256)),
        ("--learning_rate", dict(type=float, default=0.5, help="init learning rate")),
        ("--weight_decay", dict(type=float, default=1e-3, help="weight decay")),
        ("--momentum", dict(type=float, default=0, help="momentum")),
        ("--dataset_dir", dict(type=str, default="./output/dataset", help="dir of dataset")),
        ("--output_dir", dict(type=str, default="./output", help="dir of output")),
        (
            "--exp_type",
            dict(
                type=str,
                default="standard",
                choices=["standard", "rescale", "init_equiv"],
                help="type of experiments",
            ),
        ),
        (
            "--enlarge_factor",
            dict(
                type=float,
                default=10,
                help="factor to enlarge the intialized value of weight norm while preserving the evolution of angular update",
            ),
        ),
    )

    cli = argparse.ArgumentParser("ResNet-18 on cifar10")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def _setup_logging(output_dir):
    """Send INFO-level log records to stdout and to ``<output_dir>/log.txt``."""
    log_format = "[%(asctime)s] %(message)s"
    logging.basicConfig(
        stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt="%d %I:%M:%S"
    )
    fh = logging.FileHandler(os.path.join(output_dir, "log.txt"))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)


def _build_data_loaders(args):
    """Create the CIFAR-10 training and validation data loaders.

    The dataset is downloaded into ``args.dataset_dir`` when it is missing.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    # NOTE(review): these values look like the commonly published CIFAR-100
    # channel statistics rather than CIFAR-10 ones. They are kept unchanged so
    # existing results stay reproducible -- confirm which was intended.
    CIFAR10_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR10_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    train_aug = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD),
        ]
    )
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.dataset_dir, exist_ok=True)
    train_dataset = datasets.CIFAR10(
        args.dataset_dir, train=True, download=True, transform=train_aug
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True, pin_memory=True, drop_last=True
    )

    # Evaluation uses no augmentation, only tensor conversion + normalisation.
    test_aug = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD),]
    )
    val_dataset = datasets.CIFAR10(
        args.dataset_dir, train=False, download=True, transform=test_aug
    )
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=200, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, val_loader


def _build_optimizer(model, args):
    """Create the SGD optimizer with a dedicated parameter group for conv weights.

    Parameters whose name contains ``"conv"`` form group 0; all remaining
    parameters form group 1 and use the optimizer-wide defaults.

    For ``exp_type == "init_equiv"`` the conv parameters are multiplied in
    place by ``args.enlarge_factor``, their learning rate is multiplied by the
    factor squared and their weight decay divided by it. For every other
    experiment type ``args.enlarge_factor`` is reset to 1, so both groups end
    up with the same hyper-parameters.
    """
    params_conv = []
    params_others = []
    for name, p in model.named_parameters():
        if "conv" in name:
            params_conv.append(p)
        else:
            params_others.append(p)

    if args.exp_type == "init_equiv":
        # Initialisation tweak only: it must not be recorded by autograd.
        with torch.no_grad():
            for p in params_conv:
                p.mul_(args.enlarge_factor)
    else:
        args.enlarge_factor = 1

    conv_scale = args.enlarge_factor ** 2
    param_groups = [
        {
            "params": params_conv,
            "lr": args.learning_rate * conv_scale,
            "weight_decay": args.weight_decay / conv_scale,
        },
        {
            "params": params_others,
        },
    ]
    return torch.optim.SGD(
        param_groups,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        momentum=args.momentum,
    )


def main():
    """Entry point: set up logging, data, model and optimizer, then train.

    Outputs (text log, TensorBoard events, and whatever ``save_checkpoint``
    writes during training) go under ``<output_dir>/<exp_type>``. The data
    loaders, loss function, optimizer and scheduler are attached to ``args``
    because ``train`` and ``validate`` read them from there.
    """
    args = get_args()
    args.output_dir = f"{args.output_dir}/{args.exp_type}"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    _setup_logging(args.output_dir)
    # Logged before the loaders/optimizer are attached, so only CLI options appear.
    logging.info(args)
    tb_writer = SummaryWriter(os.path.join(args.output_dir, "tb_dir"))

    try:
        args.train_loader, args.val_loader = _build_data_loaders(args)
        print("load data successfully")

        model = resnet18().cuda()
        optimizer = _build_optimizer(model, args)

        # gamma=1 multiplies the learning rate by 1 at each milestone, i.e. the
        # schedule leaves it constant. NOTE(review): presumably intentional for
        # these experiments -- confirm.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[60, 120, 160],
            gamma=1
        )
        start_iter = 0

        args.loss_function = nn.CrossEntropyLoss()
        args.optimizer = optimizer
        args.scheduler = scheduler

        logging.info("Training begin!")
        train(model, args, start_iter + 1, tb_writer)
    finally:
        # Close the writer explicitly so pending events are flushed even if
        # setup or training raises.
        tb_writer.close()

def _log_conv_angular_updates(model, optimizer, momentum, tb_writer, global_step):
    """Write ``<layer>/au`` and ``<layer>/norm`` scalars for every ``nn.Conv2d``.

    ``au`` is ``lr * ||update|| / ||weight||``, where ``update`` is the weight
    gradient when ``momentum == 0`` and the optimizer's momentum buffer
    otherwise.
    """
    # NOTE(review): the learning rate of parameter group 0 is applied to every
    # Conv2d module, whichever group its weight actually belongs to -- this
    # only matters if the groups use different learning rates.
    lr = optimizer.param_groups[0]['lr']
    for name, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight_norm = m.weight.data.norm().item()
        if momentum == 0:
            update = m.weight.grad.data.norm().item()
        else:
            update = optimizer.state[m.weight]["momentum_buffer"].norm().item()
        au = update * lr / weight_norm

        tb_writer.add_scalar(f"{name}/au", au, global_step=global_step)
        tb_writer.add_scalar(f"{name}/norm", weight_norm, global_step=global_step)


def train(model, args, start_iter, tb_writer):
    """Train ``model`` epoch by epoch, logging, checkpointing and validating.

    Args:
        model: network to optimise. Batches are moved with ``.cuda()``, so the
            model is expected to already be on the CUDA device.
        args: namespace providing ``optimizer``, ``scheduler``,
            ``loss_function``, ``train_loader``, ``exp_type``, ``momentum``
            and ``output_dir`` (plus whatever ``validate`` reads).
        start_iter: number of the first epoch to run; training continues up to
            and including epoch 200.
        tb_writer: writer exposing ``add_scalar`` for TensorBoard logging.
    """
    optimizer = args.optimizer
    scheduler = args.scheduler
    loss_function = args.loss_function
    Iters = len(args.train_loader)
    for epoch in range(start_iter, 201):
        model.train()
        Top1, Top5, Loss = 0.0, 0.0, 0.0
        if args.exp_type == "rescale" and epoch == 61:
            # One-off rescaling of every parameter by 0.2**0.25 (~0.669).
            # In-place edits of leaf tensors that require grad raise a
            # RuntimeError while autograd is recording, hence no_grad().
            with torch.no_grad():
                for p in model.parameters():
                    p.mul_(0.2**0.25)
        # The context manager closes the bar at the end of each epoch instead
        # of leaving one open progress bar per epoch.
        with tqdm(total=Iters) as pbar:
            for i, (data, label) in enumerate(args.train_loader):
                data, label = data.cuda(), label.cuda()
                output = model(data)
                loss = loss_function(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                prec1, prec5 = accuracy(output, label, topk=(1, 5))
                Top1 += prec1.item()
                Top5 += prec5.item()
                Loss += loss.item()
                pbar.update()
                pbar.set_description(
                    'Loss {:.2f}, Top-1 {:.2f}, Top-5 {:.2f}'.format(
                        Loss/(i+1), Top1/(i+1), Top5/(i+1),
                    )
                )
                # Per-layer statistics every 20 iterations (i = 1, 21, 41, ...);
                # i == 1 guarantees at least one optimizer step has happened.
                if i%20 == 1:
                    _log_conv_angular_updates(
                        model, optimizer, args.momentum, tb_writer,
                        global_step=(epoch-1)*Iters+i,
                    )

        tb_writer.add_scalar(
            'train/Loss', Loss/Iters, global_step=epoch
        )
        tb_writer.add_scalar(
            'train/Top-1', Top1/Iters, global_step=epoch
        )
        tb_writer.add_scalar(
            'train/Top-5', Top5/Iters, global_step=epoch
        )

        # Read the rate the optimizer actually used for group 0 this epoch.
        # scheduler.get_lr() is not meant to be called outside step() and can
        # report a doubly-decayed value at milestone epochs.
        current_lr = optimizer.param_groups[0]['lr']
        printInfo = (
            "TRAIN Epoch {}: lr = {:.6f},\tloss = {:.6f},\t".format(epoch, current_lr, Loss / Iters)
            + "Top-1 = {:.6f},\t".format(Top1 / Iters)
            + "Top-5 = {:.6f},\t".format(Top5 / Iters)
        )
        logging.info(printInfo)
        scheduler.step()
        save_checkpoint(model, epoch, args.output_dir, optimizer, scheduler)
        validate(model, args, tb_writer, epoch)


def validate(model, args, tb_writer, iters):
    """Evaluate ``model`` on ``args.val_loader`` and report the results.

    Loss, top-1 and top-5 are accumulated over the whole loader in
    ``AvgrageMeter`` instances (updated with each batch value and the batch
    size), then written to the log and to TensorBoard.

    Args:
        model: network to evaluate; switched to eval mode here. Batches are
            moved with ``.cuda()``.
        args: namespace providing ``val_loader`` and ``loss_function``.
        tb_writer: writer exposing ``add_scalar``.
        iters: value used as ``global_step`` for the TensorBoard scalars.
    """
    criterion = args.loss_function
    loss_meter, top1_meter, top5_meter = AvgrageMeter(), AvgrageMeter(), AvgrageMeter()

    model.eval()
    started = time.time()
    with torch.no_grad():
        for images, targets in args.val_loader:
            images = images.cuda()
            targets = targets.cuda()
            logits = model(images)
            batch_loss = criterion(logits, targets)
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
            batch_size = images.size(0)
            for meter, value in (
                (loss_meter, batch_loss),
                (top1_meter, acc1),
                (top5_meter, acc5),
            ):
                meter.update(value.item(), batch_size)

    # Same text as four concatenated fragments, produced by one format call.
    logging.info(
        "TEST loss = {:.6f},\tTop-1 = {:.6f},\tTop-5 = {:.6f},\tval_time = {:.6f}".format(
            loss_meter.avg, top1_meter.avg, top5_meter.avg, time.time() - started
        )
    )

    for tag, meter in (
        ('test/Loss', loss_meter),
        ('test/Top-1', top1_meter),
        ('test/Top-5', top5_meter),
    ):
        tb_writer.add_scalar(tag, meter.avg, global_step=iters)


# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
