import sys; sys.path.append("..")

import os
import numpy as np
import argparse, math

import torch
from torch.optim import SGD, Adam
from data.cifar import Cifar10, Cifar100
from utility.log import Log
from model.resnet import *
from model.densenet import *
from model.vgg import *
from utility.initialize import initialize
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, MultiStepLR
from utility.utils import *
# Ask PyTorch's CUDA caching allocator to release any unused cached memory
# before the script starts allocating tensors.
torch.cuda.empty_cache()

def parse_cli_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='resnet18', type=str, help="select model")
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--epochs", default=2000, type=int, help="Total number of epochs.")
    parser.add_argument("--learning_rate", '-lr', default=1e-4, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--dataset", default="cifar10", type=str, help="dataset name")
    parser.add_argument("--threads", default=4, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed passed to initialize().")
    parser.add_argument("--shift", default=500, type=int, help="The number of epochs to shift to stage 2.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="label smoothing")
    parser.add_argument("--logs", default='logs', type=str, help="log save folder")
    return parser.parse_args(argv)


def build_dataset(args, size=(32, 32)):
    """Create the dataset wrapper selected by ``args.dataset``.

    Args:
        args: Parsed options; ``dataset``, ``batch_size`` and ``threads`` are read.
        size: Image size forwarded to the dataset constructor.

    Returns:
        A ``(dataset, num_classes)`` tuple: 10 classes for ``cifar10`` and
        100 for ``cifar100`` (the name is matched case-insensitively).

    Raises:
        ValueError: If the dataset name is not one of the supported ones.
    """
    name = args.dataset.lower()
    if name == 'cifar10':
        return Cifar10(args.batch_size, args.threads, size), 10
    if name == 'cifar100':
        return Cifar100(args.batch_size, args.threads, size), 100
    raise ValueError(
        "Unsupported dataset %r (expected 'cifar10' or 'cifar100')." % args.dataset
    )


def build_model(name, num_classes, device):
    """Instantiate the network selected by ``name`` and move it to ``device``.

    Args:
        name: Model name, matched case-insensitively.
        num_classes: Number of output classes given to the model constructor.
        device: Target ``torch.device``.

    Returns:
        The constructed model, already moved to ``device``.

    Raises:
        ValueError: If ``name`` is not one of the supported models.
    """
    # Lambdas keep construction lazy: only the selected architecture is built.
    builders = {
        'resnet18': lambda: ResNet18(num_classes=num_classes),
        'resnet34': lambda: ResNet34(num_classes=num_classes),
        'densenet': lambda: DenseNet121(num_classes=num_classes),
        'vgg13': lambda: VGG('VGG13', num_classes=num_classes),
        'vgg19': lambda: VGG('VGG19', num_classes=num_classes),
    }
    key = name.lower()
    if key not in builders:
        raise ValueError(
            "Unsupported model %r (expected one of: %s)."
            % (name, ', '.join(sorted(builders)))
        )
    return builders[key]().to(device)


def make_run_name(args):
    """Return the file-name stem that encodes the run's hyper-parameters.

    The learning rate is stored as ``int(1e5 * learning_rate)``; the remaining
    fields are the string forms of the corresponding options.
    """
    return (args.dataset+'lr'+str(int(1e5*args.learning_rate))
            +'model'+str(args.model)+'_hmgprior_'
            +'seed'+str(args.seed)
            +'shift'+str(args.shift)
            +'batch'+str(args.batch_size)
            +'lb'+str(args.label_smoothing)
            )


def main():
    """Train the selected model in two stages, then save the final state.

    While ``epoch < args.shift`` (stage 1) the cross-entropy loss is combined
    with a KL-based term, and the optimizers for ``p`` and ``b`` are stepped
    alongside the model optimizer. From epoch ``args.shift`` onwards (stage 2)
    only the cross-entropy loss is back-propagated and only the model
    optimizer is stepped.
    """
    args = parse_cli_args()

    initialize(args, seed=args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset, labels = build_dataset(args, size=(32, 32))
    model = build_model(args.model, labels, device)

    file_name = make_run_name(args)

    args.logs = args.logs+'/'+args.dataset+'/'+args.model+'/'
    # NOTE(review): Log writes under the hard-coded './auto1/' folder and does
    # not use --logs; only save_model() receives args.logs. Confirm intended.
    log = Log(log_each=10, file_name=file_name, logs='./auto1/'+args.dataset+'/'+args.model+'/')
    # reduction='none' keeps one loss value per sample (the modern spelling of
    # the deprecated reduce=False); the per-sample values are what gets logged.
    criterion = torch.nn.CrossEntropyLoss(reduction='none', label_smoothing=args.label_smoothing)

    others = []
    w0, p, layers = initialization(model)
    # b starts at log(mean(|w0|)); the loss terms below consume exp(b).
    b = torch.nn.Parameter(torch.ones(1, device=device)*(torch.log(w0.abs().mean()))*1, requires_grad=True)
    opt1 = Adam(model.parameters(), lr=args.learning_rate)
    opt2 = Adam([p], lr=args.learning_rate)
    opt3 = Adam([b], lr=args.learning_rate)
    sch1 = ReduceLROnPlateau(opt1, mode='max', factor=0.1, patience=20)

    min_gamma = 0.5
    max_gamma = 10
    achieve_target_acc = 0
    prior_list, K_list = compute_K_sample(model, dataset, criterion, min_gamma, max_gamma)
    for epoch in range(args.epochs):
        stage_one = epoch < args.shift
        model.train()
        log.train(len_dataset=len(dataset.train))
        for batch in dataset.train:
            opt1.zero_grad()
            opt2.zero_grad()
            opt3.zero_grad()

            # noise injection and ||w-w0||^2
            wdecay = weight_decay(model, w0)
            noises, noises_scaled = noise_injection(model, p)

            # loss 1: computed once per sample, reused for backward and logging
            inputs, targets = (t.to(device) for t in batch)
            predictions = model(inputs)
            per_sample_loss = criterion(predictions, targets)
            loss1 = per_sample_loss.mean()

            if stage_one:
                kl = get_kl_term_with_b(wdecay, p, b)
                gamma1 = fun_K_auto(torch.exp(b),prior_list,K_list)**(-1)*( 2*(kl+60) /5e4/3 )**0.5
                gamma1 = torch.clip(gamma1,max=max_gamma,min=min_gamma)
                loss2 = 3*fun_K_auto(torch.exp(b),prior_list,K_list)**2*gamma1/2 + (kl+60)/5e4/gamma1
            else:
                # Fresh zero: avoids a NameError when --shift is 0 and does not
                # keep the last stage-1 autograd graph alive.
                loss2 = torch.zeros(1, device=device)

            # backward
            loss1.backward(retain_graph=True)
            if stage_one:
                if epoch < 50:
                    kl_term_backward_mean(loss2, model, p, noises)
                else:
                    kl_term_backward(loss2, model, p, noises)

            # remove noises
            rm_injected_noises(model, noises_scaled)

            opt1.step()
            if stage_one:
                opt2.step()
                opt3.step()

            correct = torch.argmax(predictions.data, 1) == targets
            log(model, per_sample_loss.detach().cpu()+loss2.item(), correct.cpu(),
                fun_K_auto(torch.exp(b),prior_list,K_list).detach().cpu().item())

        train_acc = log.epoch_state["accuracy"] / log.epoch_state["steps"]
        # Stop once more than 20 post-shift epochs (cumulative, not necessarily
        # consecutive) have reached the target training accuracy.
        # NOTE(review): epoch == args.shift is excluded here and in the
        # scheduler check below ('>' rather than '>='); confirm intended.
        if train_acc >= 0.999 and epoch > args.shift:
            achieve_target_acc += 1
            if achieve_target_acc > 20:
                break

        # no need to keep training
        if opt1.param_groups[0]['lr'] < 1e-5:
            break

        # scheduler
        if epoch > args.shift:
            sch1.step(train_acc)

        others.append([p.mean().cpu().item()])
        # reset the non-trainable parameters of batchnorm: gradient-free forward
        # passes over the training set after the injected noise was removed
        with torch.no_grad():
            for batch in dataset.train:
                inputs, targets = (t.to(device) for t in batch)
                model(inputs)

        # prediction
        evaluation(model, criterion, dataset, log,
                   fun_K_auto(torch.exp(b),prior_list,K_list).detach().cpu().item())

    log.flush()
    # Saved inside the entry point so importing this module has no side effects.
    save_model(model, w0, p, epoch, b,
               opt1, opt2, sch1, file_name, [others, K_list], folder=args.logs)


if __name__ == "__main__":
    main()
