import itertools
import json
import math
import random
import os
from tqdm import tqdm

###########################
# 1. Helper Functions
###########################

def check_winner(board):
    """
    Inspect a 3x3 board (flat list of 9 cells: 0 = empty, 1 = P1, 2 = P2).

    Returns a pair (winner, is_terminal):
      - (1, True) / (2, True) when that player owns a full row, column,
        or diagonal;
      - (0, True) when nobody has a line but no empty cell remains (draw);
      - (0, False) otherwise (game still in progress).
    """
    winning_triples = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    for first, second, third in winning_triples:
        owner = board[first]
        if owner != 0 and owner == board[second] == board[third]:
            return owner, True

    # No line completed: the game is over only if the board is full.
    return 0, 0 not in board

def get_token_for_move(player, cell_idx):
    """
    Encode a move as a token id.

    `cell_idx` is a row-major cell index in 0..8 ((0,0)->0, (0,1)->1, ...).
    Player 1 moves map to tokens 1..9 and any other player value to 10..18.
    """
    offset = 1 if player == 1 else 10
    return cell_idx + offset

def get_cell_idx_from_token(token):
    """
    Decode a move token back into (player, cell_idx).

    Tokens 1..9 belong to player 1 and 10..18 to player 2; the cell index
    is the row-major position 0..8. Any other token raises ValueError.
    """
    for owner, first_token in ((1, 1), (2, 10)):
        if first_token <= token <= first_token + 8:
            return (owner, token - first_token)
    raise ValueError(f"Invalid token: {token}")

def apply_move(board, player, cell_idx):
    """
    Return a copy of `board` with `player` written into `cell_idx` (0..8).

    The input list is left untouched (it is duplicated with `.copy()`), so
    the caller keeps the pre-move state. No legality check is performed.
    """
    next_board = board.copy()
    next_board[cell_idx] = player
    return next_board

def board_to_ascii(board):
    """
    Render a 9-cell board as three text lines.

    Cells are drawn as 'X' (P1), 'O' (P2) or '.' (empty), separated by
    single spaces within a row; rows are joined with newlines.
    """
    glyphs = {0: '.', 1: 'X', 2: 'O'}
    return "\n".join(
        " ".join(glyphs[board[3 * row + col]] for col in range(3))
        for row in range(3)
    )

def board_to_text_instruction(board):
    """
    Describe a 9-cell board in plain English, one sentence per row.

    Cells are named 'X' (P1), 'O' (P2) or 'empty'. Example output:
      "Row 0: X, O, empty. Row 1: empty, X, empty. Row 2: O, empty, empty."
    """
    names = {0: 'empty', 1: 'X', 2: 'O'}
    sentences = []
    for row in range(3):
        cells = ", ".join(names[board[row * 3 + col]] for col in range(3))
        sentences.append(f"Row {row}: {cells}.")
    return " ".join(sentences)

def generate_symmetries(board):
    """
    Return the 8 rotation/reflection images of a 3x3 board as a list.

    Order: the four clockwise quarter-turn rotations (0, 90, 180, 270
    degrees), then each of those four mirrored top-to-bottom (row r swaps
    with row 2 - r). The first entry is `board` itself (not a copy); every
    other entry is a new length-9 list. Identical images are not removed,
    so a symmetric board yields repeated entries.
    """
    # Each table reads new_board[i] = old_board[table[i]].
    # Clockwise quarter turn: cell (r, c) moves to (c, 2 - r).
    quarter_turn = (6, 3, 0, 7, 4, 1, 8, 5, 2)
    # Top-to-bottom mirror: cell (r, c) moves to (2 - r, c).
    top_bottom_mirror = (6, 7, 8, 3, 4, 5, 0, 1, 2)

    rotations = [board]
    for _ in range(3):
        previous = rotations[-1]
        rotations.append([previous[src] for src in quarter_turn])

    mirrored = [[rotated[src] for src in top_bottom_mirror] for rotated in rotations]
    return rotations + mirrored

def canonical_symmetry_id(board):
    """
    Return the lexicographically smallest symmetric image of `board`.

    The result is a tuple of 9 ints chosen among the 8 images produced by
    `generate_symmetries`, so every board related by rotation or reflection
    maps to the same value and it can serve as a grouping key.
    """
    return min(tuple(image) for image in generate_symmetries(board))

###########################
# Best move calculation helpers
###########################

# `get_token_for_move` and `check_winner` are defined once, in section 1
# (Helper Functions); the minimax helpers below reuse those definitions.

def _depth_adjusted_score(winner, depth):
    """Score a finished game from player 1's view: 10 - depth for a P1 win,
    -10 + depth for a P2 win, 0 for a draw (faster wins / slower losses
    score better)."""
    if winner == 1:
        return 10 - depth
    if winner == 2:
        return -10 + depth
    return 0


def _alpha_beta_value(board, player, depth, alpha, beta):
    """
    Depth-sensitive minimax value of `board` with `player` to move.

    Fail-soft alpha-beta search: the result is exact when it lies strictly
    inside (alpha, beta); otherwise it is only a bound on the true value.
    Player 1 maximizes and places 1s; any other player value minimizes and
    places 2s.
    """
    winner, is_terminal = check_winner(board)
    if is_terminal:
        return _depth_adjusted_score(winner, depth)

    maximizing = player == 1
    mark, opponent = (1, 2) if maximizing else (2, 1)
    best = -math.inf if maximizing else math.inf
    for move, cell in enumerate(board):
        if cell != 0:
            continue
        child = board[:]
        child[move] = mark
        value = _alpha_beta_value(child, opponent, depth + 1, alpha, beta)
        if maximizing:
            best = max(best, value)
            alpha = max(alpha, value)
        else:
            best = min(best, value)
            beta = min(beta, value)
        if beta <= alpha:
            break  # remaining siblings cannot change the parent's choice
    return best


def minimax_all_moves_depth_sensitive(board, player, depth=0, alpha=-math.inf, beta=math.inf):
    """
    Return (value, best_moves) for `board` with `player` to move.

    value: minimax score from player 1's perspective that prefers faster
        wins and slower losses: 10 - d for a P1 win, -10 + d for a P2 win,
        0 for a draw, where d is the ply count starting from `depth`.
    best_moves: every empty cell index (ascending) whose move achieves
        `value`; [] when the board is already terminal.

    Player 1 maximizes; any other value of `player` minimizes and plays 2.

    Every candidate move is scored with its own full (-inf, +inf) window.
    A single alpha-beta pass that narrows one window across sibling moves
    returns only upper/lower *bounds* for later siblings, so a strictly
    worse move can come back equal to the best score and be reported as a
    tie. `alpha` and `beta` are kept for backward compatibility but are not
    used.
    """
    winner, is_terminal = check_winner(board)
    if is_terminal:
        return _depth_adjusted_score(winner, depth), []

    maximizing = player == 1
    mark, opponent = (1, 2) if maximizing else (2, 1)

    # Exact value of every legal move, in ascending cell order.
    move_values = []
    for move, cell in enumerate(board):
        if cell != 0:
            continue
        child = board[:]
        child[move] = mark
        value = _alpha_beta_value(child, opponent, depth + 1, -math.inf, math.inf)
        move_values.append((move, value))

    best_value = (max if maximizing else min)(value for _, value in move_values)
    best_moves = [move for move, value in move_values if value == best_value]
    return best_value, best_moves

def find_all_best_moves(row):
    """
    Return the sorted tokens of every optimal move for the side to move.

    `row` must support row['board'] (list of 9 ints) and row['is_terminal'],
    e.g. an entry dict from build_tictactoe_dataset (or a pandas row used
    with `apply`). Terminal states yield [].
    """
    if row['is_terminal']:
        return []

    board = row['board']
    # P1 always opens, so it is P2's turn exactly when P1 has more marks.
    mover = 2 if board.count(1) > board.count(2) else 1

    _, best_cells = minimax_all_moves_depth_sensitive(board, mover)
    return sorted(get_token_for_move(mover, cell) for cell in best_cells)


###########################
# 2. Generate All Valid States
###########################

def explore_all_sequences():
    """
    Walk every legal game depth-first from the empty board.

    Returns a list of (board, move_sequence) pairs, one per visited node:
    the empty board, every intermediate position, and every terminal
    position. `move_sequence` is the list of tokens played to reach that
    board. A board reachable through several move orders appears once per
    path. Children are visited in ascending cell order, P1 moving first.
    """
    collected = []

    def visit(board, tokens_so_far, mover):
        collected.append((board, tokens_so_far))
        _, finished = check_winner(board)
        if finished:
            return  # a won or full board has no successors

        opponent = 1 if mover == 2 else 2
        for cell in range(9):
            if board[cell] != 0:
                continue
            visit(
                apply_move(board, mover, cell),
                tokens_so_far + [get_token_for_move(mover, cell)],
                opponent,
            )

    visit([0] * 9, [], 1)
    return collected

###########################
# 3. Building the Dataset
###########################

def build_tictactoe_dataset():
    """
    Build one record per unique board reachable in legal play.

    All (board, move_sequence) pairs from `explore_all_sequences` are grouped
    by board, so each record lists every sequence that reaches it. Each
    returned dict has, in this key order:
      - state_id: the record's index in the returned list
      - board: list of 9 ints (0 = empty, 1 = P1/X, 2 = P2/O)
      - ascii_board: rendering from `board_to_ascii`
      - text_instruction: rendering from `board_to_text_instruction`
      - move_sequences: list of token sequences that produce this board
      - is_terminal: True if someone has won or the board is full
      - winner: 0 = no winner / draw, 1 = P1, 2 = P2
      - next_legal_moves: tokens of the side to move ([] when terminal)
      - canonical_symmetry_id: tuple shared by all symmetric boards
      - text_instruction_alt, ascii_board_alt: the same renderings with
        'X' replaced by '+' and 'O' replaced by 'Y'
      - best_moves: sorted tokens of every optimal move
        (from `find_all_best_moves`; [] when terminal)
    """
    # Group every path by the board it reaches. Dict insertion order keeps
    # boards in first-discovery order, which fixes the state_id numbering.
    board_dict = {}
    for board, seq in explore_all_sequences():
        board_dict.setdefault(tuple(board), []).append(seq)

    dataset = []
    for state_id, (board_tuple, sequences) in enumerate(tqdm(board_dict.items())):
        board_list = list(board_tuple)
        winner, terminal = check_winner(board_list)

        next_moves = []
        if not terminal:
            # P1 always opens, so equal mark counts means it is P1's turn.
            current_player = 1 if board_list.count(1) == board_list.count(2) else 2
            next_moves = [
                get_token_for_move(current_player, cell_idx)
                for cell_idx, cell in enumerate(board_list)
                if cell == 0
            ]

        ascii_rep = board_to_ascii(board_list)
        text_rep = board_to_text_instruction(board_list)

        entry = {
            "state_id": state_id,
            "board": board_list,
            "ascii_board": ascii_rep,
            "text_instruction": text_rep,
            "move_sequences": sequences,
            "is_terminal": terminal,
            "winner": winner,
            "next_legal_moves": next_moves,
            "canonical_symmetry_id": canonical_symmetry_id(board_list),
            "text_instruction_alt": text_rep.replace('X', '+').replace('O', 'Y'),
            "ascii_board_alt": ascii_rep.replace('X', '+').replace('O', 'Y'),
        }
        # find_all_best_moves reads only 'board' and 'is_terminal' from entry.
        entry["best_moves"] = find_all_best_moves(entry)
        dataset.append(entry)

    return dataset

###########################
# 4. Split the Dataset
###########################

def split_dataset(dataset, test_ratio=0.2, skip_symmetries=False, rng=None):
    """
    Randomly split `dataset` into (train_data, test_data).

    :param dataset: list of items to split; it is not modified.
    :param test_ratio: fraction in [0, 1] of the shuffled units sent to test
        (count rounded down); the remainder goes to train.
    :param skip_symmetries: if True, items are grouped by their
        "canonical_symmetry_id" and each group is assigned wholesale to one
        split, so symmetric states never straddle train and test. The ratio
        then counts groups, not items.
    :param rng: optional object with a ``shuffle`` method (for example
        ``random.Random(seed)``) for reproducible splits; defaults to the
        module-level ``random`` generator.
    :return: (train_data, test_data)
    :raises ValueError: if test_ratio is not within [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError("test_ratio must be between 0.0 and 1.0")

    if skip_symmetries:
        groups = {}
        for item in dataset:
            # tuple() also makes ids reloaded from JSON (lists) hashable.
            groups.setdefault(tuple(item["canonical_symmetry_id"]), []).append(item)
        units = list(groups.values())
    else:
        # Each item is its own one-element unit so both modes share the
        # same shuffle/slice logic below.
        units = [[item] for item in dataset]

    (random if rng is None else rng).shuffle(units)

    test_count = int(len(units) * test_ratio)
    test_data = [item for unit in units[:test_count] for item in unit]
    train_data = [item for unit in units[test_count:] for item in unit]
    return train_data, test_data

def split_dataset_3way(dataset, test_ratio=0.15, val_ratio=0.15, skip_symmetries=False, rng=None):
    """
    Splits `dataset` into train, validation, and test sets.

    :param dataset: list of items to split; it is not modified.
    :param test_ratio: Fraction of units to allocate to test (e.g. 0.15);
        counts are rounded down.
    :param val_ratio: Fraction of units to allocate to validation (e.g. 0.15);
        counts are rounded down.
    :param skip_symmetries: If True, items are grouped by
        "canonical_symmetry_id" and each group is assigned wholesale to a
        single split (no symmetric overlap); ratios then count groups.
    :param rng: optional object with a ``shuffle`` method (for example
        ``random.Random(seed)``) for reproducible splits; defaults to the
        module-level ``random`` generator.

    :return: (train_data, val_data, test_data); train receives whatever
        remains after test and validation.
    :raises ValueError: if a ratio is negative or test_ratio + val_ratio > 1.
    """
    if test_ratio < 0 or val_ratio < 0:
        raise ValueError("test_ratio and val_ratio must be non-negative")
    if test_ratio + val_ratio > 1.0:
        raise ValueError("test_ratio + val_ratio must be <= 1.0")

    if skip_symmetries:
        groups = {}
        for item in dataset:
            # tuple() also makes ids reloaded from JSON (lists) hashable.
            groups.setdefault(tuple(item["canonical_symmetry_id"]), []).append(item)
        units = list(groups.values())
    else:
        # Each item is its own one-element unit so both modes share the
        # same shuffle/slice logic below.
        units = [[item] for item in dataset]

    (random if rng is None else rng).shuffle(units)

    total = len(units)
    test_end = int(total * test_ratio)
    val_end = test_end + int(total * val_ratio)

    def flatten(chunk):
        return [item for unit in chunk for item in unit]

    test_data = flatten(units[:test_end])
    val_data = flatten(units[test_end:val_end])
    train_data = flatten(units[val_end:])
    return train_data, val_data, test_data


###########################
# 5. Main / Execution
###########################

if __name__ == "__main__":
    # 1) Build the dataset
    dataset = build_tictactoe_dataset()

    # 2) Decide your ratios
    test_ratio = 0.15
    val_ratio = 0.15

    # 3) Split into train/val/test. split_dataset_3way shuffles with the
    #    module-level `random` generator; seeding it makes re-runs reproduce
    #    the same split instead of overwriting the saved files with a new one.
    SPLIT_SEED = 42
    random.seed(SPLIT_SEED)
    train_set, val_set, test_set = split_dataset_3way(
        dataset,
        test_ratio=test_ratio,
        val_ratio=val_ratio,
        skip_symmetries=True,  # keep symmetric boards within one split
    )

    # 4) Print stats
    print(f"Total states in dataset: {len(dataset)}")
    print(f"Training set size: {len(train_set)}")
    print(f"Validation set size: {len(val_set)}")
    print(f"Test set size: {len(test_set)}")

    # 5) Save as JSON
    BASE_SAVE_PATH = "/mnt/shared/data/stlm-logic"
    dataset_path = os.path.join(BASE_SAVE_PATH, "datasets")

    # Ensure the directory exists
    os.makedirs(dataset_path, exist_ok=True)

    dataset_train_path = os.path.join(dataset_path, "tictactoe_train.json")
    dataset_val_path   = os.path.join(dataset_path, "tictactoe_val.json")
    dataset_test_path  = os.path.join(dataset_path, "tictactoe_test.json")
    dataset_full_path  = os.path.join(dataset_path, "tictactoe_dataset.json")

    print("Sample state:\n", json.dumps(dataset[0], indent=2))

    outputs = [
        (dataset_full_path, dataset),
        (dataset_train_path, train_set),
        (dataset_val_path, val_set),
        (dataset_test_path, test_set),
    ]
    for out_path, records in outputs:
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2)

    print("Dataset generation complete!")
