import json
import math
import random
from collections import Counter

import numpy as np
import textdistance
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class Reward:
    """Per-node reward bookkeeping for a trajectory search tree.

    Rewards live in ``self.reward`` keyed by node id. ``self.diversity``
    receives one key per scored tool selection (API name + parameters).
    ``get_path_reward`` uses it to add a bonus that shrinks the more often a
    selection on the path has already been recorded.

    NOTE(review): this class references ``CorvusActionValue`` and ``api_util``,
    which are not imported in this module. It also calls
    ``self.retrieve_api``, ``self.retrieve_reflection`` and
    ``self._retrieve_path``, which it does not define. Presumably these are
    supplied elsewhere (e.g. by a subclass); confirm.
    """

    def __init__(self):
        # node id -> reward value
        self.reward = {}
        # one key per scored tool selection; see _diversity_key
        self.diversity = []

    @staticmethod
    def _diversity_key(name, parameter):
        """Return the string that identifies a tool selection for diversity counting."""
        return json.dumps(name) + json.dumps(parameter)

    def get_reward(self, traj_id):
        """Return the stored reward of ``traj_id`` (raises KeyError if unset)."""
        return self.reward[traj_id]

    def update_self_reward(self, traj_id, reward):
        """Store an explicit ``reward`` for ``traj_id``."""
        self.reward[traj_id] = reward

    def update_default(self, traj_id):
        """Give ``traj_id`` the default reward of 1."""
        self.reward[traj_id] = 1

    def get_path_reward(self, traj_id, traj_path):
        """Score the path that leads to ``traj_id``.

        Walk ``traj_path`` from its last element backwards:

        * Skip the node whose id is ``traj_id``.
        * Stop, without including it, at the first ``ToolPlanning`` node.
        * Sum the stored rewards of every other visited node.

        Then add a bonus of ``sqrt(1 / count)``. Here ``count`` is 1 plus the
        number of entries in ``self.diversity`` equal to any tool-selection
        key seen during the walk, so selections that were already recorded
        many times earn a smaller bonus.

        Args:
            traj_id: id of the node being scored.
            traj_path: nodes on the path; presumably ordered root -> leaf
                (TODO confirm).

        Returns:
            The summed rewards plus the diversity bonus.

        Raises:
            KeyError: if a visited node has no stored reward.
        """
        rewards = []
        seen_selections = set()
        for traj in traj_path[::-1]:
            local_traj_id = traj.get_node_id()
            if local_traj_id == traj_id:
                continue
            if traj.agent_state == CorvusActionValue.ToolPlanning:
                break

            if traj.agent_state == CorvusActionValue.ToolSelection:
                memory = traj.get_memory()
                if 'api_json' in memory:
                    seen_selections.add(
                        self._diversity_key(memory['api_json']['name'], memory['parameter'])
                    )
                else:
                    seen_selections.add("default")
            rewards.append(self.reward[local_traj_id])

        # 1 + number of recorded selections that also appear on this path.
        count = 1 + sum(1 for key in self.diversity if key in seen_selections)
        return sum(rewards) + math.sqrt(1 / count)

    def update_start_reward(self, traj_id):
        """Give a start node the reward 1."""
        self.reward[traj_id] = 1

    def update_planning_reward(self, traj_id):
        """Give a planning node the reward 1."""
        self.reward[traj_id] = 1

    def update_selection_reward(self, traj_id, state, message, api_json, parameter):
        """Score a tool-selection node against the ground-truth (golden) APIs.

        If ``message`` has no ``'golden'`` entry, or it is empty, the reward
        is 1 and nothing else is recorded. Otherwise:

        * Each golden API adds 1 to the denominator.
        * When its standardized name equals ``api_json['name']``, it adds 1
          to the numerator.
        * For such a matching API, each golden parameter adds 1 to the
          denominator, and adds 1 to the numerator when ``parameter`` holds
          the same key with an equal value.

        The reward is numerator / denominator, and the selection's key is
        appended to ``self.diversity``.

        ``state`` is not used.
        """
        golden_apis = message['golden'] if 'golden' in message else None
        if not golden_apis:
            # No ground truth to compare against. An empty list used to
            # reach the division below with a zero denominator.
            self.reward[traj_id] = 1
            return

        matched = 0
        total = 0
        for api in golden_apis:
            golden_name = api_util.change_name(api_util.standardize(api['api']['api_name']))
            total += 1
            if golden_name == api_json['name']:
                matched += 1
                for key, value in api['parameter'].items():
                    if key in parameter and parameter[key] == value:
                        matched += 1
                    total += 1
        self.diversity.append(self._diversity_key(api_json['name'], parameter))
        self.reward[traj_id] = matched / total

    def update_execuation(self, traj_id, sucess):
        """Store ``sucess`` as the execution node's reward.

        The misspelled method and parameter names are kept for existing callers.
        """
        self.reward[traj_id] = sucess

    def update_tool_reflection_reward(self, traj_id, api_traj, tool_traj, status):
        """Score a tool-reflection node from the API and tool node rewards.

        Let ``r`` be the sum of the stored rewards of ``api_traj`` and
        ``tool_traj``:

        * 0 when ``r == 2`` and ``status == 1``;
        * 0 when ``r == 0`` and ``status != 1``;
        * ``r / 2`` otherwise.
        """
        ref_reward = self.reward[api_traj.get_node_id()] + self.reward[tool_traj.get_node_id()]
        if ref_reward == 2 and status == 1:
            self.reward[traj_id] = 0
        elif ref_reward == 0 and status != 1:
            self.reward[traj_id] = 0
        else:
            self.reward[traj_id] = ref_reward / 2

    def update_task_reflection_reward(self, traj_id, message, traj_path, status):
        """Score a task-reflection node by the recall of golden API names.

        The reward is the fraction of golden APIs whose standardized name
        appears among ``memory['api_json']['name']`` of the nodes in
        ``traj_path``. If ``message`` has no ``'golden'`` entry, or it is
        empty, the reward is 1.

        NOTE(review): when recall is 0, both branches below yield 0, so
        ``status`` currently has no effect on the value. Confirm the intent.
        """
        # Check for ground truth first: without it, the path memories need
        # not be read at all.
        golden_apis = message['golden'] if 'golden' in message else None
        if not golden_apis:
            self.reward[traj_id] = 1
            return

        pred_api = {traj.get_memory()['api_json']["name"] for traj in traj_path}

        matched = 0
        total = 0
        for api in golden_apis:
            golden_name = api_util.change_name(api_util.standardize(api['api']['api_name']))
            total += 1
            if golden_name in pred_api:
                matched += 1

        if matched / total == 0 and status != 1:
            self.reward[traj_id] = 0
        else:
            self.reward[traj_id] = matched / total

    def ranking_by_diversity(self, traj_leaves):
        """Return leaf indices sorted by ``calculate_diversity_with_jaccard`` score (ascending).

        Each leaf is summarized as one string built from two parts, in order:

        * ``"<api ID>: <parameter>"`` for every non-empty node from
          ``self.retrieve_api``;
        * the reflection ``status`` of every non-empty node from
          ``self.retrieve_reflection``.

        NOTE(review): the Jaccard score is a similarity, so ascending order
        puts the least-similar leaf first. Confirm this is the intended order.
        """
        traj_strings = []
        for traj in traj_leaves:
            node_id = traj.get_node_id()
            traj_action = []
            for api_action in self.retrieve_api(node_id):
                memory = api_action.get_memory()
                if memory:
                    traj_action.append("{}: {}".format(memory['api']["ID"], memory['parameter']))
            for ref_action in self.retrieve_reflection(node_id):
                memory = ref_action.get_memory()
                if memory:
                    traj_action.append(str(memory['reflection']['status']))
            traj_strings.append("".join(traj_action))
        scores = calculate_diversity_with_jaccard(traj_strings)
        return sorted(range(len(scores)), key=lambda i: scores[i])

    def ranking_by_reward(self, traj_leaves):
        """Return leaf indices sorted by path reward, ascending (lowest first).

        When every leaf has the same reward, a random permutation of the
        indices is returned instead, so ties are broken randomly.
        NOTE(review): confirm callers expect lowest-reward-first order.
        """
        traj_reward = []
        for traj in traj_leaves:
            traj_id = traj.get_node_id()
            traj_reward.append(self.get_path_reward(traj_id, self._retrieve_path(traj_id)))

        if traj_reward and all(r == traj_reward[0] for r in traj_reward):
            indices = list(range(len(traj_reward)))
            random.shuffle(indices)
            return indices
        return sorted(range(len(traj_reward)), key=lambda i: traj_reward[i])



def calculate_diversity_with_cosine(strings):
    """Return a TF-IDF cosine diversity score for each string.

    The strings are vectorised with ``TfidfVectorizer`` and their pairwise
    cosine similarities computed. A string's score is ``1 -`` the mean of its
    row in the similarity matrix; that mean includes the string's similarity
    with itself. Higher means more different from the rest.

    Args:
        strings: iterable of documents to compare.

    Returns:
        numpy array with one score per string.
    """
    tfidf_matrix = TfidfVectorizer().fit_transform(strings)
    similarity = cosine_similarity(tfidf_matrix.toarray())
    row_means = similarity.mean(axis=1)
    return 1 - row_means

def calculate_diversity_with_jaccard(strings):
    """Return each string's mean ``textdistance.jaccard`` score against the others.

    NOTE(review): ``textdistance.jaccard`` is a *similarity* (1.0 for
    identical inputs). Despite this function's name, a higher score therefore
    means the string is *less* diverse. With default settings it compares
    character multisets rather than words. Confirm both are intended.

    Args:
        strings: sequence of strings.

    Returns:
        numpy float array of length ``len(strings)``. It is all zeros when
        fewer than two strings are given, since there are no pairs to compare.
    """
    num_strings = len(strings)
    scores = np.zeros(num_strings)
    if num_strings < 2:
        # With a single string the mean below would be 0/0 -> nan.
        return scores

    # The Jaccard measure is symmetric, so score each pair once and credit
    # both members. Each row still accumulates in index order.
    for i in range(num_strings):
        for j in range(i + 1, num_strings):
            sim = textdistance.jaccard(strings[i], strings[j])
            scores[i] += sim
            scores[j] += sim
    scores /= num_strings - 1
    return scores