import os
import json
from config import Config

# Create the experiment config holding all hyperparameters. It is a
# module-level global that main_deformation_test() reads directly (stage, tag,
# exp_dir, vis_resolution, n_timesteps, dt, write_pointcloud, ...).
# NOTE(review): this runs before the `models` / `helpers` imports below;
# presumably deliberate (Config() may configure the environment those modules
# depend on) — confirm before moving it after the imports.
cfg = Config()

from models import get_model
from helpers import *


def _prepare_result_dir(subdir):
    """Create ``<cfg.exp_dir>/<subdir>``, dump ``cfg`` into it as config.json, and return its path."""
    result_dir = os.path.join(cfg.exp_dir, subdir)
    ensure_dir(result_dir)
    with open(os.path.join(result_dir, 'config.json'), 'w') as f:
        json.dump(cfg.__dict__, f, indent=2)
    return result_dir


def _save_figure(fig, save_path):
    """Save ``fig`` to ``save_path`` and close it.

    The figure returned by ``draw_field`` is saved explicitly rather than
    whatever pyplot considers the "current" figure. If ``draw_field`` returned
    None, this falls back to the current figure, matching the previous
    ``plt.savefig`` behavior.
    """
    if fig is None:
        plt.savefig(save_path)
    else:
        fig.savefig(save_path)
    plt.close(fig)


def _render_step(model, result_dir, pointcloud_result_dir):
    """Draw the deformation field at ``model.timestep`` and write ``stepV_tNNN.png``.

    When ``cfg.write_pointcloud == 1``, ``draw_field`` is additionally given
    ``stepV_tNNN.ply`` under ``pointcloud_result_dir`` as its output file.
    """
    stem = f'stepV_t{model.timestep:03d}'
    if cfg.write_pointcloud == 1:
        pointcloud_save_path = os.path.join(pointcloud_result_dir, f'{stem}.ply')
        fig = model.draw_field(cfg.vis_resolution, attr='deformation', output_filename=pointcloud_save_path)
    else:
        fig = model.draw_field(cfg.vis_resolution, attr='deformation')
    _save_figure(fig, os.path.join(result_dir, f'{stem}.png'))


def main_deformation_test():
    """Run the deformation experiment stage selected by ``cfg.stage``.

    Stages:
        'init':     initialize the model and save ``init.png`` under
                    ``<exp_dir>/results/init_deformation_<tag>``.
        'simulate': reset the optimizer, load the 'initialize' checkpoint
                    (or initialize from scratch if loading fails), render the
                    starting frame, save checkpoint
                    ``step_deformation_<tag>_0``, then advance
                    ``cfg.n_timesteps`` steps rendering one frame per step,
                    and finally assemble the PNG frames into ``step_anim.gif``.

    Both stages write the full config to ``config.json`` in their result dir.

    Raises:
        ValueError: if ``cfg.stage`` is neither 'init' nor 'simulate'.
    """
    # create network and training agent
    model = get_model(cfg)

    if cfg.stage == 'init':
        result_dir = _prepare_result_dir(f'results/init_deformation_{cfg.tag}')

        model.initialize()
        fig = model.draw_field(cfg.vis_resolution, attr='deformation')
        _save_figure(fig, os.path.join(result_dir, 'init.png'))

    elif cfg.stage == 'simulate':
        result_dir = _prepare_result_dir(f'results/step_deformation_{cfg.tag}')

        # NOTE: reset optimizer
        model.create_optimizer(use_scheduler=True)

        # Load the network with the source fitted (t=0). Fall back to a fresh
        # initialization, but report why the checkpoint could not be used
        # instead of swallowing the error silently.
        try:
            model.load_ckpt('initialize')
        except Exception as e:
            print(f"could not load initial trained model ({e!r}); initializing from scratch.")
            model.initialize()
        else:
            print("load initial trained model.")

        pointcloud_result_dir = os.path.join(cfg.exp_dir, f'results_pointcloud/step_deformation_{cfg.tag}')
        ensure_dir(pointcloud_result_dir)

        # Starting frame, then checkpoint the t=0 state.
        _render_step(model, result_dir, pointcloud_result_dir)
        model.save_ckpt(f'step_deformation_{model.cfg.tag}_0')

        for _ in range(cfg.n_timesteps):
            model.timestep += 1
            model.step()
            _render_step(model, result_dir, pointcloud_result_dir)

        # fps * cfg.dt == 2, i.e. each second of GIF covers 2 units of
        # simulated time, assuming cfg.dt is the time per step — confirm.
        save_path = os.path.join(result_dir, 'step_anim.gif')
        frames2gif(result_dir, save_path, fps=20 * 0.1 / cfg.dt)

    else:
        raise ValueError(f"unknown stage {cfg.stage!r}; expected 'init' or 'simulate'")


# Script entry point: run whichever stage cfg.stage selects.
if __name__ == "__main__":
    main_deformation_test()
