import random

import grpc
import numpy as np
import torch
import gfootball.env as football_env

from expground.logger import Log


# Number of discrete actions; random_actions samples uniformly from the
# inclusive range [0, NUM_ACTIONS - 1].
# NOTE(review): presumably matches the gfootball default action set size — confirm.
NUM_ACTIONS = 19


def random_actions(obs):
    """Pick one uniformly random action index for every controlled player.

    An ``obs`` with exactly three dimensions is treated as a single player;
    otherwise its leading dimension is taken as the number of players.

    Returns:
        A list with one int in ``[0, NUM_ACTIONS - 1]`` per player.
    """
    single_player = len(obs.shape) == 3
    player_count = 1 if single_player else obs.shape[0]
    return [random.randint(0, NUM_ACTIONS - 1) for _ in range(player_count)]


def preprocessing(observation):
    """Bit-pack an observation along its last axis and view it as uint16.

    Steps: prepend a batch axis of size 1, pack the last axis into bytes with
    ``np.packbits`` (non-zero entries become 1 bits), append one zero byte when
    the packed byte count is odd, then reinterpret the bytes as uint16 words.
    """
    packed = np.packbits(np.expand_dims(observation, axis=0), axis=-1)
    if packed.shape[-1] % 2:
        # A uint16 view needs an even number of bytes on the last axis.
        pad_width = [(0, 0)] * (packed.ndim - 1)
        pad_width.append((0, 1))
        packed = np.pad(packed, pad_width, mode="constant")
    return packed.view(np.uint16)


def generate_actions(obs, model):
    """Evaluate ``model`` once per controlled player and collect the outputs.

    An ``obs`` with exactly three dimensions is a single player. Otherwise the
    leading dimension indexes the players you control, and each slice is
    preprocessed and evaluated separately.

    Returns:
        A list holding ``model(preprocessing(player_obs))[0][0].numpy()`` for
        each player, in order.
    """
    if len(obs.shape) == 3:
        player_views = [obs]
    else:
        player_views = [obs[idx] for idx in range(obs.shape[0])]
    return [model(preprocessing(view))[0][0].numpy() for view in player_views]


def get_inference_model(inference_model: str, random_act: bool = False):
    """Build a policy callable that maps an observation to a list of actions.

    Args:
        inference_model: Path to a serialized model. When empty or None, the
            random policy is returned instead.
        random_act: When True, return the random policy even if
            ``inference_model`` is set.

    Returns:
        ``random_actions`` when no model is requested. Otherwise, a callable
        ``policy(obs)`` that runs ``generate_actions`` with the loaded model.
    """
    if not inference_model or random_act:
        return random_actions

    # The previous code called ``tf.saved_model.load`` although ``tf`` is never
    # imported here, so this branch always raised NameError. This module
    # imports torch (and the original comment said "torch load"), so load a
    # TorchScript archive. torch.jit.load does not unpickle arbitrary objects,
    # unlike torch.load on a pickled nn.Module.
    # NOTE(review): assumes the model was exported with torch.jit.save — confirm.
    module = torch.jit.load(inference_model, map_location="cpu")
    module.eval()

    def model(packed):
        # preprocessing() yields a numpy uint16 array. Torch modules need a
        # Tensor, and older torch versions cannot convert uint16 arrays, so
        # widen to int32 first; every value is preserved.
        return module(torch.from_numpy(np.asarray(packed).astype(np.int32)))

    def policy(obs):
        # Without no_grad, outputs may require grad and .numpy() would raise.
        with torch.no_grad():
            return generate_actions(obs, model)

    return policy


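# NOTE: Disabled example driver. It relies on a FLAGS object (username, token,
# model_name, track, render, how_many) that this module does not define.
# def main(inference_model):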
#     model = get_inference_model(inference_model)
#     env = football_env.create_remote_environment(
#         FLAGS.username,
#         FLAGS.token,
#         FLAGS.model_name,
#         track=FLAGS.track,
#         representation="extracted",
#         stacked=True,
#         include_rendering=FLAGS.render,
#     )
#     for _ in range(FLAGS.how_many):
#         ob = env.reset()
#         cnt = 1
#         done = False
#         while not done:
#             try:
#                 action = model(ob)
#                 ob, rew, done, _ = env.step(action)
#                 Log.info(
#                     "Playing the game, step %d, action %s, rew %s, done %d",
#                     cnt,
#                     action,
#                     rew,
#                     done,
#                 )
#                 cnt += 1
#             except grpc.RpcError as e:
#                 print(e)
#                 break
#         print("=" * 50)
