import numpy as np
import torch
from params import get_args
from env.env import JSP_Env
from model.REINFORCE import REINFORCE
import os

def valid(episode):
    """Greedily roll out the current policy on every validation instance.

    Relies on the module-level ``args``, ``env`` and ``policy`` objects that
    the ``__main__`` section creates before this function is called.

    Args:
        episode: Label of the checkpoint being evaluated. It is only used in
            the progress output and in the appended result line.

    Returns:
        The sum of ``env.get_makespan()`` over all entries of
        ``args.valid_dir``.

    Raises:
        ValueError: If ``args.valid_dir`` contains no entries, since an
            average makespan cannot be computed.
    """
    # List the directory once so the loop and the averaging denominator are
    # guaranteed to refer to the same set of instances.
    instances = os.listdir(args.valid_dir)
    if not instances:
        raise ValueError(f'No validation instances found in {args.valid_dir}')

    total_ms = 0.
    for instance in instances:
        file = os.path.join(args.valid_dir, instance)
        print(f'Date : {args.date} \t Episode : {episode} \tInstance: {instance}')
        avai_op = env.load_instance(file)
        done = False
        while not done:
            data = env.get_graph_data()
            # greedy=True is passed to the policy; only the chosen index is used here.
            action_idx, _, _, _ = policy(avai_op, data, greedy=True)
            avai_op, _, done = env.step(avai_op[action_idx])
        total_ms += env.get_makespan()

    # Make sure the output directory exists before appending to the log.
    result_dir = "./result/{}".format(args.date)
    os.makedirs(result_dir, exist_ok=True)
    with open("{}/valid_result.txt".format(result_dir), "a") as outfile:
        outfile.write(f'episode : {episode}\t average_ms : {total_ms / len(instances)}\n')

    return total_ms

if __name__ == '__main__':
    # Evaluate every checkpoint stored under ./weight/<date>/ on the validation
    # set and keep a copy of the one with the lowest total makespan as "best".
    # `args`, `env` and `policy` stay module-level because valid() reads them.
    args = get_args()
    print(args)

    weight_dir = './weight/{}'.format(args.date)
    # isdir on the nested path also covers a missing ./weight parent.
    if not os.path.isdir(weight_dir):
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception so the intended message reaches the user.
        raise FileNotFoundError("No valid data exist: {}".format(weight_dir))

    env = JSP_Env(args)
    policy = REINFORCE(args).to(args.device)

    # Start from infinity so the first evaluated checkpoint always becomes the
    # incumbent, regardless of how large the summed makespan is.
    best_result = float('inf')
    # NOTE: the "best" file written below lives in the directory being
    # iterated, so a later re-run will load and evaluate it as well.
    for episode in os.listdir(weight_dir):
        state_dict = torch.load(os.path.join(weight_dir, episode), map_location=args.device)
        # strict=False: mismatching keys between checkpoint and model are tolerated.
        policy.load_state_dict(state_dict, strict=False)
        with torch.no_grad():
            valid_result = valid(episode)

        if valid_result < best_result:
            best_result = valid_result
            torch.save(policy.state_dict(), os.path.join(weight_dir, 'best'))
