import argparse
import os

import torch
from gfn.gflownet import TBGFlowNet, SubTBGFlowNet
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.samplers import Sampler

from models.resnet import ChessResNet
from chess_utils import open_egtb, close_egtb
from models.temperature_policy_estimator import TemperaturePolicyEstimator
import config
import log_writer
from chess_env import OutcomeEnv, MoveEnv, ChessEnv
from evaluate import close_engine, initialize_engine
from fen_generator import gen_fens
from uniform_sampler import uniform_to_engine_errors, sample_and_store_uniform
from models.winter_net import WinterNet
from train import train
from utils import get_default_experiment_name, backup_sources


_EXPERIMENTS = ("uniform", "weighted", "move", "outcome")


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description='Train GFN to target model')
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--log-frequency', type=int, default=20,
                        help="Controls size of interval between tensorboard log updates.")
    parser.add_argument('--tb-path', type=str, default="/home/Chess/TB_Merged",
                        help='Path to find EGTB')
    parser.add_argument('--name', type=str, default=get_default_experiment_name())
    parser.add_argument('--experiment', type=str, default='Move')
    parser.add_argument('--depth', type=int, default=1, help='Depth for engine to search in experiments. '
                                                             'Relevance depends on experiment type.')
    parser.add_argument('--nodes', type=int, default=None, help='Node limit for engine in experiments.')
    parser.add_argument('--engine', type=str, default='Winter', help='Target engine')
    # Integer default so the value agrees with the declared type (argparse
    # does not run ``type`` on non-string defaults).
    parser.add_argument('--episodes', type=int, default=25_000_000)
    parser.add_argument('--learning-rate', type=float, default=3e-5)
    parser.add_argument('--lr-mult', type=float, default=10,
                        help="LogR/LogF estimator learning rate multiplier.")
    parser.add_argument('--max-temp', type=float, default=32,
                        help="Initial temperature for GFN forward estimates")
    parser.add_argument('--min-temp', type=float, default=1,
                        help="Temperature for GFN forward estimates as iterations tend towards infinite.")
    parser.add_argument('--temp-decay', type=float, default=0.996,
                        help="Rate at which temperature decays per 1024 samples")
    parser.add_argument('--base-reward', type=float, default=0.1,
                        help="Base reward value passed to the environment")
    parser.add_argument('--reward-balance', type=float, default=0.9,
                        help="Extra reward for generating balanced samples")
    parser.add_argument('--reward-fool', type=float, default=125,
                        help="Extra reward for generating samples that fool the target model")
    parser.add_argument('--load', type=str, default=None,
                        help="Path to load initial model weights from.")
    parser.add_argument('--num-pieces', type=int, default=5,
                        help="Number of pieces the GFN should output per position.")
    parser.add_argument('--block-sizes', nargs='+', type=int, default=[128, 128, 128, 128, 128, 128, 128, 64],
                        help="Residual block sizes of model bodies.")
    parser.add_argument('--uci-separate', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--gen-fens-only', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--uniform-kings', default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument('--gen-uniform-simple', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--gus-start', type=int, default=0)
    parser.add_argument('--gus-max', type=int, default=20000)
    parser.add_argument('--uniform-on-engine', default=False, action=argparse.BooleanOptionalAction)
    return parser


def _move_setting_str(args: argparse.Namespace) -> str:
    """Return the run-directory name used for "move" experiments.

    The format is ``{engine}[_us_]_move_{depth | n<nodes>}_{num_pieces}``;
    the node limit takes precedence over the depth when it is set.
    """
    eng_setting_str = "_us_" if args.uci_separate else ""
    limit_str = f"{args.depth}" if args.nodes is None else f"n{args.nodes}"
    return f"{args.engine}{eng_setting_str}_move_{limit_str}_{args.num_pieces}"


def _log_subdir(args: argparse.Namespace) -> str:
    """Return the tensorboard sub-directory for the selected experiment."""
    experiment = args.experiment.lower()
    if experiment in ("weighted", "uniform"):
        return f"{experiment}_{args.num_pieces}"
    if experiment == "move":
        return _move_setting_str(args)
    return f"{args.engine}_outcome_{args.num_pieces}"


def _build_env(args: argparse.Namespace):
    """Build the environment for ``args.experiment``.

    Returns:
        ``(env, model)`` where ``model`` is the target network for the
        "outcome" experiment and ``None`` otherwise.

    Raises:
        ValueError: if the experiment name is not one of ``_EXPERIMENTS``.

    For the "move" experiment the target engine is started as the last step.
    """
    experiment = args.experiment.lower()
    if experiment == "uniform":
        env = ChessEnv(num_pieces=args.num_pieces, device=config.device, base_reward=args.base_reward,
                       reward_balance=0)
        return env, None
    if experiment == "weighted":
        env = ChessEnv(num_pieces=args.num_pieces, device=config.device, base_reward=args.base_reward,
                       reward_balance=args.reward_balance)
        return env, None
    if experiment == "move":
        print(f"Uniform kings is {args.uniform_kings}")
        env = MoveEnv(num_pieces=args.num_pieces, device=config.device, depth=args.depth, nodes=args.nodes,
                      base_reward=args.base_reward, reward_balance=args.reward_balance,
                      reward_fool=args.reward_fool, uniform_kings=args.uniform_kings)
        initialize_engine(args.engine, uci_separate=args.uci_separate)
        return env, None
    if experiment == "outcome":
        model = WinterNet()
        model.load_state_dict(torch.load("../pretrained/rn16HD64b_ep4.pt", map_location=torch.device("cpu"),
                                         weights_only=True))
        model.eval()
        model = model.to(config.device)
        return OutcomeEnv(target_model=model, num_pieces=args.num_pieces, device=config.device), model
    # An explicit error instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently run the "outcome" branch.
    raise ValueError(f"Unknown experiment {args.experiment!r}; expected one of: {', '.join(_EXPERIMENTS)}")


def _restore_state(args: argparse.Namespace, gflownet, optimizer):
    """Load initial weights / checkpoint state as requested by ``--load``.

    Returns:
        ``(start_iteration, temperature)``. Without a checkpoint these are
        ``None`` and ``args.max_temp``; ``.tar`` and ``auto`` loads take both
        values from ``log_writer.load_checkpoint``.
    """
    start_iteration = None
    temperature = args.max_temp
    if args.load is None:
        return start_iteration, temperature

    if args.load == "auto":
        subpath = _move_setting_str(args)
        # NOTE(review): for uniform/weighted runs this resolves to
        # "<name>/<name>/ckpt.tar", and "outcome" runs reuse the move-style
        # path; neither matches the tensorboard directory layout used when
        # logging. Confirm against log_writer before relying on "auto" there.
        if args.experiment.lower() in ("uniform", "weighted"):
            subpath = args.name
        start_iteration, temperature = log_writer.load_checkpoint(
            gflownet, optimizer, os.path.join(subpath, args.name, "ckpt.tar"))
    elif args.load.endswith(".pt"):
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        gflownet.load_state_dict(torch.load(os.path.join("..", "models", args.load), weights_only=False))
    elif args.load.endswith(".tar"):
        start_iteration, temperature = log_writer.load_checkpoint(gflownet, optimizer, args.load)
    else:
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        gflownet.load_state_dict(torch.load(os.path.join("..", "models", args.load, "gflownet.pt"),
                                            weights_only=False))
    return start_iteration, temperature


def main() -> int:
    """Parse the command line and run the selected mode.

    Modes, checked in this order:
      * ``--gen-uniform-simple``: sample and store uniform positions, then exit.
      * ``--uniform-on-engine``: run stored uniform samples against the engine,
        then exit.
      * otherwise: build the environment and GFlowNet for ``--experiment`` and
        either generate FENs only (``--gen-fens-only``) or train.

    Returns:
        0 on completion.

    Raises:
        ValueError: if ``--experiment`` is not a known experiment name.
    """
    args = _build_parser().parse_args()
    print(args)

    config.tb_path = args.tb_path
    if args.gen_uniform_simple:
        open_egtb()
        try:
            sample_and_store_uniform(start_iter=args.gus_start, max_iterations=args.gus_max,
                                     num_pieces=args.num_pieces)
        finally:
            close_egtb()
        return 0

    if args.uniform_on_engine:
        open_egtb()
        try:
            initialize_engine(args.engine, uci_separate=True)
            try:
                uniform_to_engine_errors(args.engine, batch=args.gus_start, count=args.gus_max,
                                         num_pieces=args.num_pieces, nodes=args.nodes)
            finally:
                close_engine()
        finally:
            close_egtb()
        return 0

    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.set_default_device(config.device)

    n_episodes = args.episodes
    learning_rate = args.learning_rate

    env, model = _build_env(args)

    egtb_open = False
    log_open = False
    try:
        # Forward policy estimator.
        pf_module = ChessResNet(args.block_sizes, num_classes=env.n_actions, out_method="PF")
        pf_estimator = TemperaturePolicyEstimator(
            module=pf_module,
            n_actions=env.n_actions,
            preprocessor=env.preprocessor,
        )

        # Backward policy estimator; its module is handed the forward
        # module's residual blocks as torso.
        pb_module = ChessResNet([], num_classes=env.n_actions - 1, out_method="PB", torso=pf_module.res_blocks,
                                in_planes=pf_module.in_planes)
        pb_estimator = DiscretePolicyEstimator(
            module=pb_module,
            n_actions=env.n_actions,
            is_backward=True,
            preprocessor=env.preprocessor,
        )

        # logF estimator, also given the forward module's residual blocks.
        z_module = ChessResNet(args.block_sizes, num_classes=1, out_method="Linear",
                               torso=pf_module.res_blocks, in_planes=pf_module.in_planes)
        logF_estimator = ScalarEstimator(module=z_module, preprocessor=env.preprocessor)

        gflownet = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.8)
        gflownet = gflownet.to(config.device)

        # Two parameter groups, split on whether the parameter name contains
        # "logF"; the logF group trains with a multiplied learning rate.
        shared_body_params = [v for k, v in gflownet.named_parameters() if "logF" not in k]
        logf_params = [v for k, v in gflownet.named_parameters() if "logF" in k]
        optimizer = torch.optim.Adam([
            {"params": shared_body_params, "lr": learning_rate},
            {"params": logf_params, "lr": args.lr_mult * learning_rate},
        ])

        start_iteration, temperature = _restore_state(args, gflownet, optimizer)

        sampler = Sampler(estimator=pf_estimator)

        if args.gen_fens_only:
            gen_fens(gflownet, env, 10000, 2000)
            return 0

        open_egtb()
        egtb_open = True

        log_subdir = _log_subdir(args)
        log_writer.init(log_path=os.path.join("..", "tensorboard", log_subdir), name=args.name)
        log_open = True
        if args.experiment.lower() == "move":
            backup_sources(os.path.join(log_subdir, args.name))

        train(gflownet, optimizer, env, target_model=model, batch_size=args.batch_size, n_episodes=n_episodes,
              log_frequency=args.log_frequency, sampler=sampler,
              max_temp=temperature, min_temp=args.min_temp, temp_decay=args.temp_decay, start=start_iteration)
    finally:
        # Release resources even when setup or training raises, in the same
        # order as a normal run: log writer, engine, EGTB. close_engine() is
        # called for every experiment type, as at the end of a normal run.
        if log_open:
            log_writer.close()
        close_engine()
        if egtb_open:
            close_egtb()

    return 0


if __name__ == '__main__':
    print("torch version", torch.__version__)
    # Hand main()'s integer status to the interpreter as the process exit
    # code instead of discarding it.
    raise SystemExit(main())
