#!/usr/bin/env python
# coding: utf-8

"""
import torch
torch.set_num_threads(7)
torch.set_num_interop_threads(7)
torch.backends.cudnn.benchmark = True
#"""
from utility import *
# ------------------------------------------------------
import argparse
import datetime
def get_config():
    """Build the command-line parser for a training run and parse ``sys.argv``.

    Returns:
        argparse.Namespace exposing ``core``, ``model``, ``data``,
        ``batch_size``, ``epoch``, ``sharpness``, ``lr`` and ``loss``.
    """
    # (flag, default, type) triples; they are registered in the order listed.
    options = (
        # general
        ('--core', 0, int),
        ('--model', 'MLP', str),
        ('--data', 'AVILA2', str),
        # optim
        ('--batch_size', 10430, int),
        ('--epoch', 10000, int),
        ('--sharpness', 1, float),
        ('--lr', 1e-2, float),
        ('--loss', "MSE", str),
    )
    parser = argparse.ArgumentParser()
    for flag, default, caster in options:
        parser.add_argument(flag, default=default, type=caster)
    return parser.parse_args()

# -------------------------------------------------------------------
import torch
import math
from sharpness.tools import Hessian_trace, Hessian_diag 
import sharpness.Minimum as Minimum
import os
import copy
import numpy as np
def optimize(dataset, model, config):
    """Train ``model`` on ``dataset.train`` with SGD, logging metrics every epoch.

    Each epoch shuffles the training indices, takes one SGD step per
    mini-batch of ``config.batch_size`` samples, advances the learning-rate
    schedule once, and then reports loss and accuracy on both splits through
    ``report``.

    Args:
        dataset: object exposing ``train`` and ``test`` splits. Each split
            provides ``x`` and ``y``; ``train`` additionally provides the
            sample count ``n``.
        model: model updated in place; it must provide ``parameters()``,
            ``train()`` and ``eval()`` (presumably a ``torch.nn.Module``).
        config: namespace providing ``lr``, ``batch_size`` and ``epoch``,
            plus whatever ``get_loss`` reads from it. A non-positive
            ``epoch`` performs no training.

    Raises:
        AssertionError: if the train or test loss becomes NaN.
    """
    loss_func = get_loss(config)
    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, weight_decay=0.1, momentum=0.0)
    # Halve the learning rate at each of these epoch counts.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[500, 1000, 2000, 3000, 4000], gamma=0.5)

    def update():
        """Run one epoch of mini-batch SGD over a fresh random permutation."""
        model.train()
        index = torch.randperm(dataset.train.n)
        for idx in torch.split(index, config.batch_size):
            optimizer.zero_grad()
            x = dataset.train.x[idx]
            y = dataset.train.y[idx]
            loss = loss_func(model(x), y)
            loss.backward()
            # Step once per batch: gradients are zeroed at the top of every
            # iteration, so stepping after the loop would apply only the last
            # batch's gradient and silently ignore the rest of the epoch.
            optimizer.step()
        # The schedule is expressed in epochs, so advance it once per epoch.
        scheduler.step()

    def evaluate(data):
        """Return ``(loss, accuracy)`` of the model on one whole split."""
        model.eval()
        with torch.no_grad():
            output = model(data.x)                  # logit
            loss = loss_func(output, data.y).item()
            predicted = output.max(dim=1)[1]        # logit -> index
            # NOTE(review): accuracy assumes data.y holds class indices
            # comparable with the argmax over dim 1 -- confirm against dataset.
            correct = (predicted == data.y)
            accuracy = correct.float().mean().item()  # bool -> float
            return loss, accuracy

    # Bounded loop: runs exactly config.epoch epochs and terminates even when
    # config.epoch is zero or negative.
    for epoch in range(1, config.epoch + 1):
        report(f'epoch:{epoch}')
        update()

        status = {}
        train_loss, train_acc = evaluate(dataset.train)
        status['train'] = {'loss': train_loss, 'accuracy': train_acc}
        test_loss, test_acc = evaluate(dataset.test)
        status['test'] = {'loss': test_loss, 'accuracy': test_acc}

        # Dump Log
        for mode in ['train', 'test']:
            message = [f'\t{mode:5}']
            message += [f"loss:{status[mode]['loss']: 18.7f}"]
            message += [f"accuracy:{status[mode]['accuracy']: 9.7f}"]
            report(*message)

        # Abort on divergence. Raised explicitly (rather than via ``assert``)
        # so the check still runs under ``python -O``.
        if math.isnan(status['train']['loss']):
            raise AssertionError('find nan in train-loss')
        if math.isnan(status['test']['loss']):
            raise AssertionError('find nan in test-loss')


# -------------------------------------------------------------------    
import os
import time
import json
import shutil
def main():
    """Parse the CLI config, train the selected model and save its weights.

    The trained ``state_dict`` is written to ``out/Trained.weight``, relative
    to the current working directory.
    """
    config = get_config()
    set_device(config)
    dataset = get_dataset(config)
    model = get_model(config)

    # Create the output directory before training: torch.save does not create
    # missing parent directories, and failing only after the full run would
    # throw the trained weights away.
    os.makedirs("out", exist_ok=True)

    optimize(dataset, model, config)
    torch.save(model.state_dict(), "out/Trained.weight")


if __name__ == '__main__':
    main()

# -------------------------------------------------------------------

