import os

import random
import torch
import pickle
import numpy as np
import warp as wp
from tqdm import tqdm
from simulator import MPMSimulator

def evaluate(data_path, config_path, output_path, evaluate_path, gripper_type, optim_stage1_params_path, optim_stage2_params_path):
    """Roll out the MPM simulation for one sequence and pickle the structure points.

    Args:
        data_path: Scene data location, forwarded to ``MPMSimulator``.
        config_path: Config file path, forwarded to ``MPMSimulator``.
        output_path: Output location, forwarded to ``MPMSimulator``.
        evaluate_path: Directory passed as ``save_path`` to
            ``simulator.get_current_state`` on every frame; it also receives
            ``inference.pkl`` at the end.
        gripper_type: One of ``'push'``, ``'single_gripper'`` or
            ``'double_gripper'``; selects the simulator step method called
            once per frame.
        optim_stage1_params_path: Optional stage-1 parameter path, forwarded
            to ``MPMSimulator`` (may be ``None``).
        optim_stage2_params_path: Optional stage-2 parameter path. When given
            it is loaded with ``load_params_cloth`` if ``simulator.material``
            is ``'cloth'``, else with ``load_params``. When ``None``,
            ``simulator.finalize_phys()`` is called instead.

    Raises:
        ValueError: If ``gripper_type`` is not one of the supported values.
    """
    # Name of the simulator method that advances one frame for each supported
    # gripper type. Reject unknown types up front: otherwise the loop below
    # would run without ever stepping the simulator and save a frozen rollout.
    step_method_names = {
        'push': 'step_push',
        'single_gripper': 'step_single_gripper',
        'double_gripper': 'step_double_gripper',
    }
    if gripper_type not in step_method_names:
        raise ValueError(
            f"Unsupported gripper_type {gripper_type!r}; expected one of "
            f"{sorted(step_method_names)}"
        )

    # The positional False / 'cuda:0' arguments are forwarded unchanged;
    # 'cuda:0' presumably selects the device — confirm against MPMSimulator.
    simulator = MPMSimulator(data_path, config_path, output_path, gripper_type, False, 'cuda:0', optim_stage1_params_path)
    push_idx = 0
    simulator.init_scene(push_idx)
    simulator.init_solver(push_idx)
    if optim_stage2_params_path is not None:
        if simulator.material == 'cloth':
            simulator.load_params_cloth(optim_stage2_params_path)
        else:
            simulator.load_params(optim_stage2_params_path)
    else:
        simulator.finalize_phys()

    step = getattr(simulator, step_method_names[gripper_type])

    # get_current_state receives evaluate_path as save_path on every frame, so
    # make sure the directory exists before the loop, not only at the end.
    os.makedirs(evaluate_path, exist_ok=True)

    # One entry for the initial state plus one per simulated frame.
    all_pos = [simulator.extract_structure_points()]
    for _ in tqdm(range(simulator.start_frame, simulator.end_frame)):
        step()
        simulator.get_current_state(save_path=evaluate_path, epoch=0, draw_gt=True)
        all_pos.append(simulator.extract_structure_points())

    save_points_to_pkl(evaluate_path, all_pos)

def save_points_to_pkl(save_path, data):
    """Convert ``data`` to a NumPy array and pickle it to ``<save_path>/inference.pkl``.

    Args:
        save_path: Destination directory. It is created (with parents) if it
            does not exist; an empty string means the current directory.
        data: Sequence converted with ``np.array``; its entries should share a
            shape so the result is a regular array.

    Returns:
        The path of the written pickle file.
    """
    # Create the directory here so callers need not remember to. Skip it for
    # an empty path, which os.makedirs rejects but os.path.join maps to CWD.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    data_array = np.array(data)
    pkl_path = os.path.join(save_path, 'inference.pkl')
    with open(pkl_path, 'wb') as f:
        pickle.dump(data_array, f)
    print(f"pkl save to {pkl_path}")
    return pkl_path

def set_all_seeds(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) RNGs.

    Also configures cuDNN to pick deterministic algorithms and disables its
    benchmark-driven autotuning, trading speed for run-to-run repeatability.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Module-level setup. These two calls run at import time, not only when the
# file is executed as a script, so merely importing this module initialises
# Warp and reseeds the global Python / NumPy / PyTorch RNGs (seed 42).
# Because they sit above the __main__ block, they run before evaluate() builds
# the simulator.
# NOTE(review): set_all_seeds does not touch any RNG state that Warp kernels
# may keep on their own — confirm if bit-exact reproducibility matters.
wp.init()
set_all_seeds(42)


if __name__ == '__main__':
    import argparse

    # Command-line entry point. The defaults reproduce the previously
    # hard-coded experiment, so running the script without arguments behaves
    # exactly as before.
    parser = argparse.ArgumentParser(
        description='Roll out an MPM simulation and pickle per-frame structure points.'
    )
    parser.add_argument(
        '--name', default='single_push_rope',
        help='Experiment/scene name used to build the config, data and output paths.',
    )
    parser.add_argument(
        '--gripper_type', default='push',
        choices=['push', 'single_gripper', 'double_gripper'],
        help='Which simulator step method drives each frame.',
    )
    parser.add_argument(
        '--stage1_params', default=None,
        help='Optional stage-1 params file, e.g. ./optim_stage1/<name>/optim_physics.pkl',
    )
    parser.add_argument(
        '--stage2_params', default=None,
        help='Optional stage-2 params file, e.g. ./optim_stage2/<name>/best_params.pkl',
    )
    args = parser.parse_args()

    name = args.name
    gripper_type = args.gripper_type

    config_path = f'./config/{name}_config.json'
    data_path = f'./data/phystwin_data/{name}'
    output_path = f'./output/evaluate/{name}'
    evaluate_path = f'./evaluate/{name}'

    # Inference only: disable autograd bookkeeping for the whole rollout.
    with torch.no_grad():
        evaluate(data_path, config_path, output_path, evaluate_path, gripper_type,
                 args.stage1_params, args.stage2_params)
