import torch
from tqdm import trange

from chess_utils import tensor_to_board, tb_probe_result, softmax, tb_probe_wdl, tb_probe_wdl_ab, tb_res_to_wdl
import log_writer
from chess_env import OutcomeEnv, MoveEnv
from evaluate import evaluate_board
import config


def move_loss(board, limit):
    """Score the engine's chosen move on ``board`` using tablebase probes.

    The position is probed with ``tb_probe_wdl`` before the move and with
    ``tb_probe_wdl_ab`` after it. Both probe values are halved and truncated
    toward zero, and the loss is the pre-move bucket minus the post-move bucket.

    Args:
        board: Position to evaluate. The engine move is pushed and popped again,
            so the board is restored when the function returns normally.
        limit: Forwarded to ``evaluate_board`` as its ``limit`` argument.

    Returns:
        A ``(loss, move)`` tuple:
          * ``(0.0, "null")`` when the pre-move probe equals -2; the engine is
            not queried in that case.
          * ``(2.0, None)`` when the engine's move is not legal on ``board``.
          * ``(float, move)`` otherwise, where ``move`` is the engine's move.
    """
    # NOTE(review): probe values are presumably WDL scores in [-2, 2] from the
    # side to move's point of view -- confirm against chess_utils.
    wdl_before = int(tb_probe_wdl(board))
    if wdl_before == -2:
        return 0.0, "null"

    chosen = evaluate_board(board, limit=limit)[1]
    if not board.is_legal(chosen):
        return 2.0, None

    # Probe values whose half truncates to zero share one bucket; for those the
    # follow-up probe is always called with alpha=1.
    bucket_before = int(wdl_before / 2)
    alpha = 1 if bucket_before == 0 else -wdl_before

    board.push(chosen)
    # The post-move probe result is negated before it is bucketed.
    wdl_after = -int(tb_probe_wdl_ab(board, alpha=alpha))
    board.pop()

    bucket_after = int(wdl_after / 2)
    return 1.0 * (bucket_before - bucket_after), chosen


def log_gfn(gflownet, env, target_model, iteration, num_samples=2000, log_boards=True):
    """Sample terminating states from ``gflownet`` and log validation statistics.

    Each sampled state is converted to a board; boards that are not valid or
    are already game-over are skipped. For the remaining boards an error is
    computed according to the environment type:

    * ``OutcomeEnv``: absolute difference between ``1 - tb_probe_result(board)``
      and ``pred[0, 0] - pred[0, 2]`` of the softmaxed ``target_model`` output.
    * ``MoveEnv``: the value returned by ``move_loss`` for the board.

    Args:
        gflownet: Sampler exposing ``sample_terminating_states(env, n)``.
        env: Environment; ``OutcomeEnv`` and ``MoveEnv`` get extra handling.
        target_model: Model evaluated on each board when ``env`` is an
            ``OutcomeEnv``; unused otherwise.
        iteration: Step index passed to every ``log_writer`` call.
        num_samples: Number of terminating states to request.
        log_boards: When true (and 8 valid boards were collected) board images
            are logged as well.

    Returns:
        None. Results are emitted through ``log_writer`` and ``print``.
    """
    legal = 0
    model_error = 0
    max_error = 0
    reg_error_count = 0
    max_error_count = 0
    board_list = []
    success_boards = []
    success_fens = []
    win_to_loss_boards = []

    if isinstance(env, MoveEnv):
        # Persist and reset what the environment accumulated so far, before the
        # validation sampling below runs.
        log_writer.save_fens(env.good_fens, "fens", iteration, path="fens_v2")
        env.good_fens.clear()
        log_writer.save_fens(env.illegal_moves, "fens", iteration, path="illegal_moves")
        env.illegal_moves.clear()

    # Stays None when no error was ever computed (no usable board, or an
    # environment type without an error definition).
    sample_error = None
    with torch.no_grad():
        states = gflownet.sample_terminating_states(env, num_samples)
        board_tensor = states.tensor
        valid_boards = []
        for idx in range(board_tensor.shape[0]):
            board = tensor_to_board(board_tensor[idx])
            if not board.is_valid() or board.is_game_over():
                continue
            # Keep at most 8 boards for the image log.
            if len(valid_boards) < 8:
                valid_boards.append(board)
            board_list.append(board_tensor[idx])
            legal += 1

            if isinstance(env, OutcomeEnv):
                label = torch.tensor([tb_probe_result(board)], dtype=torch.int64)
                label_v = 1 - label
                pred = softmax(target_model(board_tensor[idx].view(1, 12 * 64)))
                pred_v = pred[0, 0] - pred[0, 2]
                sample_error = torch.abs(label_v - pred_v).item()
            elif isinstance(env, MoveEnv):
                sample_error, engine_move = move_loss(board, env.engine_limit)
                if sample_error >= 1.0:
                    # move_loss returns None as the move when it was illegal.
                    if engine_move is None:
                        success_fens.append((board.fen(), "illegal"))
                    else:
                        success_fens.append((board.fen(), engine_move.uci()))
                    reg_error_count += 1
                    if len(success_boards) < 8:
                        success_boards.append(board)
                if sample_error == 2.0:
                    max_error_count += 1
                    if len(win_to_loss_boards) < 8:
                        win_to_loss_boards.append(board)

            if sample_error is not None:
                model_error += sample_error
                max_error = max(max_error, sample_error)

    if len(valid_boards) == 8 and log_boards:
        log_writer.log_boards("val/boards", valid_boards, iteration)
        log_writer.log_boards("val/good_board", success_boards, iteration)
        log_writer.log_boards("val/win_loss_boards", win_to_loss_boards, iteration)
    log_writer.add_board_tensor_dist("val/dist", board_tensor.view(-1, 12, 64), iteration)

    log_writer.add_scalar("val/legal", 1.0 * legal / board_tensor.shape[0], iteration)
    if sample_error is not None:
        log_writer.add_scalar("val/target_error", model_error / board_tensor.shape[0], iteration)
        log_writer.add_scalar("val/max_error", max_error, iteration)
    if isinstance(env, MoveEnv):
        # Reset the environment's lists again after the validation sampling.
        env.good_fens.clear()
        env.illegal_moves.clear()
        log_writer.add_scalar("val/error_percent", 100.0 * reg_error_count / num_samples, iteration)
        log_writer.add_scalar("val/win_to_loss_percent", 100.0 * max_error_count / num_samples, iteration)
    if legal > 0:
        # Report against the requested sample count instead of a hard-coded
        # 2000, so the summary stays correct for non-default ``num_samples``.
        print(f"Legal: {legal}/{num_samples}, Model Error: {model_error}, avg legal error: {model_error / legal}")
        log_writer.add_board_tensor_dist("val/legal_dist", torch.stack(board_list).view(-1, 12, 64), iteration)
        log_writer.add_scalar("val/target_legal_error", model_error / legal, iteration)
    log_writer.add_scalar("continuations", config.continuations, iteration)
    log_writer.save_fens(success_fens, "fens", iteration)


def train_target_model(gflownet, env, model, optimizer, batch_size):
    """Run one supervised update of ``model`` on states sampled from ``gflownet``.

    ``batch_size`` terminating states are sampled without gradient tracking.
    For every state whose first 768 features convert to a valid board, the
    label is ``tb_probe_result(board)`` and the input is that 768-wide slice.
    If at least one valid board exists, a single optimizer step is taken on the
    summed cross-entropy loss.

    Args:
        gflownet: Sampler exposing ``eval()`` and
            ``sample_terminating_states(env, n)``. It is left in eval mode.
        env: Environment passed through to the sampler.
        model: Classifier being trained; switched to train mode for the update
            and back to eval mode before returning.
        optimizer: Optimizer for ``model``.
        batch_size: Number of states to sample and inspect.
    """
    gflownet.eval()
    model.train()
    with torch.no_grad():
        samples = gflownet.sample_terminating_states(env, batch_size)
    state_tensor = samples.tensor

    inputs = []
    targets = []
    for row in range(batch_size):
        position = tensor_to_board(state_tensor[row, :768])
        if not position.is_valid():
            continue
        targets.append(torch.tensor([tb_probe_result(position)], dtype=torch.int64))
        # Slice with a range so the row keeps its leading batch dimension.
        inputs.append(state_tensor[row:(row + 1), :768])

    if inputs:
        batch_inputs = torch.cat(inputs)
        batch_targets = torch.cat(targets)
        optimizer.zero_grad()
        logits = model(batch_inputs)
        loss = torch.nn.functional.cross_entropy(logits, batch_targets, reduction='sum')
        loss.backward()
        optimizer.step()
    model.eval()


def train(gflownet, optimizer, env, target_model=None, target_optimizer=None, batch_size=128, n_episodes=25_000,
          log_frequency=None, max_temp=32, min_temp=1, temp_decay=0.996, sampler=None, start=None):
    """Optimise ``gflownet`` on sampled trajectories and count what was sampled.

    Runs ``n_episodes // batch_size`` iterations beginning at ``start``. Each
    iteration sets the forward policy temperature, optionally updates the
    target model, samples a batch of trajectories and takes one optimizer step
    on ``gflownet.loss``.

    Args:
        gflownet: Model exposing ``pf`` (with ``epsilon``/``temperature``),
            ``train()`` and ``loss(env, trajectories)``.
        optimizer: Optimizer for ``gflownet``.
        env: Environment passed to the sampler, the loss and the logging helper.
        target_model: Forwarded to ``log_gfn`` and ``train_target_model``.
        target_optimizer: When not None, ``train_target_model`` is called once
            per iteration with it.
        batch_size: Trajectories per iteration. Batches smaller than 1024 only
            log the loss and decay the temperature every ``1024 // batch_size``
            iterations.
        n_episodes: Total trajectory budget; divided by ``batch_size`` to get
            the iteration count.
        log_frequency: When not None, validation logging and checkpointing run
            every ``log_frequency`` such ticks; board images are requested on
            every tenth of those.
        max_temp: Initial temperature (also stored in the initial checkpoint).
        min_temp: Floor added to the decaying temperature component.
        temp_decay: Multiplier applied to the decaying component on each tick.
        sampler: Object exposing ``sample_trajectories(env=..., n_trajectories=...)``.
            Must be supplied; the ``None`` default is not usable by the loop.
        start: First iteration index. ``None`` means a fresh run: an initial
            validation log and checkpoint are written and counting starts at 0.

    Returns:
        The sum of ``len(trajectories)`` over all iterations.
    """
    if start is None:
        # Fresh run: record a baseline before any parameter update.
        log_gfn(gflownet, env, target_model, 0, log_boards=True)
        log_writer.save_checkpoint(gflownet, optimizer, 0, max_temp, name="ckpt.tar")
        start = 0

    total_trajectories = 0
    gflownet.pf.epsilon = 0.00
    temp_span = max_temp - min_temp
    # Number of iterations that make up one logging/decay tick.
    tick = 1 if batch_size >= 1024 else 1024 // batch_size
    log_period = None if log_frequency is None else tick * log_frequency

    for iteration in trange(start, start + (n_episodes // batch_size)):
        gflownet.pf.temperature = min_temp + temp_span

        if target_optimizer is not None:
            train_target_model(gflownet, env, target_model, target_optimizer, batch_size)

        gflownet.train()
        batch = sampler.sample_trajectories(env=env, n_trajectories=batch_size)
        optimizer.zero_grad()
        loss = gflownet.loss(env, batch)
        loss.backward()
        optimizer.step()
        total_trajectories += len(batch)

        if iteration % tick == 0:
            log_writer.add_scalar("train/loss", loss.item(), iteration // tick)
            temp_span *= temp_decay

        completed = iteration + 1
        if log_period is None or completed % log_period != 0:
            continue

        with_boards = completed % (log_period * 10) == 0
        log_gfn(gflownet, env, target_model, (completed // tick), log_boards=with_boards)
        log_writer.save_checkpoint(gflownet, optimizer, completed, temp_span + min_temp, name="ckpt.tar")

    return total_trajectories
