import argparse
import json
import os
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from gymnasium import spaces
import torch as th
from torch import nn
import chess

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.vec_env import DummyVecEnv

from models.resnet import ChessResNet
from mygym import MoveGymEnv
from evaluate import close_engine
import config


def store_fens(fens, iteration_count, setting, run, base="fens"):
    """Write generated FENs to ``../<base>/<setting>/<run>/fens_<iteration_count>.csv``.

    Each entry of ``fens`` becomes one line:

    * a ``str`` entry is written as ``"<fen>, "``;
    * a ``tuple`` entry is written as ``"<item0>, <item1>, "`` (only the first two items are used).

    The trailing ``", "`` on every line is preserved for compatibility with existing readers.

    :param fens: iterable of FEN strings and/or tuples
    :param iteration_count: training iteration, used in the output file name
    :param setting: experiment-setting subdirectory
    :param run: run subdirectory
    :param base: top-level output directory, joined under ``..``
    :raises TypeError: if an entry is neither a ``str`` nor a ``tuple`` (nothing is written then)
    """
    fen_path = os.path.join("..", base, setting, run)
    # exist_ok=True replaces the racy "exists, then makedirs" check.
    os.makedirs(fen_path, exist_ok=True)

    # Validate and format everything first so a bad entry does not leave a truncated file behind.
    lines = []
    for fen in fens:
        if isinstance(fen, str):
            lines.append(f"{fen}, \n")
        elif isinstance(fen, tuple):
            lines.append(f"{fen[0]}, {fen[1]}, \n")
        else:
            raise TypeError(f"FEN entries must be str or tuple, got {type(fen).__name__}")

    with open(os.path.join(fen_path, f"fens_{iteration_count}.csv"), "w", encoding="utf-8") as f:
        f.writelines(lines)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for a PPO training run."""
    parser = argparse.ArgumentParser(description='Train GFN to target model')
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--log-frequency', type=int, default=20,
                        help="Controls size of interval between tensorboard log updates.")
    parser.add_argument('--tb-path', type=str, default="/home/Chess/TB_Merged",
                        help='Path to find EGTB')
    # The two help fragments are implicitly concatenated; the trailing space keeps them apart.
    parser.add_argument('--depth', type=int, default=1,
                        help='Depth for engine to search in experiments. '
                             'Relevance depends on experiment type.')
    parser.add_argument('--nodes', type=int, default=None, help='Node limit for engine in experiments.')
    parser.add_argument('--engine', type=str, default='Stockfish', help='Target engine')
    # argparse does not run ``type`` on non-string defaults, so the default must already be an int.
    parser.add_argument('--episodes', type=int, default=25_000_000)
    parser.add_argument('--learning-rate', type=float, default=3e-4)
    # NOTE(review): main() only tests this value for truthiness; the checkpoint is always read
    # from the run's own checkpoint directory, not from the given path.
    parser.add_argument('--load', type=str, default=None,
                        help="Path to load initial model weights from.")
    parser.add_argument('--num-pieces', type=int, default=5,
                        help="Number of pieces the GFN should output per position.")
    parser.add_argument('--block-sizes', nargs='+', type=int, default=[128, 128, 128, 128, 128, 128, 128, 64],
                        help="Residual block sizes of model bodies.")
    return parser.parse_args()


def _run_episodes(model, env, num_episodes: int):
    """Roll out ``num_episodes`` stochastic episodes with ``model`` and return the summed reward.

    ``env`` must be a vectorized env holding exactly one sub-environment. The returned sum has
    whatever type the env's ``step`` rewards have.
    """
    total_reward_sum = 0
    for _ in range(num_episodes):
        obs = env.reset()
        done = False
        total_rewards = 0
        while not done:
            action, _states = model.predict(obs, deterministic=False)
            obs, reward, done, info = env.step(action)
            total_rewards += reward
            # Reward is only expected on the terminal step.
            assert done or reward == 0
        assert len(env.envs) == 1
        total_reward_sum += total_rewards
    return total_reward_sum


def main():
    """Train (or resume) a PPO agent on ``MoveGymEnv`` and periodically dump generated FENs.

    Each cycle:

    1. runs ``iterations_per_update`` PPO updates;
    2. writes the successes collected during training under ``../fens_v2/...``;
    3. rolls out ``episodes_per_cycle`` stochastic episodes and writes their successes under
       ``../fens/...``;
    4. saves the model and an iteration counter so that ``--load`` can resume the run.

    The engine is always released via ``close_engine()``, even if training raises.
    """
    args = _parse_args()
    print(args)

    config.tb_path = args.tb_path
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.set_default_device(config.device)

    base_env = MoveGymEnv(engine=args.engine, num_pieces=args.num_pieces, nodes=args.nodes)
    # Steps per episode assumed when sizing each learn() call.
    max_episode_steps = args.num_pieces - 2
    # Close over ``base_env`` rather than ``env``: a lambda over ``env`` would see the rebinding
    # to the DummyVecEnv if the factory were ever called after this line.
    env = DummyVecEnv([lambda: base_env])
    inner_env = env.envs[0]

    setting_str = f"{args.engine}_us__move_n{args.nodes}_{args.num_pieces}"
    run_str = f"{args.engine}_move_n{args.nodes}_{args.num_pieces}_ppo"
    run_dir = os.path.join("..", "checkpoints", setting_str, run_str)
    save_path = os.path.join(run_dir, "ckpt")
    counter_file = os.path.join(run_dir, "counter.json")

    total_iterations = 7200
    iterations_per_update = 5
    episodes_per_cycle = 500

    try:
        if args.load:
            print("Loading model and counter...")
            model = PPO.load(save_path, env=env, device=config.device)
            with open(counter_file, "r") as f:
                iteration_count = json.load(f).get("iteration", 0)
        else:
            print("Initializing a new model...")
            model = PPO(
                CustomActorCriticPolicy,
                env,
                verbose=1,
                device=config.device,
                batch_size=args.batch_size,
                # Previously parsed but ignored; the 3e-4 default equals PPO's own default.
                learning_rate=args.learning_rate,
            )
            iteration_count = 0

        while iteration_count < total_iterations:
            iteration_count += iterations_per_update
            model.learn(total_timesteps=max_episode_steps * iterations_per_update * 1024)
            print(f"Completed {iteration_count} iterations of model updates")
            print(len(env.envs))

            # Successes accumulated by the env during the training rollouts above.
            store_fens(inner_env.success, iteration_count, setting_str, run_str, base="fens_v2")
            inner_env.success.clear()

            total_reward_sum = _run_episodes(model, env, episodes_per_cycle)
            print(f"Episode {episodes_per_cycle}: Total Reward = {total_reward_sum}")

            # Successes accumulated during the evaluation episodes.
            store_fens(inner_env.success, iteration_count, setting_str, run_str, base="fens")
            print(f"Successes: {inner_env.success}")
            print(f"Completed {episodes_per_cycle} episodes after {iteration_count} iterations.")
            inner_env.success.clear()

            print(f"Saving model and iteration counter at iteration {iteration_count}")
            model.save(save_path)
            with open(counter_file, "w") as f:
                json.dump({"iteration": iteration_count}, f)

        print("Training completed!")
    finally:
        close_engine()


class CustomNetwork(nn.Module):
    """
    Policy/value body for the PPO agent: two ``ChessResNet`` heads sharing one residual torso.

    It receives as input the features extracted by the features extractor and returns a policy
    latent (with an additive invalid-action penalty applied) and a value latent.

    :param feature_dim: dimension of the extracted features. Not used by the networks themselves;
        it is kept for the SB3 ``_build_mlp_extractor`` calling convention.
    :param last_layer_dim_pi: number of policy latent units. Must equal ``10 * 64``, because the
        action mask has one entry per (placeable piece plane, square).
    :param last_layer_dim_vf: number of value latent units produced by the value head.
    :raises ValueError: if ``last_layer_dim_pi`` does not match the action-mask width.
    """

    # 10 piece planes kept by invalid_action_mask (planes 0-4 and 6-10) x 64 squares.
    _MASK_DIM = 10 * 64

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64*10,
        last_layer_dim_vf: int = 64*10,
    ):
        super().__init__()

        # The policy latent is summed element-wise with the action mask, so the widths must agree.
        if last_layer_dim_pi != self._MASK_DIM:
            raise ValueError(
                f"last_layer_dim_pi must be {self._MASK_DIM} to match the action mask, "
                f"got {last_layer_dim_pi}"
            )

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf
        layer_sizes = [256] * 14

        # Policy network (owns the shared residual torso).
        self.policy_net = nn.Sequential(
            ChessResNet(layer_sizes, num_classes=last_layer_dim_pi, out_method="PB")
        )
        # Value network: no extra blocks of its own; it reuses the policy torso's res_blocks.
        self.value_net = nn.Sequential(
            ChessResNet([], num_classes=last_layer_dim_vf, out_method="PB", torso=self.policy_net[0].res_blocks,
                        in_planes=self.policy_net[0].in_planes)
        )
        self.policy_net = self.policy_net.to(config.device)
        self.value_net = self.value_net.to(config.device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def invalid_action_mask(self, features: th.Tensor) -> th.Tensor:
        """Build an additive penalty (0 = allowed, -100 = disallowed) of shape ``(batch, 640)``.

        ``features`` is viewed as ``(batch, 12, 64)`` (planes x squares); this assumes it holds
        exactly 12 * 64 values per sample.

        A placement is disallowed if:

        * it puts plane 0 or plane 6 on squares 0-7 or 56-63 (presumably pawns on the back
          ranks — TODO confirm the plane order against the board encoder); or
        * the target square is already occupied on any plane.

        Only planes ``[0, chess.QUEEN)`` and ``[chess.KING, chess.KING + chess.QUEEN)`` are kept.
        With python-chess constants these are planes 0-4 and 6-10, so planes 5 and 11 are dropped.
        """
        board_tensor = features.view(-1, 12, 64)
        valid_action_mask = torch.ones_like(board_tensor)
        valid_action_mask[..., 0, 0:8] = 0
        valid_action_mask[..., 0, (7 * 8):(8 * 8)] = 0
        valid_action_mask[..., 6, 0:8] = 0
        valid_action_mask[..., 6, (7 * 8):(8 * 8)] = 0
        # A square counts as occupied if any plane has a non-zero entry there.
        square_counts = board_tensor.sum(dim=-2)
        square_empty = (square_counts == 0).view(-1, 1, 64)
        valid_action_mask *= square_empty
        invalid_mask = (1 - valid_action_mask) * (-100)
        return torch.cat([invalid_mask[:, :chess.QUEEN],
                          invalid_mask[:, chess.KING:(chess.KING + chess.QUEEN)]],
                         dim=1).view(-1, self._MASK_DIM)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return the policy latent with the invalid-action penalty added.

        NOTE(review): SB3's ActorCriticPolicy applies its own ``action_net`` on top of this
        latent. For discrete actions that is a Linear layer, so this additive mask may not act
        as a hard mask on the final logits — confirm.
        """
        policy = self.policy_net(features)
        invalid_mask = self.invalid_action_mask(features)
        return policy + invalid_mask

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the value latent."""
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    """SB3 actor-critic policy whose latent extractor is ``CustomNetwork``.

    Orthogonal initialization is always turned off, so the ``ChessResNet`` bodies keep their own
    initialization. Every other constructor argument is forwarded unchanged to
    ``ActorCriticPolicy``.

    NOTE(review): the base class still builds its ``action_net``/``value_net`` heads on top of
    ``CustomNetwork``'s latents; confirm this is intended given the additive action mask.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Force-disable orthogonal init, overriding any caller-supplied value.
        kwargs.update(ortho_init=False)
        super().__init__(observation_space, action_space, lr_schedule, *args, **kwargs)

    def _build_mlp_extractor(self) -> None:
        """Install ``CustomNetwork`` as the shared policy/value latent extractor."""
        extractor = CustomNetwork(self.features_dim)
        self.mlp_extractor = extractor


if __name__ == '__main__':
    # Record the torch build in the run log before training starts.
    torch_version = torch.__version__
    print("torch version", torch_version)
    main()
