import json
import os
import random
import re

from constants import TASK_DESCRIPTIONS,TASK_DESCRIPTIONS_CN,PROMPT_TYPES,PROMPT_TYPES_CN,TASK_DIR,WCST_RULES, WCST_TOTAL_TRIALS, WCST_CORRECT_TO_CHANGE,ROLE_DESCRIPTIONS
from data_logger import WCSTDataLogger

class ExperimentFlow:
    """Base class for a turn-based experiment run through a chat session.

    Subclasses provide the task-specific hooks (data logger, trial setup,
    prompt generation, response scoring). The hooks defined here only raise
    NotImplementedError. Hooks whose concrete signatures vary between
    subclasses accept arbitrary arguments, so an unimplemented hook reports
    NotImplementedError rather than a TypeError about argument counts.
    """

    def __init__(self, chat_session, is_human, language):
        """Initialise state shared by all experiment flows.

        Args:
            chat_session: Session object stored for use by subclasses.
            is_human: Whether the participant is a human.
            language: Language selector stored for use by subclasses.

        Raises:
            NotImplementedError: if ``create_data_logger`` is not overridden,
                so the base class itself cannot be instantiated.
        """
        # Built first, through the (subclass-provided) hook.
        self.data_logger = self.create_data_logger()
        self.chat_session = chat_session
        self.is_human = is_human
        self.current_trial = 0  # 0-based index of the trial in progress
        self.total_trials = 100  # default; subclasses may override
        self.current_prompt = ''
        self.current_response = ''
        self.base_dir = ''
        self.info_list = []
        self.feedback_prompt = ''
        self.current_trial_data = {}
        self.language = language

    def run_trial(self, *args, **kwargs):
        """Prepare/present the current trial (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def evaluate_response(self, *args, **kwargs):
        """Score the participant's response (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def generate_prompt(self, *args, **kwargs):
        """Build the next prompt text (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_image_description(self, *args, **kwargs):
        """Return a textual description of a stimulus (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def is_experiment_complete(self):
        """Return True once all trials have been run (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

    def create_data_logger(self):
        """Return the logger object stored as ``self.data_logger`` (subclass hook)."""
        raise NotImplementedError("Subclasses must implement this method")

class WCSTFlow(ExperimentFlow):
    """Trial flow for the Wisconsin Card Sorting Test (WCST).

    Loads the card deck ``trial<N>/cards.json`` under ``TASK_DIR["WCST"]``,
    tracks the active sorting rule, and builds the prompts and feedback sent
    to the participant (human or model). After ``WCST_CORRECT_TO_CHANGE``
    consecutive correct answers a different rule is chosen and a category is
    counted as completed.
    """

    def __init__(self, chat_session, is_human, trial_number, language, rule=True, impairment_type=None):
        """Set up one WCST run.

        Args:
            chat_session: Passed through to ``ExperimentFlow``.
            is_human: Whether the participant is a human (affects prompt text).
            trial_number: Selects the ``trial{trial_number}/cards.json`` deck.
            language: ``"English"`` selects English texts; any other value
                selects the ``*_CN`` texts.
            rule: If truthy use the ``"WCST"`` task description, otherwise
                ``"WCST_without_restriction"``.
            impairment_type: Optional key into ``ROLE_DESCRIPTIONS`` appended to
                the opening prompt; also recorded in every trial record.
        """
        # super().__init__ already builds self.data_logger through
        # create_data_logger(); a second call would only create a discarded logger.
        super().__init__(chat_session, is_human, language)
        self.total_trials = WCST_TOTAL_TRIALS
        self.current_rule = None
        self.correct_in_row = 0
        self.category_completed = 0
        self.base_dir = TASK_DIR["WCST"]
        self.wcst_data = self.load_wcst_data(trial_number)
        self.impairment_type = impairment_type
        descriptions = TASK_DESCRIPTIONS if self.language == "English" else TASK_DESCRIPTIONS_CN
        self.task_description = descriptions["WCST" if rule else "WCST_without_restriction"]
        self.set_new_rule()

    def set_new_rule(self):
        """Pick a random rule from WCST_RULES that differs from the current one."""
        available_rules = [rule for rule in WCST_RULES if rule != self.current_rule]
        self.current_rule = random.choice(available_rules)

    def create_data_logger(self):
        """Return the WCST-specific data logger."""
        return WCSTDataLogger()

    def load_wcst_data(self, trial_number):
        """Load and return the parsed ``<base_dir>/trial<N>/cards.json`` deck.

        ``os.path.join`` works whether or not ``base_dir`` ends with a separator.
        """
        data_path = os.path.join(self.base_dir, f"trial{trial_number}", "cards.json")
        with open(data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_current_card(self, num):
        """Return the card whose 1-based ``trialNumber`` equals ``num + 1``.

        Raises:
            ValueError: if the deck has no such card.
        """
        for card_data in self.wcst_data.values():
            if card_data['trialNumber'] == num + 1:
                return card_data
        raise ValueError(f"No card found for trial {num + 1}")

    def _correct_card_index(self, card):
        """Return the 0-based reference-card index that matches *card* under the active rule."""
        # The part of the rule name before the first '_' selects the card's
        # '<prefix>Rule' entry (cards carry 'colorRule', 'shapeRule', 'numberRule').
        return card[self.current_rule.split('_')[0] + 'Rule']

    @staticmethod
    def _parse_selection(response):
        """Return the card number N from a response containing 'Selection: N'."""
        # NOTE(review): if the marker is missing, re.search returns None and this
        # raises AttributeError (kept for compatibility with existing callers).
        return int(re.search(r'Selection: (\d+)\.?\s*', response).group(1))

    def run_trial(self, presentation_mode, prompt_type, web=False):
        """Prepare the current trial and return ``(image_path, prompt)``.

        On the first trial the opening prompt is generated here; later prompts
        come from ``evaluate_response``. Also refreshes ``info_list`` to
        ``[current_rule, correct card (1-based), correct_in_row]``.
        """
        current_card = self.get_current_card(self.current_trial)
        image_path = current_card['image']
        if web:
            # Web mode uses paths without the './task_datasets/' prefix.
            image_path = image_path.replace('./task_datasets/', '')
        if self.current_trial == 0:
            self.current_prompt = self.generate_prompt(presentation_mode, prompt_type, '')
        correct_card = self._correct_card_index(current_card)
        self.info_list = [self.current_rule, correct_card + 1, self.correct_in_row]
        return image_path, self.current_prompt

    def evaluate_response(self, feedback, presentation_mode, prompt_type, is_human=False, tokens=None, is_auto=False, auto_press=False):
        """Score the answer for the current trial and build the next prompt.

        Args:
            feedback: For human participants and for ``is_auto`` without
                ``auto_press``, the selected card as a 1-based number (int or
                numeric string). Otherwise, the correctness verdict itself; the
                selected card is then parsed from ``self.current_response``.
            presentation_mode, prompt_type: Forwarded to ``generate_prompt``.
            is_human: Human participant; token counts are recorded as zeros.
            tokens: ``[prompt, completion, total]`` token counts; zeros if omitted.
            is_auto, auto_press: Select how correctness and the selection are
                obtained (see ``feedback``).

        Returns:
            For humans, the feedback word; otherwise the next prompt (also
            stored in ``self.feedback_prompt``).

        The trial record is only staged in ``self.current_trial_data``; the
        streak update, logging and trial advance happen in
        ``finalize_current_trial``.
        """
        current_card = self.get_current_card(self.current_trial)
        correct_card = self._correct_card_index(current_card)
        if is_human:
            is_correct = int(feedback) == correct_card + 1
            applied_rule = int(feedback) - 1
            tokens = [0, 0, 0]
        elif is_auto and not auto_press:
            is_correct = int(feedback) == correct_card + 1
            applied_rule = int(feedback) - 1
        else:
            # feedback already carries the verdict; recover the chosen card
            # (0-based) from the model's "Selection: N" answer.
            is_correct = feedback
            applied_rule = self._parse_selection(self.current_response) - 1
        if tokens is None:
            tokens = [0, 0, 0]

        # Dimensions the chosen card agrees with: C=color, S=shape, N=number.
        applied_rules = ''.join(
            code
            for code, key in (('C', 'colorRule'), ('S', 'shapeRule'), ('N', 'numberRule'))
            if applied_rule == current_card[key]
        )

        self.current_trial_data = {
            "trial_number": self.current_trial + 1,
            "image_path": current_card['image'],
            "prompt": self.current_prompt,
            "response": self.current_response,
            "prompt_tokens": tokens[0],
            "completion_tokens": tokens[1],
            "total_tokens": tokens[2],
            "is_correct": is_correct,
            "correct_card": correct_card,
            "correct_in_row": self.correct_in_row,
            "current_rule": self.current_rule,
            "category_completed": self.category_completed,
            "applied_rules": applied_rules,
            "impairment_type": self.impairment_type,
        }

        if self.language == "English":
            feedback_word = "correct" if is_correct else "incorrect"
        else:
            feedback_word = "正确" if is_correct else "错误"
        self.current_prompt = self.generate_prompt(presentation_mode, prompt_type, feedback_word, False)
        if is_human:
            return feedback_word
        self.feedback_prompt = self.current_prompt
        return self.feedback_prompt

    def finalize_current_trial(self):
        """Commit the trial staged by ``evaluate_response``.

        Updates the correct-answer streak, logs the record, switches to a new
        rule (counting a completed category) once the streak reaches
        ``WCST_CORRECT_TO_CHANGE``, and advances to the next trial. Does
        nothing if no trial is staged.
        """
        # current_trial_data starts as {} (base __init__) and is deleted after
        # each commit, so test for a non-empty record, not mere presence.
        trial_data = getattr(self, 'current_trial_data', None)
        if not trial_data:
            return
        self.correct_in_row = self.correct_in_row + 1 if trial_data['is_correct'] else 0
        trial_data['correct_in_row'] = self.correct_in_row
        self.data_logger.log_trial(trial_data)
        if self.correct_in_row == WCST_CORRECT_TO_CHANGE:
            self.set_new_rule()
            self.correct_in_row = 0
            self.category_completed += 1
        self.current_trial += 1
        del self.current_trial_data

    def generate_prompt(self, presentation_mode, prompt_type, feedback, gen=True):
        """Build the prompt text for the participant.

        With ``gen=True`` (first trial only) returns the opening prompt.
        It consists of:
            - the task description;
            - the first card's description, in modes "OIT"/"OT";
            - the prompt-type instructions (omitted for humans in mode "OI");
            - the role description for ``impairment_type``, if any.

        With ``gen=False`` returns feedback on the previous selection. In
        modes "OIT"/"OT" the feedback also describes the next card.

        Raises:
            ValueError: if ``gen`` is True after the first trial, or the opening
                prompt is requested for an unknown presentation mode.
        """
        text_modes = ("OIT", "OT")

        if not gen:
            if presentation_mode in text_modes:
                # NOTE(review): this text is English-only even for non-English
                # runs, and on the last trial the description is None.
                image_description = self.get_image_description(num=self.current_trial + 1)
                return f"Your previous selection was {feedback}. Please make your next selection. The next image can be described as:{image_description}"
            if self.language == "English":
                return f"Your previous selection was {feedback}. Please make your next selection."
            return f"您之前的选择{feedback}。请做出下一个选择。"

        if self.current_trial != 0:
            raise ValueError("The opening prompt (gen=True) can only be generated on the first trial")

        prompt_texts = PROMPT_TYPES if self.language == "English" else PROMPT_TYPES_CN
        prompt_text = prompt_texts[prompt_type]
        if presentation_mode == "OI":
            # Human participants only get the task description.
            prompt = f"{self.task_description}" if self.is_human else f"{self.task_description}\n{prompt_text}"
        elif presentation_mode in text_modes:
            image_description = self.get_image_description(num=self.current_trial)
            prompt = f"{self.task_description}\n{image_description}\n{prompt_text}"
        else:
            raise ValueError(f"Unknown presentation_mode: {presentation_mode!r}")

        if self.impairment_type and self.impairment_type in ROLE_DESCRIPTIONS:
            prompt += f"\n{ROLE_DESCRIPTIONS[self.impairment_type]}"
        return prompt

    def get_image_description(self, num):
        """Describe the four reference cards plus the stimulus card of trial ``num`` (0-based).

        Returns None when ``num`` is at or beyond ``total_trials``.
        """
        if num >= self.total_trials:
            return None
        card = self.get_current_card(num)
        number = card['number']
        shape = card['shape']
        if shape == "cross":
            shape += ' sign'
        shape_text = shape if number == 1 else shape + 's'
        card_name = f"{number} {card['color']} {shape_text}"
        return f"The image shows four cards with white background at the top, the first card is with 1 single red triangle, the second card is with 2 green stars, the third card is with 3 yellow cross signs, the fourth card is with 4 blue circles. And there is one card at the bottom left with {card_name}."

    def is_experiment_complete(self):
        """Return True once every trial in the deck has been finalized."""
        return self.current_trial >= self.total_trials
