import torch
import json
import logging

from data import init_graph_for_sim
from data import cosserat_rod_mixed
from net import graph_tuple
from net import encode_process_decode
from my_config import config as cfg

logger = logging.getLogger(__name__)


def load_model(init_graph, trained_file):
    """Build the EncodeDecode network and restore its trained weights.

    Args:
        init_graph: Graph container exposing ``NODES_SHAPE``, ``EDGES_SHAPE``
            and ``GLOBALS_SHAPE``. Only the first entry of the node and edge
            shapes is reused; the second entry is fixed to 8 and 5.
        trained_file: Path to a checkpoint readable by ``torch.load`` that
            stores the network weights under the ``"state_dict"`` key.

    Returns:
        The network with the checkpoint weights loaded, moved to CUDA when
        ``cfg.flag_gpu`` is set.
    """
    out_dim = 6

    nodes_shape = (init_graph.NODES_SHAPE[0], 8)
    edges_shape = (init_graph.EDGES_SHAPE[0], 5)
    globals_shape = init_graph.GLOBALS_SHAPE

    # Create the network
    model = encode_process_decode.EncodeDecode(nodes_shape, edges_shape, globals_shape, out_dim, cfg)

    logger.info("Loading model from: %s", trained_file)
    # Deserialize onto the CPU first so a checkpoint saved on a GPU can still
    # be restored on a CPU-only host; the model is moved to CUDA below only
    # when requested.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(trained_file, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

    if cfg.flag_gpu:
        model = model.cuda()
    return model


def _init_rod_graph(data_type):
    """Create the initial simulation graph for the requested scenario.

    Returns the 4-tuple produced by the matching ``init_graph_for_sim``
    initializer: ``(position_norm, max_E_G, split_list, graph)``.

    Raises:
        ValueError: If ``data_type`` is neither ``"swing"`` nor ``"helix"``.
    """
    if data_type == "swing":
        num_nodes = 30
        length = 4.0
        angle = 30.0
        young_m = 1.0e5
        return init_graph_for_sim.init_bend_rod_graph(num_nodes, length, angle, young_m)

    if data_type == "helix":
        num_nodes = 60
        hr = 0.5
        hh = 0.5
        hw = 2.5
        torsion_m = 1.0e5
        return init_graph_for_sim.init_helix_spring_graph(num_nodes, hr, hh, hw, torsion_m)

    raise ValueError(
        "data type should be 'swing' or 'helix', got {!r}".format(data_type)
    )


def test_main(norm_file_name, trained_file, flag_dump_gif,
              flag_dump_json, flag_save_png, flag_to_end, data_type, flag_plot_loss,
              flag_play_sim, flag_play_mixed, flag_net_count, flag_forced_length_cons):
    """Roll out the rod simulation with the physical and/or mixed solver.

    Args:
        norm_file_name: Path to a JSON file providing the ``"max_corr_x"`` and
            ``"max_lambda_v"`` normalization values.
        trained_file: Checkpoint path forwarded to ``load_model``.
        data_type: Scenario to simulate, ``"swing"`` or ``"helix"``.
        flag_play_sim: Run the rollout with ``0.0`` as the final solver argument.
        flag_play_mixed: Run the rollout with ``1.0`` as the final solver argument.
        flag_dump_gif, flag_dump_json, flag_save_png, flag_to_end,
        flag_plot_loss, flag_net_count, flag_forced_length_cons: Accepted for
            interface compatibility; this function does not read them.

    Raises:
        ValueError: If ``data_type`` is not a supported scenario.
    """
    logger.info("norm file name: %s", norm_file_name)
    with open(norm_file_name, 'r', encoding="utf-8") as inputfile:
        data_norm = json.load(inputfile)
    nodes_norm = data_norm["max_corr_x"]
    edges_norm = data_norm["max_lambda_v"]

    position_norm, max_E_G, split_list, graph = _init_rod_graph(data_type)
    graph = graph_tuple.GraphTuple([graph])
    logger.info("data type %s", data_type)

    step_num = 100
    step_size = 0.02

    # load model from trained parameter
    model = load_model(graph, trained_file)

    # Both runs go through the same solver entry point and differ only in the
    # final argument (0.0 for the physical run, 1.0 for the mixed run); its
    # exact meaning is defined by cosserat_rod_mixed -- presumably it selects
    # how much the network is used, TODO confirm.
    runs = (
        (flag_play_sim, "Playing physical simulation", 0.0),
        (flag_play_mixed, "Playing mix simulation", 1.0),
    )
    for enabled, message, mix_value in runs:
        if not enabled:
            continue
        logger.info(message)
        # The solver returns three values; they are unpacked (which checks the
        # arity) but not consumed further by this function.
        _rollout_list, _iter_record, _num_cg = cosserat_rod_mixed.roll_out_physics_gc_from_net_solver(
            model, graph, position_norm, max_E_G, split_list,
            step_num, step_size, nodes_norm, edges_norm, mix_value)


if __name__ == "__main__":
    # Timestamped log records at DEBUG level and above go to a stream handler.
    logging.basicConfig(
        format='%(asctime)s %(message)s',
        level=logging.DEBUG,
        handlers=[logging.StreamHandler()],
    )

    # Settings shared with the network code through the project config object.
    cfg.flag_normed = False
    cfg.flag_gpu = False
    cfg.latent_size = 32
    cfg.num_layers = 2  # default 3, change to 2 for case with GN iterations
    cfg.gn_iteration = 3  # number of gn iterations

    # Every argument of test_main, spelled out by keyword.
    run_options = dict(
        norm_file_name="dataset/cosserat_test_normalized10.json",
        trained_file="out/rod/best_val_model.pth.tar",
        data_type="swing",
        # which simulators to run
        flag_play_sim=True,
        flag_play_mixed=True,
        # remaining switches are all off
        flag_dump_gif=False,
        flag_dump_json=False,
        flag_save_png=False,
        flag_to_end=False,
        flag_plot_loss=False,
        flag_net_count=False,
        flag_forced_length_cons=False,
    )

    test_main(**run_options)
