import os.path
import random
import re
import numpy as np
import torch
import yaml
from gym_cooking.cooking_world.world_objects import *
from gym_cooking.cooking_world.abstract_classes import *
from gym_cooking.environment.cooking_zoo import Ingred2ID
from learning.model import LatentActor
import subprocess

# Names of the objects that can be the `static_obj` (target) of a 'Put' event.
static_objects = ['CutBoard','DeliverSquare','Divider','Plate']
# Names of the ingredient classes the rule-based policies reason about.
ingredients = ['Lettuce','Tomato','Potato','Onion','Carrot']
# Names of the objects that can be the `dynamic_obj` of an event.
dynamic_objects = ['Plate'] + ingredients
# High-level event actions; see the `event` class.
Action = ['Put','Take','Chop']
# Sentinel distance meaning "no candidate location found" in nearest-target searches.
inf = 1e5


class event:
    """A high-level task: one action applied to a dynamic object.

    'Put' events additionally name the static object the item is put onto;
    'Take' and 'Chop' events must not name one.  The remaining attributes are
    scratch state filled in by the policy that plans the event.
    """

    def __init__(self, action, dynamic_obj, static_obj=None):
        has_target = static_obj is not None
        targetless_action = action == 'Take' or action == "Chop"
        # Exactly the 'Put' action carries a static target.
        assert (targetless_action and not has_target) or (action == 'Put' and has_target)

        self.action = action
        self.dynamic_obj = dynamic_obj
        self.static_obj = static_obj

        # Planning state, all initially "nothing decided yet".
        self.target_location = None   # grid cell the agent has to interact with
        self.available = False        # result of the last availability check
        self.from_divider = False     # whether the target cell lies on the divider
        self.done = False             # set once the interaction has been issued

    def __str__(self):
        action, dynamic_obj, static_obj = self.action, self.dynamic_obj, self.static_obj
        return f'Event(action={action}, dynamic_obj={dynamic_obj}, static_obj={static_obj})'


class PretrainedPolicy:
    """Opponent policy backed by a saved (pickled) policy checkpoint.

    The checkpoint is loaded lazily: nothing is read from disk in `__init__`.
    Loading happens on the first `reset()` call, or on the first `__call__` if
    `reset()` was not called before, so that the model is created inside the
    process that actually uses it and no tensor has to cross a process boundary.
    """

    def __init__(self, model_path, agent_id, is_self_play=True, batch_size=1, device='cpu'):
        """Record the configuration; the checkpoint itself is loaded later.

        Args:
            model_path: path of the checkpoint passed to `torch.load`.
            agent_id: index of the actor to use when the checkpoint is a
                two-actor self-play policy.
            is_self_play: whether the checkpoint holds two separate actors
                (True) or a single joint actor (False).
            batch_size: number of observations handled per call.
            device: device the model and all tensors are placed on.
        """
        self.model_path = model_path
        self.agent_id = agent_id
        self.actor = self.rnn_states = self.rnn_hidden_dim = None
        self.is_self_play = is_self_play
        self.batch_size = batch_size
        self.device = device

    def set_id(self, aid):
        """Select which actor of a self-play checkpoint is used (effective before loading)."""
        self.agent_id = aid

    def __call__(self, obs):
        """Sample an action for `obs`.

        Returns a Python scalar when `batch_size == 1`, otherwise the action
        tensor with its trailing dimension squeezed.
        """
        if self.actor is None:
            # Honour the lazy-loading contract even if reset() was never called.
            self.reset()
        if not isinstance(obs, torch.Tensor):
            obs = torch.from_numpy(obs).float()
        # Masks and RNN states live on self.device; the observation must too.
        obs = obs.to(self.device)
        if self.batch_size == 1:
            obs = obs.unsqueeze(0)
        with torch.no_grad():
            action, _, self.rnn_states, _ = self.actor.act(obs, self.rnn_states,
                                                           torch.ones(self.batch_size, 1, device=self.device),
                                                           None, deterministic=False)
        return action.item() if self.batch_size == 1 else action.squeeze(-1)

    def reset(self):
        """Load the checkpoint on first use and zero the recurrent state."""
        if self.actor is None:
            # One-time initialization.
            # NOTE(review): this unpickles a full Python object, so only trusted
            # checkpoints must be loaded.  Recent PyTorch releases default
            # `torch.load` to `weights_only=True`, which presumably rejects such
            # checkpoints -- confirm against the installed version.
            policy = torch.load(self.model_path, map_location=self.device)
            if self.is_self_play:
                assert len(policy.actors) == 2, 'Separate-model self-play policy should have exactly 2 actors'
                self.actor: LatentActor = policy.actors[self.agent_id]
            else:
                assert len(policy.actors) == 1, 'Joint policy should have exactly 1 actor'
                self.actor: LatentActor = policy.actors[0]
            if policy.is_recurrent:
                self.rnn_hidden_dim = policy.rnn_hidden_dim
                self.rnn_states = torch.zeros(self.batch_size, self.rnn_hidden_dim, device=self.device)
            if not hasattr(self.actor, 'rnn'):
                # Make sure the attribute exists on actors that lack it.
                self.actor.rnn = None

        if self.rnn_states is not None:
            self.rnn_states.zero_()


class RuleBasedPolicy:
    def __init__(self, agent_id, policy_type, nav_p, tar_p, rand_p, convention, env_name, support_set=None, ingredient_support_set=None, event_probs=None):
        """Configure a scripted opponent.

        Args:
            agent_id: index of the controlled agent in `world.agents`; can be
                changed later with `set_id`.
            policy_type: one of 'minimum', 'medium', 'full', 'specified'.
                minimum: only take ingredients to the divider and deliver plates;
                medium: minimum + put ingredients into plates;
                full: all possible events;
                specified: use the given `support_set` / `event_probs`.
            nav_p: probability of moving right/left instead of up/down during navigation.
            tar_p: probability of choosing a random target location instead of the nearest.
            rand_p: probability of doing a random action instead of carrying out the current plan.
            convention: 0, 1, 2 or None; selects the divider segment from
                which plates may be taken (None disables the restriction).
            env_name: map name; only names containing "divider" are supported.
            support_set: events to choose from (required iff 'specified').
            ingredient_support_set: ingredient names this policy cares about;
                None means all ingredients.
            event_probs: selection weight per event in `support_set`
                (required iff 'specified').

        Raises:
            AssertionError: on inconsistent arguments.
            NotImplementedError: if `env_name` is not a divider map.
        """
        # Fix: the constructor argument used to be dropped (always None), which
        # made the policy unusable unless set_id() was called afterwards.
        self.agent_id = agent_id
        assert policy_type in ['minimum', 'medium', 'full', 'specified']
        assert convention in [0,1,2,None]  # only take plate if it is in certain position
        self.policy_type = policy_type
        self.nav_p = nav_p
        self.tar_p = tar_p
        self.rand_p = rand_p
        if policy_type == 'specified':
            # Checked explicitly so that a missing argument fails the assertion
            # instead of raising TypeError from len(None).
            assert support_set is not None and event_probs is not None and len(support_set) == len(event_probs)
        else:
            assert support_set is None and event_probs is None
        self.env_name = env_name

        if "divider" not in env_name:
            raise NotImplementedError
        # The divider occupies column x == 3; its length depends on the map size.
        if "large" in env_name:
            interval, upper = 4, 12
        else:
            interval, upper = 2, 6
        self.divider_loc = [(3, i) for i in range(1, upper)]
        if convention is None:
            self.convention = None
        else:
            # Divider cells belonging to the `convention`-th segment.
            first = interval * convention + 1
            last = min(interval * (convention + 1) + 1, upper)
            self.convention = [(3, i) for i in range(first, last)]

        if support_set is None:
            def put_allowed(dynamic_obj, static_obj):
                """Whether 'Put dynamic_obj onto static_obj' belongs to this policy type."""
                if dynamic_obj == "Plate":
                    if static_obj == "DeliverSquare":
                        return True
                    # Putting a plate back onto the divider is a 'full'-only task.
                    return static_obj == "Divider" and policy_type == 'full'
                if static_obj == "DeliverSquare":
                    return False
                if policy_type == 'full':
                    return True
                if policy_type == 'medium':
                    # No chopping: ingredients never go onto a cutboard.
                    return static_obj != 'CutBoard'
                if policy_type == 'minimum':
                    # Ingredients are only moved onto the divider.
                    return static_obj == 'Divider'
                return False

            if ingredient_support_set is None:
                carried = dynamic_objects
            else:
                carried = list(ingredient_support_set) + ['Plate']

            self.support_set = []
            for action in Action:
                for dynamic_obj in carried:
                    if action == "Put":
                        for static_obj in static_objects:
                            if put_allowed(dynamic_obj, static_obj):
                                self.support_set.append(event(action, dynamic_obj, static_obj))
                    elif action == "Chop":
                        if dynamic_obj != "Plate" and policy_type == 'full':
                            self.support_set.append(event(action, dynamic_obj))
                    else:
                        self.support_set.append(event(action, dynamic_obj))
            # Random preference over events, kept away from zero and normalized.
            self.event_probs = np.random.rand(len(self.support_set)) + 1
            self.event_probs = self.event_probs/np.sum(self.event_probs)
        else:
            self.support_set = support_set
            self.event_probs = event_probs
        self.ingredient_support_set = ingredients if ingredient_support_set is None else ingredient_support_set
        self.cur_event = None # current event

    def set_id(self, aid):
        self.agent_id = aid

    def distance(self, loc1, loc2):
        return np.abs(loc1[0]-loc2[0]) + np.abs(loc1[1]-loc2[1])

    def refresh_event(self):
        self.cur_event = None
        for e in self.support_set:
            e.available = False
            e.target_location = None
            e.done = False

    def in_ingredient_support(self, dynamic_obj):
        res = False
        for ingredient in self.ingredient_support_set:
            if isinstance(dynamic_obj, StringToClass[ingredient]):
                res = True
                break
        return res

    def is_reachable(self, agent_loc, target_loc):
        # whether the target_location is reachable from current agent location
        if "divider" in self.env_name:
            divider_x = self.divider_loc[0][0]
            if (agent_loc[0]-divider_x)*(target_loc[0]-divider_x)<0:
                return False
            else:
                return True
        else:
            raise NotImplementedError
    
    def is_movable(self, world, target_loc):
        # whether the target_location is only floor that agent can move onto
        static_obj = world.get_objects_at(target_loc, StaticObject)
        if len(static_obj)>1 or (len(static_obj)==1 and not isinstance(static_obj[0], Floor)):
            return False
        else:
            return True

    def is_event_available(self, e: event, world):
        # return the availability and if available, set the position of the target to the event
        action = e.action
        dynamic_obj = e.dynamic_obj
        static_obj = e.static_obj
        self.loc = world.agents[self.agent_id].location
        if action == "Put":
            if not world.agents[self.agent_id].holding:
                e.available = False
                return False
            elif not isinstance(world.agents[self.agent_id].holding, StringToClass[dynamic_obj]):
                e.available = False
                return False
            elif static_obj == "Divider":
                # if e.target_location is not None:
                #     #print("available check",e.target_location)
                #     if len(world.get_objects_at(e.target_location, DynamicObject))==0 and len(world.get_objects_at(e.target_location, CutBoard))==0:
                #         return True
                min_dist = inf
                avail_loc = []
                for loc in self.divider_loc:
                    if len(world.get_objects_at(loc, DynamicObject))==0 and len(world.get_objects_at(loc, CutBoard))==0:
                        # Can not put plate/ingredient onto a divider where there is ingredient/plate/cutboard on it
                        avail_loc.append(loc)
                        if self.distance(loc, self.loc)<min_dist:
                            min_dist = self.distance(self.loc, loc)
                            e.target_location = loc
                if min_dist == inf:
                    e.available = False
                    return False
                else:
                    e.available = True
                    # if there are multiple available locations, with probability tar_p randomly choose one
                    if np.random.rand()<self.tar_p and len(avail_loc)>1:
                        e.target_location = avail_loc[np.random.randint(len(avail_loc))]
                    return True
            elif static_obj == "Plate":
                if world.agents[self.agent_id].holding.done():
                    # can only put a chopped food into a plate
                    min_dist = inf
                    for plate in world.world_objects['Plate']:
                            loc = plate.location
                            if self.is_reachable(self.loc, loc) and self.distance(self.loc, loc)<min_dist:
                                min_dist = self.distance(self.loc, loc)
                                e.target_location = loc
                    if min_dist == inf:
                        e.available = False
                        return False
                    else:
                        e.available = True
                        return True
                else:
                    return False
            elif static_obj == "CutBoard":
                if world.agents[self.agent_id].holding.done():
                    # cannot chop a chopped food
                    e.available = False
                    return False
                else:
                    min_dist = inf
                    for cutboard in world.world_objects['CutBoard']:
                        loc = cutboard.location
                        if self.is_reachable(self.loc, loc) and len(world.get_objects_at(loc, DynamicObject))==0 and self.distance(self.loc, loc)<min_dist:
                            min_dist = self.distance(self.loc, loc)
                            e.target_location = loc
                    if min_dist == inf:
                        e.available = False
                        return False
                    else:
                        e.available = True
                        return True
            elif static_obj == "DeliverSquare":
                if len(world.agents[self.agent_id].holding.content)==0:
                    # empty plate cannot be put onto deliversquare
                    e.available = False
                    return False
                else:
                    min_dist = inf
                    for deliver_square in world.world_objects['DeliverSquare']:
                        loc = deliver_square.location
                        if self.is_reachable(self.loc, loc) and len(world.get_objects_at(loc, DynamicObject))==0 and self.distance(self.loc, loc)<min_dist:
                            min_dist = self.distance(self.loc, loc)
                            e.target_location = loc
                    if min_dist == inf:
                        e.available = False
                        return False
                    else:
                        e.available = True
                        return True
        elif action == "Chop":
            if world.agents[self.agent_id].holding:
                e.available = False
                return False
            else:
                min_dist = inf
                for food in world.world_objects[dynamic_obj]:
                    loc = food.location
                    if not food.done() and self.is_reachable(self.loc, loc):
                        for cutboard in world.world_objects["CutBoard"]:
                            if loc == cutboard.location:
                                min_dist = self.distance(self.loc, loc)
                                e.target_location = loc
                if min_dist == inf:
                    e.available = False
                    return False
                else:
                    e.available = True
                    return True
        else:
            if world.agents[self.agent_id].holding:
                # Can't take anything when holding something
                e.available = False
                return False
            else:
                min_dist = inf
                for obj in world.world_objects[dynamic_obj]:
                    loc = obj.location
                    if dynamic_obj !="Plate" and len(world.get_objects_at(loc, DynamicObject))>1:
                        # Can not take ingredients already in a plate, instead take the plate
                        continue
                    elif dynamic_obj =="Plate":
                        # Can't take a plate unless to serve a dish
                        in_plate = world.get_objects_at(loc, Food)
                        if len(in_plate) == 0 or (len(in_plate) == 1 and len(self.ingredient_support_set) > 1):
                            # recipe not satisfied
                            continue
                        # Can't take a plate if not in position matching convention
                        if self.convention is not None and not loc in self.convention:
                            continue                    
                        # Can not take a plate with ingredients out of ingredient support set
                        out_ingredient_support_set = False
                        for ingred in in_plate:
                            if not self.in_ingredient_support(ingred):
                                out_ingredient_support_set = True
                                break
                        if out_ingredient_support_set:
                            continue
                    else:
                        # Take an ingredient. If medium or minimum policy, can't take an ingredient from the divider to somewhere else
                        # NOTE: this only works for the right hand side of the divider
                        if (self.policy_type == 'medium' or self.policy_type == 'minimum') and (loc in self.divider_loc):
                            continue
                    if self.is_reachable(self.loc, loc) and self.distance(self.loc, loc)<min_dist:
                        min_dist = self.distance(self.loc, loc)
                        e.target_location = loc
                        e.from_divider = loc in self.divider_loc
                if min_dist == inf:
                    e.available = False
                    return False
                else:
                    e.available = True
                    return True
    
    def check_still_available(self, e:event, world):
        # If current event is not done, check whether it is still available every time step
        dynamic_object = e.dynamic_obj
        static_object = e.static_obj
        target_location = e.target_location
        if e.action == "Take":
            # check whether the object is still at the target location
            if len(world.get_objects_at(target_location, StringToClass[dynamic_object]))==0:
                return False
            else:
                return True
        elif e.action == "Put":
            # check whether current object can be put to target location
            if static_object != "Plate":
                # Cannot put something on Cutboard/Divider/DeliverSquare when there is already something on it
                if len(world.get_objects_at(target_location, DynamicObject))>0:
                    return False
                else:
                    return True
            else:
                # Cannot conitune to put ingredients on a plate if the plate is gone
                if len(world.get_objects_at(target_location, Plate))==0:
                    return False
                else:
                    return True
        elif e.action == "Chop":
            # Cannot chop if the food is gone or is already chopped
            if len(world.get_objects_at(target_location, StringToClass[dynamic_object]))==0 or world.get_objects_at(target_location, StringToClass[dynamic_object])[0].done():
                return False
            else:
                return True
        else:
            raise ValueError("Unknown action")

    def get_available_events(self, world):
        return [e for e in self.support_set if self.is_event_available(e, world)]

    def select_available_event(self, world):
        cur_prob = np.zeros(len(self.support_set))
        for (i,e) in enumerate(self.support_set):
            if self.is_event_available(e, world):
                cur_prob[i] = self.event_probs[i]
        if np.sum(cur_prob)==0:
            return None
        cur_prob = cur_prob/np.sum(cur_prob) # normalization
        idx = np.random.choice(len(self.support_set),p=cur_prob)
        # print(cur_prob)
        return self.support_set[idx]

    def get_random_action(self):
        random_action = np.random.randint(5)
        #print(self.cur_event,"random action",random_action)
        return random_action

    def get_action_from_event(self, world):
        if "divider" in self.env_name:
            agent_loc = world.agents[self.agent_id].location
            agent_orientation = world.agents[self.agent_id].orientation
            target_loc = self.cur_event.target_location
            if self.distance(agent_loc,target_loc)>1:
                # still need to move
                action_list = []
                if target_loc[0]>agent_loc[0] and self.is_movable(world,(agent_loc[0]+1,agent_loc[1])):
                    action_list.append(2)
                elif target_loc[0]<agent_loc[0] and self.is_movable(world,(agent_loc[0]-1,agent_loc[1])):
                    action_list.append(1)
                if target_loc[1]>agent_loc[1] and self.is_movable(world,(agent_loc[0],agent_loc[1]+1)):
                    action_list.append(3)
                elif target_loc[1]<agent_loc[1] and self.is_movable(world,(agent_loc[0],agent_loc[1]-1)):
                    action_list.append(4)
                if len(action_list)==1:
                    action = action_list[0]
                else:
                    action = action_list[0] if np.random.rand()<self.nav_p else action_list[1]
            else:
                # only need to change orientation if necessary
                if target_loc[0]==agent_loc[0]:
                    if target_loc[1]>agent_loc[1]:
                        orientation = 3
                    else:
                        orientation = 4
                else:
                    if target_loc[0]>agent_loc[0]:
                        orientation = 2
                    else:
                        orientation = 1
                if agent_orientation == orientation:
                    action = 5
                    self.cur_event.done = True
                else:
                    action = orientation
        else:
            raise NotImplementedError
        # with rand_p probability choose random action instead of carrying out current plan
        #print(self.cur_event, self.cur_event.target_location, action)
        if np.random.rand()<self.rand_p:
            action = self.get_random_action()
        
        return action

    def reset(self):
        pass
    
    def __call__(self, world):
        # print("Current event:",self.cur_event)
        if self.cur_event is not None:
            # judge whether the current event has become invalid
            if not self.check_still_available(self.cur_event, world):
                self.cur_event = None
        if self.cur_event is None:
            # select a new event
            new_event = self.select_available_event(world)
            # print("New event:",new_event)
            if new_event is None:
                # No available event, do some random actions
                action = self.get_random_action()
            else:
                self.cur_event = new_event 
                action = self.get_action_from_event(world)
        else:
            action = self.get_action_from_event(world)
        # print('Executing event', self.cur_event)

        if self.cur_event is not None and self.cur_event.done:
            if self.cur_event.action == "Put" and self.cur_event.static_obj == "CutBoard":
                self.cur_event = event("Chop", self.cur_event.dynamic_obj)
            elif self.cur_event.action == "Chop":
                self.cur_event = event("Take", self.cur_event.dynamic_obj)
            # elif self.cur_event.action == "Take" and self.cur_event.dynamic_obj != "Plate":
            #     take_obj = world.get_objects_at(self.cur_event.target_location, DynamicObject)
            #     if len(take_obj) == 1 and not take_obj[0].done():
            #         self.cur_event = event("Put", self.cur_event.dynamic_obj, "CutBoard")
            #     else:
            #         self.cur_event = None
            #         self.refresh_event()
            else:
                self.cur_event = None
                self.refresh_event()        
        
        return action

        # if np.random.rand() < self.p:
        #     return self.get_random_action()
        # available_events = self.get_available_events(obs)
        # if len(available_events) == 0:
        #     return self.get_random_action()
        # if self.event_prio is not None:
        #     for e in self.event_prio:
        #         if e in available_events:
        #             selected_event = e
        #             break
        #     else:
        #         raise ValueError('ERROR: available event not found')
        # else:
        #     selected_event = np.random.choice(self.support_set, p=self.event_probs)
        # return self.get_action_from_event(selected_event)


def get_train_eval_pool(args):
    """Build the training and evaluation opponent pools described by `args`.

    Three configurations are supported:
      * `args.desire_id` set: the train pool holds the single 5-bit desire
        vector encoded by that id; the eval pool is empty.
      * rule-based opponents and/or an eval pool requested: both pools are cut
        from one generated policy list, and the train pool is topped up with
        pretrained self-play opponents up to `args.train_pool_size`.
      * otherwise: the train pool consists of self-play opponents only.

    Returns:
        (policy_pool_train, policy_pool_eval)
    """
    assert args.env_name == 'Overcooked'
    if args.desire_id is not None:
        assert args.desire_id < 2 ** 5, f'Desire id out of range: {args.desire_id}'
        # Bit i of the id says whether ingredient i is desired.
        policy_pool_train = [[((args.desire_id >> i) & 1) for i in range(5)]]
        policy_pool_eval = []
        self_play_opponents = 0
        print('Put 1 desire into train pool')
    elif args.rule_based_opponents > 0 or args.eval_pool_size > 0:
        with open(args.env_config, 'r') as env_config_file:
            env_map = yaml.safe_load(env_config_file)['mode']
        print('Using map', env_map)
        policy_pool_train_eval = generate_policy_pool(args.multi_agent > 1, args.p, 1 - args.player_id, env_map,
                                                      args.rule_based_opponents + args.eval_pool_size)
        policy_pool_train = policy_pool_train_eval[:args.rule_based_opponents]
        policy_pool_eval = policy_pool_train_eval[args.rule_based_opponents:]
        self_play_opponents = args.train_pool_size - args.rule_based_opponents

        def support_sets(pool):
            # Fix: with multi_agent > 1 the pool holds plain desire vectors,
            # which have no `ingredient_support_set`; show them as they are
            # instead of raising AttributeError.
            return [getattr(p, 'ingredient_support_set', p) for p in pool]

        print('Put', len(policy_pool_train), 'rule-based opponents into train pool, ingredient support sets:',
              support_sets(policy_pool_train))
        print('Put', len(policy_pool_eval), 'rule-based opponents into eval pool, ingredient support sets:',
              support_sets(policy_pool_eval))
    else:
        policy_pool_train = []
        policy_pool_eval = []
        self_play_opponents = args.train_pool_size
    if self_play_opponents > 0:
        # Only checkpoints for the potato-hard map are wired in here.
        assert 'potato_hard' in args.env_config, f'Loading potato hard fcp checkpoints for map {args.env_config}'
        self_play_pool = load_potato_hard_self_play_policy_pool(1 - args.player_id)
        assert len(self_play_pool) >= self_play_opponents, \
            f'Requesting {self_play_opponents} self-play opponents, got {len(self_play_pool)}'
        policy_pool_train.extend(self_play_pool[:self_play_opponents])
        print('Put', self_play_opponents, 'self-play opponents into train pool, model paths:',
              [p.model_path for p in self_play_pool[:self_play_opponents]])
    return policy_pool_train, policy_pool_eval


def generate_policy_pool(gen_desire, p_max, player_id, env_name, pool_size):
    """Generate `pool_size` opponents with a fixed random seed.

    Candidate ingredient support sets are: each of the first three
    ingredients alone, every pair of them, and every pair made of one of the
    first three and one of the remaining ingredients.  The candidates are
    shuffled once (seed 1) and assigned round-robin.

    Args:
        gen_desire: if True each pool entry is a 0/1 desire vector over
            `ingredients`; otherwise it is a 'minimum' `RuleBasedPolicy`.
        p_max: upper bound for the sampled `nav_p` and `tar_p`.
        player_id: agent id handed to each `RuleBasedPolicy`.
        env_name: map name handed to each `RuleBasedPolicy`.
        pool_size: number of entries to generate.

    Returns:
        The list of generated pool entries.
    """
    old_state = np.random.get_state()
    np.random.seed(1)
    try:
        left_ingred = ingredients[:3]
        right_ingred = ingredients[3:]
        ingredient_sets_all = []
        # one ingredient in left
        for ingred in left_ingred:
            ingredient_sets_all.append([ingred])
        # two ingredient in left
        for i in range(len(left_ingred)):
            for j in range(i+1, len(left_ingred)):
                ingredient_sets_all.append([left_ingred[i], left_ingred[j]])
        # one ingredient in left, one in right
        for ingred in left_ingred:
            for ingred2 in right_ingred:
                ingredient_sets_all.append([ingred, ingred2])
        print('All ingredient support sets:', ingredient_sets_all)
        print('Corresponding indices:', [sum(1 << Ingred2ID[ingred] for ingred in ingred_set) for ingred_set in ingredient_sets_all])

        pool = []
        all_policy_indices = np.arange(len(ingredient_sets_all))
        np.random.shuffle(all_policy_indices)
        n_candidates = len(all_policy_indices)
        for i in range(pool_size):
            # Wrap around when more policies are requested than candidates exist.
            ingredient_support_set = ingredient_sets_all[all_policy_indices[i % n_candidates]]
            print(f'Policy generated with support set {ingredient_support_set} and convention {None}')
            if gen_desire:
                policy = [int(ing in ingredient_support_set) for ing in ingredients]
            else:
                policy = RuleBasedPolicy(player_id, 'minimum', np.random.rand() * p_max, np.random.rand() * p_max, 0.05, None, env_name,
                                         ingredient_support_set=ingredient_support_set)
            pool.append(policy)
        return pool
    finally:
        # Fix: restore the caller's RNG stream even if generation fails, so an
        # exception cannot leave the global generator pinned to seed 1.
        np.random.set_state(old_state)


def load_self_play_policy_pool(player_id):
    """Return (train, eval) pools of pretrained self-play opponents from both checkpoint batches.

    The first batch (final + mid checkpoints of runs 1..15) is shuffled with
    a fixed seed and split 25 / rest; the second batch's split is appended.
    """
    print('Loading pretrained self-play opponents...')

    checkpoint_dir = './data/Overcooked/fcp_checkpoints/'
    run_ids = range(1, 16)
    finals = [PretrainedPolicy(os.path.join(checkpoint_dir, f'{i}_final.pt'), player_id) for i in run_ids]
    mids = [PretrainedPolicy(os.path.join(checkpoint_dir, f'{i}_mid.pt'), player_id) for i in run_ids]
    all_policies = finals + mids
    print(f'{len(all_policies)} policies loaded.')

    # Fix partition across runs: shuffle under seed 1, then restore the RNG.
    saved_state = random.getstate()
    random.seed(1)
    random.shuffle(all_policies)
    random.setstate(saved_state)

    train2, eval2 = load_self_play_policy_pool2(player_id)
    train_pool = all_policies[:25] + train2
    eval_pool = all_policies[25:] + eval2
    return train_pool, eval_pool


def load_self_play_policy_pool2(player_id):
    """Return the (train, eval) split of the second batch of self-play checkpoints."""
    print('Loading pretrained self-play opponents, batch 2...')

    checkpoint_dir = './data/Overcooked/fcp_checkpoints/'
    run_ids = range(1, 16)
    first_kind = [PretrainedPolicy(os.path.join(checkpoint_dir, f'{i}_run2_1_4.pt'), player_id) for i in run_ids]
    second_kind = [PretrainedPolicy(os.path.join(checkpoint_dir, f'{i}_run2_3_4.pt'), player_id) for i in run_ids]
    all_policies = first_kind + second_kind
    print(f'{len(all_policies)} policies in 2nd batch loaded.')

    # Fix partition across runs: shuffle under seed 1, then restore the RNG.
    saved_state = random.getstate()
    random.seed(1)
    random.shuffle(all_policies)
    random.setstate(saved_state)

    return all_policies[:25], all_policies[25:]


def load_good_self_play_policy_pool(player_id):
    """Return (train, eval) pools built from the final checkpoints of both runs.

    Train: all first-run finals plus ten second-run finals; eval: the
    remaining second-run finals.  Both lists are shuffled with a fixed seed.
    """
    print('Loading pretrained good self-play opponents...')

    checkpoint_dir = './data/Overcooked/fcp_checkpoints/'
    run_ids = range(1, 16)
    run1_finals = [PretrainedPolicy(os.path.join(checkpoint_dir, f'{i}_final.pt'), player_id) for i in run_ids]
    run2_finals = [PretrainedPolicy(os.path.join(checkpoint_dir, f'{i}_run2_final.pt'), player_id) for i in run_ids]
    print(f'{len(run1_finals + run2_finals)} policies loaded.')

    # Fix partition across runs: both shuffles draw from one seed-1 stream.
    saved_state = random.getstate()
    random.seed(1)
    for policies in (run1_finals, run2_finals):
        random.shuffle(policies)
    random.setstate(saved_state)

    return run1_finals + run2_finals[:10], run2_finals[10:]


def load_potato_self_play_policy_pool(player_id):
    """Return the twelve `rule_potato_*_final.pt` opponents in a fixed shuffled order."""
    print('Loading pretrained good self-play opponents on the potato map...')

    checkpoint_dir = './data/Overcooked/fcp_checkpoints/'
    good_policies = [
        PretrainedPolicy(os.path.join(checkpoint_dir, f'rule_potato_{i}_final.pt'), player_id)
        for i in range(12)
    ]
    print(f'{len(good_policies)} policies loaded.')

    # Fix partition across runs: shuffle under seed 1, then restore the RNG.
    saved_state = random.getstate()
    random.seed(1)
    random.shuffle(good_policies)
    random.setstate(saved_state)

    return good_policies


def load_potato_hard_self_play_policy_pool(player_id):
    """Return the pretrained opponents used for the potato hard map in a fixed shuffled order.

    NOTE(review): this reads the same `rule_potato_{i}_final.pt` files as
    `load_potato_self_play_policy_pool` -- confirm that is intended.
    """
    print('Loading pretrained good self-play opponents on the potato hard map...')

    checkpoint_dir = './data/Overcooked/fcp_checkpoints/'
    good_policies = [
        PretrainedPolicy(os.path.join(checkpoint_dir, f'rule_potato_{i}_final.pt'), player_id)
        for i in range(12)
    ]
    print(f'{len(good_policies)} policies loaded.')

    # Fix partition across runs: shuffle under seed 1, then restore the RNG.
    saved_state = random.getstate()
    random.seed(1)
    random.shuffle(good_policies)
    random.setstate(saved_state)

    return good_policies


def aggregate_self_play_policies():
    """Copy the final and mid checkpoints of runs 1..15 into the fcp checkpoint directory.

    The seed and step budget encoded in each run's log directory depend on
    the run index.
    """
    for i in range(1, 16):
        if i in (10, 11):
            seed, steps = 3, 20
        elif i in (5, 9, 12, 13):
            seed, steps = 2, 6
        else:
            seed, steps = 1, 6
        run_path = f'./logs/Overcooked/ppo_test_desire{i}_fcp_sep_dense_rotate_2_sp01_p0_{steps}m_128_128_step200_proc64_epoch15_batch8_ent3en2_lr2en3_clip02_relu_eps1en8_tl_seed{seed}/ppo'
        copies = (
            ('individual_latest.pt', f'./data/Overcooked/fcp_checkpoints/{i}_final.pt'),
            ('fcp_mid.pt', f'./data/Overcooked/fcp_checkpoints/{i}_mid.pt'),
        )
        for source_name, destination in copies:
            subprocess.run(['cp', os.path.join(run_path, source_name), destination])


def aggregate_self_play_policies_run2():
    """Copy the second batch of self-play checkpoints for desires 1..15.

    For each desire index the run directory is selected by a hard-coded
    (seed, steps) pair. ``individual_latest.pt`` is copied to
    ``{i}_run2_final.pt`` and every file starting with ``fcp_`` is copied to
    ``{i}_run2_<name without the 'fcp_' prefix>``. Finally the run path and
    the set of first digits seen in the ``fcp_`` file names are printed.

    Raises ``AssertionError`` if an ``fcp_`` file does not look like
    ``fcp_<digit>_4.pt``; ``os.listdir`` raises if the run directory is missing.
    """
    # Desire indices whose selected run differs from the default (seed=11, steps=6).
    run_overrides = {5: (13, 20)}
    run_overrides.update({idx: (12, 20) for idx in (9, 10, 11, 12, 13, 14)})
    for i in range(1, 16):
        seed, steps = run_overrides.get(i, (11, 6))
        run_path = f'./logs/Overcooked/ppo_test_desire{i}_fcp_sep_dense_rotate_2_sp01_p0_{steps}m_128_128_step200_proc64_epoch15_batch8_ent3en2_lr2en3_clip02_relu_eps1en8_tl_seed{seed}/ppo'
        final_src = os.path.join(run_path, 'individual_latest.pt')
        subprocess.run(['cp', final_src, f'./data/Overcooked/fcp_checkpoints/{i}_run2_final.pt'])
        leading_digits = set()
        for fn in os.listdir(run_path):
            if not fn.startswith('fcp_'):
                continue
            # NOTE(review): the '.' before 'pt' is unescaped, so it matches any character.
            assert re.fullmatch(r'fcp_\d_\d.pt', fn) is not None
            # fn[4] is the first digit, fn[6] the second; the second must be 4.
            assert int(fn[6]) == 4
            leading_digits.add(int(fn[4]))
            subprocess.run(['cp', os.path.join(run_path, fn), f'./data/Overcooked/fcp_checkpoints/{i}_run2_{fn[4:]}'])
        print(run_path, leading_digits)


def aggregate_best_responses():
    """Copy best-response checkpoints for opponents 19..24 into br_checkpoints.

    Each run's ``individual_latest.pt`` is copied with ``cp`` to
    ``./data/Overcooked/br_checkpoints/opp{opp_id}_final.pt``; the source and
    destination are printed before each copy. The exit status is not checked.
    """
    # Lazily yield (source, destination) pairs, one per opponent id.
    copy_jobs = (
        (
            f'./logs/Overcooked/ppo_test_opp{opp_id}_lhsrl_dense_rotate_2_sp01_p0_15m_128_128_step200_proc64_epoch15_batch8_ent3en2_lr2en3_clip02_relu_eps1en8_tl_seed1/ppo/individual_latest.pt',
            f'./data/Overcooked/br_checkpoints/opp{opp_id}_final.pt',
        )
        for opp_id in range(19, 25)
    )
    for src, dst in copy_jobs:
        print('Copying', src, 'to', dst)
        subprocess.run(['cp', src, dst])


def aggregate_rule_potato_best_responses():
    """Copy rule-potato best-response checkpoints for opponents 0..24.

    The run directory of each opponent is selected by a hard-coded
    (seed, run_length) pair. Its ``individual_latest.pt`` is copied with ``cp``
    to ``./data/Overcooked/br_checkpoints/rule_potato_opp{opp_id}_final.pt``
    after printing the source and destination. The exit status is not checked.
    """
    # Opponents whose selected run differs from the default (seed=1, run_length=1).
    run_overrides = {
        8: (3, 5), 21: (3, 5),
        5: (2, 2), 9: (2, 2), 10: (2, 2), 17: (2, 2),
    }
    for opp_id in range(25):
        seed, run_length = run_overrides.get(opp_id, (1, 1))
        src = f'./logs/Overcooked/ppo_test_rule_potato_opp{opp_id}_lhsrl_dense_sp01_p1_{run_length}m_128_128_step100_proc64_epoch8_batch8_ent3en2_lr1en3_clip02_relu_eps1en8_tl_seed{seed}/ppo/individual_latest.pt'
        dst = f'./data/Overcooked/br_checkpoints/rule_potato_opp{opp_id}_final.pt'
        print('Copying', src, 'to', dst)
        subprocess.run(['cp', src, dst])


def aggregate_rule_potato_self_play_policies():
    """Copy potato-map self-play checkpoints into the fcp checkpoint folder.

    The 12 selected desire indices are renumbered 0..11 by their position in
    the list. For each, ``individual_latest.pt`` is copied to
    ``rule_potato_{i}_final.pt`` and every file starting with ``fcp_`` is copied
    to ``rule_potato_{i}_<name without the 'fcp_' prefix>``. The run path and
    the set of first digits seen in the ``fcp_`` file names are then printed.

    Raises ``AssertionError`` if an ``fcp_`` file does not look like
    ``fcp_<digit>_4.pt``; ``os.listdir`` raises if the run directory is missing.
    """
    selected_desires = (4, 2, 16, 6, 20, 18, 5, 12, 3, 10, 17, 24)
    for i, d in enumerate(selected_desires):
        run_path = f'./logs/Overcooked/ppo_test_desire{d}_fcp_h100_sep_dense_potato_sp01_5m_128_128_step200_proc64_epoch15_batch8_ent3en2_lr2en3_clip02_relu_eps1en8_tl_seed1/ppo'
        final_src = os.path.join(run_path, 'individual_latest.pt')
        subprocess.run(['cp', final_src, f'./data/Overcooked/fcp_checkpoints/rule_potato_{i}_final.pt'])
        leading_digits = set()
        for fn in os.listdir(run_path):
            if not fn.startswith('fcp_'):
                continue
            # NOTE(review): the '.' before 'pt' is unescaped, so it matches any character.
            assert re.fullmatch(r'fcp_\d_\d.pt', fn) is not None
            # fn[4] is the first digit, fn[6] the second; the second must be 4.
            assert int(fn[6]) == 4
            leading_digits.add(int(fn[4]))
            subprocess.run(['cp', os.path.join(run_path, fn), f'./data/Overcooked/fcp_checkpoints/rule_potato_{i}_{fn[4:]}'])
        print(run_path, leading_digits)


def aggregate_rule_potato_hard_self_play_policies():
    """Copy potato-hard-map self-play checkpoints into the fcp checkpoint folder.

    The 12 selected desire indices are renumbered 0..11 by their position in
    the list. For each, ``individual_latest.pt`` is copied to
    ``rule_potato_hard_{i}_final.pt`` and every file starting with ``fcp_`` is
    copied to ``rule_potato_hard_{i}_<name without the 'fcp_' prefix>``. The run
    path and the set of first digits seen in the ``fcp_`` names are then printed.

    Raises ``AssertionError`` if an ``fcp_`` file does not look like
    ``fcp_<digit>_4.pt``; ``os.listdir`` raises if the run directory is missing.
    """
    selected_desires = (4, 2, 16, 6, 20, 18, 5, 12, 3, 10, 17, 24)
    for i, d in enumerate(selected_desires):
        run_path = f'./logs/Overcooked/ppo_test_desire{d}_fcp_h100_sep_dense_potato_hard_sp01_3m_128_128_step200_proc64_epoch15_batch8_ent3en2_lr2en3_clip02_relu_eps1en8_tl_seed1/ppo'
        final_src = os.path.join(run_path, 'individual_latest.pt')
        subprocess.run(['cp', final_src, f'./data/Overcooked/fcp_checkpoints/rule_potato_hard_{i}_final.pt'])
        leading_digits = set()
        for fn in os.listdir(run_path):
            if not fn.startswith('fcp_'):
                continue
            # NOTE(review): the '.' before 'pt' is unescaped, so it matches any character.
            assert re.fullmatch(r'fcp_\d_\d.pt', fn) is not None
            # fn[4] is the first digit, fn[6] the second; the second must be 4.
            assert int(fn[6]) == 4
            leading_digits.add(int(fn[4]))
            subprocess.run(['cp', os.path.join(run_path, fn), f'./data/Overcooked/fcp_checkpoints/rule_potato_hard_{i}_{fn[4:]}'])
        print(run_path, leading_digits)


# TODO: pick a test policy pool
# Module-level placeholder: stays None until a test policy pool is chosen.
test_policies = None
