
import hydra
from omegaconf import DictConfig
import torch

from il_scale.nethack.utils.model import count_params
from il_scale.nethack.agent import Agent

# Per-head width used to derive the transformer head count: tf_num_heads = hdim // _HEAD_DIM.
_HEAD_DIM = 64


def _core_param_count(model):
    """Return the summed parameter count of ``model.core``, ``model.policy_head`` and ``model.modality_mixer``.

    This is the "Core size" figure printed by ``main``; any other submodules of
    ``model`` are not included.
    """
    return (
        count_params(model.core)
        + count_params(model.policy_head)
        + count_params(model.modality_mixer)
    )


def _report_model_size(cfg: DictConfig, model) -> None:
    """Print the transformer settings from ``cfg.network`` plus total and core parameter counts of ``model``."""
    net = cfg.network
    print(f"HDIM: {net.hdim}, LAYERS: {net.tf_num_layers}, HEADS: {net.tf_num_heads}")
    print(f"Model size: {count_params(model)}")
    print(f"Core size: {_core_param_count(model)}")
    print()


@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Build one "llama"-core agent model and print its parameter counts.

    The config loaded by Hydra is overridden in place with the architecture
    settings below, the model is constructed via ``Agent.construct_model``, and
    the total and core parameter counts are printed.

    Raises:
        ValueError: if ``hdim`` is not a multiple of ``_HEAD_DIM``. The head
            count is derived by integer division, so a non-multiple would be
            silently truncated. This is an explicit check rather than an
            ``assert`` so it still runs under ``python -O``.
    """
    # Settings explored earlier in this script. They are kept here as a record;
    # see version control history for the original code. For "mamba" the layer
    # count is cfg.network.mamba_num_layers. For "llama" it is
    # cfg.network.tf_num_layers, with tf_num_heads = hdim // 64. Target labels
    # come from the original comments.
    #
    #   core_mode  target  (hdim, layers)
    #   mamba      20M     (514, 11) (252, 46) (383, 20) (752, 5) (1145, 2)
    #   mamba      50M     (511, 29) (256, 113) (384, 51) (760, 13) (1030, 7)
    #   mamba      100M    (1046, 14) (514, 58) (386, 102) (768, 26)
    #   llama      ~22M    (512, 5) (256, 20) (384, 9) (768, 2)
    #   llama      50M     (512, 12) (384, 21) (256, 48)
    #
    # An earlier llama sweep did the following:
    #   - For hdim in {128, 256, ..., 1024}, with layers below a per-hdim cap of
    #     30-60, it recorded the core parameter count (see _core_param_count).
    #   - It saved the table to "hdim_layer_to_params.tar" with torch.save.
    #   - It then searched for layer counts at hdim 384-896 whose core counts
    #     were within 2.5e6 of each other and above 10e6.
    hdim = 704
    num_layers = 7

    # Validate before touching cfg so a bad setting fails loudly and early.
    if hdim % _HEAD_DIM != 0:
        raise ValueError(
            f"hdim={hdim} must be a multiple of {_HEAD_DIM} to derive tf_num_heads"
        )

    cfg.network.core_mode = "llama"
    cfg.network.hdim = hdim
    cfg.network.tf_num_layers = num_layers
    cfg.network.tf_num_heads = hdim // _HEAD_DIM

    agent = Agent(cfg, None)
    agent.construct_model(cfg)

    _report_model_size(cfg, agent.model)

    print('done!')


if __name__ == "__main__":
    # No arguments are passed here: the @hydra.main decorator on `main` supplies `cfg`.
    main()
