import random
import torch
import numpy as np
import warp as wp
from tqdm import tqdm

from simulator import MPMSimulator

def simulate(data_path, config_path, output_path, gripper_type, optim_stage1_params_path, optim_stage2_params_path):
    """Build an MPM simulator for one scene and roll it out over every frame.

    Args:
        data_path: Scene data location, forwarded to ``MPMSimulator``.
        config_path: Scene config file, forwarded to ``MPMSimulator``.
        output_path: Output location, forwarded to ``MPMSimulator``.
        gripper_type: One of ``'push'``, ``'single_gripper'`` or
            ``'double_gripper'``; selects which ``step_*`` method of the
            simulator advances each frame.
        optim_stage1_params_path: Optional stage-1 parameter file (or None),
            forwarded to ``MPMSimulator``.
        optim_stage2_params_path: Optional stage-2 parameter file (or None).
            When given it is loaded with ``load_params_cloth`` for cloth
            material and ``load_params`` otherwise; when None,
            ``finalize_phys`` is called instead.

    Raises:
        ValueError: if ``gripper_type`` is not one of the supported values.
            The check runs before the simulator is constructed. Without it,
            an unknown type would run every frame without ever stepping the
            simulation.
    """
    # Maps each supported gripper type to the simulator method that advances one frame.
    step_methods = {
        'push': 'step_push',
        'single_gripper': 'step_single_gripper',
        'double_gripper': 'step_double_gripper',
    }
    if gripper_type not in step_methods:
        raise ValueError(
            f'Unsupported gripper_type {gripper_type!r}; '
            f'expected one of {sorted(step_methods)}'
        )

    simulator = MPMSimulator(data_path, config_path, output_path, gripper_type, False, 'cuda:0', optim_stage1_params_path)
    push_idx = 0
    simulator.init_scene(push_idx)
    simulator.init_solver(push_idx)

    if optim_stage2_params_path is not None:
        if simulator.material == 'cloth':
            simulator.load_params_cloth(optim_stage2_params_path)
        else:
            simulator.load_params(optim_stage2_params_path)
    else:
        # NOTE(review): finalize_phys is only called when no stage-2 params are
        # loaded; presumably load_params*/finalize_phys are alternatives — confirm.
        simulator.finalize_phys()

    # Resolve the per-frame step method once instead of re-dispatching every frame.
    step = getattr(simulator, step_methods[gripper_type])
    for _ in tqdm(range(simulator.end_frame)):
        step()

def set_all_seeds(seed):
    """Seed every RNG this script touches and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator, the current CUDA device and all CUDA devices, in that order.
    Then it forces cuDNN into deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disabling benchmark mode stops cuDNN from picking kernels by timing,
    # which could otherwise vary between runs.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Import-time setup: initialize the Warp runtime, then seed all RNGs with a
# fixed seed (42) for reproducible runs.
# NOTE(review): both calls run whenever this module is imported, not only when
# it is executed as a script; presumably intentional so importers of
# simulate() get the same setup — confirm before moving under __main__.
wp.init()
set_all_seeds(42)

if __name__ == '__main__':
    # Script entry point. Every option defaults to the previously hard-coded
    # value, so running with no arguments behaves exactly as before.
    import argparse

    parser = argparse.ArgumentParser(description='Run an MPM simulation for one PhysTwin scene.')
    parser.add_argument(
        '--name', default='single_push_rope',
        help='Scene name; selects ./config/{name}_config.json and ./data/phystwin_data/{name}.',
    )
    parser.add_argument(
        '--gripper-type', default='push',
        choices=('push', 'single_gripper', 'double_gripper'),
        help='Which gripper step routine drives the simulation.',
    )
    parser.add_argument('--output-path', default='./output/simulation', help='Simulation output location.')
    parser.add_argument(
        '--stage1-params', default=None,
        help='Optional stage-1 optimized params, e.g. ./optim_stage1/{name}/optim_physics.pkl',
    )
    parser.add_argument(
        '--stage2-params', default=None,
        help='Optional stage-2 optimized params, e.g. ./optim_stage2/{name}/best_E.pkl',
    )
    args = parser.parse_args()

    name = args.name
    gripper_type = args.gripper_type

    config_path = f'./config/{name}_config.json'
    data_path = f'./data/phystwin_data/{name}'
    output_path = args.output_path

    optimized_stage1_params_path = args.stage1_params
    optimized_stage2_params_path = args.stage2_params

    # Pure rollout, no gradients are needed.
    with torch.no_grad():
        simulate(data_path, config_path, output_path, gripper_type, optimized_stage1_params_path, optimized_stage2_params_path)