import torch
import time
import os
import json
import logging
from shutil import copyfile

from my_config import config as cfg
from data import cosserat_rod_data
from net import encode_process_decode
from net import loss_con

# Seed PyTorch's default generator (torch.manual_seed is a re-export of
# torch.random.manual_seed) so weight initialisation and DataLoader shuffling
# repeat across runs, up to any nondeterministic GPU kernels.
SEED = 0
torch.random.manual_seed(SEED)

# Create a fresh, timestamped output directory for this run. The timestamp is
# appended directly to cfg.out_dir, so cfg.out_dir presumably ends with a path
# separator (TODO confirm in my_config). os.makedirs also creates missing parent
# directories, which os.mkdir would reject. It still raises if the exact
# directory already exists, e.g. two runs started in the same second.
out_dir = cfg.out_dir + time.strftime('%Y-%m-%d-%H-%M-%S')
os.makedirs(out_dir)

# Root logger: INFO-and-above records go both to the console and to
# <out_dir>/log.txt.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
file_handler = logging.FileHandler(os.path.join(out_dir, "log.txt"))
file_handler.setLevel(logging.INFO)
logger.addHandler(stream_handler)
logger.addHandler(file_handler)
logger.info("\nCreating output dir {} and log files".format(out_dir))

# Snapshot the config file next to the results. Values that this script
# reassigns on cfg after this point are not reflected in the copy.
copyfile("./my_config.py", os.path.join(out_dir, 'my_config.py'))

# Output dimension handed to the model as OUTPUT_SHAPE (presumably the size
# of the per-node correction it predicts; TODO confirm against the net module).
out_dim = 6

# Dataset locations are hard-coded here and override whatever my_config sets.
cfg.train_file_name = "dataset/train_file.json"
cfg.val_file_name = "dataset/val_file.json"

# Page-locked host memory only speeds up host->GPU copies, so pin only when
# training on the GPU. Without an accelerator, PyTorch warns and ignores it.
_pin_memory = bool(cfg.flag_gpu)

# Load the training dataset.
logger.info("\nTrain dataset info:")
train_dataset = cosserat_rod_data.CosseratRodCorrectionDataSet(cfg.train_file_name)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=cfg.tr_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=_pin_memory,
)
logger.info("Train file path: {}".format(cfg.train_file_name))

# Load the validation dataset. Shuffling is not needed for evaluation. It is
# kept so the draws from the global torch RNG stay the same as before, which
# presumably keeps training order under SEED unchanged.
logger.info("\nValidation dataset info:")
val_dataset = cosserat_rod_data.CosseratRodCorrectionDataSet(cfg.val_file_name)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=cfg.val_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=_pin_memory,
)
logger.info("Val file path: {}".format(cfg.val_file_name))

# Graph component shapes reported by the training set. They are passed to the
# model constructor. The validation set is assumed to share them.
TR_NODES_SHAPE = train_dataset.NODES_SHAPE
TR_EDGES_SHAPE = train_dataset.EDGES_SHAPE
TR_GLOBAL_SHAPE = train_dataset.GLOBAL_SHAPE
OUTPUT_SHAPE = out_dim

# Both splits must have one of the two node-feature widths the config knows
# about. Raise explicitly rather than `assert`, because `python -O` strips
# assertions and the check would silently disappear.
_valid_node_dims = (cfg.node_dim, cfg.node_dim_corr)
for _split_name, _split_dataset in (("trainset", train_dataset), ("valset", val_dataset)):
    _found_dim = _split_dataset.NODES_SHAPE[-1]
    if _found_dim not in _valid_node_dims:
        raise ValueError(
            "{} nodes dimension {} not matching (expected {} or {})".format(
                _split_name, _found_dim, cfg.node_dim, cfg.node_dim_corr))

# Model hyper-parameters that override my_config. The config snapshot copied
# into out_dir was taken before these assignments, so log them to keep the
# run reproducible.
cfg.latent_size = 32
cfg.num_layers = 2  # default 3, change to 2 for case with GN iterations
cfg.gn_iteration = 3  # number of gn iterations
logger.info("Config overrides: latent_size={}, num_layers={}, gn_iteration={}".format(
    cfg.latent_size, cfg.num_layers, cfg.gn_iteration))

# Build the model from net.encode_process_decode using the training-set shapes.
model = encode_process_decode.EncodeDecode(TR_NODES_SHAPE, TR_EDGES_SHAPE, TR_GLOBAL_SHAPE, OUTPUT_SHAPE, cfg)

# Optionally warm-start from a checkpoint dict containing a "state_dict" entry.
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine. load_state_dict then copies the weights into the model's
# own parameters, which are moved to the GPU below if requested.
if cfg.flag_from_trained:
    logger.info('\nLoading model from {} as initial model.'.format(cfg.trained_file))
    trained_model = torch.load(cfg.trained_file, map_location="cpu")
    model.load_state_dict(trained_model["state_dict"])

loss_func_corr = loss_con.loss_func_corr()

# Only touch CUDA when GPU training is requested. Calling set_device
# unconditionally crashes on machines without CUDA even when flag_gpu is off.
if cfg.flag_gpu:
    torch.cuda.set_device(cfg.gpus)
    model = model.cuda()
    loss_func_corr.cuda()

# The optimizer is created after the model is moved so it holds the
# parameters on their final device.
optim = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optim, milestones=cfg.milestones, gamma=cfg.gamma)


# Bookkeeping for the training loop.
# Number of training samples. The summed training loss is divided by it.
num_time_steps = len(train_dataset)

log_every_seconds = 20  # minimum wall-clock gap between periodic checkpoints
start_time = time.time()
last_log_time = start_time
# Start at +inf so any finite loss counts as an improvement. A finite sentinel
# such as 10000 would never be beaten if the loss scale is larger, and then no
# "best" model would ever be saved.
best_tr_loss = float("inf")
best_val_loss = float("inf")

logger.info("\nModel summary:")
logger.info(model)
logger.info("Start looping......")

tr_losses = []  # one training-loss value per epoch
val_losses = []  # one validation-loss value per epoch

tr_time = 0  # seconds spent training since the last periodic log
val_time = 0  # seconds spent validating since the last periodic log

# Every checkpoint / loss file written to out_dir is also copied here, so the
# most recent results are always available at a fixed path.
MIRROR_DIR = './out'
os.makedirs(MIRROR_DIR, exist_ok=True)  # copyfile fails if the target dir is missing


def _prepare_batch(graph, corr_x_gt, lambda_v_gt):
    """Convert one DataLoader batch into model-ready tensors.

    Float features and targets are cast to float32 and connectivity/index
    tensors to int64. Everything is moved to the GPU when ``cfg.flag_gpu``
    is set.

    Returns:
        ``(model_inputs, corr_x_gt, lambda_v_gt)``. ``model_inputs`` is the
        positional-argument tuple ``(nodes, edges, globals, senders,
        receivers, nodes_index, edges_index)`` expected by ``model(...)``.
    """
    model_inputs = (
        graph["nodes"].float(),
        graph["edges"].float(),
        graph["globals"].float(),
        graph["senders"].long(),
        graph["receivers"].long(),
        graph["nodes_index"].long(),
        graph["edges_index"].long(),
    )
    corr_x_gt = corr_x_gt.float()
    lambda_v_gt = lambda_v_gt.float()
    if cfg.flag_gpu:
        model_inputs = tuple(tensor.cuda() for tensor in model_inputs)
        corr_x_gt = corr_x_gt.cuda()
        lambda_v_gt = lambda_v_gt.cuda()
    return model_inputs, corr_x_gt, lambda_v_gt


def _save_and_mirror(state, filename):
    """torch.save ``state`` to ``out_dir/filename`` and copy it into MIRROR_DIR."""
    path = os.path.join(out_dir, filename)
    torch.save(state, path)
    copyfile(path, os.path.join(MIRROR_DIR, filename))


for t in range(cfg.num_iters):
    # ---- Training ----
    model.train()
    tr_loss = 0
    tr_time_start = time.time()
    for graph, corr_x_gt, lambda_v_gt in train_loader:
        model_inputs, corr_x_gt, lambda_v_gt = _prepare_batch(graph, corr_x_gt, lambda_v_gt)
        output_corr, output_lambda_v, _ = model(*model_inputs)
        loss = loss_func_corr(corr_x_gt, output_corr, lambda_v_gt, output_lambda_v)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Weight by batch size so dividing by the sample count below yields a
        # per-sample average (assumes loss_func_corr returns a batch mean;
        # TODO confirm its reduction).
        tr_loss += loss.item() * corr_x_gt.shape[0]
    tr_time += time.time() - tr_time_start

    # Epoch-level LR schedule. Since PyTorch 1.1 the scheduler must be stepped
    # after optimizer.step(). Stepping it at the start of the epoch shifts
    # every MultiStepLR milestone one epoch early.
    lr_scheduler.step()

    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_start_time = time.time()
    # No autograd graph is needed for evaluation; no_grad saves memory and time.
    with torch.no_grad():
        for graph, corr_x_gt, lambda_v_gt in val_loader:
            model_inputs, corr_x_gt, lambda_v_gt = _prepare_batch(graph, corr_x_gt, lambda_v_gt)
            output_corr, output_lambda_v, _ = model(*model_inputs)
            loss = loss_func_corr(corr_x_gt, output_corr, lambda_v_gt, output_lambda_v)
            val_loss += loss.item() * corr_x_gt.shape[0]

    # Per-sample averages: each split is normalised by its own size.
    tr_loss /= num_time_steps
    tr_losses.append(tr_loss)
    val_loss /= len(val_dataset)
    val_losses.append(val_loss)

    out_dict = model.state_dict()
    val_time += time.time() - val_start_time

    if tr_loss < best_tr_loss:
        logger.info('\nEpoch: {}, updating best training loss model'.format(t))
        logger.info("Run time: {}".format(time.time() - start_time))
        logger.info("Train Loss: {}".format(tr_loss))
        best_tr_loss = tr_loss
        _save_and_mirror({'epoch': t, 'state_dict': out_dict, 'tr_loss': tr_loss,
                          'optimizer': optim.state_dict()}, 'best_tr_model.pth.tar')

    if val_loss < best_val_loss:
        logger.info('\nEpoch: {}, updating best validation loss model'.format(t))
        logger.info("Run time: {}".format(time.time() - start_time))
        logger.info("Val Loss: {}".format(val_loss))
        best_val_loss = val_loss
        _save_and_mirror({'epoch': t, 'state_dict': out_dict, 'val_loss': val_loss,
                          'optimizer': optim.state_dict()}, 'best_val_model.pth.tar')

    # Periodic checkpoint + loss history. The last epoch is always included so
    # the final state and complete loss curve are never lost to the timer.
    is_last_epoch = t == cfg.num_iters - 1
    if time.time() - last_log_time > log_every_seconds or is_last_epoch:
        last_log_time = time.time()
        logger.info('\nTime: {}'.format(time.strftime('%Y-%m-%d-%H:%M:%S')))
        logger.info('Epoch: {}, updating checkpoint model'.format(t))
        logger.info("Run time: {}".format(time.time() - start_time))
        logger.info("Train Loss: {}".format(tr_loss))
        logger.info("Val Loss: {}".format(val_loss))
        logger.info("Used train time: {}".format(tr_time))
        logger.info("Used val time: {}".format(val_time))
        tr_time = 0
        val_time = 0
        _save_and_mirror({
            'epoch': t,
            'state_dict': out_dict,
            'tr_loss': tr_loss,
            'val_loss': val_loss,
            'optimizer': optim.state_dict()},
            'checkpoint_model.pth.tar')

        loss_path = os.path.join(out_dir, "loss.json")
        with open(loss_path, 'w') as outputfile:
            json.dump({"tr_losses": tr_losses, "val_losses": val_losses},
                      outputfile)
        copyfile(loss_path, os.path.join(MIRROR_DIR, "loss.json"))


