import copy

from smacv2.env.starcraft2.distributions import get_distribution
from smacv2.env.starcraft2.starcraft2 import StarCraft2Env, CannotResetException
from smacv2.env import MultiAgentEnv

# Start-position distribution config referenced by every task in this module.
share_start_position = {
    "dist_type": "surrounded_and_reflect",
    "p": 0.5,
    "map_x": 32,
    "map_y": 32,
}


def _weighted_teams(unit_types, weights, exception_unit_types=None):
    """Build a fresh ``weighted_teams`` team-generation config.

    Args:
        unit_types: unit names, in the same order as ``weights``.
        weights: one sampling weight per entry of ``unit_types``.
        exception_unit_types: optional unit names stored under the
            ``exception_unit_types`` key; the key is omitted when ``None``.

    Returns:
        A new dict with keys ``dist_type``, ``unit_types``,
        [``exception_unit_types``], ``weights`` and ``observe`` (always
        ``True``), in that order. Every list in it is newly allocated.
    """
    config = {"dist_type": "weighted_teams", "unit_types": list(unit_types)}
    if exception_unit_types is not None:
        config["exception_unit_types"] = list(exception_unit_types)
    config["weights"] = list(weights)
    config["observe"] = True
    return config


_TERRAN = ("marine", "marauder", "medivac")
_PROTOSS = ("stalker", "zealot", "colossus")
_ZERG = ("zergling", "baneling", "hydralisk")

# Base mixtures.
share_team_gen_terran = _weighted_teams(_TERRAN, (0.45, 0.45, 0.1), ("medivac",))
share_team_gen_protoss = _weighted_teams(_PROTOSS, (0.45, 0.45, 0.1))
share_team_gen_zerg = _weighted_teams(_ZERG, (0.45, 0.1, 0.45), ("baneling",))

# Meta-training mixtures: each one gives a single unit type weight 0.0.
# For terran/protoss, meta1 zeroes the third type and meta2 the second; for
# zerg it is the other way round (meta1 zeroes baneling, meta2 hydralisk).
# meta3 always zeroes the first type.
share_team_gen_terran_meta1 = _weighted_teams(_TERRAN, (0.5, 0.5, 0.0), ("medivac",))
share_team_gen_terran_meta2 = _weighted_teams(_TERRAN, (0.5, 0.0, 0.5), ("medivac",))
share_team_gen_terran_meta3 = _weighted_teams(_TERRAN, (0.0, 0.5, 0.5), ("medivac",))
share_team_gen_protoss_meta1 = _weighted_teams(_PROTOSS, (0.5, 0.5, 0.0))
share_team_gen_protoss_meta2 = _weighted_teams(_PROTOSS, (0.5, 0.0, 0.5))
share_team_gen_protoss_meta3 = _weighted_teams(_PROTOSS, (0.0, 0.5, 0.5))
share_team_gen_zerg_meta1 = _weighted_teams(_ZERG, (0.5, 0.0, 0.5), ("baneling",))
share_team_gen_zerg_meta2 = _weighted_teams(_ZERG, (0.5, 0.5, 0.0), ("baneling",))
share_team_gen_zerg_meta3 = _weighted_teams(_ZERG, (0.0, 0.5, 0.5), ("baneling",))

def _task_entry(n_units, n_enemies, team_gen):
    """Return one task entry: team sizes plus team-gen and start-position configs.

    ``team_gen`` and ``share_start_position`` are stored by reference, not
    copied, so every task built from the same config shares one dict object.
    """
    return {
        "n_units": n_units,
        "n_enemies": n_enemies,
        "team_gen": team_gen,
        "start_positions": share_start_position,
    }


# Base tasks: each race at five ally/enemy sizes, using the base mixtures.
task2team_base = {
    "terran_5_vs_5": _task_entry(5, 5, share_team_gen_terran),
    "terran_10_vs_10": _task_entry(10, 10, share_team_gen_terran),
    "terran_10_vs_11": _task_entry(10, 11, share_team_gen_terran),
    "terran_20_vs_20": _task_entry(20, 20, share_team_gen_terran),
    "terran_20_vs_23": _task_entry(20, 23, share_team_gen_terran),
    "protoss_5_vs_5": _task_entry(5, 5, share_team_gen_protoss),
    "protoss_10_vs_10": _task_entry(10, 10, share_team_gen_protoss),
    "protoss_10_vs_11": _task_entry(10, 11, share_team_gen_protoss),
    "protoss_20_vs_20": _task_entry(20, 20, share_team_gen_protoss),
    "protoss_20_vs_23": _task_entry(20, 23, share_team_gen_protoss),
    "zerg_5_vs_5": _task_entry(5, 5, share_team_gen_zerg),
    "zerg_10_vs_10": _task_entry(10, 10, share_team_gen_zerg),
    "zerg_10_vs_11": _task_entry(10, 11, share_team_gen_zerg),
    "zerg_20_vs_20": _task_entry(20, 20, share_team_gen_zerg),
    "zerg_20_vs_23": _task_entry(20, 23, share_team_gen_zerg),
}

# Meta-training tasks: the same five size/mixture slots for every race.
task2team_meta_train_terran = {
    "terran_5_vs_5_meta1": _task_entry(5, 5, share_team_gen_terran_meta1),
    "terran_5_vs_5_meta2": _task_entry(5, 5, share_team_gen_terran_meta2),
    "terran_10_vs_10_meta3": _task_entry(10, 10, share_team_gen_terran_meta3),
    "terran_10_vs_11_meta1": _task_entry(10, 11, share_team_gen_terran_meta1),
    "terran_10_vs_11_meta2": _task_entry(10, 11, share_team_gen_terran_meta2),
}
task2team_meta_train_protoss = {
    "protoss_5_vs_5_meta1": _task_entry(5, 5, share_team_gen_protoss_meta1),
    "protoss_5_vs_5_meta2": _task_entry(5, 5, share_team_gen_protoss_meta2),
    "protoss_10_vs_10_meta3": _task_entry(10, 10, share_team_gen_protoss_meta3),
    "protoss_10_vs_11_meta1": _task_entry(10, 11, share_team_gen_protoss_meta1),
    "protoss_10_vs_11_meta2": _task_entry(10, 11, share_team_gen_protoss_meta2),
}
task2team_meta_train_zerg = {
    "zerg_5_vs_5_meta1": _task_entry(5, 5, share_team_gen_zerg_meta1),
    "zerg_5_vs_5_meta2": _task_entry(5, 5, share_team_gen_zerg_meta2),
    "zerg_10_vs_10_meta3": _task_entry(10, 10, share_team_gen_zerg_meta3),
    "zerg_10_vs_11_meta1": _task_entry(10, 11, share_team_gen_zerg_meta1),
    "zerg_10_vs_11_meta2": _task_entry(10, 11, share_team_gen_zerg_meta2),
}
# Meta-test tasks (terran): fixed, hand-written team compositions.
# Each entry uses the "fixed_given" team-gen config with explicit ally_team
# and enemy_team unit lists, whose lengths equal n_units / n_enemies.
# - Symmetric tasks (N_vs_N): the two lists are identical.
# - 10_vs_11 and 20_vs_23 tasks: enemy_team is ally_team followed by the
#   extra enemy units.
# All entries reference the same module-level share_start_position dict.
# NOTE(review): "fixed_given" is resolved by get_distribution(); confirm there
# that it spawns exactly these lists, in this order.
task2team_meta_test_terran = {
    "terran_5_vs_5_test1":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marauder', 'marauder', 'marauder', 'medivac', 'marine'],
            "enemy_team": ['marauder', 'marauder', 'marauder', 'medivac', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_5_vs_5_test2":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'marine', 'marine', 'marauder', 'marauder'],
            "enemy_team": ['medivac', 'marine', 'marine', 'marauder', 'marauder'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_5_vs_5_test3":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'marine', 'marauder', 'marine', 'marine'],
            "enemy_team": ['medivac', 'marine', 'marauder', 'marine', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_10_vs_10_test1":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'medivac', 'marine', 'marauder', 'marauder', 'medivac', 'marine', 'marine', 'medivac', 'marauder'],
            "enemy_team": ['medivac', 'medivac', 'marine', 'marauder', 'marauder', 'medivac', 'marine', 'marine', 'medivac', 'marauder'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_10_vs_10_test2":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marine', 'marauder', 'marauder', 'medivac', 'marauder', 'medivac', 'marauder', 'medivac', 'medivac', 'marauder'],
            "enemy_team": ['marine', 'marauder', 'marauder', 'medivac', 'marauder', 'medivac', 'marauder', 'medivac', 'medivac', 'marauder'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_10_vs_10_test3":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marine', 'marine', 'medivac', 'marauder', 'medivac', 'marine', 'medivac', 'marine', 'medivac', 'marine'],
            "enemy_team": ['marine', 'marine', 'medivac', 'marauder', 'medivac', 'marine', 'medivac', 'marine', 'medivac', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_10_vs_11_test1":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'marine', 'marauder', 'marauder', 'marine', 'marine', 'marauder', 'medivac', 'marine', 'marine'],
            "enemy_team": ['medivac', 'marine', 'marauder', 'marauder', 'marine', 'marine', 'marauder', 'medivac', 'marine', 'marine', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_10_vs_11_test2":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marine', 'marine', 'marauder', 'marauder', 'medivac', 'marauder', 'marauder', 'marauder', 'medivac', 'marauder'],
            "enemy_team": ['marine', 'marine', 'marauder', 'marauder', 'medivac', 'marauder', 'marauder', 'marauder', 'medivac', 'marauder', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_10_vs_11_test3":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'marine', 'medivac', 'marine', 'marine', 'marine', 'marauder', 'marauder', 'medivac', 'marauder'],
            "enemy_team": ['medivac', 'marine', 'medivac', 'marine', 'marine', 'marine', 'marauder', 'marauder', 'medivac', 'marauder', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_20_vs_20_test1":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'marine', 'medivac', 'medivac', 'marine', 'marauder', 'medivac', 'medivac', 'medivac', 'medivac', 'marine', 'marine', 'marine', 'marine', 'marine', 'medivac', 'medivac', 'marauder', 'marine', 'medivac'],
            "enemy_team": ['medivac', 'marine', 'medivac', 'medivac', 'marine', 'marauder', 'medivac', 'medivac', 'medivac', 'medivac', 'marine', 'marine', 'marine', 'marine', 'marine', 'medivac', 'medivac', 'marauder', 'marine', 'medivac'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_20_vs_20_test2":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marauder', 'marauder', 'marine', 'marauder', 'medivac', 'marine', 'marine', 'medivac', 'medivac', 'marauder', 'marine', 'medivac', 'medivac', 'marauder', 'marauder', 'marine', 'marauder', 'marine', 'marauder', 'marauder'],
            "enemy_team": ['marauder', 'marauder', 'marine', 'marauder', 'medivac', 'marine', 'marine', 'medivac', 'medivac', 'marauder', 'marine', 'medivac', 'medivac', 'marauder', 'marauder', 'marine', 'marauder', 'marine', 'marauder', 'marauder'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_20_vs_20_test3":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'medivac', 'marine', 'marine', 'marauder', 'marine', 'marauder', 'medivac', 'marine', 'marauder', 'marauder', 'medivac', 'marine', 'marauder', 'medivac', 'marauder', 'medivac', 'medivac', 'marauder', 'marauder'],
            "enemy_team": ['medivac', 'medivac', 'marine', 'marine', 'marauder', 'marine', 'marauder', 'medivac', 'marine', 'marauder', 'marauder', 'medivac', 'marine', 'marauder', 'medivac', 'marauder', 'medivac', 'medivac', 'marauder', 'marauder'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_20_vs_23_test1":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['medivac', 'marauder', 'marine', 'marauder', 'marine', 'medivac', 'medivac', 'medivac', 'marine', 'medivac', 'medivac', 'marauder', 'marauder', 'marauder', 'marauder', 'marine', 'marauder', 'marauder', 'marauder', 'medivac'],
            "enemy_team": ['medivac', 'marauder', 'marine', 'marauder', 'marine', 'medivac', 'medivac', 'medivac', 'marine', 'medivac', 'medivac', 'marauder', 'marauder', 'marauder', 'marauder', 'marine', 'marauder', 'marauder', 'marauder', 'medivac', 'marauder', 'marine', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_20_vs_23_test2":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marauder', 'medivac', 'marine', 'marine', 'medivac', 'marine', 'marauder', 'medivac', 'marine', 'medivac', 'medivac', 'medivac', 'marine', 'medivac', 'marauder', 'marauder', 'marine', 'marauder', 'marine', 'medivac'],
            "enemy_team": ['marauder', 'medivac', 'marine', 'marine', 'medivac', 'marine', 'marauder', 'medivac', 'marine', 'medivac', 'medivac', 'medivac', 'marine', 'medivac', 'marauder', 'marauder', 'marine', 'marauder', 'marine', 'medivac', 'marauder', 'marauder', 'medivac'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "terran_20_vs_23_test3":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['marine', 'marine', 'medivac', 'marine', 'medivac', 'marauder', 'medivac', 'marine', 'marine', 'marauder', 'medivac', 'medivac', 'marauder', 'medivac', 'marauder', 'marauder', 'marauder', 'medivac', 'marine', 'marine'],
            "enemy_team": ['marine', 'marine', 'medivac', 'marine', 'medivac', 'marauder', 'medivac', 'marine', 'marine', 'marauder', 'medivac', 'medivac', 'marauder', 'medivac', 'marauder', 'marauder', 'marauder', 'medivac', 'marine', 'marine', 'medivac', 'marauder', 'marine'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
}
# Meta-test tasks (protoss): fixed, hand-written team compositions.
# Each entry uses the "fixed_given" team-gen config with explicit ally_team
# and enemy_team unit lists, whose lengths equal n_units / n_enemies.
# - Symmetric tasks (N_vs_N): the two lists are identical.
# - 10_vs_11 and 20_vs_23 tasks: enemy_team is ally_team followed by the
#   extra enemy units.
# All entries reference the same module-level share_start_position dict.
# NOTE(review): "fixed_given" is resolved by get_distribution(); confirm there
# that it spawns exactly these lists, in this order.
task2team_meta_test_protoss = {
    "protoss_5_vs_5_test1":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'zealot', 'zealot', 'colossus', 'stalker'],
            "enemy_team": ['zealot', 'zealot', 'zealot', 'colossus', 'stalker'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_5_vs_5_test2":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['colossus', 'zealot', 'stalker', 'stalker', 'zealot'],
            "enemy_team": ['colossus', 'zealot', 'stalker', 'stalker', 'zealot'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_5_vs_5_test3":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['colossus', 'colossus', 'stalker', 'zealot', 'colossus'],
            "enemy_team": ['colossus', 'colossus', 'stalker', 'zealot', 'colossus'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_10_vs_10_test1":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['colossus', 'zealot', 'stalker', 'zealot', 'stalker', 'zealot', 'colossus', 'colossus', 'stalker', 'stalker'],
            "enemy_team": ['colossus', 'zealot', 'stalker', 'zealot', 'stalker', 'zealot', 'colossus', 'colossus', 'stalker', 'stalker'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_10_vs_10_test2":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['colossus', 'stalker', 'stalker', 'colossus', 'stalker', 'zealot', 'stalker', 'stalker', 'colossus', 'zealot'],
            "enemy_team": ['colossus', 'stalker', 'stalker', 'colossus', 'stalker', 'zealot', 'stalker', 'stalker', 'colossus', 'zealot'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_10_vs_10_test3":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['colossus', 'zealot', 'zealot', 'colossus', 'zealot', 'zealot', 'zealot', 'stalker', 'stalker', 'colossus'],
            "enemy_team": ['colossus', 'zealot', 'zealot', 'colossus', 'zealot', 'zealot', 'zealot', 'stalker', 'stalker', 'colossus'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_10_vs_11_test1":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['stalker', 'stalker', 'stalker', 'colossus', 'colossus', 'colossus', 'zealot', 'colossus', 'zealot', 'colossus'],
            "enemy_team": ['stalker', 'stalker', 'stalker', 'colossus', 'colossus', 'colossus', 'zealot', 'colossus', 'zealot', 'colossus', 'zealot'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_10_vs_11_test2":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'stalker', 'zealot', 'stalker', 'colossus', 'zealot', 'zealot', 'zealot', 'zealot', 'zealot'],
            "enemy_team": ['zealot', 'stalker', 'zealot', 'stalker', 'colossus', 'zealot', 'zealot', 'zealot', 'zealot', 'zealot', 'colossus'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_10_vs_11_test3":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'stalker', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'stalker', 'stalker', 'zealot'],
            "enemy_team": ['zealot', 'stalker', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'stalker', 'stalker', 'zealot', 'stalker'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_20_vs_20_test1":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'stalker', 'zealot', 'colossus', 'colossus', 'zealot', 'stalker', 'colossus', 'zealot', 'colossus', 'colossus', 'stalker', 'colossus', 'colossus', 'colossus', 'stalker', 'colossus', 'zealot', 'zealot', 'stalker'],
            "enemy_team": ['zealot', 'stalker', 'zealot', 'colossus', 'colossus', 'zealot', 'stalker', 'colossus', 'zealot', 'colossus', 'colossus', 'stalker', 'colossus', 'colossus', 'colossus', 'stalker', 'colossus', 'zealot', 'zealot', 'stalker'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_20_vs_20_test2":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'zealot', 'zealot', 'zealot', 'stalker', 'stalker', 'zealot', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'colossus', 'colossus', 'colossus', 'zealot', 'colossus', 'colossus', 'stalker', 'colossus'],
            "enemy_team": ['zealot', 'zealot', 'zealot', 'zealot', 'stalker', 'stalker', 'zealot', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'colossus', 'colossus', 'colossus', 'zealot', 'colossus', 'colossus', 'stalker', 'colossus'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_20_vs_20_test3":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['stalker', 'colossus', 'colossus', 'stalker', 'stalker', 'stalker', 'zealot', 'colossus', 'stalker', 'zealot', 'stalker', 'zealot', 'stalker', 'zealot', 'zealot', 'zealot', 'stalker', 'colossus', 'colossus', 'zealot'],
            "enemy_team": ['stalker', 'colossus', 'colossus', 'stalker', 'stalker', 'stalker', 'zealot', 'colossus', 'stalker', 'zealot', 'stalker', 'zealot', 'stalker', 'zealot', 'zealot', 'zealot', 'stalker', 'colossus', 'colossus', 'zealot'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_20_vs_23_test1":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'zealot', 'zealot', 'zealot', 'zealot', 'stalker', 'colossus', 'stalker', 'colossus', 'zealot', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'colossus', 'stalker', 'zealot', 'zealot', 'colossus'],
            "enemy_team": ['zealot', 'zealot', 'zealot', 'zealot', 'zealot', 'stalker', 'colossus', 'stalker', 'colossus', 'zealot', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'colossus', 'stalker', 'zealot', 'zealot', 'colossus', 'zealot', 'colossus', 'colossus'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_20_vs_23_test2":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['colossus', 'stalker', 'zealot', 'stalker', 'colossus', 'zealot', 'stalker', 'stalker', 'zealot', 'zealot', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'colossus', 'colossus', 'stalker', 'colossus', 'colossus'],
            "enemy_team": ['colossus', 'stalker', 'zealot', 'stalker', 'colossus', 'zealot', 'stalker', 'stalker', 'zealot', 'zealot', 'zealot', 'stalker', 'zealot', 'colossus', 'stalker', 'colossus', 'colossus', 'stalker', 'colossus', 'colossus', 'colossus', 'stalker', 'colossus'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "protoss_20_vs_23_test3":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zealot', 'zealot', 'stalker', 'stalker', 'colossus', 'colossus', 'stalker', 'colossus', 'zealot', 'colossus', 'stalker', 'stalker', 'stalker', 'colossus', 'stalker', 'stalker', 'stalker', 'zealot', 'colossus', 'zealot'],
            "enemy_team": ['zealot', 'zealot', 'stalker', 'stalker', 'colossus', 'colossus', 'stalker', 'colossus', 'zealot', 'colossus', 'stalker', 'stalker', 'stalker', 'colossus', 'stalker', 'stalker', 'stalker', 'zealot', 'colossus', 'zealot', 'colossus', 'stalker', 'zealot'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
}
# Meta-test tasks (zerg): fixed, hand-written team compositions.
# Each entry uses the "fixed_given" team-gen config with explicit ally_team
# and enemy_team unit lists, whose lengths equal n_units / n_enemies.
# - Symmetric tasks (N_vs_N): the two lists are identical.
# - 10_vs_11 and 20_vs_23 tasks: enemy_team is ally_team followed by the
#   extra enemy units.
# All entries reference the same module-level share_start_position dict.
# NOTE(review): "fixed_given" is resolved by get_distribution(); confirm there
# that it spawns exactly these lists, in this order.
task2team_meta_test_zerg = {
    "zerg_5_vs_5_test1":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['baneling', 'zergling', 'zergling', 'baneling', 'hydralisk'],
            "enemy_team": ['baneling', 'zergling', 'zergling', 'baneling', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_5_vs_5_test2":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zergling', 'hydralisk', 'hydralisk', 'baneling', 'hydralisk'],
            "enemy_team": ['zergling', 'hydralisk', 'hydralisk', 'baneling', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_5_vs_5_test3":{
        "n_units": 5,
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zergling', 'hydralisk', 'zergling', 'baneling', 'zergling'],
            "enemy_team": ['zergling', 'hydralisk', 'zergling', 'baneling', 'zergling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_10_vs_10_test1":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['hydralisk', 'hydralisk', 'baneling', 'hydralisk', 'zergling', 'baneling', 'zergling', 'hydralisk', 'zergling', 'hydralisk'],
            "enemy_team": ['hydralisk', 'hydralisk', 'baneling', 'hydralisk', 'zergling', 'baneling', 'zergling', 'hydralisk', 'zergling', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_10_vs_10_test2":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['hydralisk', 'zergling', 'hydralisk', 'baneling', 'zergling', 'zergling', 'baneling', 'baneling', 'zergling', 'hydralisk'],
            "enemy_team": ['hydralisk', 'zergling', 'hydralisk', 'baneling', 'zergling', 'zergling', 'baneling', 'baneling', 'zergling', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_10_vs_10_test3":{
        "n_units": 10,
        "n_enemies": 10,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['baneling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'zergling', 'zergling', 'baneling', 'baneling', 'zergling'],
            "enemy_team": ['baneling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'zergling', 'zergling', 'baneling', 'baneling', 'zergling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_10_vs_11_test1":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['baneling', 'baneling', 'zergling', 'zergling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'zergling', 'hydralisk'],
            "enemy_team": ['baneling', 'baneling', 'zergling', 'zergling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'zergling', 'hydralisk', 'zergling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_10_vs_11_test2":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'zergling'],
            "enemy_team": ['hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_10_vs_11_test3":{
        "n_units": 10,
        "n_enemies": 11,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zergling', 'baneling', 'hydralisk', 'zergling', 'hydralisk', 'zergling', 'baneling', 'baneling', 'hydralisk', 'hydralisk'],
            "enemy_team": ['zergling', 'baneling', 'hydralisk', 'zergling', 'hydralisk', 'zergling', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'baneling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_20_vs_20_test1":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team":  ['baneling', 'baneling', 'zergling', 'zergling', 'zergling', 'hydralisk', 'baneling', 'zergling', 'baneling', 'baneling', 'zergling', 'zergling', 'zergling', 'zergling', 'baneling', 'zergling', 'hydralisk', 'hydralisk', 'zergling', 'baneling'],
            "enemy_team":  ['baneling', 'baneling', 'zergling', 'zergling', 'zergling', 'hydralisk', 'baneling', 'zergling', 'baneling', 'baneling', 'zergling', 'zergling', 'zergling', 'zergling', 'baneling', 'zergling', 'hydralisk', 'hydralisk', 'zergling', 'baneling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_20_vs_20_test2":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['hydralisk', 'hydralisk', 'zergling', 'zergling', 'hydralisk', 'zergling', 'zergling', 'hydralisk', 'hydralisk', 'zergling', 'hydralisk', 'zergling', 'zergling', 'hydralisk', 'baneling', 'hydralisk', 'baneling', 'hydralisk', 'hydralisk', 'hydralisk'],
            "enemy_team": ['hydralisk', 'hydralisk', 'zergling', 'zergling', 'hydralisk', 'zergling', 'zergling', 'hydralisk', 'hydralisk', 'zergling', 'hydralisk', 'zergling', 'zergling', 'hydralisk', 'baneling', 'hydralisk', 'baneling', 'hydralisk', 'hydralisk', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_20_vs_20_test3":{
        "n_units": 20,
        "n_enemies": 20,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['zergling', 'baneling', 'hydralisk', 'baneling', 'zergling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk', 'zergling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'baneling'],
            "enemy_team": ['zergling', 'baneling', 'hydralisk', 'baneling', 'zergling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk', 'zergling', 'hydralisk', 'baneling', 'baneling', 'zergling', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'baneling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_20_vs_23_test1":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['baneling', 'hydralisk', 'zergling', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'zergling', 'baneling', 'zergling', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'zergling', 'baneling', 'hydralisk', 'zergling', 'baneling', 'zergling'],
            "enemy_team": ['baneling', 'hydralisk', 'zergling', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'zergling', 'baneling', 'zergling', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'zergling', 'baneling', 'hydralisk', 'zergling', 'baneling', 'zergling', 'zergling', 'hydralisk', 'zergling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_20_vs_23_test2":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['hydralisk', 'zergling', 'baneling', 'zergling', 'zergling', 'zergling', 'baneling', 'zergling', 'baneling', 'baneling', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'hydralisk', 'baneling', 'zergling', 'baneling', 'baneling', 'zergling'],
            "enemy_team": ['hydralisk', 'zergling', 'baneling', 'zergling', 'zergling', 'zergling', 'baneling', 'zergling', 'baneling', 'baneling', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'hydralisk', 'baneling', 'zergling', 'baneling', 'baneling', 'zergling', 'zergling', 'hydralisk', 'zergling'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
    "zerg_20_vs_23_test3":{
        "n_units": 20,
        "n_enemies": 23,
        "team_gen": {
            "dist_type": "fixed_given",
            "ally_team": ['baneling', 'hydralisk', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'hydralisk', 'hydralisk', 'zergling', 'baneling', 'hydralisk', 'zergling', 'zergling', 'zergling', 'zergling', 'zergling', 'hydralisk', 'baneling'],
            "enemy_team": ['baneling', 'hydralisk', 'hydralisk', 'hydralisk', 'baneling', 'baneling', 'hydralisk', 'hydralisk', 'hydralisk', 'hydralisk', 'zergling', 'baneling', 'hydralisk', 'zergling', 'zergling', 'zergling', 'zergling', 'zergling', 'hydralisk', 'baneling', 'zergling', 'baneling', 'hydralisk'],
            "observe": True,
        },
        "start_positions": share_start_position,
    },
}
# Every task name accepted by StarCraftMetaCapabilityEnvWrapper(task=...).
# Tables are merged left to right, so on a duplicate name the later table wins.
task2team = {
    **task2team_base,
    **task2team_meta_train_terran,
    **task2team_meta_train_protoss,
    **task2team_meta_train_zerg,
    **task2team_meta_test_terran,
    **task2team_meta_test_protoss,
    **task2team_meta_test_zerg,
}


class StarCraftMetaCapabilityEnvWrapper(MultiAgentEnv):
    """Wrap ``StarCraft2Env`` for one named task from ``task2team``.

    The ``task`` keyword selects a capability config from the module-level
    ``task2team`` table, and that config replaces any ``capability_config``
    the caller passed. Every config entry other than ``n_units``/``n_enemies``
    becomes a distribution object. Each :meth:`reset` samples a fresh episode
    config from all of them. The remaining methods delegate to the wrapped env.
    """

    def __init__(self, **kwargs):
        """Create the wrapped environment for ``kwargs["task"]``.

        Args:
            task: Key into ``task2team``. It is consumed here and not
                forwarded to ``StarCraft2Env``.
            **kwargs: Forwarded to ``StarCraft2Env``. ``capability_config``
                is overwritten with (a copy of) the task's config.

        Raises:
            KeyError: if ``task`` is missing or is not a key of ``task2team``.
        """
        self.task = kwargs.pop("task")
        try:
            task_config = task2team[self.task]
        except KeyError:
            raise KeyError(
                f"unknown task {self.task!r}; expected a key of task2team"
            ) from None
        # _parse_distribution_config writes env_key/n_units/n_enemies into the
        # nested dicts. In task2team those nested dicts are module-level
        # objects shared by many tasks (e.g. share_start_position), so copy
        # them first rather than overwriting other tasks' configs.
        self.distribution_config = {
            key: dict(value) if isinstance(value, dict) else value
            for key, value in task_config.items()
        }
        kwargs["capability_config"] = self.distribution_config
        self.env_key_to_distribution_map = {}
        self._parse_distribution_config()
        self.env = StarCraft2Env(**kwargs)

    def _parse_distribution_config(self):
        """Build one distribution per entry of ``self.distribution_config``.

        ``n_units`` and ``n_enemies`` are plain counts, not distributions, so
        they are skipped. Each other entry first gets annotated with its own
        key (``env_key``) and both counts, then is passed to the distribution
        class selected by its ``dist_type``.
        """
        for env_key, config in self.distribution_config.items():
            if env_key in ("n_units", "n_enemies"):
                continue
            config["env_key"] = env_key
            # add n_units / n_enemies so each distribution knows team sizes
            config["n_units"] = self.distribution_config["n_units"]
            config["n_enemies"] = self.distribution_config["n_enemies"]
            distribution = get_distribution(config["dist_type"])(config)
            self.env_key_to_distribution_map[env_key] = distribution

    def reset(self):
        """Sample a new episode config and reset the wrapped env with it.

        If ``StarCraft2Env.reset`` raises ``CannotResetException``, a new
        config is sampled and the reset is retried.
        NOTE: there is no retry limit.

        Returns:
            Whatever ``StarCraft2Env.reset`` returns for the successful attempt.
        """
        while True:
            try:
                reset_config = {}
                for distribution in self.env_key_to_distribution_map.values():
                    # Later distributions win on key collisions.
                    reset_config.update(distribution.generate())
                return self.env.reset(reset_config)
            except CannotResetException:
                # Resample and retry. A loop (instead of recursing) returns
                # the successful reset's result and keeps the stack flat.
                continue

    def __getattr__(self, name):
        """Delegate lookups that fail on the wrapper to the wrapped env.

        ``env`` itself is never delegated. If it is not set yet (e.g. on an
        instance created without running ``__init__``), evaluating
        ``self.env`` here would re-enter this method and recurse without end.
        """
        if name == "env":
            raise AttributeError(name)
        try:
            return getattr(self.env, name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def get_obs(self):
        return self.env.get_obs()

    def get_obs_feature_names(self):
        return self.env.get_obs_feature_names()

    def get_state(self):
        return self.env.get_state()

    def get_state_feature_names(self):
        return self.env.get_state_feature_names()

    def get_avail_actions(self):
        return self.env.get_avail_actions()

    def get_env_info(self):
        return self.env.get_env_info()

    def get_obs_size(self):
        return self.env.get_obs_size()

    def get_state_size(self):
        return self.env.get_state_size()

    def get_total_actions(self):
        return self.env.get_total_actions()

    def get_capabilities(self):
        return self.env.get_capabilities()

    def get_obs_agent(self, agent_id):
        return self.env.get_obs_agent(agent_id)

    def get_avail_agent_actions(self, agent_id):
        return self.env.get_avail_agent_actions(agent_id)

    def render(self, mode="human"):
        return self.env.render(mode=mode)

    def step(self, actions):
        return self.env.step(actions)

    def get_stats(self):
        return self.env.get_stats()

    def full_restart(self):
        return self.env.full_restart()

    def save_replay(self):
        self.env.save_replay()

    def close(self):
        return self.env.close()




class StarCraftCapabilityEnvWrapper(MultiAgentEnv):
    """Wrap ``StarCraft2Env`` with capability distributions from the caller.

    ``capability_config`` (a required keyword) maps env keys to distribution
    configs, plus the plain counts ``n_units``/``n_enemies``. Each
    :meth:`reset` samples a fresh episode config from all distributions. The
    remaining methods delegate to the wrapped env.
    """

    def __init__(self, **kwargs):
        """Create the distributions and the wrapped environment.

        Args:
            **kwargs: Forwarded to ``StarCraft2Env``. Must include
                ``capability_config``. Its nested dicts are annotated in place
                (``env_key``, ``n_units``, ``n_enemies``) before the env is
                built, so the env receives the annotated dict.

        Raises:
            KeyError: if ``capability_config`` is missing.
        """
        self.distribution_config = kwargs["capability_config"]
        self.env_key_to_distribution_map = {}
        self._parse_distribution_config()
        self.env = StarCraft2Env(**kwargs)

    def _parse_distribution_config(self):
        """Build one distribution per entry of ``self.distribution_config``.

        ``n_units`` and ``n_enemies`` are plain counts, not distributions, so
        they are skipped. Each other entry first gets annotated with its own
        key (``env_key``) and both counts, then is passed to the distribution
        class selected by its ``dist_type``.
        """
        for env_key, config in self.distribution_config.items():
            if env_key in ("n_units", "n_enemies"):
                continue
            config["env_key"] = env_key
            # add n_units / n_enemies so each distribution knows team sizes
            config["n_units"] = self.distribution_config["n_units"]
            config["n_enemies"] = self.distribution_config["n_enemies"]
            distribution = get_distribution(config["dist_type"])(config)
            self.env_key_to_distribution_map[env_key] = distribution

    def reset(self):
        """Sample a new episode config and reset the wrapped env with it.

        If ``StarCraft2Env.reset`` raises ``CannotResetException``, a new
        config is sampled and the reset is retried.
        NOTE: there is no retry limit.

        Returns:
            Whatever ``StarCraft2Env.reset`` returns for the successful attempt.
        """
        while True:
            try:
                reset_config = {}
                for distribution in self.env_key_to_distribution_map.values():
                    # Later distributions win on key collisions.
                    reset_config.update(distribution.generate())
                return self.env.reset(reset_config)
            except CannotResetException:
                # Resample and retry. A loop (instead of recursing) returns
                # the successful reset's result and keeps the stack flat.
                continue

    def __getattr__(self, name):
        """Delegate lookups that fail on the wrapper to the wrapped env.

        ``env`` itself is never delegated. If it is not set yet (e.g. on an
        instance created without running ``__init__``), evaluating
        ``self.env`` here would re-enter this method and recurse without end.
        """
        if name == "env":
            raise AttributeError(name)
        try:
            return getattr(self.env, name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def get_obs(self):
        return self.env.get_obs()

    def get_obs_feature_names(self):
        return self.env.get_obs_feature_names()

    def get_state(self):
        return self.env.get_state()

    def get_state_feature_names(self):
        return self.env.get_state_feature_names()

    def get_avail_actions(self):
        return self.env.get_avail_actions()

    def get_env_info(self):
        return self.env.get_env_info()

    def get_obs_size(self):
        return self.env.get_obs_size()

    def get_state_size(self):
        return self.env.get_state_size()

    def get_total_actions(self):
        return self.env.get_total_actions()

    def get_capabilities(self):
        return self.env.get_capabilities()

    def get_obs_agent(self, agent_id):
        return self.env.get_obs_agent(agent_id)

    def get_avail_agent_actions(self, agent_id):
        return self.env.get_avail_agent_actions(agent_id)

    def render(self, mode="human"):
        return self.env.render(mode=mode)

    def step(self, actions):
        return self.env.step(actions)

    def get_stats(self):
        return self.env.get_stats()

    def full_restart(self):
        return self.env.full_restart()

    def save_replay(self):
        self.env.save_replay()

    def close(self):
        return self.env.close()
