"""Implements ESCHER.


The algorithm defines `regret`, `value`, and `average policy` networks.
The regret network is trained to estimate the cumulative regret for an infostate-action pair.
The policy at each timestep comes from performing regret matching on the estimated cumulative regret.
The value network is trained to estimate the value of a game under the current joint policy conditioned
on a history (state).
The average policy network is trained to estimate the average policy over all timesteps.
To train these networks we use three reservoir buffers, one for each network.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import collections
import contextlib
import os
import random
import numpy as np

# import tensorflow as tf
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from open_spiel.python import policy
import wandb

import pyspiel
import time

from algorithms.escher.escher_nets import PolicyNetwork, RegretNetwork, ValueNetwork
from algorithms.escher.buffer import ReservoirBuffer

# Shuffle-buffer sizes: how much of the data is reshuffled each epoch within
# one training iteration, for the regret, value and average-policy datasets.
# NOTE(review): none of these constants is referenced in the portion of the
# file under review -- confirm they are still used before tuning them.
REGRET_TRAIN_SHUFFLE_SIZE = 100000
VALUE_TRAIN_SHUFFLE_SIZE = 100000
AVERAGE_POLICY_TRAIN_SHUFFLE_SIZE = 1000000


class ESCHERSolver(policy.Policy):
    def __init__(
        self,
        game,
        policy_network_layers=(512, 512, 512),
        regret_network_layers=(512, 512, 512),
        value_network_layers=(512, 512, 512),
        num_iterations: int = 100,
        num_traversals: int = 130000,
        num_val_fn_traversals: int = 100,
        learning_rate: float = 1e-3,
        batch_size_regret: int = 10000,
        batch_size_value: int = 2024,
        batch_size_average_policy: int = 10000,
        markov_soccer: bool = False,
        phantom_ttt=False,
        dark_hex: bool = False,
        memory_capacity: int = int(1e5),
        policy_network_train_steps: int = 15000,
        regret_network_train_steps: int = 5000,
        value_network_train_steps: int = 4048,
        # check_exploitability_every: int = 20,
        reinitialize_regret_networks: bool = True,
        reinitialize_value_network: bool = True,
        save_regret_networks: str = None,
        append_legal_actions_mask=False,
        save_average_policy_memories: str = None,
        save_policy_weights=True,
        expl: float = 1.0,
        val_expl: float = 0.01,
        importance_sampling_threshold: float = 100.0,
        importance_sampling: bool = True,
        clear_value_buffer: bool = True,
        val_bootstrap=False,
        use_balanced_probs: bool = False,
        val_op_prob=0.0,
        debug_val=False,
        play_against_random=False,
        train_device="cpu",
        infer_device="cpu",
        experiment_string=None,
        all_actions=True,
        random_policy_path=None,
        use_wandb=None,
        num_random_games=None,
        max_steps=1e6,
        eval_every=1,
        *args,
        **kwargs,
    ):
        """Initialize the ESCHER algorithm.

        Builds the average-policy network, two regret networks per player
        (`_regret_networks` and `_regret_networks_train`), two value networks
        (`_val_network` and `_val_network_train`), Adam optimizers for the
        policy network and the `*_train` networks, and the reservoir buffers.

        Args:
          game: Open Spiel game. Simultaneous-move games are rejected.
          policy_network_layers: (list[int]) Layer sizes of average_policy net MLP.
          regret_network_layers: (list[int]) Layer sizes of regret net MLP.
          value_network_layers: (list[int]) Layer sizes of value net MLP.
          num_iterations: Number of iterations.
          num_traversals: Number of traversals per iteration.
          num_val_fn_traversals: Number of history value function traversals
            per iteration.
          learning_rate: Learning rate used for every Adam optimizer built here.
          batch_size_regret: (int) Batch size to sample from regret memories.
          batch_size_value: (int) Batch size for value-network training; only
            stored here (the consumer is not part of this constructor).
          batch_size_average_policy: (int) Batch size to sample from
            average_policy memories.
          markov_soccer: Accepted but not read anywhere in this constructor.
          phantom_ttt: Forced to True when the game's short name is
            "phantom_ttt".
          dark_hex: Forced to True when the game's short name is "dark_hex".
          memory_capacity: Number of samples that can be stored in each
            reservoir buffer.
          policy_network_train_steps: Number of policy network training steps (one
            policy training iteration at the end).
          regret_network_train_steps: Number of regret network training steps
            (per iteration).
          value_network_train_steps: Number of value network training steps;
            only stored here.
          reinitialize_regret_networks: Whether to re-initialize the regret
            network before training on each iteration.
          reinitialize_value_network: Whether to re-initialize the value
            network before training on each iteration.
          save_regret_networks: If provided, all regret network iterations
            are saved in the given folder (created here if missing). This can
            be useful to implement SD-CFR https://arxiv.org/abs/1901.07621
          save_average_policy_memories: If provided, path that average_policy
            memories are written to instead of the reservoir buffer; its
            parent directory is created here. NOTE(review): the writer
            (`_memories_tfrecordfile`) is only initialised to None in this
            constructor -- confirm where it is opened.
          save_policy_weights: Whether `solve` saves the policy network after
            each evaluation step.
          clear_value_buffer: Whether the value memories are cleared after
            each value-network training.
          play_against_random: NOTE(review): currently has no effect -- the
            attribute is unconditionally set to True further down.
          train_device: Device passed to every network constructor.
          infer_device: Device the policy network is moved to right after
            construction.
          num_random_games: Number of evaluation rounds played against a
            random opponent.
          max_steps: `solve` stops once more than this many nodes were visited.
          eval_every: `solve` retrains the average policy every `eval_every`
            iterations.
          *args, **kwargs: Accepted and ignored.

        Raises:
          ValueError: If `game` has simultaneous dynamics.
        """
        all_players = list(range(game.num_players()))
        super().__init__(game, all_players)
        self._game = game
        self._save_policy_weights = save_policy_weights
        self._compute_exploitability = True
        self._dark_hex = dark_hex or self._game.get_type().short_name == "dark_hex"
        self._phantom_ttt = (
            phantom_ttt or self._game.get_type().short_name == "phantom_ttt"
        )
        self._play_against_random = play_against_random
        self._append_legal_actions_mask = append_legal_actions_mask
        self._num_random_games = num_random_games
        # NOTE(review): the next two assignments override the values set a few
        # lines above, so exploitability is always disabled and the
        # `play_against_random` argument is effectively ignored.
        self._compute_exploitability = False
        self._play_against_random = True
        if game.get_type().dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
            # `_traverse_game_tree` does not take into account this option.
            raise ValueError("Simulatenous games are not supported.")
        self._batch_size_regret = batch_size_regret
        self._batch_size_value = batch_size_value
        self._batch_size_average_policy = batch_size_average_policy
        self._policy_network_train_steps = policy_network_train_steps
        self._regret_network_train_steps = regret_network_train_steps
        self._value_network_train_steps = value_network_train_steps
        self._policy_network_layers = policy_network_layers
        self._regret_network_layers = regret_network_layers
        self._value_network_layers = value_network_layers
        self._num_players = game.num_players()
        self._root_node = self._game.new_initial_state()
        self._eval_every = eval_every

        # Regret/policy networks take one player's information-state tensor.
        self._embedding_size = len(self._root_node.information_state_tensor(0))

        # The value network input is the concatenation of player 0's and
        # player 1's information-state tensors (two players are hard-coded).
        hist_state = np.append(
            self._root_node.information_state_tensor(0),
            self._root_node.information_state_tensor(1),
        )

        self._value_embedding_size = len(hist_state)
        self._num_iterations = num_iterations
        self._num_traversals = num_traversals
        self._num_val_fn_traversals = num_val_fn_traversals
        self._reinitialize_regret_networks = reinitialize_regret_networks
        self._reinit_value_network = reinitialize_value_network
        self._num_actions = game.num_distinct_actions()
        self._iteration = 1
        self._learning_rate = learning_rate
        self._save_regret_networks = save_regret_networks
        self._save_average_policy_memories = save_average_policy_memories
        self._infer_device = infer_device
        self._train_device = train_device
        self._memories_tfrecordpath = None
        self._memories_tfrecordfile = None
        # self._check_exploitability_every = check_exploitability_every
        self._expl = expl
        self._val_expl = val_expl
        self._importance_sampling = importance_sampling
        self._importance_sampling_threshold = importance_sampling_threshold
        self._clear_value_buffer = clear_value_buffer
        self._nodes_visited = 0
        self._games_played = 0
        self._example_info_state = [None, None]
        self._example_hist_state = None
        self._example_legal_actions_mask = [None, None]
        self._squared_errors = []
        self._squared_errors_child = []
        self._balanced_probs = {}
        self._use_balanced_probs = use_balanced_probs
        self._val_op_prob = val_op_prob
        self._val_bootstrap = val_bootstrap
        self._debug_val = debug_val
        self._experiment_string = experiment_string
        self._all_actions = all_actions
        self._random_policy_path = random_policy_path
        self._use_wandb = use_wandb
        self._max_steps = max_steps

        if self._save_regret_networks:
            os.makedirs(self._save_regret_networks, exist_ok=True)

        if self._save_average_policy_memories:
            os.makedirs(
                os.path.split(self._save_average_policy_memories)[0], exist_ok=True
            )
            self._memories_tfrecordpath = self._save_average_policy_memories

        # Initialize policy network, loss, optimizer
        self._reinitialize_policy_network()

        # Initialize regret networks, losses, optimizers
        self._regret_networks = []
        self._regret_networks_train = []
        self._loss_regrets = []

        def weighted_mse_loss(inputs, targets, weights=None):
            # Mean of the element-wise squared error; `weights`, when given,
            # is broadcast over the last dimension before averaging.
            loss = (inputs - targets) ** 2
            if weights is not None:
                loss *= weights.unsqueeze(-1).expand_as(loss)
            loss = torch.mean(loss)
            return loss

        self._optimizer_regrets = []
        for player in range(self._num_players):
            self._regret_networks.append(
                RegretNetwork(
                    self._embedding_size,
                    self._regret_network_layers,
                    self._num_actions,
                    device=self._train_device,
                )
            )

            self._regret_networks_train.append(
                RegretNetwork(
                    self._embedding_size,
                    self._regret_network_layers,
                    self._num_actions,
                    device=self._train_device,
                )
            )
            self._loss_regrets.append(weighted_mse_loss)
            # The optimizer is bound to the `*_train` copy only.
            self._optimizer_regrets.append(
                torch.optim.Adam(
                    self._regret_networks_train[-1].parameters(), lr=learning_rate
                )
            )
            # self._regret_train_step.append(self._get_regret_train_graph(player))

        self._create_memories(memory_capacity)

        # Initialize value networks, losses, optimizers
        self._val_network = ValueNetwork(
            self._value_embedding_size,
            self._value_network_layers,
            device=self._train_device,
        )
        self._val_network_train = ValueNetwork(
            self._value_embedding_size,
            self._value_network_layers,
            device=self._train_device,
        )
        self._loss_value = torch.nn.MSELoss()  # tf.keras.losses.MeanSquaredError()
        self._optimizer_value = torch.optim.Adam(
            self._val_network_train.parameters(), lr=learning_rate
        )  # tf.keras.optimizers.Adam(learning_rate=learning_rate)
        # self._value_train_step = self._get_value_train_graph()
        # self._value_test_step = self._get_value_test_graph()

    def _reinitialize_policy_network(self):
        """Build a fresh average-policy network plus its optimizer and loss.

        The network is constructed with the training device and then moved to
        the inference device.
        """
        fresh_policy = PolicyNetwork(
            self._embedding_size,
            self._policy_network_layers,
            self._num_actions,
            device=self._train_device,
        )
        self._policy_network = fresh_policy
        self._policy_network.to(self._infer_device)
        self._optimizer_policy = torch.optim.Adam(
            self._policy_network.parameters(), lr=self._learning_rate
        )

        def weighted_mse_loss(inputs, targets, weights=None):
            # Squared error, optionally scaled by `weights` (broadcast over the
            # last dimension), reduced to its mean.
            squared_error = (inputs - targets) ** 2
            if weights is None:
                return torch.mean(squared_error)
            squared_error *= weights.unsqueeze(-1).expand_as(squared_error)
            return torch.mean(squared_error)

        self._loss_policy = weighted_mse_loss

    def _reinitialize_regret_network(self, player):
        """Replace `player`'s training regret network and optimizer with new ones."""
        fresh_net = RegretNetwork(
            self._embedding_size,
            self._regret_network_layers,
            self._num_actions,
            device=self._train_device,
        )
        self._regret_networks_train[player] = fresh_net
        # Every training regret network (not only the new one) is moved to the
        # inference device.
        for train_net in self._regret_networks_train:
            train_net.to(self._infer_device)
        fresh_optimizer = torch.optim.Adam(
            self._regret_networks_train[player].parameters(), lr=self._learning_rate
        )
        self._optimizer_regrets[player] = fresh_optimizer

    def get_example_info_state(self, player):
        return self._example_info_state[player]

    def get_example_hist_state(self):
        return self._example_hist_state

    def get_example_legal_actions_mask(self, player):
        return self._example_legal_actions_mask[player]

    def _reinitialize_value_network(self):
        """Replace the training value network and its optimizer with new ones."""
        fresh_value_net = ValueNetwork(
            self._value_embedding_size,
            self._value_network_layers,
            device=self._train_device,
        )
        self._val_network_train = fresh_value_net
        # Built with the training device, then moved to the inference device.
        self._val_network_train.to(self._infer_device)
        fresh_optimizer = torch.optim.Adam(
            self._val_network_train.parameters(), lr=self._learning_rate
        )
        self._optimizer_value = fresh_optimizer

    @property
    def regret_buffers(self):
        return self._regret_memories

    @property
    def average_policy_buffer(self):
        return self._average_policy_memories

    def clear_regret_buffers(self):
        for p in range(self._num_players):
            self._regret_memories[p].clear()

    def _create_memories(self, memory_capacity):
        """Create the reservoir buffers, each with capacity `memory_capacity`.

        One buffer for the average policy, one regret buffer per player, and
        two value buffers (training and test).
        """
        self._average_policy_memories = ReservoirBuffer(memory_capacity)

        per_player_buffers = []
        for _ in range(self._num_players):
            per_player_buffers.append(ReservoirBuffer(memory_capacity))
        self._regret_memories = per_player_buffers

        self._value_memory = ReservoirBuffer(memory_capacity)
        self._value_memory_test = ReservoirBuffer(memory_capacity)

    def get_val_weights(self):
        return self._val_network.get_weights()

    def set_val_weights(self, weights):
        self._val_network.set_weights(weights)

    def get_num_calls(self):
        num_calls = 0
        for p in range(self._num_players):
            num_calls += self._regret_memories[p].get_num_calls()
        print(num_calls)
        return num_calls

    def set_iteration(self, iteration):
        self._iteration = iteration

    def get_weights(self):
        regret_weights = [
            self._regret_networks[player].get_weights()
            for player in range(self._num_players)
        ]
        return regret_weights

    def get_policy_weights(self):
        policy_weights = self._policy_network.get_weights()
        return policy_weights

    def set_policy_weights(self, policy_weights):
        self._reinitialize_policy_network()
        self._policy_network.set_weights(policy_weights)

    def get_regret_memories(self, player):
        return self._regret_memories[player].get_data()

    def get_value_memory(self):
        return self._value_memory.get_data()

    def clear_value_memory(self):
        self._value_memory.clear()

    def get_value_memory_test(self):
        return self._value_memory_test.get_data()

    def get_average_policy_memories(self):
        return self._average_policy_memories.get_data()

    def get_num_nodes(self):
        return self._nodes_visited

    def get_squared_errors(self):
        return self._squared_errors

    def reset_squared_errors(self):
        self._squared_errors = []

    def get_squared_errors_child(self):
        return self._squared_errors_child

    def reset_squared_errors_child(self):
        self._squared_errors_child = []

    def clear_val_memories_test(self):
        self._value_memory_test.clear()

    def clear_val_memories(self):
        self._value_memory.clear()

    def traverse_game_tree_n_times(
        self,
        n,
        p,
        train_regret=False,
        train_value=False,
        track_mean_squares=True,
        on_policy_prob=0.0,
        expl=0.6,
        val_test=False,
    ):
        for i in range(n):
            if i > 0:
                track_mean_squares = False
            self._traverse_game_tree(
                self._root_node,
                p,
                my_reach=1.0,
                opp_reach=1.0,
                sample_reach=1.0,
                my_sample_reach=1.0,
                train_regret=train_regret,
                train_value=train_value,
                track_mean_squares=track_mean_squares,
                on_policy_prob=on_policy_prob,
                expl=expl,
                val_test=val_test,
            )

    def play_game_against_random(self, win_rate=False):
        """Play two games against a uniformly random opponent, one per seat.

        The policy network controls player 0 in the first game and player 1
        in the second; the other seat picks uniformly among legal actions.
        Chance nodes are sampled from the game's chance distribution.

        Args:
          win_rate: If True, return the number of games (0-2) in which the
            policy's return was strictly positive instead of the summed return.

        Returns:
          Sum of the policy player's returns over both games, or the win
          count when `win_rate` is True.
        """
        reward = 0
        wins = 0
        for player in [0, 1]:
            state = self._game.new_initial_state()
            while not state.is_terminal():
                if state.is_chance_node():
                    outcomes, probs = zip(*state.chance_outcomes())
                    aidx = np.random.choice(range(len(outcomes)), p=probs)
                    action = outcomes[aidx]
                else:
                    cur_player = state.current_player()
                    if cur_player == player:
                        # Network inputs are only needed on the policy's own
                        # turn, so they are built here rather than every step.
                        legal_actions_mask = torch.as_tensor(
                            state.legal_actions_mask(cur_player),
                            dtype=torch.float32,
                        )
                        if (not self._phantom_ttt) and (not self._dark_hex):
                            obs = torch.as_tensor(
                                state.observation_tensor(), dtype=torch.float32
                            )
                        else:
                            obs = torch.as_tensor(
                                state.information_state_tensor(cur_player),
                                dtype=torch.float32,
                            )
                        if len(obs.shape) == 1:
                            obs = torch.unsqueeze(obs, axis=0)
                        # Pure evaluation: no autograd graph is needed.
                        with torch.no_grad():
                            probs = self._policy_network(
                                (
                                    obs.to(self._policy_network.device),
                                    legal_actions_mask.to(
                                        self._policy_network.device
                                    ),
                                ),
                            )
                        probs = probs.cpu().detach().numpy()[0]
                        # Renormalise so np.random.choice accepts the vector.
                        probs /= probs.sum()
                        action = np.random.choice(
                            range(state.num_distinct_actions()), p=probs
                        )
                    elif cur_player == 1 - player:
                        action = random.choice(state.legal_actions())
                    else:
                        # Unexpected player id: abandon this game and score
                        # the state as it stands.
                        print("Got player ", str(cur_player))
                        break
                state.apply_action(action)
            reward += state.returns()[player]
            wins += state.returns()[player] > 0
        if win_rate:
            return wins
        else:
            return reward

    def play_n_games_against_random(self, n, win_rate=False):
        total_reward = 0
        for i in range(n):
            reward = self.play_game_against_random(win_rate)
            total_reward += reward
        return total_reward / (2 * n)

    def print_mse(self):
        # track MSE
        squared_errors = self.get_squared_errors()
        self.reset_squared_errors()
        squared_errors_child = self.get_squared_errors_child()
        self.reset_squared_errors_child()
        print(sum(squared_errors) / len(squared_errors), "Mean Squared Errors")
        print(
            sum(squared_errors_child) / len(squared_errors_child),
            "Mean Squared Errors Child",
        )

    def iteration(
        self,
        regret_losses,
        value_losses,
        timestr,
        avg_reward,
        avg_win_against_random,
        algo_start_time,
        save_path_convs=None,
    ):
        """Run one ESCHER iteration: value training, then regret training per player.

        Steps: collect value-function trajectories and train the value
        network; for every player collect regret trajectories and train that
        player's regret network; evaluate against a random opponent; log to
        wandb when enabled.

        Args:
          regret_losses: Mapping player -> list; the last regret loss of each
            player is appended in place.
          value_losses: List; the last value loss is appended in place.
          timestr: Passed through unchanged.
          avg_reward: Ignored on input; recomputed from games against a
            random opponent.
          avg_win_against_random: Logged and passed through unchanged.
          algo_start_time: `time.time()` of the run start, used for the
            overall fps figure.
          save_path_convs: Unused; kept for interface compatibility.

        Returns:
          Tuple `(regret_losses, value_losses, timestr, avg_reward,
          avg_win_against_random, algo_start_time)`.
        """
        wandb_log = {}
        print(
            self._nodes_visited,
        )

        start = time.time()
        start_nodes = self._nodes_visited
        if self._experiment_string is not None:
            print(self._experiment_string)

        # Collect value-function training data (always traversed as player 0).
        self.traverse_game_tree_n_times(
            self._num_val_fn_traversals,
            0,
            train_value=True,
            track_mean_squares=False,
            on_policy_prob=self._val_op_prob,
            expl=self._val_expl,
        )

        val_traj_time = time.time()
        print(val_traj_time - start, "val trajectory time")

        if self._reinit_value_network:
            self._reinitialize_value_network()

        value_loss = self.learn_value_network()
        value_losses.append(value_loss[-1])
        wandb_log.update({"value_loss": value_loss[-1]})

        if self._clear_value_buffer:
            self.clear_val_memories_test()
            self.clear_val_memories()

        val_train_time = time.time()
        print(val_train_time - val_traj_time, "val train time")
        wandb_log.update({"value_train_time": val_train_time - val_traj_time})

        # Train each player's regret network in turn.
        for p in range(self._num_players):
            regret_start_time = time.time()

            self.traverse_game_tree_n_times(
                self._num_traversals,
                p,
                train_regret=True,
                track_mean_squares=False,
                expl=self._expl,
            )
            num_nodes = self.get_num_nodes()
            regret_traj_time = time.time()

            print(regret_traj_time - regret_start_time, "regret trajectory time")
            # NOTE: despite its name this key holds a throughput (cumulative
            # visited nodes per second of this traversal phase); the key is
            # kept so existing dashboards stay comparable.
            wandb_log.update(
                {
                    f"regret_traj_time_{p}": num_nodes
                    / (regret_traj_time - regret_start_time)
                }
            )

            if self._reinitialize_regret_networks:
                self._reinitialize_regret_network(p)

            regret_loss = self.learn_regret_network(p)
            regret_losses[p].append(regret_loss[-1])
            wandb_log.update({f"regret_loss_{p}": regret_loss[-1]})

            if self._save_regret_networks:
                os.makedirs(self._save_regret_networks, exist_ok=True)
                self._regret_networks[p].save(
                    os.path.join(
                        self._save_regret_networks,
                        f"regretnet_p{p}_it{self._iteration:04}",
                    )
                )
            print(time.time() - regret_traj_time, "regret train time")

        total_regret_time = time.time()
        # Time spent on the whole regret phase (all players), measured from
        # the end of value training. The same figure is printed and logged.
        regret_phase_duration = total_regret_time - val_train_time
        print(regret_phase_duration, "total regret time")
        wandb_log.update({"total_regret_time": regret_phase_duration})

        # check exploitability
        self._iteration += 1
        print(
            "Iteration",
            self._iteration,
            "took",
            time.time() - start,
            "seconds and fps is: ",
            (self._nodes_visited - start_nodes) / (time.time() - start),
        )

        avg_reward = self.play_n_games_against_random(self._num_random_games)
        wandb_log.update({"global_step": self._nodes_visited})
        wandb_log.update({"avg_reward_on_random": avg_reward})
        wandb_log.update({"avg_win_against_random": avg_win_against_random})
        wandb_log.update({"fps": self._nodes_visited / (time.time() - algo_start_time)})
        if self._use_wandb:
            wandb.log(wandb_log)

        return (
            regret_losses,
            value_losses,
            timestr,
            avg_reward,
            avg_win_against_random,
            algo_start_time,
        )

    def pre_iteration(self):
        regret_losses = collections.defaultdict(list)
        value_losses = []
        timestr = "{:%Y_%m_%d_%H_%M_%S}".format(datetime.now())
        if self._use_balanced_probs:
            self._get_balanced_probs(self._root_node)
        self._policy_network.to(self._policy_network.device)
        avg_reward = self.play_n_games_against_random(self._num_random_games)
        avg_win_against_random = self.play_n_games_against_random(
            self._num_random_games, win_rate=True
        )
        algo_start_time = time.time()

        return (
            regret_losses,
            value_losses,
            timestr,
            avg_reward,
            avg_win_against_random,
            algo_start_time,
        )

    def solve(self, save_path_convs=None):
        """Run the full ESCHER training loop and train the final average policy.

        Each iteration trains the value network, then each player's regret
        network. Every `eval_every` iterations the average-policy network is
        rebuilt, retrained and (optionally) saved. The loop stops early once
        more than `max_steps` nodes have been visited.

        Args:
          save_path_convs: Directory for the periodic policy checkpoints.
            When None, checkpoints are skipped even if `save_policy_weights`
            is set (previously this raised TypeError).

        Returns:
          Tuple `(regret_losses, policy_loss, convs, nodes)`; `convs` and
          `nodes` are always empty lists.
        """
        regret_losses = collections.defaultdict(list)
        value_losses = []
        if self._use_balanced_probs:
            self._get_balanced_probs(self._root_node)
        self._policy_network.to(self._policy_network.device)
        avg_reward = self.play_n_games_against_random(self._num_random_games)
        # NOTE: the win rate is measured once here and logged unchanged on
        # every iteration.
        avg_win_against_random = self.play_n_games_against_random(
            self._num_random_games, win_rate=True
        )
        algo_start_time = time.time()
        convs = []
        nodes = []

        for i in range(self._num_iterations + 1):
            wandb_log = {}
            print(
                self._nodes_visited,
            )
            if self._nodes_visited > self._max_steps:
                break

            start = time.time()
            start_nodes = self._nodes_visited
            if self._experiment_string is not None:
                print(self._experiment_string)

            # Collect value-function training data (always as player 0).
            self.traverse_game_tree_n_times(
                self._num_val_fn_traversals,
                0,
                train_value=True,
                track_mean_squares=False,
                on_policy_prob=self._val_op_prob,
                expl=self._val_expl,
            )

            val_traj_time = time.time()
            print(val_traj_time - start, "val trajectory time")

            if self._reinit_value_network:
                self._reinitialize_value_network()

            value_loss = self.learn_value_network()
            value_losses.append(value_loss[-1])
            wandb_log.update({"value_loss": value_loss[-1]})

            if self._clear_value_buffer:
                self.clear_val_memories_test()
                self.clear_val_memories()

            val_train_time = time.time()
            print(val_train_time - val_traj_time, "val train time")
            wandb_log.update({"value_train_time": val_train_time - val_traj_time})

            # Train each player's regret network in turn.
            for p in range(self._num_players):
                regret_start_time = time.time()

                self.traverse_game_tree_n_times(
                    self._num_traversals,
                    p,
                    train_regret=True,
                    track_mean_squares=False,
                    expl=self._expl,
                )
                num_nodes = self.get_num_nodes()
                regret_traj_time = time.time()

                print(
                    regret_traj_time - regret_start_time, "regret trajectory time"
                )
                # NOTE: despite its name this key holds a throughput
                # (cumulative visited nodes per second of this phase); the
                # key is kept so existing dashboards stay comparable.
                wandb_log.update(
                    {
                        f"regret_traj_time_{p}": num_nodes
                        / (regret_traj_time - regret_start_time)
                    }
                )

                if self._reinitialize_regret_networks:
                    self._reinitialize_regret_network(p)

                regret_loss = self.learn_regret_network(p)
                regret_losses[p].append(regret_loss[-1])
                wandb_log.update({f"regret_loss_{p}": regret_loss[-1]})

                if self._save_regret_networks:
                    os.makedirs(self._save_regret_networks, exist_ok=True)
                    self._regret_networks[p].save(
                        os.path.join(
                            self._save_regret_networks,
                            f"regretnet_p{p}_it{self._iteration:04}",
                        )
                    )
                print(time.time() - regret_traj_time, "regret train time")

            total_regret_time = time.time()
            # Duration of the whole regret phase (all players), measured from
            # the end of value training. The same figure is printed and logged.
            regret_phase_duration = total_regret_time - val_train_time
            print(regret_phase_duration, "total regret time")
            wandb_log.update({"total_regret_time": regret_phase_duration})

            # check exploitability
            self._iteration += 1
            print(
                "Iteration",
                i,
                "took",
                time.time() - start,
                "seconds and fps is: ",
                (self._nodes_visited - start_nodes) / (time.time() - start),
            )

            if i % self._eval_every == 0:
                self._reinitialize_policy_network()
                self.learn_average_policy_network()
                # Checkpoint only when a destination directory was given.
                if self._save_policy_weights and save_path_convs is not None:
                    os.makedirs(save_path_convs, exist_ok=True)
                    model_path = save_path_convs + "/policy_nodes_" + str(num_nodes)
                    torch.save(self._policy_network.state_dict(), model_path)

                    print("saved policy to ", model_path)
                    self.save_policy_network(model_path + "full_model")
                    print("saved policy to ", model_path + "full_model")

            avg_reward = self.play_n_games_against_random(self._num_random_games)
            wandb_log.update({"global_step": self._nodes_visited})
            wandb_log.update({"avg_reward_on_random": avg_reward})
            wandb_log.update({"avg_win_against_random": avg_win_against_random})
            wandb_log.update(
                {"fps": self._nodes_visited / (time.time() - algo_start_time)}
            )
            if self._use_wandb:
                wandb.log(wandb_log)

        # Train policy network.
        policy_loss = self.learn_average_policy_network()
        return regret_losses, policy_loss, convs, nodes

    def save_policy_network(self, outputfolder):
        """Saves the policy network to the given folder."""
        os.makedirs(outputfolder, exist_ok=True)
        torch.save(
            self._policy_network.state_dict(), outputfolder + "/policy_network.pt"
        )
        return outputfolder + "/policy_network.pt"
        # self._policy_network.save(outputfolder)

    def train_policy_network_from_file(
        self,
        tfrecordpath,
        iteration=None,
        batch_size_average_policy=None,
        policy_network_train_steps=None,
        reinitialize_policy_network=True,
    ):
        """Retrain the average-policy network, optionally overriding settings.

        Records `tfrecordpath` on the solver, applies every override that is
        truthy, optionally resets the policy network, and then runs
        `learn_average_policy_network`.

        Args:
          tfrecordpath: Stored on `self._memories_tfrecordpath`.
          iteration: If truthy, replaces `self._iteration`.
          batch_size_average_policy: If truthy, replaces
            `self._batch_size_average_policy`.
          policy_network_train_steps: If truthy, replaces
            `self._policy_network_train_steps`.
          reinitialize_policy_network: If truthy, calls
            `self._reinitialize_policy_network()` before training.

        Returns:
          Whatever `learn_average_policy_network` returns.
        """
        self._memories_tfrecordpath = tfrecordpath
        overrides = (
            ("_iteration", iteration),
            ("_batch_size_average_policy", batch_size_average_policy),
            ("_policy_network_train_steps", policy_network_train_steps),
        )
        # Falsy overrides (None, 0) leave the current setting untouched.
        for attribute, override in overrides:
            if override:
                setattr(self, attribute, override)
        if reinitialize_policy_network:
            self._reinitialize_policy_network()
        return self.learn_average_policy_network()

    def _add_to_average_policy_memory(
        self, info_state, iteration, average_policy_action_probs, legal_actions_mask
    ):
        # pylint: disable=g-doc-args
        """Adds the given average_policy data to the memory.

        Uses either a tfrecordsfile on disk if provided, or a reservoir buffer.
        """
        serialized_example = self._serialize_average_policy_memory(
            info_state, iteration, average_policy_action_probs, legal_actions_mask
        )
        if self._save_average_policy_memories:
            self._memories_tfrecordfile.write(serialized_example)
        else:
            self._average_policy_memories.add(serialized_example)

    def _serialize_average_policy_memory(
        self, info_state, iteration, average_policy_action_probs, legal_actions_mask
    ):
        """Create serialized example to store a average_policy entry."""

        return (info_state, iteration, average_policy_action_probs, legal_actions_mask)

    def _serialize_regret_memory(
        self, info_state, iteration, samp_regret, legal_actions_mask
    ):
        """Create serialized example to store an regret entry."""
        serialized = (info_state, iteration, samp_regret, legal_actions_mask)
        return serialized  # example.SerializeToString()

    def _serialize_value_memory(
        self, hist_state, iteration, samp_value, legal_actions_mask
    ):
        """Create serialized example to store a value entry."""
        serialized = (hist_state, iteration, samp_value, legal_actions_mask)
        return serialized  # example.SerializeToString()

    def _baseline(self, state, aidx):  # pylint: disable=unused-argument
        # Default to vanilla outcome sampling
        return 0

    def _baseline_corrected_child_value(
        self, state, sampled_aidx, aidx, child_value, sample_prob
    ):
        # Applies Eq. 9 of Schmid et al. '19
        baseline = self._baseline(state, aidx)
        if aidx == sampled_aidx:
            return baseline + (child_value - baseline) / sample_prob
        else:
            return baseline

    def _exact_value(self, state, update_player):
        # self._exact_value_num += 1
        # if self._exact_value_num % 1000 == 0:
        #     # print(self._exact_value_num)
        #     print(state.is_chance_node())
        state = state.clone()
        if state.is_terminal():
            ret = state.player_return(update_player)
            # print(ret)
            return ret  # state.player_return(update_player)
        if state.is_chance_node():
            outcomes, probs = zip(*state.chance_outcomes())
            val = 0
            for aidx in range(len(outcomes)):
                new_state = state.child(outcomes[aidx])
                val += probs[aidx] * self._exact_value(new_state, update_player)
            return val
        cur_player = state.current_player()
        legal_actions = state.legal_actions()
        num_legal_actions = len(legal_actions)
        _, policy = self._sample_action_from_regret(state, cur_player)
        val = 0
        for aidx in range(num_legal_actions):
            new_state = state.child(legal_actions[aidx])
            val += policy[aidx] * self._exact_value(new_state.clone(), update_player)
        return val

    def _get_balanced_probs(self, state):
        if state.is_terminal():
            return 1
        elif state.is_chance_node():
            legal_actions = state.legal_actions()
            num_nodes = 0
            for action in legal_actions:
                num_nodes += self._get_balanced_probs(state.child(action))
            return num_nodes
        else:
            legal_actions = state.legal_actions()
            num_nodes = 0
            balanced_probs = np.zeros((state.num_distinct_actions()))
            for action in legal_actions:
                nodes = self._get_balanced_probs(state.child(action))
                balanced_probs[action] = nodes
                num_nodes += nodes
            self._balanced_probs[state.information_state_string()] = (
                balanced_probs / balanced_probs.sum()
            )
            return num_nodes

    def _traverse_game_tree(
        self,
        state,
        player,
        my_reach,
        opp_reach,
        sample_reach,
        my_sample_reach,
        train_regret,
        train_value,
        on_policy_prob=0.0,
        track_mean_squares=True,
        expl=1.0,
        val_test=False,
        last_action=0,
    ):
        """Samples one trajectory from `state` and records training targets.

        At chance nodes one outcome is sampled from the chance distribution.
        At decision nodes a single action is sampled from `sample_policy` and
        the traversal recurses into that child only (i.e. one sampled path per
        call, not a full enumeration of the acting player's actions). The
        current policy comes from `_sample_action_from_regret` and value
        estimates come from `_estimate_value_from_hist`.

        If train_regret=True, a sampled regret vector is added to
        `self._regret_memories[player]` at nodes where `player` acts, and the
        current policy is added to the average_policy memory at every other
        decision node. If train_value=True, a value target for the current
        history is added to `self._value_memory` (or to
        `self._value_memory_test` when val_test=True).

        Args:
          state: Current OpenSpiel game state.
          player: (int) Player index for this traversal.
          my_reach: Running reach value for `player`; multiplied by the current
            policy's probability of the sampled action at `player`'s nodes.
          opp_reach: Running reach value for everyone else; multiplied by the
            chance probability at chance nodes and by the current policy's
            probability of the sampled action at other players' nodes.
          sample_reach: Running probability of the sampled path; multiplied by
            the chance probability at chance nodes and by the sampling
            probability of the sampled action at decision nodes.
          my_sample_reach: Like `sample_reach` but only updated at `player`'s
            own nodes.
          train_regret: Whether to record regret / average_policy samples.
          train_value: Whether to record value targets. Requires player == 0.
          on_policy_prob: Probability, checked at this node while `expl` is
            non-zero, of setting `expl` to 0 for the rest of the trajectory.
            NOTE(review): the recursive calls do not forward this argument, so
            deeper nodes see the default of 0.0 -- confirm this is intended.
          track_mean_squares: If True, compares the value estimates with
            `_exact_value` at every decision node and appends the squared
            errors to `self._squared_errors` / `self._squared_errors_child`.
          expl: Weight of the uniform (or balanced) policy in the sampling
            policy; the current policy gets weight `1 - expl`.
          val_test: Routes value targets to the test memory instead of the
            training memory.
          last_action: Action that led to `state`; only forwarded to
            `_estimate_value_from_hist`.

        Returns:
          A pair `(importance_weighted_sampled_value, sampled_value)`. The
          first entry is the first entry returned by the sampled child,
          multiplied by `policy[a] / sample_policy[a]` for the sampled action
          `a`. The second entry is passed up unchanged from the terminal state,
          where both entries equal `state.returns()[player]`.
        """
        self._nodes_visited += 1
        if state.is_terminal():
            self._games_played += 1
            # print(self._games_played)
            return state.returns()[player], state.returns()[player]
        elif state.is_chance_node():
            # If this is a chance node, sample an action
            outcomes, probs = zip(*state.chance_outcomes())
            aidx = np.random.choice(range(len(outcomes)), p=probs)
            action = outcomes[aidx]
            new_state = state.child(action)
            # Chance probabilities are folded into the opponent reach and the
            # sample reach; the return value is passed through unweighted.
            return self._traverse_game_tree(
                new_state,
                player,
                my_reach,
                probs[aidx] * opp_reach,
                probs[aidx] * sample_reach,
                my_sample_reach,
                train_regret,
                train_value,
                expl=expl,
                track_mean_squares=track_mean_squares,
                val_test=val_test,
                last_action=action,
            )

        # With probability on_policy_prob we switch over to an on-policy rollout
        # (expl = 0) for the remainder of the trajectory. Used for value
        # estimation to get coverage without needing importance sampling.
        if expl != 0.0:
            if np.random.rand() < on_policy_prob:
                expl = 0.0

        cur_player = state.current_player()
        legal_actions = state.legal_actions()
        num_legal_actions = len(legal_actions)
        num_actions = state.num_distinct_actions()
        # Current (regret-matched) policy of the acting player, indexed by
        # action id.
        _, policy = self._sample_action_from_regret(state, state.current_player())

        # Exploration is mixed in only where it matters: at the traversing
        # player's nodes, or everywhere when collecting value targets.
        if cur_player == player or train_value:
            uniform_policy = np.array(state.legal_actions_mask()) / num_legal_actions
            if self._use_balanced_probs:
                # Replace uniform exploration by the distribution precomputed
                # in `_get_balanced_probs`.
                uniform_policy = self._balanced_probs[state.information_state_string()]
            sample_policy = expl * uniform_policy + (1.0 - expl) * policy
        else:
            sample_policy = policy

        # In-place renormalisation. In the `else` branch above `sample_policy`
        # is the same array object as `policy`, so `policy` is normalised too.
        sample_policy /= sample_policy.sum()
        sampled_action = np.random.choice(
            range(state.num_distinct_actions()), p=sample_policy
        )
        orig_state = state.clone()
        new_state = state.child(sampled_action)
        # Value-network estimates for the sampled child and for this node, both
        # from `player`'s perspective.
        child_value = self._estimate_value_from_hist(
            new_state.clone(), player, last_action=sampled_action
        )
        value_estimate = self._estimate_value_from_hist(
            state.clone(), player, last_action=last_action
        )

        if track_mean_squares:
            # Diagnostics: compare the estimates with exact values obtained by
            # enumerating the subtree (see `_exact_value`).
            # NOTE(review): this prints once per visited decision node and
            # `track_mean_squares` defaults to True.
            print("tracking mean squares")
            oracle_child_value = self._exact_value(new_state.clone(), player)
            oracle_value_estimate = self._exact_value(state.clone(), player)
            squared_error = (oracle_value_estimate - value_estimate) ** 2
            self._squared_errors.append(squared_error)
            squared_child_error = (oracle_child_value - child_value) ** 2
            self._squared_errors_child.append(squared_child_error)

        # Update the reach values along the sampled edge.
        if cur_player == player:
            new_my_reach = my_reach * policy[sampled_action]
            new_opp_reach = opp_reach
            new_my_sample_reach = my_sample_reach * sample_policy[sampled_action]
        else:
            new_my_reach = my_reach
            new_opp_reach = opp_reach * policy[sampled_action]
            new_my_sample_reach = my_sample_reach
        new_sample_reach = sample_reach * sample_policy[sampled_action]
        iw_sampled_value, sampled_value = self._traverse_game_tree(
            new_state,
            player,
            new_my_reach,
            new_opp_reach,
            new_sample_reach,
            new_my_sample_reach,
            train_regret,
            train_value,
            expl=expl,
            track_mean_squares=track_mean_squares,
            val_test=val_test,
            last_action=sampled_action,
        )
        # Importance weight for having sampled from `sample_policy` instead of
        # `policy` at this node.
        importance_weighted_sampled_value = (
            iw_sampled_value * policy[sampled_action] / sample_policy[sampled_action]
        )

        # Compute each of the child estimated values.
        child_values = np.zeros(num_actions, dtype=np.float64)
        if self._all_actions:
            # Query the value network for every legal child.
            for aidx in range(num_legal_actions):
                cloned_state = orig_state.clone()
                action = legal_actions[aidx]
                new_cloned_state = cloned_state.child(action)
                child_values[action] = self._estimate_value_from_hist(
                    new_cloned_state.clone(), player, last_action=action
                )
        else:
            # Only the sampled child is estimated, divided by its sampling
            # probability; all other entries stay zero.
            child_values[sampled_action] = child_value / sample_policy[sampled_action]

        if train_regret:
            if cur_player == player:
                # Zero array with the same shape as `policy`.
                cf_action_values = 0 * policy
                for action in range(num_actions):
                    if self._importance_sampling:
                        # Neither `action_sample_reach` nor `cf_value` depends
                        # on the loop variable. Both weights are capped at
                        # `self._importance_sampling_threshold`.
                        action_sample_reach = (
                            my_sample_reach * sample_policy[sampled_action]
                        )
                        cf_value = value_estimate * min(
                            1 / my_sample_reach, self._importance_sampling_threshold
                        )
                        cf_action_value = child_values[action] * min(
                            1 / action_sample_reach, self._importance_sampling_threshold
                        )
                    else:
                        cf_action_value = child_values[action]
                        cf_value = value_estimate
                    cf_action_values[action] = cf_action_value

                # Sampled regret per action, zeroed for illegal actions.
                samp_regret = (cf_action_values - cf_value) * state.legal_actions_mask(
                    player
                )

                network_input = state.information_state_tensor()

                self._regret_memories[player].add(
                    self._serialize_regret_memory(
                        network_input,
                        self._iteration,
                        samp_regret,
                        state.legal_actions_mask(player),
                    )
                )
            else:
                # At nodes where `player` does not act, record the acting
                # player's current policy as an average_policy sample.
                obs_input = state.information_state_tensor(cur_player)

                self._add_to_average_policy_memory(
                    obs_input,
                    self._iteration,
                    policy,
                    state.legal_actions_mask(cur_player),
                )

        # value function predicts value for player 0
        if train_value:
            # if on_policy_prob = 0 then we are doing importance weighted sampling
            # if on_policy_prob > 0 then we need to wait until expl = 0 to get pure on-policy rollouts
            if on_policy_prob == 0 or expl == 0:
                # History encoding: both players' info-state tensors concatenated.
                hist_state = np.append(
                    state.information_state_tensor(0), state.information_state_tensor(1)
                )

                assert player == 0
                if self._val_bootstrap:
                    # Bootstrapped target built from value-network estimates of
                    # the children.
                    if self._all_actions:
                        target = policy @ child_values
                    else:
                        target = (
                            child_value
                            * policy[sampled_action]
                            / sample_policy[sampled_action]
                        )
                elif self._debug_val:
                    target = (
                        child_value
                        * policy[sampled_action]
                        / sample_policy[sampled_action]
                    )
                    print(target, "value target")
                else:
                    # Default target: the first value returned by the sampled
                    # child, without this node's importance weight.
                    target = iw_sampled_value
                if val_test:
                    self._value_memory_test.add(
                        self._serialize_value_memory(
                            hist_state,
                            self._iteration,
                            target,
                            state.legal_actions_mask(cur_player),
                        )
                    )
                else:
                    self._value_memory.add(
                        self._serialize_value_memory(
                            hist_state,
                            self._iteration,
                            target,
                            state.legal_actions_mask(cur_player),
                        )
                    )

        return importance_weighted_sampled_value, sampled_value

    def _get_matched_regrets(self, info_state, legal_actions_mask, player):
        """PyTorch version to calculate regret matching."""
        # Assuming _regret_networks[player] is a PyTorch model and info_state and legal_actions_mask are PyTorch tensors
        # PyTorch models expect batch dimensions, so ensure info_state is properly batched
        regrets = self._regret_networks[player](
            (info_state.unsqueeze(0), legal_actions_mask)
        )[0]

        regrets = torch.clamp(regrets, min=0)
        summed_regret = torch.sum(regrets)
        if summed_regret > 0:
            matched_regrets = regrets / summed_regret
        else:
            very_large_negative_number = torch.tensor(-1e20, device=regrets.device)
            where_condition = torch.where(
                legal_actions_mask.to(regrets.device) == 1,
                regrets,
                very_large_negative_number,
            )
            # For one-hot encoding, assuming the dimensionality is correct
            matched_regrets = F.one_hot(
                torch.argmax(where_condition), num_classes=self._num_actions
            ).float()

        return regrets, matched_regrets

    def _get_estimated_value(self, hist_state, legal_actions_mask):
        estimated_val = self._val_network(
            (torch.unsqueeze(hist_state, 0), legal_actions_mask)
        )[0]
        return estimated_val

    def _sample_action_from_regret(self, state, player):
        """Returns an info state policy by applying regret-matching.

        Args:
          state: Current OpenSpiel game state.
          player: (int) Player index over which to compute regrets.

        Returns:
          1. (np-array) regret values for info state actions indexed by action.
          2. (np-array) Matched regrets, prob for actions indexed by action.
        """

        info_state = torch.as_tensor(
            state.information_state_tensor(player), dtype=torch.float32
        )
        legal_actions_mask = torch.as_tensor(
            state.legal_actions_mask(player), dtype=torch.float32
        )
        self._example_info_state[player] = info_state
        self._example_legal_actions_mask[player] = legal_actions_mask
        regrets, matched_regrets = self._get_matched_regrets(
            info_state, legal_actions_mask, player
        )
        return regrets.cpu().detach().numpy(), matched_regrets.cpu().detach().numpy()

    def _estimate_value_from_hist(self, state, player, last_action=0):
        """Returns an info state policy by applying regret-matching.

        Args:
          state: Current OpenSpiel game state.
          player: (int) Player index over which to compute regrets.

        Returns:
          1. (np-array) regret values for info state actions indexed by action.
          2. (np-array) Matched regrets, prob for actions indexed by action.
        """
        state = state.clone()
        if state.is_terminal():
            return state.player_return(player)

        hist_state = np.append(
            state.information_state_tensor(0), state.information_state_tensor(1)
        )

        self._example_hist_state = hist_state
        hist_state = torch.as_tensor(hist_state, dtype=torch.float32)
        legal_actions_mask = torch.as_tensor(
            state.legal_actions_mask(player), dtype=torch.float32
        )
        estimated_value = self._get_estimated_value(hist_state, legal_actions_mask)
        if player == 1:
            estimated_value = -estimated_value
        return estimated_value.cpu().detach().numpy()

    def action_probabilities(self, state):
        """Return the average-policy probabilities of the legal actions.

        Args:
          state: OpenSpiel game state at a decision node.

        Returns:
          Dict mapping every legal action of the current player to the
          probability that `self._policy_network` assigns to it.
        """
        cur_player = state.current_player()
        legal_actions = state.legal_actions(cur_player)

        # `learn_average_policy_network` leaves the policy network on the
        # training device, so build the inputs on whatever device the network
        # currently lives on instead of assuming the CPU.
        first_param = next(self._policy_network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        legal_actions_mask = torch.as_tensor(
            state.legal_actions_mask(cur_player), dtype=torch.float32, device=device
        )
        info_state_vector = torch.as_tensor(
            state.information_state_tensor(), dtype=torch.float32, device=device
        )
        # The network expects a batch dimension on the info state.
        if info_state_vector.dim() == 1:
            info_state_vector = info_state_vector.unsqueeze(0)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            probs = self._policy_network((info_state_vector, legal_actions_mask))
        probs = probs.cpu().detach().numpy()
        return {action: probs[0][action] for action in legal_actions}

    def _get_regret_dataset(self, player):
        """Returns the collected regrets for the given player as a dataset."""
        regret_dataloader = DataLoader(
            self._regret_memories[player],
            batch_size=self._batch_size_regret,
            shuffle=True,
        )
        return regret_dataloader

    def _get_value_dataset(self):
        """Returns the collected value estimates for the given player as a dataset."""

        value_dataloader = DataLoader(
            self._value_memory, batch_size=self._batch_size_value, shuffle=True
        )
        return value_dataloader

    def _get_value_dataset_test(self):
        """Returns the collected value estimates for the given player as a dataset."""
        data = self._value_memory_test
        value_dataloader = DataLoader(
            data, batch_size=self._batch_size_value, shuffle=True
        )
        return value_dataloader

    def learn_value_network(self):
        """Train the value network on samples drawn from the value memory.

        Runs `self._value_network_train_steps` optimisation steps on the
        training copy `self._val_network_train`, then copies the trained
        weights into the inference copy `self._val_network` and moves that copy
        to the inference device.

        Returns:
          List with the loss (float) of every training step.
        """
        self._val_network_train.to(self._train_device)
        tfit = torch.tensor(self._iteration, dtype=torch.float32).to(self._train_device)

        # Batches come straight from the memory's own sampler (as in
        # `learn_regret_network`); no DataLoader is needed here.
        # Presumably `batch_data` stages the stored samples on the training
        # device for `batch_sample` -- see ReservoirBuffer.
        self._value_memory.batch_data(device=self._train_device)

        losses = []
        for _ in range(self._value_network_train_steps):
            full_hist_states, _iterations, values, masks = (
                self._value_memory.batch_sample(self._batch_size_value)
            )
            # Add a trailing dimension to the targets before computing the loss.
            values = values.unsqueeze(-1)
            loss = self.value_train_step(full_hist_states, values, masks, tfit)
            losses.append(loss.item())

        # Publish the trained weights to the network used during traversals.
        self._val_network.load_state_dict(self._val_network_train.state_dict())
        self._val_network.to(self._infer_device)
        return losses

    def value_train_step(self, full_hist_states, values, masks, tfit):
        """Run one optimisation step of the value network.

        Args:
          full_hist_states: Batch of history encodings.
          values: Batch of value targets.
          masks: Batch of legal-action masks, passed to the network.
          tfit: Current iteration as a tensor; not used by this step.

        Returns:
          The loss tensor of this step.
        """
        network = self._val_network_train
        optimizer = self._optimizer_value

        network.train()
        optimizer.zero_grad()
        predictions = network((full_hist_states, masks)).float()
        targets = values.float()
        loss = self._loss_value(predictions, targets).float()
        loss.backward()
        optimizer.step()
        return loss

    def learn_regret_network(self, player):
        """Train `player`'s regret network on samples from its regret memory.

        Runs `self._regret_network_train_steps` steps of `regret_train_step` on
        the training copy of the network, then copies the trained weights into
        the inference copy and moves that copy to the inference device.

        Args:
          player: (int) Player whose regret network is trained.

        Returns:
          List with the value returned by `regret_train_step` for every step.
        """
        train_net = self._regret_networks_train[player]
        train_net.to(self._train_device)
        tfit = torch.tensor(self._iteration, dtype=torch.float32).to(self._train_device)

        memory = self._regret_memories[player]
        memory.batch_data(device=self._train_device)

        step_losses = []
        for _ in range(self._regret_network_train_steps):
            info_state, iteration, samp_regret, legal_actions_mask = (
                memory.batch_sample(self._batch_size_regret)
            )
            step_loss = self.regret_train_step(
                player, info_state, samp_regret, iteration, legal_actions_mask, tfit
            )
            step_losses.append(step_loss)

        # Publish the trained weights to the network used during traversals.
        infer_net = self._regret_networks[player]
        infer_net.load_state_dict(train_net.state_dict())
        infer_net.to(self._infer_device)
        return step_losses

    def regret_train_step(
        self, player, info_states, regrets, iterations, masks, iteration
    ):
        """Run one optimisation step of `player`'s regret network.

        Each sample's loss is weighted by `2 * iterations / iteration`.

        Args:
          player: (int) Player whose training network is updated.
          info_states: Batch of info-state tensors.
          regrets: Batch of sampled-regret targets.
          iterations: Iteration at which each sample was collected.
          masks: Batch of legal-action masks, passed to the network.
          iteration: Current iteration, used to normalise the weights.

        Returns:
          The loss of this step as a Python float.
        """
        device = self._train_device
        network = self._regret_networks_train[player]

        network.train()
        network.zero_grad()
        network_input = (info_states.to(device).float(), masks.to(device).float())
        predictions = network(network_input)
        sample_weights = iterations.to(device) * 2 / iteration
        loss = self._loss_regrets[player](
            predictions, regrets.to(device), weights=sample_weights
        )
        loss.backward()
        self._optimizer_regrets[player].step()
        return loss.item()

    def _get_average_policy_dataset(self):
        """Returns the collected average_policy memories as a dataset."""
        data = DataLoader(
            self._average_policy_memories,
            batch_size=self._batch_size_average_policy,
            shuffle=True,
        )
        return data

    def learn_average_policy_network(self):
        """Train the average_policy network on the average_policy memory.

        Runs `self._policy_network_train_steps` steps of
        `average_policy_train_step` on batches sampled from
        `self._average_policy_memories`. The policy network is moved to the
        training device and is left there afterwards.

        Returns:
          The loss obtained on the last training batch, or `None` when no
          training step was run.
        """
        self._policy_network.to(self._train_device)
        # Presumably `batch_data` stages the stored samples on the training
        # device for `batch_sample` -- see ReservoirBuffer.
        self._average_policy_memories.batch_data(device=self._train_device)

        # Stays None when `_policy_network_train_steps` is zero.
        main_loss = None
        for _ in range(self._policy_network_train_steps):
            info_state, iteration, average_policy_action_probs, legal_actions_mask = (
                self._average_policy_memories.batch_sample(
                    self._batch_size_average_policy
                )
            )
            main_loss = self.average_policy_train_step(
                info_state, average_policy_action_probs, iteration, legal_actions_mask
            )

        return main_loss

    def average_policy_train_step(self, info_states, action_probs, iterations, masks):
        """Run one optimisation step of the average_policy network.

        Each sample's loss is weighted by `2 * iterations / self._iteration`.

        Args:
          info_states: Batch of info-state tensors.
          action_probs: Batch of target action probabilities.
          iterations: Iteration at which each sample was collected.
          masks: Batch of legal-action masks, passed to the network.

        Returns:
          The loss tensor of this step.
        """
        device = self._train_device
        network = self._policy_network

        network.train()
        self._optimizer_policy.zero_grad()
        predictions = network((info_states.to(device), masks.to(device)))
        sample_weights = iterations.to(device) * 2 / self._iteration
        loss = self._loss_policy(predictions, action_probs.to(device), sample_weights)
        loss.backward()
        self._optimizer_policy.step()
        return loss
