import os
import sys
matnet_base_path = os.path.abspath('/data1/XXX/MatNet/FFSP/FFSP_MatNet')
sys.path.append(matnet_base_path)
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
sys.path.append(base_path)

import pickle
import numpy as np
import torch

from environment.used.BaseEnv_COP import DataProblem, RawData
from matnet_test import eval_main, env_params, tester_params

# Restrict this process to GPU 7.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been
# initialised yet. It is set here after `torch` and `matnet_test` are imported;
# confirm neither of them initialises CUDA at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

# Destination directory for the generated pickle files.
to_data_dir = '/data1/XXX/gato-revise/toy-gato-data/data/used/FFSP_V2/'

# Number of instances to generate for each output split.
data_num_dict = dict(train=1000, problem=100)

# Problem-size settings re-exported from the MatNet tester's env_params.
stage_cnt, machine_cnt_list, process_time_params, job_cnt = (
    env_params[key]
    for key in ('stage_cnt', 'machine_cnt_list', 'process_time_params', 'job_cnt')
)
batch_size = 100  # instances sampled and solved per eval_main call
total_machine_cnt = sum(machine_cnt_list)  # machines summed over all stages

def get_random_problems(batch_size, stage_cnt, machine_cnt_list, job_cnt, process_time_params):
    """Sample random integer processing-time matrices, one tensor per stage.

    Args:
        batch_size: number of instances to sample.
        stage_cnt: number of stages; only the first ``stage_cnt`` entries of
            ``machine_cnt_list`` are used.
        machine_cnt_list: number of machines in each stage.
        job_cnt: number of jobs per instance.
        process_time_params: dict with integer keys ``'time_low'`` (inclusive)
            and ``'time_high'`` (exclusive, as in ``torch.randint``).

    Returns:
        A list of ``stage_cnt`` tensors; entry ``s`` has shape
        ``(batch_size, job_cnt, machine_cnt_list[s])``.
    """
    low, high = process_time_params['time_low'], process_time_params['time_high']
    return [
        torch.randint(low=low, high=high, size=(batch_size, job_cnt, machine_cnt_list[stage]))
        for stage in range(stage_cnt)
    ]




#####################################################
### generate problem.pkl
# Sample data_num_dict['problem'] random instances, solve them with the MatNet
# tester (eval_main) and package instance + solver actions as a DataProblem.
num_problems = data_num_dict['problem']
# Ceil division so a count that is not a multiple of batch_size still yields
# enough instances (floor division under-generated, then failed with
# IndexError while building problem_list).
num_batches = (num_problems + batch_size - 1) // batch_size
total_trajs = []
total_problems = []
for _ in range(num_batches):
    problems_INT_list = get_random_problems(batch_size, stage_cnt, machine_cnt_list, job_cnt, process_time_params)
    # eval_main returns (trajs, returns); only trajs (a list of dicts) is used.
    # NOTE(review): an earlier comment said trajs is "24x" the batch size.
    # The pairing below assumes trajs[k] belongs to instance k of this batch;
    # confirm against eval_main.
    trajs, _ = eval_main(problems_INT_list)
    # Join the per-stage matrices along the machine axis:
    # (batch, job, total_machine_cnt).
    problems = torch.cat(problems_INT_list, dim=2).cpu().numpy()
    # Append one all-zero row, giving job+1 rows. Presumably this is the
    # dummy/"no job" slot; TODO confirm. The padding dtype matches the problems
    # so the integer matrices are not upcast to float64.
    padding = np.zeros((batch_size, 1, total_machine_cnt), dtype=problems.dtype)
    problems = np.concatenate((problems, padding), axis=1)  # (batch, job+1, total_machine_cnt)
    total_problems.extend(problems)  # one (job+1, total_machine_cnt) array per instance
    total_trajs.extend(trajs)
if len(total_problems) < num_problems or len(total_trajs) < num_problems:
    raise RuntimeError(
        f'expected {num_problems} problems/trajectories, '
        f'got {len(total_problems)}/{len(total_trajs)}'
    )
problem_list = [{'durations': problem.astype(np.int32)}  # (job+1, total_machine_cnt)
                for problem in total_problems[:num_problems]]
answer_list = [list(traj['action'])  # one entry per step
               for traj in total_trajs[:num_problems]]
problem_data = DataProblem(prefix_list=None, problem_list=problem_list, answer_list=answer_list)


#####################################################
### generate train.pkl
# Sample data_num_dict['train'] instances, solve them with eval_main and turn
# each solved trajectory into a record with prefix / prefix_masks /
# observations / actions / rewards / terminals.
num_train = data_num_dict['train']
# Ceil division so a count that is not a multiple of batch_size still yields
# enough instances.
num_batches = (num_train + batch_size - 1) // batch_size
total_trajs = []
total_problems = []
for _ in range(num_batches):
    problems_INT_list = get_random_problems(batch_size, stage_cnt, machine_cnt_list, job_cnt, process_time_params)
    trajs, _ = eval_main(problems_INT_list)  # trajs: list of dicts; returns unused
    problems = torch.cat(problems_INT_list, dim=2).cpu().numpy()  # (batch, job, total_machine_cnt)
    # Append one all-zero row (job+1 rows in total), keeping the integer dtype.
    padding = np.zeros((batch_size, 1, total_machine_cnt), dtype=problems.dtype)
    problems = np.concatenate((problems, padding), axis=1)  # (batch, job+1, total_machine_cnt)
    total_problems.extend(problems)  # one (job+1, total_machine_cnt) array per instance
    total_trajs.extend(trajs)
if len(total_problems) < num_train or len(total_trajs) < num_train:
    raise RuntimeError(
        f'expected {num_train} problems/trajectories, '
        f'got {len(total_problems)}/{len(total_trajs)}'
    )
train_data = []
for problem, traj in zip(total_problems[:num_train], total_trajs[:num_train]):
    actions = traj['action']
    # Rewards are all zero, and only the last step is marked terminal.
    rewards = np.zeros_like(actions)
    terminals = np.zeros_like(actions, dtype=bool)
    terminals[-1] = True
    train_data.append({
        # (job+1, total_machine_cnt) -> flat vector of length (job+1) * total_machine_cnt
        'prefix': {'durations': problem.reshape(-1).astype(np.int32)},
        # Per-step masks over the different jobs and the machines of the
        # different stages: (steps, act_num, job+1, machine) ->
        # (steps, act_num, (job+1) * total_machine_cnt).
        # The original hard-coded reshape(-1, 12, 21*12) for 20 jobs / 12 machines.
        # NOTE(review): the act_num axis is assumed to equal total_machine_cnt,
        # as the literal 12 did; confirm.
        'prefix_masks': {'durations': traj['prefix_mask'].reshape(
            -1, total_machine_cnt, (job_cnt + 1) * total_machine_cnt)},
        'observations': {
            'machine_wait_time': traj['machine_wait_time'].astype(np.int32),  # (steps, total_machine_cnt)
            'time_idx': traj['time_idx'][:, None].astype(np.int32),           # (steps, 1)
        },
        'actions': actions,                                                    # (steps, act_num)
        'rewards': rewards,
        'terminals': terminals,
    })



# Persist the generated splits. The file names are derived from data_num_dict
# so they stay correct if the split sizes change. They were previously
# hard-coded as "train-1000" / "problem-100", which match the default sizes.
os.makedirs(to_data_dir, exist_ok=True)

train_path = os.path.join(to_data_dir, f"ffsp{job_cnt}_train-{data_num_dict['train']}.pkl")
with open(train_path, 'wb') as f:
    pickle.dump(train_data, f)
print('train_data saved')

problem_path = os.path.join(to_data_dir, f"ffsp{job_cnt}_problem-{data_num_dict['problem']}.pkl")
with open(problem_path, 'wb') as f:
    pickle.dump(problem_data, f)
print('problem saved')

#####################################################
### generate train_problem.pkl
# If ffsp{job_cnt}_train.pkl exists, build a DataProblem from up to its first
# 500 records and save it as ffsp{job_cnt}_train_problem.pkl.
# NOTE(review): this reads "ffsp{job_cnt}_train.pkl", not the
# "ffsp{job_cnt}_train-<N>.pkl" file written above; confirm this is intended.
num_train_problems = 500
train_pkl_path = os.path.join(to_data_dir, f'ffsp{job_cnt}_train.pkl')
problem_list = []
answer_list = []
if os.path.exists(train_pkl_path):
    print('generating train_problem.pkl from train.pkl')
    # Locally generated file; never unpickle data from an untrusted source.
    with open(train_pkl_path, 'rb') as f:
        train_data = pickle.load(f)
    # Slicing tolerates files with fewer than num_train_problems records
    # (indexing up to 500 used to raise IndexError).
    for record in train_data[:num_train_problems]:
        # Durations were stored flattened; restore (job+1, total_machine_cnt).
        # The original hard-coded reshape(21, 12) for 20 jobs / 12 machines.
        durations = record['prefix']['durations'].reshape(job_cnt + 1, total_machine_cnt)
        problem_list.append({'durations': durations})
        answer_list.append(list(record['actions']))
    problem_data = DataProblem(prefix_list=None, problem_list=problem_list, answer_list=answer_list)
    with open(os.path.join(to_data_dir, f'ffsp{job_cnt}_train_problem.pkl'), 'wb') as f:
        pickle.dump(problem_data, f)
    print('train_problem data saved')