import csv
import os

import chess
import torch

from chess_utils import tb_probe_wdl, tb_probe_wdl_ab
from evaluate import evaluate_board


def _store_fens(fens, iteration_count, num_pieces, suffix="simple", base="fens_v2"):
    """Write FEN entries to a CSV-like file, creating the directory if needed.

    The file is ``../<base>/uniform_<num_pieces>/<suffix>/fens_<iteration_count>.csv``
    (relative to the current working directory) and is overwritten if present.

    Args:
        fens: Iterable of entries. A ``str`` entry is written as one column;
            any other entry must be a tuple, whose first two items are written
            as two columns.
        iteration_count: Value embedded in the output file name.
        num_pieces: Value embedded in the output directory name.
        suffix: Last component of the output directory.
        base: Directory component joined after ``..``.

    Raises:
        AssertionError: If an entry is neither a ``str`` nor a ``tuple``.
    """
    fen_path = os.path.join("..", base, f"uniform_{num_pieces}", suffix)
    # exist_ok avoids the check-then-create race when several workers start
    # at the same time and all find the directory missing.
    os.makedirs(fen_path, exist_ok=True)
    with open(os.path.join(fen_path, f"fens_{iteration_count}.csv"), 'w') as f:
        for fen in fens:
            # Every row keeps its trailing ", " so the on-disk format stays
            # exactly what existing readers of these files were given.
            if isinstance(fen, str):
                f.write(f"{fen}, \n")
            else:
                assert isinstance(fen, tuple)
                f.write(f"{fen[0]}, {fen[1]}, \n")


def gen_simple_uniform(num_pieces):
    """Place ``num_pieces`` randomly drawn pieces on an empty board.

    Draws ``num_pieces`` distinct codes from ``range(768)`` (2 colors x 6
    piece types x 64 squares) and decodes each as
    ``code = color * 384 + (piece_type - 1) * 64 + square``.

    The codes are distinct, but two codes can decode to the same square; the
    pieces are set in draw order, so the board may end up with fewer than
    ``num_pieces`` pieces.

    Args:
        num_pieces: Number of codes to draw.

    Returns:
        The resulting ``chess.Board``.
    """
    # Exact uniform process described in the paper.
    board = chess.Board(None)
    # tolist() yields plain Python ints, so python-chess receives a real
    # square index and a real bool color instead of 0-d tensors that only
    # work through their __index__ conversion.
    for code in torch.randperm(768)[:num_pieces].tolist():
        color_and_type, square = divmod(code, 64)
        color, type_offset = divmod(color_and_type, 6)
        # Piece types are numbered from 1, hence the +1.
        board.set_piece_at(square, chess.Piece(type_offset + 1, bool(color)))
    return board


def sample_and_store_uniform(start_iter=0, max_iterations=20000, iteration_step_size=20, num_samples=1024,
                             num_pieces=5):
    """Sample random positions in batches and store them with their labels.

    The iteration counter advances from ``start_iter`` in steps of
    ``iteration_step_size`` for as long as it is below
    ``start_iter + max_iterations``. Each step generates
    ``num_samples * iteration_step_size`` candidate boards with
    ``gen_simple_uniform``, keeps those that are valid and not game over,
    labels them with ``tb_probe_wdl`` and writes them with ``_store_fens``
    under the ``"permmute"`` suffix, named after the advanced counter.

    Args:
        start_iter: Counter value to start from (the first file is named
            after ``start_iter + iteration_step_size``).
        max_iterations: Span of counter values to cover.
        iteration_step_size: Counter increment per stored file; must be
            positive.
        num_samples: Candidate boards per single iteration.
        num_pieces: Passed to ``gen_simple_uniform`` and ``_store_fens``.

    Raises:
        ValueError: If ``iteration_step_size`` is not positive.
    """
    if iteration_step_size <= 0:
        # A non-positive step never moves the counter towards the end value,
        # so the loop below would spin forever.
        raise ValueError(
            f"iteration_step_size must be positive, got {iteration_step_size}"
        )

    candidates_per_file = num_samples * iteration_step_size
    iteration = start_iter
    final_iteration = start_iter + max_iterations
    while iteration < final_iteration:
        iteration += iteration_step_size
        boards = []
        for _ in range(candidates_per_file):
            board = gen_simple_uniform(num_pieces)
            # Rejected candidates are dropped, not resampled, so a file can
            # hold fewer rows than candidates_per_file.
            if board.is_valid() and not board.is_game_over():
                boards.append(board)
        fens_and_labels = [(board.fen(), tb_probe_wdl(board)) for board in boards]
        _store_fens(fens_and_labels, iteration, num_pieces, suffix="permmute")
        print(f"Stored {len(fens_and_labels)} samples for iteration {iteration}")


def gen_uniform_board(num_pieces):
    """Build a random board with one king per side on distinct squares.

    ``num_pieces`` distinct squares are drawn; the last two receive the white
    and black king, and every remaining square receives a piece of random
    color whose type is drawn from pawn through queen.

    Args:
        num_pieces: Total number of pieces including both kings; must be
            between 2 and 64.

    Returns:
        The resulting ``chess.Board``.

    Raises:
        ValueError: If ``num_pieces`` is outside ``[2, 64]``.
    """
    if not 2 <= num_pieces <= 64:
        # Fewer than 2 leaves no room for both kings; more than 64 would make
        # the non-king loop run into the king squares or past the end.
        raise ValueError(f"num_pieces must be between 2 and 64, got {num_pieces}")

    # tolist() gives python-chess plain int squares rather than 0-d tensors.
    squares = torch.randperm(64)[:num_pieces].tolist()
    board = chess.Board(None)
    # To optimize chance of getting a valid board we generate king separately
    board.set_piece_at(squares[-1], chess.Piece(chess.KING, chess.WHITE))
    board.set_piece_at(squares[-2], chess.Piece(chess.KING, chess.BLACK))
    for square in squares[:-2]:
        # randint's upper bound is exclusive: 0..QUEEN-1, shifted to 1..QUEEN.
        piece_type = torch.randint(chess.QUEEN, (1,)).item() + 1
        color = bool(torch.randint(2, (1,)).item())
        board.set_piece_at(square, chess.Piece(piece_type, color))
    return board


def get_simple_fen_filenames(directory, start=0, max_files=None, iteration_step_size=20):
    """Return the expected ``fens_<k>.csv`` names for one batch of files.

    The names are computed from the *number* of entries in ``directory``, not
    from the entries' actual names: entry ``i`` (0-based) is assumed to be
    called ``fens_<iteration_step_size * (i + 1)>.csv``. Callers should check
    that each returned name really exists.

    Args:
        directory: Directory whose entries are counted.
        start: Zero-based batch index; the batch begins at entry
            ``start * max_files``.
        max_files: Batch size. ``None`` means every entry in ``directory``.
        iteration_step_size: Spacing between the numbers in consecutive file
            names. The default of 20 matches the default step used by
            ``sample_and_store_uniform`` when writing these files.

    Returns:
        A list of file names (without directory); empty when the batch starts
        past the last entry.
    """
    num_files = len(os.listdir(directory))
    if max_files is None:
        max_files = num_files
    first = start * max_files
    # Clamp so the last batch does not run past the entry count.
    last = min(first + max_files, num_files)
    return [f"fens_{iteration_step_size * (i + 1)}.csv" for i in range(first, last)]


def uniform_to_engine_errors(engine, batch=0, count=None, num_pieces=5, nodes=400, base_dir="../fens_v2"):
    """Find stored positions where the engine's move lowers the tablebase label.

    Reads the files listed by ``get_simple_fen_filenames`` from
    ``<base_dir>/uniform_<num_pieces>/permmute``, asks ``evaluate_board`` for
    a move in every position, and writes the positions whose label gets worse
    after that move to
    ``<base_dir>/<engine>_us__move_n<nodes>_<num_pieces>/persimple/<same name>``
    as ``fen, move_uci, `` rows. An output file is written for every input
    file that exists, even when no errors were found.

    Args:
        engine: Used in this function only to name the output directory.
            NOTE(review): the engine actually queried is whatever
            ``evaluate_board`` uses; confirm the two agree.
        batch: Batch index forwarded as ``start`` to
            ``get_simple_fen_filenames``.
        count: Batch size forwarded as ``max_files``; ``None`` means all files.
        num_pieces: Selects the input directory and is part of the output
            directory name.
        nodes: Node limit for the engine search.
        base_dir: Root directory for both input and output.
    """
    # ``import chess`` alone does not load the ``chess.engine`` submodule;
    # importing it here keeps ``chess.engine.Limit`` working even when no
    # other module has imported it first.
    import chess.engine

    file_path = os.path.join(base_dir, f"uniform_{num_pieces}", "permmute")
    files_to_process = get_simple_fen_filenames(file_path, start=batch, max_files=count)
    limit = chess.engine.Limit(nodes=nodes)
    setting_str = f"{engine}_us__move_n{nodes}_{num_pieces}"
    fen_path = os.path.join(base_dir, setting_str, "persimple")
    # exist_ok avoids the check-then-create race between parallel batches.
    os.makedirs(fen_path, exist_ok=True)

    for filename in files_to_process:
        # File names are predicted from a directory count, so some may not exist.
        if not os.path.isfile(os.path.join(file_path, filename)):
            print(f"file {filename} not in {file_path}")
            continue
        engine_errors = list()
        with open(os.path.join(file_path, filename), 'r') as f:
            print(f"Reading: {file_path}/{filename}")
            csv_reader = csv.reader(f)
            for line in csv_reader:
                if not line:  # Skip empty rows
                    continue
                fen = line[0]
                outcome = int(line[1])
                # Presumably -2 is a tablebase loss for the side to move,
                # where no move can make the result worse -- confirm against
                # chess_utils.tb_probe_wdl.
                if outcome == -2:
                    continue
                board = chess.Board(fen=fen)
                engine_move = evaluate_board(board, limit=limit)[1]
                if not board.is_legal(engine_move):
                    continue
                board.push(engine_move)
                # int(outcome / 2) is 0 exactly for -1, 0 and 1; all three
                # are collapsed to -1 before the probe bound is derived.
                if int(outcome / 2) == 0:
                    outcome = -1
                # Negated because the probe runs after the move was pushed
                # (presumably it reports from the opponent's perspective --
                # confirm against chess_utils.tb_probe_wdl_ab).
                post_label = -int(tb_probe_wdl_ab(board, alpha=-outcome))
                board.pop()
                # Compare on the coarse scale int(x / 2): -1, 0 or 1.
                if int(outcome / 2) > int(post_label / 2):
                    engine_errors.append((fen, engine_move.uci()))
        with open(os.path.join(fen_path, filename), 'w') as f:
            for fen in engine_errors:
                assert isinstance(fen, tuple)
                f.write(f"{fen[0]}, {fen[1]}, \n")
