import argparse
import os
import shutil
import json
from tqdm import tqdm
import numpy as np

import torch
from torch import optim

from nn import LinearSGM
from misc import fix_randomness
from data_io import load_data


def _slice_count(text):
    """Parse the ``--num_slices`` value into an ``int``.

    The option is a count, but it used to be declared with ``type=float``,
    so ``--num_slices 2`` reached the model as ``2.0``.  Integral float
    spellings such as ``"2.0"`` are still accepted so existing command
    lines keep working; non-integral values are rejected.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid slice count: %r' % (text,))
    if not value.is_integer():
        raise argparse.ArgumentTypeError('slice count must be a whole number: %r' % (text,))
    return int(value)


# Command-line interface of the training script.
parser = argparse.ArgumentParser()

# Data / bookkeeping options.
parser.add_argument('--data_dir', type=str, required=True)
parser.add_argument('--data_dim', type=int, required=True)
parser.add_argument('--save_dir', type=str, required=True)
parser.add_argument('--batch_size', type=int, default=500)
parser.add_argument('--random_state', type=int, default=123)
parser.add_argument('--num_epochs', type=int, default=100)

# Model / optimisation options.
parser.add_argument('--learning_rate', type=float, default=0.01)
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--coeff_epsilon', type=float, default=0.5)
parser.add_argument('--num_slices', type=_slice_count, default=1)
parser.add_argument('--grad_norm', type=float, default=5.0)

args = parser.parse_args()
# Echo the effective configuration so that every log starts with the settings used.
print(json.dumps(args.__dict__, indent=True, ensure_ascii=False), end='\n\n', flush=True)

# Seed the random number generators before any data or parameters are created.
fix_randomness(args.random_state)

batch_iter = load_data(args.data_dir, args.batch_size)
# NOTE(review): --num_layers is parsed but not forwarded to the model here;
# confirm against LinearSGM's constructor whether it should be.
model = LinearSGM(args.data_dim, args.hidden_dim, args.data_dim)
if torch.cuda.is_available():
    # Only the model is moved here; batches are moved one by one in the
    # training loop.  (A previous line moved an undefined `data_support`
    # to the GPU and raised NameError on every CUDA machine.)
    model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

# Start from an empty output directory: anything already in save_dir is deleted.
if os.path.exists(args.save_dir):
    shutil.rmtree(args.save_dir)
os.makedirs(args.save_dir)

# Training loop: one optimiser step per batch, with a checkpoint written
# whenever the epoch's accumulated hybrid loss improves on the best so far.
use_cuda = torch.cuda.is_available()  # loop-invariant, so query it once
lowest_loss = np.inf
for epoch_id in range(0, args.num_epochs):
    cumu_fl, cumu_sl, cumu_hl, step_id = 0.0, 0.0, 0.0, 0

    for samples in tqdm(batch_iter, ncols=70, desc='epoch ' + str(epoch_id)):
        if use_cuda:
            samples = samples.cuda()
        # Enable gradients w.r.t. the inputs on CPU as well as on GPU; this
        # used to happen only inside the CUDA branch, so the two devices
        # behaved differently.
        samples.requires_grad_(True)

        optimizer.zero_grad()
        field_loss, scalar_loss = model.loss_fn(samples, args.coeff_epsilon, args.num_slices)
        hybrid_loss = (field_loss + scalar_loss)
        hybrid_loss.backward()
        # Apply --grad_norm as the maximum gradient norm; the option used to
        # be parsed but silently ignored.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
        optimizer.step()

        cumu_fl += field_loss.item()
        cumu_sl += scalar_loss.item()
        cumu_hl += hybrid_loss.item()
        print('step:', step_id, '| field loss:', field_loss.item(), '| scalar loss:', scalar_loss.item(), flush=True)
        step_id += 1

    print('epoch:', epoch_id, '| field loss:', cumu_fl, '| scalar loss:', cumu_sl, flush=True)
    if step_id == 0:
        # No batch was produced (e.g. an empty or already exhausted iterator),
        # so there is no loss to compare; do not checkpoint on a stale value.
        print('warning: no batches in epoch', epoch_id, flush=True)
    elif lowest_loss > cumu_hl:
        # Compare the loss accumulated over the whole epoch (a plain float)
        # rather than the last batch's tensor, which is noisy and would keep
        # its autograd graph alive while stored in lowest_loss.
        lowest_loss = cumu_hl
        torch.save(model, os.path.join(args.save_dir, 'model_epoch' + str(epoch_id) + '.pt'))
        print('model saved at epoch', str(epoch_id), flush=True)
    print(end='\n\n', flush=True)

# Always keep the final state, whether or not it is the best one.
torch.save(model, os.path.join(args.save_dir, 'model_final.pt'))
print('final model saved', flush=True)
