from utils import *
from planner import Planner
from patroller import Patroller
from memory import Memory
from performer import Performer
from percipient import Percipient
# from worldmodel import WorldModel # use the world model in buffer
from buffer import Buffer
from ruleminer import RuleMiner
from minedojo.sim import InventoryItem
import minedojo
import random
import argparse
import numpy as np
import json
import datetime
from reasoner_mc import *


# Task list to run and the dataset tag embedded in every result/log path.
task_dir = '/home/**/Workspace/MP5/MP5_agent/agent/tasks/task_techtree_testset.json'
running_dataset = 'TEST'
# task_dir = '/home/**/Workspace/MP5/MP5_agent/agent/tasks/task_techtree_trainingset.json'
# running_dataset = 'TRAIN'


# Retry budgets
# Plan-and-execute attempts per task in the real environment
# (author's notes for other settings: train 5; test 10).
every_task_max_retries = 50
# World-model planning budget per attempt, passed to MPC_mc
# (author's notes for other settings: train None, i.e. no world model; test 300).
every_task_max_planning_retries = 40



# ID=27 # 0, ..., 27
ID = None
# Experiment tag; part of every result, log and trajectory file name.
testModel = 'wm-reflex-norule-noprior'

interval = -1  # forwarded unchanged to write_taskresult_to_csv
#############################################
#############################################
roundtimes = 0  # task ids are offset by roundtimes * 30 (every round covers 30 tasks)
initialtaskID = 0  # index of the first task in the task list to run
# rules_dir = f'/home/**/Workspace/MP5/MP5_agent/agent/buffer_rules/rules_library_{ID}.json'
rules_dir = None
# rule_code_file = f'/home/**/Workspace/MP5/MP5_agent/agent/rules_codeset/rules_codeset_{ID}-selected.json'
rule_code_file = None
# traj_memory_dir = '/home/**/Workspace/MP5/MP5_agent/agent/buffer_traj_memory/combined_data_batch_0.json'
traj_memory_dir = None


#############################################


# LLM backend for the modules below (alternatives: "gpt-3.5-turbo", "gpt-4o-mini").
model_name = 'gpt-4o'
memory = Memory(use_history_workflow=False)
patroller = Patroller(memory=memory, model_name=model_name)

# Rule-miner sampling settings. A commented-out alternative used
# temperature=0.3 and did not pass choice_num.
temperature = 0.7
choice_num = 3
ruleminer = RuleMiner(
    rules_dir=rules_dir,
    model_name=model_name,
    temperature=temperature,
    choice_num=choice_num,
)
# TODO: placeholder string; Percipient is imported but not instantiated here.
percipient = "TBD"


##########################################################
## World model setting
##########################################################
# 'rules' -> world model with rules; 'ref' -> world model with transitions.
worldmodel_mode = 'rules'  # alternative: 'ref'

# LLM backend for the world model (alternatives: gpt-4o-mini, gpt-4-turbo).
model_name = 'gpt-4o'
# NOTE: deliberately no natural-language rules library for the buffer:
# rules_dir stays None and only rule_code_file is handed over.
buffer = Buffer(
    memory=memory,
    worldmodel_mode=worldmodel_mode,
    rules_dir=None,
    rule_code_file=rule_code_file,
    model_name=model_name,
)
##########################################################
## End of world model setting
##########################################################


##########################################################
## Algorithm setting
##########################################################
# Search-algorithm presets. `planner_mode`: 0 = agent + world model,
# 1 = agent with rules, 2 = agent with rules and state; the MCTS/MPC presets
# use None. To switch, copy one row into the assignments below.
#
#   preset                     mode  temperature  choice_num  search_alg                 trajmemory
#   reflexion                  1     0            1           'original__reflex'         False
#   reflexion + memory (*)     0     0            1           'original__reflex_memory'  True
#   MCTS                       None  0 (or 0.5)   1 (or 2)    'MCTS'                     False
#   MPC, 2 candidate routes    None  0.5          2           'MPC'                      False
#   MPC, 1 candidate route     None  0.5          1           'MPC'                      False   <- active
#
# (*) also requires memory = Memory(use_history_workflow=True); see
#     planner.py (reference_plan_string) and memory.py.

# Active preset: MPC with a single candidate route.
planner_mode = None
planner_temperature = 0.5
planner_choice_num = 1
planner_search_alg = 'MPC'
module_trajmemory = False

## Planner setting: exploration schedule (currently disabled)
# initial_epsilon = 0.9  # start with high exploration
# decay = 0.97  # epsilon is multiplied by 0.97 (a 3% decay) each episode
# minimum_epsilon = 0.1  # minimum epsilon value

# LLM backend for the planner ("gpt-3.5-turbo", "gpt-4o", "gpt-4o-mini", "gpt-4-turbo").
model_name = 'gpt-4o'
planner = Planner(
    memory=memory,
    mode=planner_mode,
    rules_dir=rules_dir,
    traj_memory_dir=traj_memory_dir,
    model_name=model_name,
    temperature=planner_temperature,
    choice_num=planner_choice_num,
)
##########################################################
## End of algorithm setting
##########################################################


# Append a run header to the token-usage log of this dataset/model pair.
with open(f'/home/**/Workspace/MP5/MP5_agent/agent/task_result/[tokenuse]_{running_dataset}_{testModel}.log', 'a') as f:
    f.write(f'\n\n***** MODEL: {testModel}, Dataset: {running_dataset} *****')


# Output folders: ../video is removed and recreated for every run; ../images
# and ../logs are passed to f_mkdir (presumably create-if-missing — confirm in utils).
f_mkdir("../images")
f_remove("../video")
f_mkdir("../video")
f_mkdir("../logs")
# Plain-message log, truncated at the start of every run.
logging.basicConfig(
    filename='../logs/agent.log',
    filemode='w',
    level=logging.INFO,
    format='%(message)s',
)


def MPC_mc(planner, buffer, check_result, initial_state, task_information, every_task_max_planning_retries):
    """Build a workflow by alternating LLM planning and world-model rollouts.

    Each round asks ``planner`` for candidate workflows from the current state.
    The world model in ``buffer`` simulates them. The simulated action sequence
    is appended to the accumulated plan, and planning continues from the
    predicted final state. The loop stops as soon as the world model predicts
    success. Otherwise it runs at most ``every_task_max_planning_retries + 1``
    rounds.

    Args:
        planner: provides ``get_workflow_for_search(...)``, whose first returned
            element is handed to the world model.
        buffer: provides ``wm_prediction_with_multiple_action_seqs(...)``
            returning ``(action_seq, check_result, final_state)``.
        check_result: feedback from the previous real-environment attempt. It is
            sent to the planner in the first round. Afterwards each round
            replaces it with the world model's verdict (a dict with a
            ``"success"`` key).
        initial_state: state the first round plans from.
        task_information: task description, forwarded to the planner and logs.
        every_task_max_planning_retries: planning rounds allowed after the first.

    Returns:
        ``{"workflow": actions}``, where ``actions`` concatenates the action
        sequences of all rounds. It is returned even if success was never
        predicted.
    """
    act_seq_final = []

    count = 0
    while count <= every_task_max_planning_retries:
        log_info(f"[Initial state]: {initial_state}")
        # !! planning
        multi_workflow, _ = planner.get_workflow_for_search(task_information=task_information, check_result=check_result, initial_state=initial_state, running_dataset=running_dataset, testModel=testModel)
        # !! world model prediction
        action_seq, check_result, final_state_output = buffer.wm_prediction_with_multiple_action_seqs(initial_state, multi_workflow, running_dataset, testModel)
        # Exactly one increment per round. Failed rounds used to be counted
        # twice, which silently halved the planning budget.
        count += 1
        # Keep the simulated actions and replan from where the rollout ended.
        initial_state = final_state_output
        act_seq_final.extend(action_seq)

        if check_result["success"]:
            log_info(f"[World Model]: Task:{task_information}, Success Planning in World Model: {check_result}")
            return {"workflow": act_seq_final}

        # Failed halfway through: log this round and replan from the predicted state.
        log_info(f"[World Model]: Task:{task_information}, world model planning times:{count}, Failure predicted by World Model: {check_result}")

    log_info(f"[World Model]: Task:{task_information}, Failure predicted by World Model: {check_result}")
    return {"workflow": act_seq_final}

# without world model
def Normal_mc(planner, buffer, check_result, initial_state, task_information, task_id, every_task_max_retries, every_task_max_planning_retries, planner_mode):
    """Ask the planner for a workflow directly, with no world-model search.

    Only ``planner``, ``check_result``, ``initial_state`` and
    ``task_information`` are read. The other parameters are accepted for
    caller compatibility and ignored.

    Returns:
        ``{"workflow": <planner.get_workflow(...) result>}``.
    """
    # Stage 1: workflow decision. Mode-specific behaviour (rules only,
    # rules + state, ...) lives inside the planner itself.
    plan = planner.get_workflow(
        task_information=task_information,
        check_result=check_result,
        initial_state=initial_state,
        running_dataset=running_dataset,
        testModel=testModel,
    )
    return dict(workflow=plan)

##########################################################
## Debug env
##########################################################
# Biome options include "plains", "jungle", "taiga", "forest", "swampland", "ice_plains".
biome_string = "taiga"
seed = 3  # alternative: 2
world_seed = 18  # other values tried: 0, 6, 13, 14
# seed = random.randint(1,1000000000000)
# world_seed = random.randint(1,1000000000000)
start_time = 1000  # morning 1000; night 13000; midnight 18000
initial_weather = "clear"  # normal, clear, rain, thunder
image_size = (512, 820)
vradius = 5
env = minedojo.make(
    "open-ended",
    # task_id="survival", target_names="diamond", target_quantities=1,
    image_size=image_size,
    seed=seed,
    world_seed=world_seed,
    generate_world_type="specified_biome",  # alternative: "flat"
    # flat_world_seed_string = "3;minecraft:bedrock,2*minecraft:dirt,minecraft:grass;1;village(size=350 distance=9),biome_1(distance=9),decoration",
    specified_biome=biome_string,
    initial_weather=initial_weather,
    start_time=start_time,
    break_speed_multiplier=100.0,
    # Voxel window spans -vradius..vradius on every axis.
    use_voxel=True,
    voxel_size=dict(
        xmin=-vradius, ymin=-vradius, zmin=-vradius,
        xmax=vradius, ymax=vradius, zmax=vradius,
    ),
    # Lidar fan: pitch and yaw each from -60 to 55 degrees in 5-degree steps,
    # ray range 10 (100 was also tried).
    use_lidar=True,
    lidar_rays=[
        (np.pi * pitch / 180, np.pi * yaw / 180, 10)
        for pitch in np.arange(-60, 60, 5)
        for yaw in np.arange(-60, 60, 5)
    ],
    # Other starting items tried (commented out): wooden_sword (slot 0),
    # shears (2), bucket (3), wooden_pickaxe (4), diamond_pickaxe (5),
    # furnace (6), crafting_table (7), coal x60 (1).
    initial_inventory=[
        InventoryItem(slot=9, name="dirt", variant=None, quantity=60),
        InventoryItem(slot=10, name="dirt", variant=None, quantity=60),
        InventoryItem(slot=11, name="bucket", variant=None, quantity=60),
    ],
    # spawn_in_village = True,
)
env.reset()
##########################################################
## End of debug env
##########################################################


# # ##########################################################
# # ## Debug env
# # ##########################################################
# seed = random.randint(1,1000000000000)
# vradius = 5
# log_info(seed)
# biome_string = "forest"

# env = minedojo.make(
#     task_id="harvest", target_names="diamond",
#     image_size=(512, 820), 
#     target_quantities=100, seed=3, 
#     specified_biome = biome_string, 
#     spawn_rate=1, 
#     break_speed_multiplier = 100.0, 
#     spawn_range_low=(-10, -10, -10), spawn_range_high=(10, 10, 10), 
#     start_at_night = False, world_seed = seed, use_voxel = True, 
#     voxel_size=dict(xmin=-vradius, ymin=-vradius, zmin=-vradius, xmax=vradius, ymax=vradius, zmax=vradius), # doesn't really matter
#     use_lidar=True,
#     lidar_rays=[
#             (np.pi * pitch / 180, np.pi * yaw / 180, 10) # ALERT: lidar range is now 10
#             for pitch in np.arange(-60, 60, 5)
#             for yaw in np.arange(-60, 60, 5)
#     ]
# )


# env.reset()

# env.set_inventory([InventoryItem(slot=9, name="dirt", variant=None, quantity=6), InventoryItem(slot=16, name="coal", variant=None, quantity=3), InventoryItem(slot=17, name="coal", variant=None, quantity=3)])
# events, _, _, _ = env.step([0,0,0,12,6,0,0,0])
# # ##########################################################
# # ## Debug env
# # ##########################################################




# Main evaluation loop: for each task (from `initialtaskID` on), make up to
# `every_task_max_retries` plan-and-execute attempts and record the outcome.
with open(task_dir, 'r') as f:
    task_list = json.load(f)
# The task file is closed before the long-running loop starts.

try:
    # for task_id, task_information in enumerate(task_list[::-1]): # reverse order
    for task_id, task_information in enumerate(task_list[initialtaskID:], start=initialtaskID):

        # Per-task header in the token-usage log.
        with open(f'/home/**/Workspace/MP5/MP5_agent/agent/task_result/[tokenuse]_{running_dataset}_{testModel}.log', 'a') as f:
            f.write(f'\n\n***** MODEL: {testModel}, Dataset: {running_dataset} *****')
            f.write(f'\n\n***** task_id: {task_id}, task_information: {task_information} *****\n\n')

        # Offset ids so successive rounds (30 tasks each) do not collide.
        task_id = task_id + roundtimes*30
        check_result = {}
        # Outcome written to the results CSV. It becomes True only when the
        # inventory check (check_item_availability) confirms the task.
        task_success = False
        log_info(f"Task: { task_information['task'] }")

        ###############################################
        ## balance exploration & exploitation
        ###############################################
        # Decay epsilon, but not below the minimum value
        # current_epsilon = max(minimum_epsilon, initial_epsilon * (decay ** task_id))

        ###########################################################################################
        # task env initialization #################################################################
        ###########################################################################################
        env.reset()

        # # armor
        # armorID = 00
        # armor = [0, 1, 0, 1, 0]
        # bias0 = 0 # 0, 1
        # label = task_id%len(armor)
        # if armor[label+bias0]:
        #     env.set_inventory(
        #         [
        #             InventoryItem(slot=36, name="diamond_boots", variant=None, quantity=1),
        #             InventoryItem(slot=37, name="diamond_leggings", variant=None, quantity=1),
        #             InventoryItem(slot=38, name="diamond_chestplate", variant=None, quantity=1),
        #             InventoryItem(slot=39, name="diamond_helmet", variant=None, quantity=1),
        #             # InventoryItem(slot=40, name="shield", variant=None, quantity=1)
        #         ]
        #         )
        #     armorID = 20
        armorID = 99

        # time of day: cycle through `start_times` by task id
        bias1 = 0 # 0, 1, 2
        start_times = [1000, 18000, 18000, 1000]
        label = task_id % len(start_times)
        # Wrap around so a non-zero bias cannot index past the end of the list.
        start_tick = start_times[(label + bias1) % len(start_times)]
        env.set_time(start_tick) # morning 1000; night 13000; midnight 18000
        # Compare the selected tick, not the whole list; the old list-vs-int
        # comparison tagged every task as 'night'.
        timeID = 'day' if start_tick == 1000 else 'night'

        # # biome
        # env.random_teleport(800)

        # # weather
        # bias2 = 0
        # weather = ['normal', 'clear', 'rain', 'thunder', 'normal', 'thunder']
        # label = task_id%len(weather)
        # env.set_weather(weather[label+bias2]) # morning 1000; night 13000; midnight 18000

        # # weapon
        # wooden level

        # 00:# 'wooden_pickaxe'
        # 01:# 'wooden_axe'
        # 02:# 'wooden_hoe'
        # 03:# 'wooden_sword'
        # 04:# 'wooden_shovel'
        # if task_information['tasklevel'] in ['iron', 'golden']:
        #     env.set_inventory(
        #         [
        #             InventoryItem(slot=0, name="wooden_pickaxe", variant=None, quantity=1), # diamond_sword
        #             # InventoryItem(slot=1, name="wooden_axe", variant=None, quantity=1), # diamond_sword
        #             # InventoryItem(slot=2, name="wooden_hoe", variant=None, quantity=1), # diamond_sword
        #             InventoryItem(slot=3, name="wooden_sword", variant=None, quantity=1), # diamond_sword
        #             # InventoryItem(slot=4, name="wooden_shovel", variant=None, quantity=1), # diamond_sword
        #             # InventoryItem(slot=6, name="furnace", variant=None, quantity=1),
        #             # InventoryItem(slot=7, name="crafting_table", variant=None, quantity=1),
        #         ]
        #         )

        # give the env some ticks to apply the changes above
        for i in range(45):
            obs, _, _, _ = env.step(env.action_space.no_op())
        events, _, _, _ = env.step(env.action_space.no_op())
        # sync the memory
        # share_memory(memory=memory, events=events)
        ###########################################################################################
        # end of task env initialization ##########################################################
        ###########################################################################################

        count = 0
        while count <= every_task_max_retries:
            count += 1

            # if count==4:
            #     env.set_inventory(
            #         [
            #             InventoryItem(slot=20, name="wooden_pickaxe", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=1, name="wooden_axe", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=2, name="wooden_hoe", variant=None, quantity=1), # diamond_sword
            #             InventoryItem(slot=21, name="wooden_sword", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=4, name="wooden_shovel", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=6, name="furnace", variant=None, quantity=1),
            #             # InventoryItem(slot=7, name="crafting_table", variant=None, quantity=1),
            #         ]
            #         )

            # if count==12:
            #     env.set_inventory(
            #         [
            #             InventoryItem(slot=20, name="wooden_pickaxe", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=1, name="wooden_axe", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=2, name="wooden_hoe", variant=None, quantity=1), # diamond_sword
            #             InventoryItem(slot=21, name="wooden_sword", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=4, name="wooden_shovel", variant=None, quantity=1), # diamond_sword
            #             # InventoryItem(slot=6, name="furnace", variant=None, quantity=1),
            #             # InventoryItem(slot=7, name="crafting_table", variant=None, quantity=1),
            #         ]
            #         )

            # if count==6:
            #     if timeID == 'night':
            #         env.set_time(1000)
            #     else:
            #         env.set_time(18000)
            #     for i in range(300):
            #         obs, _, _, _ = env.step(env.action_space.no_op())

            # if count==12:
            #     if timeID == 'night':
            #         env.set_time(1000)
            #     else:
            #         env.set_time(1000)
            #     for i in range(300):
            #         obs, _, _, _ = env.step(env.action_space.no_op())

            # Out of attempts: this last pass through the loop only logs and exits.
            if count > every_task_max_retries:
                log_info("************Failed to complete this task. Consider updating your prompt.************\n\n")
                log_info(f"**task_id:{task_id}, task_information:{task_information}, every_task_max_retries:{every_task_max_retries}**\n\n")
                break

            ## Stage0: gather initial state information
            for i in range(30):
                obs, _, _, _ = env.step(env.action_space.no_op())
            events, _, _, _ = env.step(env.action_space.no_op())
            state_initial = gather_state_info(obs)

            # sync the memory
            share_memory(memory=memory, events=events)
            log_info(f"My inventory: {memory.inventory}")

            ###############################################
            ## balance exploration & exploitation
            ###############################################
            # Decide whether to explore or exploit
            # if task_id >= 5:
            #     if np.random.rand() < current_epsilon:
            #         # just use prior knowledge
            #         rules_dir = None
            #     else:
            #         rules_dir = '/home/**/Workspace/MP5/MP5_agent/agent/buffer_rules/rules_library.json'
            # rules_dir = '/home/**/Workspace/MP5/MP5_agent/agent/buffer_rules/rules_library.json'
            # rules_dir = None
            # planner = Planner(memory=memory, mode = planner_mode, rules_dir = rules_dir, model_name=model_name, temperature=planner_temperature, choice_num=planner_choice_num)

            ## Stage1: workflow decision with the configured search algorithm
            if planner_search_alg.startswith("original"):
                workflow_dict = Normal_mc(
                    planner=planner,
                    buffer=buffer,
                    check_result=check_result,
                    initial_state=state_initial,
                    task_information=task_information,
                    task_id=task_id,
                    every_task_max_retries=every_task_max_retries,
                    every_task_max_planning_retries = every_task_max_planning_retries,
                    planner_mode = planner_mode
                    )

            elif planner_search_alg == 'MCTS':
                world_model = MCWorldModelRAP(statepred = buffer, initial_state = state_initial, max_steps=10)
                config = MCConfigRAP(action_gen = planner)
                algorithm = MCTS(depth_limit=10, disable_tqdm=False, output_trace_in_each_iter=True, n_iters=3, uct_with_fast_reward=True)
                reasoner_rap = Reasoner(world_model=world_model, search_config=config, search_algo=algorithm)
                result_rap = reasoner_rap(task_information)
                workflow_dict = {"workflow": result_rap[1]} # TODO remain to check

            elif planner_search_alg == 'MPC':
                workflow_dict = MPC_mc(
                    planner = planner,
                    buffer = buffer,
                    check_result = check_result,
                    initial_state = state_initial,
                    task_information = task_information,
                    every_task_max_planning_retries = every_task_max_planning_retries
                    )
            else:
                raise NotImplementedError

            # Record the chosen action sequence for this attempt.
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = f"/home/**/Workspace/MP5/MP5_agent/agent/actseqrec_{running_dataset}_{testModel}/traj_{task_id}/act_info_{task_information['task']}_taskID{task_id}_armor{armorID}_time{timeID}_count{count}_{timestamp}.json"
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'w') as f:
                json.dump(workflow_dict, f, cls=NumpyEncoder, indent=4)

            ## Stage2: Interface & Update Inventory & save trajectories
            performer = Performer(memory=memory, percipient=percipient, checker=patroller)
            check_result, transition_list = performer.check_and_execute_workflow(env=env, workflow_dict=workflow_dict, task_information=task_information, task_id = task_id, every_task_max_retries = count)
            # The agent died while executing: give up on this task.
            if 'died' in check_result.get('feedback', ''):
                break

            # Save the executed transitions of this attempt.
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = f"/home/**/Workspace/MP5/MP5_agent/agent/buffer_traj_TEST_{testModel}/traj_{task_id}/transition_info_{task_information['task']}_taskID{task_id}_armor{armorID}_time{timeID}_count{count}_{timestamp}.json"
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'w') as f:
                json.dump(transition_list, f, cls=NumpyEncoder, indent=4)

            # Abort the task if no bucket is left in the inventory.
            events, _, _, _ = env.step(env.action_space.no_op())
            name = events['inventory']['name'].tolist()
            if 'bucket' not in name:
                print('bucket disappear...')
                check_result["success"] = False
                break

            # Fail halfway through: retry with the execution feedback.
            if not check_result["success"]:
                log_info(f"Action Preparation Failure: {check_result}")
                continue

            # Stage3: Check the final task result against the inventory
            # check_result = patroller.check_task_success(task_information=task_information)
            events, _, _, _ = env.step(env.action_space.no_op())
            name = events['inventory']['name'].tolist()
            num  = events['inventory']['quantity'].tolist()
            success = check_item_availability(name, num, task_information['task'], task_information['quantity'])

            ## Stage4: Validation
            if success:
                task_success = True
                # Success: put the successful workflow into memory
                # TODO to be checked
                if module_trajmemory:
                    memory.add_successful_workflow(task_information["task"], task_information["task"], workflow_dict["workflow"])
                break
            # Failure: insufficient materials; loop again with feedback.

        # task_replan counts loop passes; after exhausting the retries it is
        # every_task_max_retries + 1.
        task_replan = count
        write_taskresult_to_csv(planner_search_alg, interval, task_id, task_information['task'], task_information['tasklevel'], task_replan, task_success, testModel = testModel)

    memory.reset_all()
    log_info("############ Successfully Finish All Tasks ############")
finally:
    # Always shut the MineDojo environment down, even if a task raised.
    env.close()

