import json
from typing import Dict
import dspy
import os
import re
from textwrap import dedent
from os.path import join as pjoin
import pandas as pd
import tqdm
from dspy.evaluate import Evaluate
import chess


from dspy.teleprompt import LabeledFewShot, BootstrapFewShotWithRandomSearch, MIPROv2

def load_text(path, encoding='utf-8'):
    """Read a text file and strip any common leading indentation.

    Args:
        path: File to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale.

    Returns:
        The file contents with common leading whitespace removed (``textwrap.dedent``).
    """
    with open(path, 'r', encoding=encoding) as f:
        return dedent(f.read())


def get_current_file_directory():
    """Return the absolute path of the directory that contains this module."""
    module_path = os.path.abspath(__file__)
    directory, _filename = os.path.split(module_path)
    return directory


def get_parent_directory():
    """Return the absolute path one level above the directory containing this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.dirname(module_dir)


# Aspects the judge scores: ['calculation', 'pattern_recognition', 'informativeness', 'rating_appropriate', 'quality', 'fun']

# DSPy signature that Judge wraps in dspy.ChainOfThought to rate a single puzzle.
# NOTE(review): DSPy presumably renders the class docstring and the field
# descriptions below into the LM prompt, so wording changes here likely change
# model outputs -- confirm before editing.
class RecommendedPuzzle(dspy.Signature):
    """Judge a puzzle's quality and return scores for different aspects."""
    # Declared as an input, but Judge.forward does not currently pass it; the
    # persona is supplied to the LM as its system prompt instead.
    persona = dspy.InputField(description="The persona of the annotator.")
    # Full text of prompt/guideline.txt, passed on every call by Judge.forward.
    annotation_guideline = dspy.InputField(
        description="The complete documentation on how to annotate the puzzle.")
    user_puzzle_elo = dspy.InputField(description="The user's current Puzzle ELO on chess.com.")
    # Text rendering of the board followed by a sentence describing the puzzle's first move.
    puzzle_board = dspy.InputField(description="The chess board representation of the puzzle that we are recommending to the user. We are judging the quality of this specific puzzle.")
    # Expected to be a JSON object with six 1-100 scores; parsed by
    # Judge.extract_score_from_response.
    scores = dspy.OutputField(
        description="You should try to judge these puzzles according to your unique expertise and experience. Provide scores for the following aspects in valid JSON format:\n"
        "- calculation: Rate 1-100 on whether this puzzle requires thinking through the entire solution before the first move.\n"
        "- pattern_recognition: Rate 1-100 on whether this puzzle teaches useful or unusual motifs.\n"
        "- informativeness: Rate 1-100 on whether the solution ensures the player understands the puzzle's point.\n"
        "- rating_appropriate: Rate 1-100 on whether this puzzle is appropriate for the user's skill level.\n"
        "- quality: Rate 1-100 on the overall quality of the puzzle, regardless of ELO level.\n"
        "- fun: Rate 1-100 on how enjoyable and fun this puzzle would be for players at the target ELO.")
    # Optional free-text justification output, currently disabled.
    # additional_comments = dspy.OutputField(description="Please write a 1-2 sentence justification of your ratings. Be brief.")


class Judge(dspy.Module):
    """DSPy module that rates a chess puzzle from one annotator persona's point of view.

    The persona (title and USCF rating) is read from ``annotator_info.json`` and
    turned into ``persona_sys_prompt``, which callers install as the LM system
    prompt. The annotation guideline text is sent with every prediction.
    """

    # Aspect names as they appear in the model's JSON output. Parsed results use
    # the key ``<aspect>_score``.
    SCORE_ASPECTS = ('calculation', 'pattern_recognition', 'informativeness',
                     'rating_appropriate', 'quality', 'fun')

    def __init__(self, anno_name='anno_1'):
        """Load the persona for ``anno_name`` and the annotation guideline.

        Args:
            anno_name: Key into ``data/llm_data/annotator_info.json`` selecting the persona.

        Raises:
            KeyError: if ``anno_name`` is not present in the annotator file.
        """
        # Let dspy.Module initialise its own bookkeeping before we add attributes.
        super().__init__()
        with open(pjoin(get_parent_directory(), "data/llm_data/annotator_info.json")) as f:
            self.all_persona_data = json.load(f)
        self.annotation_guideline = load_text(pjoin(get_current_file_directory(), 'prompt/guideline.txt'))
        self.persona_data = self.all_persona_data[anno_name]
        self.anno_name = anno_name
        self.persona_sys_prompt = self.generate_persona_description()

        # Either dspy.Predict or dspy.ChainOfThought can be used here.
        self.rater = dspy.ChainOfThought(RecommendedPuzzle)

    def generate_persona_description(self):
        """Build the persona system prompt from this annotator's title and USCF rating.

        A missing ``Title`` yields "no official title", and a missing ``Elo (USCF)``
        yields "an unknown USCF rating".
        """
        title = self.persona_data.get('Title', 'N/A')
        rating = self.persona_data.get('Elo (USCF)', 'unknown')

        title_str = f"the title {title}" if title != 'N/A' else "no official title"
        rating_str = f"a USCF rating of {rating}" if rating != 'unknown' else "an unknown USCF rating"

        return (
            f"You are a chess player with {title_str} and {rating_str}. "
            # Trailing space keeps this sentence separate from the next one.
            "You are asked to rate chess puzzles on a few different criteria. "
            "Please read the annotation guideline carefully, combined with your expert-level insights, provide quality ratings on these chess puzzles."
        )

    def forward(self, user_elo, puzzle_board):
        """Rate one puzzle for a user with the given puzzle ELO.

        The signature's ``persona`` input is deliberately not passed. The persona
        reaches the LM through the system prompt that train_judge and
        evaluate_single_judge configure.
        """
        # persona=self.persona_sys_prompt,
        response = self.rater(annotation_guideline=self.annotation_guideline,
                              user_puzzle_elo=user_elo, puzzle_board=puzzle_board)
        return response

    @staticmethod
    def _strip_code_fence(text):
        """Return the body of a Markdown code fence (``` or ```json) if ``text`` is fenced; else ``text``."""
        stripped = text.strip()
        if not stripped.startswith('```'):
            return text
        body = stripped[3:]
        if body.startswith('json'):
            body = body[len('json'):]
        return body.split('```')[0].strip()

    @staticmethod
    def extract_score_from_response(response):
        """Parse the six aspect scores from a prediction's ``scores`` field.

        The field is expected to hold a JSON object, optionally wrapped in a
        Markdown code fence. If it cannot be read as a JSON object of integer-like
        values, each score is recovered individually with ``extract_score``.
        Missing aspects in valid JSON default to 1. Out-of-range scores are printed
        and then clamped to [1, 100].

        Args:
            response: Anything indexable by ``'scores'`` (e.g. a dspy Prediction or Example).

        Returns:
            dict mapping ``'<aspect>_score'`` to an int in [1, 100].

        Raises:
            Whatever ``response['scores']`` raises when the field is absent.
        """
        scores_text = response['scores']
        try:
            scores = json.loads(Judge._strip_code_fence(scores_text))
            values = {f'{aspect}_score': int(scores.get(aspect, 1))
                      for aspect in Judge.SCORE_ASPECTS}
        except (ValueError, TypeError, AttributeError):
            # ValueError covers json.JSONDecodeError and values such as "85/100".
            # TypeError covers null scores. AttributeError covers valid JSON that
            # is not an object. In all cases, fall back to per-metric text parsing.
            values = {f'{aspect}_score': Judge.extract_score(scores_text, metric_name=aspect)
                      for aspect in Judge.SCORE_ASPECTS}

        allowed_ranges = {f'{aspect}_score': (1, 100) for aspect in Judge.SCORE_ASPECTS}

        out_of_range_metrics = [
            key for key, (min_val, max_val) in allowed_ranges.items()
            if not min_val <= values[key] <= max_val
        ]
        if out_of_range_metrics:
            print("Raw response (scores out of range):", response)
            print("Metrics out of range:", ", ".join(out_of_range_metrics))
            for metric in out_of_range_metrics:
                print(f"{metric}: {values[metric]} (allowed range: {allowed_ranges[metric]})")

        # Clamp every score into its allowed range.
        return {
            key: min(max(value, allowed_ranges[key][0]), allowed_ranges[key][1])
            for key, value in values.items()
        }

    @staticmethod
    def extract_score(score_string, metric_name="", debug=False):
        """Pull a single integer score out of free-form model text.

        If ``metric_name`` is given and a ``metric_name:`` label occurs, parsing
        starts right after the first such label. JSON-style labels such as
        ``"metric_name": "85"`` are also accepted. Only the first paragraph (up to
        a blank line) is considered, and the score is the run of digits at its start.

        Returns:
            The parsed integer, or 0 if no leading digits are found.
        """
        if metric_name:
            # Optional closing quote before the colon and optional opening quote
            # before the value, so both `fun: 80` and `"fun": "80"` parse.
            label = re.search(re.escape(metric_name) + r'["\']?\s*:\s*["\']?', score_string)
            if label:
                score_string = score_string[label.end():]
        if "\n\n" in score_string:
            # some responses have additional fields
            score_string = score_string.split('\n\n')[0].strip()

        if debug:
            print("[DEBUG] score_string: ", score_string)

        # \d+ (not \d) so multi-digit scores such as 85 or 100 are read in full.
        match = re.search(r'^\d+', score_string.strip())
        return int(match.group()) if match else 0

# Few-shot learning: data loading, the evaluation metric, and judge training/evaluation.

def load_rating_data(anno_name):
    """Return the rows of ``train_puzzles.csv`` labelled by annotator ``anno_name``."""
    csv_path = pjoin(get_parent_directory(), "data/llm_data/train_puzzles.csv")
    frame = pd.read_csv(csv_path)
    belongs_to_annotator = frame['annotator'] == anno_name
    return frame.loc[belongs_to_annotator]

def load_test_puzzles():
    """Load the evaluation puzzle set from ``data/puzzles_v3.json``.

    evaluate_single_judge iterates the result as ``{user_elo: per-user data}``.
    Earlier revisions read ``data/test_puzzles.json`` and ``data/new_test_puzzles.json``.
    """
    puzzles_path = pjoin(get_parent_directory(), "data/puzzles_v3.json")
    with open(puzzles_path, 'r') as handle:
        return json.load(handle)

def eval_metric(true: dspy.Example, prediction, trace=None):
    """Return the negated mean absolute error, on a 0-1 scale, between label and prediction.

    ``true.scores`` is a JSON string whose values train_judge has already
    rescaled to the judge's 1-100 range. They are therefore compared directly
    with the parsed predicted scores, with no second rescaling. A missing
    (null) label counts as 0. Higher is better, and 0 is a perfect match.

    Args:
        true: Labelled example with a JSON ``scores`` field.
        prediction: Model output (or example) accepted by
            ``Judge.extract_score_from_response``.
        trace: Unused; part of the DSPy metric signature.

    Returns:
        ``-MAE / 100`` averaged over the six aspects, or -10000 if the label or
        prediction cannot be processed.
    """
    aspects = ['calculation', 'pattern_recognition', 'informativeness',
               'rating_appropriate', 'quality', 'fun']
    try:
        prediction_scores = Judge.extract_score_from_response(prediction)
        true_scores = json.loads(true.scores)
        errors = []
        for key in aspects:
            raw = true_scores[key]
            true_score = int(raw) if raw is not None else 0
            pred_score = prediction_scores.get(f"{key}_score", 0)
            errors.append(abs(true_score - pred_score) / 100)
        return -(sum(errors) / len(errors))
    except Exception as e:
        # Deliberate catch-all: a broken example should score badly, not abort optimisation.
        print(f"Error in eval_metric: {e}")
        print(f"True: {true}")
        print(f"Prediction: {prediction}")
        return -10000  # Return a default value or handle the error as appropriate

def _row_to_example(row):
    """Convert one row of train_puzzles.csv into a labelled dspy.Example.

    The rating columns hold Python-literal lists, and only the first entry is
    used. Ratings are rescaled to the judge's 1-100 range: calculation x25,
    pattern_recognition x33, every other aspect x20. A missing (None) rating
    becomes "1".
    """
    # NOTE(security): eval() executes arbitrary code. This is acceptable only while
    # train_puzzles.csv is a trusted local file; consider ast.literal_eval.
    def first_entry(column):
        return eval(row[column])[0]

    puzzle_info = eval(row['puzzle_info'])
    board = chess.Board(puzzle_info['fen3'])
    from_sq = puzzle_info.get('from_sq', '')
    to_sq = puzzle_info.get('to_sq', '')
    piece_name = puzzle_info.get('piece_name', '')

    # Describe the first move, then combine it with the board rendering.
    first_move_desc = f"The puzzle starts with {piece_name} moving from {from_sq} to {to_sq}. It's now your turn to find the best move."
    puzzle_description = f"Chess board position:\n{board}\n\n{first_move_desc}"

    calculation = first_entry("calculation")
    pattern_recognition = first_entry("pattern_recognition")
    scores = {
        # These two truncate after scaling; the remaining aspects truncate before scaling.
        "calculation": str(int(calculation * 25)) if calculation is not None else "1",
        "pattern_recognition": str(int(pattern_recognition * 33)) if pattern_recognition is not None else "1",
    }
    for column in ("informativeness", "rating_appropriate", "quality", "fun"):
        rating = first_entry(column)
        scores[column] = str(int(rating) * 20) if rating is not None else "1"

    return dspy.Example({
        "user_elo": first_entry("user_rating"),
        "puzzle_board": puzzle_description,
        "scores": json.dumps(scores),
    }).with_inputs("user_elo", "puzzle_board")


def train_judge(anno_name):
    """Optimise a few-shot Judge for one annotator and save it to ``optimized_judges/``.

    Does nothing if ``optimized_judges/<anno_name>_judge.pkl`` already exists or
    if the annotator has no labelled puzzles. The first 6 labelled puzzles form
    the training set, and the rest form the validation set for
    BootstrapFewShotWithRandomSearch. The optimised program, not the untrained
    judge, is what gets saved.
    """
    optimized_judges_dir = 'optimized_judges'
    judge_path = os.path.join(optimized_judges_dir, f'{anno_name}_judge.pkl')
    if os.path.exists(judge_path):
        print(f"Optimized judge for {anno_name} already exists. Skipping training.")
        return

    judge = Judge(anno_name)

    # The persona is delivered as the LM system prompt rather than as a signature input.
    llm = dspy.OpenAI(model="gpt-4o-mini-2024-07-18", max_tokens=16383,
                      system_prompt=judge.persona_sys_prompt)  # this is the longest context we can have
    dspy.settings.configure(lm=llm)

    examples = [_row_to_example(row) for _, row in load_rating_data(anno_name).iterrows()]

    print("annotator name: ", anno_name, "number of examples: ", len(examples))
    if not examples:
        print(f"No labelled puzzles for {anno_name}. Skipping training.")
        return

    # Sanity check: score each label against itself.
    all_mae = [eval_metric(true, true) for true in examples]
    print(f"(Sanity check) MAE on ground truth: {sum(all_mae) / len(all_mae)}")

    config = dict(max_bootstrapped_demos=2, max_labeled_demos=4, num_candidate_programs=2, num_threads=6)
    teleprompter = BootstrapFewShotWithRandomSearch(metric=eval_metric, **config)

    # First 6 train, the rest validate (a half/half split for 12 examples).
    trainset, valset = examples[:6], examples[6:]
    optimized_judge = teleprompter.compile(judge, trainset=trainset, valset=valset)

    evaluate = Evaluate(devset=valset, metric=eval_metric, num_threads=6, display_progress=True, display_table=10)
    evaluate(optimized_judge)

    os.makedirs(optimized_judges_dir, exist_ok=True)
    # Save the compiled program; saving `judge` would discard the optimisation.
    optimized_judge.save(judge_path)

def evaluate_single_judge(anno_name):
    """Score the test puzzles with one annotator's judge and write ``<anno_name>_results.json``.

    Loads ``optimized_judges/<anno_name>_judge.pkl`` when it exists; otherwise
    the un-optimised judge is used. Rates the 'ours' and 'chosen' puzzle sets
    for the first 11 users returned by ``load_test_puzzles()`` and saves
    ``{puzzle_type: {aspect: scores}}``. For 'ours', each aspect's scores are
    grouped into one list per user. For other puzzle types they form a flat list.
    """
    test_puzzles = load_test_puzzles()
    judge = Judge(anno_name)

    optimized_judges_dir = 'optimized_judges'
    judge_path = os.path.join(optimized_judges_dir, f'{anno_name}_judge.pkl')
    if os.path.exists(judge_path):
        judge.load(judge_path)
        print(f"Loaded optimized judge for {anno_name}")
    else:
        print(f"No optimized judge found for {anno_name}. Using default judge.")

    # The persona is delivered as the LM system prompt rather than as a signature input.
    llm = dspy.OpenAI(model="gpt-4o-mini-2024-07-18", max_tokens=16383,
                      system_prompt=judge.persona_sys_prompt)  # this is the longest context we can have
    dspy.settings.configure(lm=llm)

    score_types = ['calculation', 'pattern_recognition', 'informativeness',
                   'rating_appropriate', 'quality', 'fun']

    def select_puzzles(data, puzzle_type):
        """Return the puzzles to rate for one user: up to 5 for 'ours', otherwise one."""
        puzzles = data[puzzle_type]
        if puzzle_type == 'ours':
            return puzzles[:5]
        if puzzle_type == 'bucketed_uniform' or isinstance(puzzles, list):
            return puzzles[:1]
        return [puzzles]

    def to_example(user_elo, puzzle_info):
        """Build an unlabelled dspy.Example describing one puzzle for ``user_elo``."""
        if isinstance(puzzle_info, list):
            puzzle_info = puzzle_info[0]  # there's a double list
        board = chess.Board(puzzle_info['fen3'])
        piece_name = puzzle_info['piece_name']
        from_sq = puzzle_info['from_sq']
        to_sq = puzzle_info['to_sq']

        first_move_desc = f"The puzzle starts with {piece_name} moving from {from_sq} to {to_sq}. It's now your turn to find the best move."
        puzzle_board = f"Chess board position:\n{board}\n\n{first_move_desc}"
        return dspy.Example({
            "user_elo": user_elo,
            "puzzle_board": puzzle_board
        }).with_inputs("user_elo", "puzzle_board")

    def evaluate_puzzles(judge, test_puzzles, puzzle_type):
        """Rate every selected puzzle of ``puzzle_type`` and collect per-aspect scores."""
        examples = []
        group_sizes = []  # number of puzzles each user contributed, in iteration order
        users = 0
        for user_elo, data in test_puzzles.items():
            users += 1
            puzzles_to_evaluate = select_puzzles(data, puzzle_type)
            examples.extend(to_example(user_elo, puzzle) for puzzle in puzzles_to_evaluate)
            group_sizes.append(len(puzzles_to_evaluate))

            # Stop after the first 11 users (presumably to limit LM calls).
            if users > 10:
                break

        print(f"Total number of examples for {puzzle_type}: {len(examples)}")
        print(f"Total number of users: {users}")  # 1143

        # The metric is a constant placeholder; only the predictions are used.
        evaluate = Evaluate(devset=examples, metric=lambda ex, pred: 1, num_threads=6,
                            display_progress=True, display_table=10, return_outputs=True)
        _, evaluation_results = evaluate(judge)
        responses = [entry[1] for entry in evaluation_results]

        results = {score_type: [] for score_type in score_types}
        results['additional_comments'] = []  # kept for output compatibility; never filled

        for pred in responses:
            scores = Judge.extract_score_from_response(pred)
            for score_type in score_types:
                results[score_type].append(scores[f'{score_type}_score'])

        # For 'ours', group scores per user using the number of puzzles each user
        # actually contributed. Fixed chunks of 5 would misalign users with fewer puzzles.
        if puzzle_type == 'ours':
            for score_type in score_types:
                flat = results[score_type]
                grouped, start = [], 0
                for size in group_sizes:
                    grouped.append(flat[start:start + size])
                    start += size
                results[score_type] = grouped

        print(f"Evaluation completed for {puzzle_type}.")
        return results

    results = {}
    for puzzle_type in ['ours', 'chosen']:  # 'bucketed_uniform' is also supported
        results[puzzle_type] = evaluate_puzzles(judge, test_puzzles, puzzle_type)

    with open(f'{anno_name}_results.json', 'w') as f:
        json.dump(results, f)

if __name__ == '__main__':
    # To train and evaluate a single annotator instead:
    #   train_judge('anno_1')
    #   evaluate_single_judge('anno_1')

    # Evaluate the judges for all 8 annotators. Training is currently disabled;
    # uncomment the two lines below to (re)train missing judges first.
    for i in range(1, 9):
        anno_name = f'anno_{i}'
        # print(f"Training judge for {anno_name}")
        # train_judge(anno_name)
        print(f"Evaluating judge for {anno_name}")
        evaluate_single_judge(anno_name)
        print(f"Completed evaluation for {anno_name}")