import eas
import numpy as np
import torch


def build_traverser(game):
    """Instantiate the ``eas`` traverser registered for ``game``.

    Args:
        game: Game identifier. Recognised values are
            ``"classical_phantom_ttt"``, ``"abrupt_phantom_ttt"``,
            ``"classical_dark_hex"`` and ``"abrupt_dark_hex"``.

    Returns:
        A freshly constructed traverser object from the ``eas`` module.

    Raises:
        ValueError: If ``game`` is not one of the recognised identifiers.
    """
    # (game identifier, name of the traverser class inside ``eas``).
    # Class names are resolved lazily so that only the requested class is
    # looked up on the ``eas`` module.
    traverser_class_names = (
        ("classical_phantom_ttt", "PtttTraverser"),
        ("abrupt_phantom_ttt", "AbruptPtttTraverser"),
        ("classical_dark_hex", "DhTraverser"),
        ("abrupt_dark_hex", "AbruptDhTraverser"),
    )

    for supported_game, class_name in traverser_class_names:
        if game == supported_game:
            return getattr(eas, class_name)()

    raise ValueError(f"Not supported game: {game}")


def get_legality_indicies(game):
    """Return the infostate positions that flag each of the 9 cells as empty.

    Legal actions for a player correspond to the cells that are empty from
    that player's point of view, so these positions double as a legality
    mask over the 9 actions.

    Args:
        game: Game identifier; it is matched by substring, so any name
            containing ``"phantom_ttt"`` or ``"dark_hex"`` is accepted.

    Returns:
        A 1-D integer tensor with 9 entries.

    Raises:
        ValueError: If ``game`` contains neither supported substring.
    """
    is_phantom_ttt = "phantom_ttt" in game
    if not is_phantom_ttt and "dark_hex" not in game:
        raise ValueError(f"Not supported game: {game}")

    cells = torch.arange(9)
    if is_phantom_ttt:
        # cell i is empty (for the current player) iff infostate[i] == 1
        return cells

    # cell i is empty (for the current player) iff infostate[9*i+4] == 1
    return cells * 9 + 4


def compute_exploitability(
    model_p0,
    model_p1,
    traverser,
    batch_size=400_000,
    action_selection=("sto", "sto"),
    game_name="phantom_ttt",
):
    """Evaluate a pair of policies and return EV plus both exploitabilities.

    Each ``model_pX`` is either

    * a callable mapping a float32 batch of infostates to logits (NOT
      probabilities), or
    * a dict with ``"models"`` and ``"weights"`` entries (e.g. a PSRO
      population); the per-model probability tables are then combined with
      an averager obtained from ``traverser.new_averager``.

    Args:
        model_p0: Policy of player 0 (callable or mixture dict, see above).
        model_p1: Policy of player 1 (callable or mixture dict, see above).
        traverser: Game traverser providing ``new_averager``,
            ``compute_openspiel_infostates`` and ``ev_and_exploitability``.
        batch_size: Number of infostates evaluated per forward pass.
        action_selection: Two entries, one per player: ``"sto"`` turns the
            logits into a softmax distribution, ``"det"`` puts all mass on
            the arg-max action. Only indexed, never mutated.
        game_name: Name used to look up the legality positions inside an
            infostate (see ``get_legality_indicies``).

    Returns:
        Tuple ``(ev0, expl_0, expl_1)`` taken from the result of
        ``traverser.ev_and_exploitability``.

    Raises:
        ValueError: If ``game_name`` or an action-selection entry is not
            supported.
    """
    models = [model_p0, model_p1]

    # Validate the game name before any expensive work is started.
    legality_indicies = get_legality_indicies(game_name)

    # Only mixture (dict) players need an averager.
    # NOTE(review): the meaning of strategy id 5 is defined by ``eas`` and is
    # not visible here -- confirm against the eas documentation.
    averagers = [
        traverser.new_averager(player, eas.AveragingStrategy(5))
        if isinstance(model, dict)
        else None
        for player, model in enumerate(models)
    ]

    infostates = [
        traverser.compute_openspiel_infostates(0),
        traverser.compute_openspiel_infostates(1),
    ]

    probs_batch = []
    for player, averager in enumerate(averagers):
        if averager is None:
            probs = build_probability_table(
                models[player],
                infostates[player],
                legality_indicies,
                action_selection[player],
                batch_size=batch_size,
            ).numpy()
        else:
            nonzero_weight_found = False
            for model, weight in zip(
                models[player]["models"], models[player]["weights"]
            ):
                # The first table pushed to the averager must not carry a
                # zero weight, so leading zero-weight members are skipped.
                # The check runs BEFORE inference so that no probability
                # table is computed just to be thrown away.
                if not nonzero_weight_found:
                    if weight < 1e-9:
                        continue
                    if weight > 0:
                        nonzero_weight_found = True

                # get model probabilities for this member of the mixture
                intermediate_probs = build_probability_table(
                    model,
                    infostates[player],
                    legality_indicies,
                    action_selection[player],
                    batch_size=batch_size,
                )

                # push probabilities to running averager
                averager.push(intermediate_probs, weight)

            # get probability table from eas
            probs = averager.running_avg()

        probs_batch.append(probs)

        # save some memory: drop our only reference to this player's
        # infostates as soon as its table is done
        infostates[player] = None

    del infostates

    probs_0, probs_1 = probs_batch

    # Warn (but do not fail) when some row is not a probability distribution.
    for label, table in (("probs_0", probs_0), ("probs_1", probs_1)):
        max_diff = np.max(np.abs(np.sum(table, axis=1) - 1))
        if not (max_diff < 1e-6):
            print(f"[WARNING] {label} max diff is {max_diff}")

    out = traverser.ev_and_exploitability(probs_0, probs_1)

    return out.ev0, out.expl[0], out.expl[1]


def compute_exploitability_cached(
    model_p0,
    model_p1,
    traverser,
    batch_size=400_000,
    action_selection=("sto", "sto"),
    game_name="phantom_ttt",
    probs_0=None,
    probs_1=None,
):
    """Evaluate a pair of policies, reusing precomputed probability tables.

    Works like ``compute_exploitability`` but a player's probability table
    can be supplied through ``probs_0`` / ``probs_1``. For such a player no
    averager is created, no infostates are enumerated and the model is not
    evaluated. The tables that were used are returned so the caller can pass
    them back in on a later call.

    Each ``model_pX`` is either a callable mapping a float32 batch of
    infostates to logits (NOT probabilities), or a dict with ``"models"``
    and ``"weights"`` entries (e.g. a PSRO population) whose per-model
    tables are combined with an averager from ``traverser.new_averager``.

    Args:
        model_p0: Policy of player 0; ignored when ``probs_0`` is given.
        model_p1: Policy of player 1; ignored when ``probs_1`` is given.
        traverser: Game traverser providing ``new_averager``,
            ``compute_openspiel_infostates`` and ``ev_and_exploitability``.
        batch_size: Number of infostates evaluated per forward pass.
        action_selection: Two entries, one per player: ``"sto"`` turns the
            logits into a softmax distribution, ``"det"`` puts all mass on
            the arg-max action. Only indexed, never mutated.
        game_name: Name used to look up the legality positions inside an
            infostate. Only consulted when at least one table has to be
            computed.
        probs_0: Optional precomputed probability table of player 0.
        probs_1: Optional precomputed probability table of player 1.

    Returns:
        Tuple ``(ev0, expl_0, expl_1, probs_0, probs_1)``; the first three
        come from ``traverser.ev_and_exploitability`` and the last two are
        the tables that were passed to it.

    Raises:
        ValueError: If a table has to be computed and ``game_name`` or the
            corresponding action-selection entry is not supported.
    """
    models = [model_p0, model_p1]
    cached_probs = [probs_0, probs_1]
    needs_compute = [cached is None for cached in cached_probs]

    # Validate the game name before any expensive work is started; nothing
    # is needed at all when both tables are cached.
    legality_indicies = (
        get_legality_indicies(game_name) if any(needs_compute) else None
    )

    # Only mixture (dict) players whose table is not cached need an averager.
    # NOTE(review): the meaning of strategy id 5 is defined by ``eas`` and is
    # not visible here -- confirm against the eas documentation.
    averagers = [
        traverser.new_averager(player, eas.AveragingStrategy(5))
        if needs_compute[player] and isinstance(model, dict)
        else None
        for player, model in enumerate(models)
    ]

    # Enumerating infostates is expensive, so skip it for cached players.
    infostates = [
        traverser.compute_openspiel_infostates(player)
        if needs_compute[player]
        else None
        for player in range(2)
    ]

    probs_batch = []
    for player, averager in enumerate(averagers):
        if not needs_compute[player]:
            probs = cached_probs[player]
        elif averager is None:
            probs = build_probability_table(
                models[player],
                infostates[player],
                legality_indicies,
                action_selection[player],
                batch_size=batch_size,
            ).numpy()
        else:
            nonzero_weight_found = False
            for model, weight in zip(
                models[player]["models"], models[player]["weights"]
            ):
                # The first table pushed to the averager must not carry a
                # zero weight, so leading zero-weight members are skipped.
                # The check runs BEFORE inference so that no probability
                # table is computed just to be thrown away.
                if not nonzero_weight_found:
                    if weight < 1e-9:
                        continue
                    if weight > 0:
                        nonzero_weight_found = True

                # get model probabilities for this member of the mixture
                intermediate_probs = build_probability_table(
                    model,
                    infostates[player],
                    legality_indicies,
                    action_selection[player],
                    batch_size=batch_size,
                )

                # push probabilities to running averager
                averager.push(intermediate_probs, weight)

            # get probability table from eas
            probs = averager.running_avg()

        probs_batch.append(probs)

        # save some memory: drop our only reference to this player's
        # infostates as soon as its table is done
        infostates[player] = None

    del infostates

    probs_0, probs_1 = probs_batch

    # Warn (but do not fail) when some row is not a probability distribution.
    for label, table in (("probs_0", probs_0), ("probs_1", probs_1)):
        max_diff = np.max(np.abs(np.sum(table, axis=1) - 1))
        if not (max_diff < 1e-6):
            print(f"[WARNING] {label} max diff is {max_diff}")

    out = traverser.ev_and_exploitability(probs_0, probs_1)

    return out.ev0, out.expl[0], out.expl[1], probs_0, probs_1


def build_probability_table(
    model, infostates, legality_indicies, action_selection, batch_size=400_000
):
    """Evaluate ``model`` on all infostates and return action probabilities.

    The infostates are processed in chunks of ``batch_size`` rows. For each
    chunk the model's logits are masked so that illegal actions receive no
    probability mass, then converted to probabilities.

    Args:
        model: Callable taking a float32 tensor of infostates and returning
            logits (NOT probabilities) with one column per action.
        infostates: 2-D array or tensor with one infostate per row; it must
            support ``.shape``, row slicing and ``[:, legality_indicies]``
            column indexing.
        legality_indicies: Column positions whose entries are truthy exactly
            when the corresponding action is legal.
        action_selection: ``"sto"`` for a softmax over the legal logits,
            ``"det"`` for a one-hot distribution on the best legal logit.
        batch_size: Maximum number of infostates per forward pass.

    Returns:
        A tensor with one row of action probabilities per infostate. For an
        empty ``infostates`` a float32 tensor of shape
        ``(0, len(legality_indicies))`` is returned.

    Raises:
        ValueError: If ``action_selection`` is unknown or ``batch_size`` is
            not positive.
        AssertionError: If at least half of a batch's model outputs already
            sum to 1, i.e. the model seems to output probabilities.
    """
    # Fail fast, before any forward pass is spent on an invalid request.
    if action_selection not in ("sto", "det"):
        raise ValueError(f"Unknown action selection method: {action_selection}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_infostates = infostates.shape[0]
    if num_infostates == 0:
        # torch.cat cannot concatenate an empty list of batches.
        return torch.empty((0, len(legality_indicies)), dtype=torch.float32)

    agent_probs = []
    for start in range(0, num_infostates, batch_size):
        infostates_batch = infostates[start : start + batch_size]

        # legal actions for a player correspond to empty cells for that player
        legal_actions_mask_batch = torch.as_tensor(
            infostates_batch[:, legality_indicies], dtype=torch.bool
        )

        with torch.no_grad():
            logits = model(torch.as_tensor(infostates_batch, dtype=torch.float32))

        # loosely check if logits are actually logits; compare against the
        # real number of rows so short (e.g. final) batches are checked too
        near_normalized = (torch.abs(logits.sum(dim=1) - 1) < 1e-6).sum()
        assert (
            near_normalized < logits.shape[0] / 2
        ), "Half the logits are close to a probability distribution"

        # mask out illegal actions (out of place: the model output is kept)
        logits = logits.masked_fill(~legal_actions_mask_batch, -1e9)

        if action_selection == "sto":
            # softmax is shift-invariant and numerically stable on its own
            probs = torch.softmax(logits, dim=-1)
        else:  # "det"
            probs = torch.zeros_like(logits)
            probs[torch.arange(logits.shape[0]), logits.argmax(dim=1)] = 1.0

        agent_probs.append(probs)

    return torch.cat(agent_probs, dim=0)
