import re
import ast
import chess
from typing import Dict, Tuple, List

from .exceptions import ParseException, IllegalMoveException

# =================================================
# General Functionality
# =================================================
def extract_solution(text: str) -> str:
    """
    Return the contents of the last ``<answer>...</answer>`` pair in *text*.

    Tags are matched non-greedily and may span newlines; the captured
    contents are stripped of surrounding whitespace before being returned.

    Raises:
        ParseException: if *text* contains no complete ``<answer>`` pair.
    """
    last_body = None
    # Walk every match and remember only the most recent one.
    for hit in re.finditer(r"<answer>(.*?)</answer>", text, re.DOTALL):
        last_body = hit.group(1)

    if last_body is None:
        raise ParseException("No valid pair of <answer> tags found.")
    return last_body.strip()


def coerce_response(text: str, task_type: str, info: Dict = None, **kwargs) -> str:
    """ Given an output generated by an LLM, coerces it to be in a more workable format. """
    try:
        processed_text = ast.literal_eval(text)
    except:
        processed_text = text

    # Need output to be a single string (no list)
    if task_type == 'choose_from_n' or task_type == 'predict_singlemove' or task_type == 'predict_in_list':
        if isinstance(processed_text, str):
            processed_text = _stringify_move(processed_text)
        elif isinstance(processed_text, list):
            if len(processed_text) == 1:
                processed_text = _stringify_move(processed_text[0])
            elif len(processed_text) > 1:
                raise ParseException("Output is a list with multiple elements.")
        else:
            raise ParseException("Output is not in the correct format.") 
    
    # Need to coerce to a list of strings
    elif task_type == 'produce_list':
        if isinstance(processed_text, list):
            # Process each element to ensure all elements are strings
            processed_text = [_stringify_move(x) for x in processed_text]
        elif isinstance(processed_text, str):
            processed_text = re.split(r'[\s,]+', processed_text.strip()) if processed_text.strip() else []
            processed_text = _coerce_string_list(processed_text)
        else:
            raise ParseException("Output is not a list.")
    
    # Need to coerce a list of tuples (position / move, piece)
    elif task_type == "hallucination":
        extracted = extract_solution(text)
        return _coerce_hallucinations(extracted, info['board'])
        
    # Need to coerce list into a dict of key: bool
    elif task_type == "reasoning_strategy":
        extracted = extract_solution(text)
        return _coerce_dict_bool(extracted)

    else:
        raise ValueError(f"Unknown eval type: {task_type}")
    
    return processed_text


def parse_fen(fen: str) -> Dict:
    """
    Split a FEN string into its six named fields.

    Args:
        fen: A Forsyth-Edwards Notation string; surrounding whitespace is ignored.

    Returns:
        A dict with the stripped original FEN, the four textual fields, and the
        half-move clock / full-move number converted to ``int``.

    Raises:
        ValueError: if *fen* does not have exactly six whitespace-separated
            fields, or if either counter field is not an integer.
    """
    stripped = fen.strip()
    # Split on *all* whitespace (no maxsplit) so surplus trailing fields are
    # rejected by the count check instead of leaking into the full-move field.
    parts = stripped.split()
    if len(parts) != 6:
        raise ValueError("FEN must contain exactly six space-separated fields.")

    board, color, castling, en_passant, halfmove, fullmove = parts

    try:
        halfmove_clock = int(halfmove)
        fullmove_number = int(fullmove)
    except ValueError as exc:
        raise ValueError(
            f"FEN move counters must be integers, got {halfmove!r} and {fullmove!r}."
        ) from exc

    return {
        'original_fen'         : stripped,
        'board_placement'      : board,
        'active_color'         : color,
        'castling_availability': castling,
        'en_passant_target'    : en_passant,
        'halfmove_clock'       : halfmove_clock,
        'fullmove_number'      : fullmove_number,
    }


def pqt_extract_ground_truth(answer, task_type):
    """
    Parse the stored ground-truth *answer* for a supported task type.

    Every supported task type stores its answer as a Python literal, so the
    answer is evaluated with ``ast.literal_eval``; its errors propagate
    unchanged.

    Raises:
        ValueError: if *task_type* is not one of ``predictmove``, ``bestmove``,
            ``worstmove`` or ``legalmoves``.
    """
    supported = ("predictmove", "bestmove", "worstmove", "legalmoves")
    if task_type not in supported:
        raise ValueError(f"Task type: {task_type} is undefined.")
    return ast.literal_eval(answer)


# ==================================================
# Helper Functions
# ==================================================
def _stringify_move(move: str) -> str:
    return ''.join(c for c in move if c.isalnum())    

def _coerce_string_list(items: list[str]) -> list[str]:
    pattern = re.compile(r'^[a-zA-Z]\d(?:[a-zA-Z]\d)?$')
    filtered = [item for item in items if pattern.match(item)]
    
    if not filtered:
        raise ParseException("No valid items found in input.")
    
    return filtered

def _coerce_dict_bool(items: str) -> Dict[str, int]:
    """
    Parse a dict-like string of reasoning-strategy flags and return
    {key: 0|1}.  Accepts 0/1, True/False, "0"/"1", "true"/"false",
    and also tuples like (True, "explanation").  Collects all problems
    before raising ParseException.
    """
    allowed_keys = {
        "Enumeration",
        "Tree Search",
        "Backtracking",
        "Self Correction",
        "Subgoal Setting",
        "Verification",
    }

    # --- literal-eval --------------------------------------------------------
    try:
        parsed = ast.literal_eval(items)
    except Exception as e:
        raise ParseException(f"Failed to parse input as a dictionary: {e}")

    if not isinstance(parsed, dict):
        raise ParseException("Parsed input is not a dictionary.")

    # --- validate & coerce ---------------------------------------------------
    errors, result = [], {}

    for key, value in parsed.items():
        # key check
        if key not in allowed_keys:
            errors.append(f"Invalid key: '{key}' (allowed: {sorted(allowed_keys)})")
            continue  # still inspect value to collect all errors

        # look only at first element if tuple / list
        v = value[0] if isinstance(value, (tuple, list)) else value

        # map to 0 / 1
        if isinstance(v, str):
            v = v.strip().lower()
            norm = 1 if v in {"1", "true"} else 0 if v in {"0", "false"} else None
        elif isinstance(v, (bool, int)):
            norm = 1 if v in {1, True} else 0 if v in {0, False} else None
        else:
            norm = None

        if norm is None:
            errors.append(
                f"Invalid value for key '{key}': {value!r} "
                "(must be 0/1, '0'/'1', True/False)"
            )
        else:
            result[key] = norm

    if errors:
        raise ParseException("Errors in input:\n" + "\n".join(errors))

    return result



# ==================================================
# LLM Parser Helpers
# ==================================================
def _coerce_hallucinations(items: str, board: str) -> Dict[str, int]:
    """
    Validate an LLM-generated list for the hallucination-detection task.

    *items* must be a Python list literal; each element is checked against
    the position *board* (a FEN string):

    • Pair («square», «colour piece») as a tuple or 2-element list:
      hallucination ⇢ the square is empty or holds a different colour /
      piece.  Example: ("e4", "black bishop").
    • String («UCI move»): hallucination ⇢ the move is not legal in the
      position.

    Returns:
        Counts of moves / pieces checked, how many of each were correct,
        and the total number of hallucinations.

    Raises:
        ParseException: if *items* is not a list literal, or if any element is
            malformed (wrong pair shape, bad square, non-UCI move, unsupported
            element type). All element problems are collected into one
            message so the caller can re-prompt once.
    """
    # ------ Initial literal eval and default setups ------
    try:
        data = ast.literal_eval(items)
    except Exception as exc:
        raise ParseException(
            "Output inside <answer> tags isn't a valid Python literal list. We are unable to use 'ast.literal_eval' on the text within the answer tags."
        ) from exc
    if not isinstance(data, list):
        raise ParseException("Expected a Python list inside <answer> tags.")

    bd       = chess.Board(board)  # FEN → Board
    piece_nm = {chess.PAWN: "pawn", chess.KNIGHT: "knight", chess.BISHOP: "bishop",
                chess.ROOK: "rook", chess.QUEEN: "queen", chess.KING: "king"}
    sq_pat   = re.compile(r"[a-h][1-8]")
    uci_pat  = re.compile(r"[a-h][1-8][a-h][1-8][qrbn]?")  # e.g. e2e4, e7e8q

    # ------ Track stats ------
    piece_tot = move_tot = piece_hit = move_hit = halluc_tot = 0
    err_parse: list[str] = []  # pairs / moves that fail validation

    # ------ Actual parsing ------
    for el in data:
        # -------- (square, colour piece) -----------------------------------
        if isinstance(el, (tuple, list)):
            piece_tot += 1
            if len(el) != 2 or not all(isinstance(x, str) for x in el):
                err_parse.append(f"Tuple not (str, str): {el!r}")
                continue

            sq = _stringify_move(el[0].lower())
            if not sq_pat.fullmatch(sq):
                err_parse.append(f"The first element of {el} isn't a valid chess board square position (e.g., e4).")
                continue

            # Collapse internal whitespace so "White  Pawn" equals "white pawn".
            claim = " ".join(el[1].lower().split())
            pc = bd.piece_at(chess.parse_square(sq))  # Piece or None
            if pc is not None:
                actual = f"{'white' if pc.color else 'black'} {piece_nm[pc.piece_type]}"
                if actual == claim:
                    piece_hit += 1
                    continue
            # Empty square or wrong colour/piece.
            halluc_tot += 1

        # --------------------- UCI move string -----------------------------
        elif isinstance(el, str):
            move_tot += 1
            mv = _stringify_move(el.lower())

            if not uci_pat.fullmatch(mv):
                err_parse.append(f"Move '{el}' is not in valid UCI notation.")
                continue

            try:
                legal = chess.Move.from_uci(mv) in bd.legal_moves
            except ValueError:
                # e.g. "e2e2": matches the pattern but from_uci rejects
                # from == to. It can never be legal, so count it as illegal.
                legal = False

            if legal:
                move_hit += 1
            else:
                halluc_tot += 1

        # ---------------- anything else is malformed ----------------------
        else:
            err_parse.append(
                f"Element {el!r} is neither a (square, piece) tuple nor a UCI move string."
            )

    # ------ Reprompt if any errors ------
    if err_parse:
        raise ParseException("Parsing errors you need to fix: " + "; ".join(err_parse))

    # ------ Otherwise return our stats ------
    return {
        "Count: Moves Checked":  move_tot,
        "Count: Moves Correct":  move_hit,
        "Count: Pieces Checked": piece_tot,
        "Count: Pieces Correct": piece_hit,
        "Count: Hallucinations": halluc_tot,
    }