from collections import Counter
import math
import numpy as np
from typing_extensions import TypedDict
from typing import Annotated, Literal
import random

from langchain_core.pydantic_v1 import Field
from langgraph.graph import StateGraph

from src.llm.openai_llm import GPTMini
from src.prompt.prompt_sorting import MethodPrompt, SortingPrompt, GetRationaleProbPrompt, AggregateRationalePrompt
from src.util.reducers import *
from src.graph.base_sorting_graph import get_method, get_answering, get_rationale_prob, get_aggregate_rationales, get_aggregate_graph


class BaseState(TypedDict):
    """Minimal graph state shared by the sorting workflow's state schemas.

    The second argument of each ``Annotated[...]`` is the reducer langgraph
    uses to merge node outputs into that key (presumably provided by the
    ``src.util.reducers`` star import — confirm there).
    """

    # Input values to be sorted; merged with the ``fixed_value`` reducer.
    data: Annotated[list[int], fixed_value]
    # Tree-node ids; the set_node step writes ``[node_name]`` here.
    # Presumably ``update_list`` appends, making this a visit history — verify.
    node_list: Annotated[list[int], update_list]
    # [parent, child] pairs; the set_node step writes one pair per visit.
    node_graph: Annotated[list[list[int]], update_list]


class UpdateInfoState(BaseState):
    """``BaseState`` plus bookkeeping about node updates."""

    # Not written by any node defined in this file. Presumably a per-node
    # counter of unsuccessful updates, maintained elsewhere — TODO confirm.
    not_update_number: Annotated[dict, update_dict]


class TemperatureState(UpdateInfoState):
    """``UpdateInfoState`` plus a sampling temperature per tree node."""

    # Maps tree-node id -> temperature (the fix_state step writes
    # ``{node_name: temperature}``). Intended range is (0, 1], but this is
    # NOT enforced: TypedDict bodies are never validated at runtime, so a
    # pydantic ``Field(...)`` default here would only create an inert class
    # attribute (and a scalar bound cannot constrain a dict anyway).
    temperature: Annotated[dict, update_dict]


class SortingState(TemperatureState):
    """Full graph state for the tree-structured sorting workflow.

    The per-node ``dict`` fields are merged with the ``update_dict`` reducer.
    The nodes defined in this file write them as ``{node_name: value}``,
    keyed by the integer tree-node id.
    """

    # Sorting method; presumably produced by the ``get_method`` node imported
    # from ``src.graph.base_sorting_graph`` — confirm there.
    method: Annotated[str, fixed_value]
    evaluate_methods: Annotated[list, fixed_value]
    # Per-node answer; fix_state compares ``answers[node_name]`` with ``data``.
    answers: Annotated[dict, update_dict]
    answer_rationales: Annotated[dict, update_dict]
    # Intended range (0, 1]; not enforced — TypedDict performs no runtime
    # validation, so no pydantic ``Field`` default is attached here.
    answer_accuracies: Annotated[dict, update_dict]
    # Per-node answer probability, intended range (0, 1]; not enforced.
    # fix_state derives the node's minimum temperature from this value.
    answer_probs: Annotated[dict, update_dict]
    answer_accuracy_rationales: Annotated[dict, update_dict]
    # Listed in ``input_variables`` for the answering nodes, so presumably
    # injected into the answering prompt — verify in get_answering.
    aggregate_rationales: Annotated[dict, update_dict]
    inaccurate_infos: Annotated[dict, update_dict]
    # fix_state sets ``{node_name: True}`` here.
    update_success: Annotated[dict, update_dict]

    # Current node. set_node writes the int tree-node id; the depth* nodes
    # write a "depth{i}" string despite the int annotation.
    node_name: Annotated[int, fixed_value]
    # False after the ``init`` node, True after the ``end`` node.
    end_state: Annotated[bool, fixed_value]

    # Test-only: per-node accuracy percentage recorded by fix_state.
    for_test_acc: Annotated[dict, update_dict]


def sorting_graph():
    """Build and compile the tree-structured LLM sorting workflow.

    Topology (tree-node ids come from ``make_graph``):

    * ``init`` -> ``get_method``, which fans out to the root nodes 0, 7 and 14
      (the ``[k, k]`` entries of ``make_graph``).
    * Every tree node ``k`` runs ``set_node{k}`` -> ``answer{k}`` ->
      ``get_rationale_prob{k}`` -> ``aggregate_rationale{k}`` ->
      ``fix_state{k}``. A child's ``set_node`` is entered from its parent's
      ``fix_state``.
    * The ``fix_state`` of each node in ``leaf_node`` feeds ``end``, which is
      followed by the ``depth0 .. depth{3 * t_max}`` chain and then
      ``aggregate``, the finish point.

    Returns:
        The compiled ``StateGraph`` over ``SortingState``.
    """
    t_max = 1
    # [parent, child] pairs; a pair whose parent equals the child is a root.
    make_graph = [[0, 0],         [7, 7],             [14, 14],
                  [0, 1], [0, 2], [7, 8],   [7, 9],   [14, 15], [14, 16],
                  [1, 3], [2, 4], [8, 10],  [9, 11],  [15, 17], [16, 18],
                  ]
    leaf_node = [3, 4, 10, 11, 17, 18]


    def init(state: SortingState):
        """Entry node: mark the run as not yet finished."""
        return {"end_state": False}

    def set_node(node_name, parent_node_name):
        """Return a node that records ``node_name`` as the current tree node."""
        def _set_node(state: SortingState):
            new_state = dict()
            new_state["node_name"] = node_name
            new_state["node_list"] = [node_name]
            new_state["node_graph"] = [[parent_node_name, node_name]]
            return new_state
        return _set_node

    def fix_state(node_name, parent_node_name):
        """Return the node that finalizes tree node ``node_name``.

        The node sets the next temperature for ``node_name``, marks its update
        as successful, and records a test-only accuracy score.
        ``parent_node_name`` is unused and is kept for signature parity with
        the other node factories.
        """
        def get_fix_state(state: SortingState):
            new_state = dict()
            # SGDR: Stochastic Gradient Descent with Warm Restarts.
            # Cosine annealing: the temperature equals max_temperature at
            # epoch 0 and min_temperature at epoch == max_epoch. "epoch" is
            # the current length of node_list.
            epoch = len(state["node_list"])
            max_epoch = len(make_graph)
            max_temperature = 0.7
            # Lower bound from the node's answer probability p:
            # 1 - sqrt(1 - (p - 1)^2). It is 0 at p == 1 and grows towards 1
            # as p -> 0, so for very small p it exceeds max_temperature.
            min_temperature = 1 - np.power(1 - np.power(state["answer_probs"][node_name]-1, 2), 0.5)
            temperature = min_temperature + (max_temperature - min_temperature) * (1 + math.cos(math.pi * epoch / max_epoch)) / 2
            new_state["temperature"] = {node_name: temperature}
            new_state["update_success"] = {node_name: True}

            # Test-only metric: multiset intersection-over-union, in percent,
            # between the input values and this node's answer. Counter
            # ignores order, so this scores WHICH values were returned, not
            # whether they are sorted.
            reference = Counter(state["data"])
            answer = Counter(state["answers"][node_name])
            union_size = sum((reference | answer).values())
            if union_size:
                acc = 100 * sum((reference & answer).values()) / union_size
            else:
                # Both multisets are empty and therefore identical; avoid
                # dividing by zero.
                acc = 100.0
            print(f"acc: {acc}")
            new_state["for_test_acc"] = {node_name: acc}

            return new_state
        return get_fix_state

    def end_condition(node_name):
        """Return the router used after ``aggregate_rationale{node_name}``.

        It currently always routes to ``fix_state{node_name}``. "end" stays in
        the ``Literal`` return annotation, which langgraph can use to infer
        the possible destinations.
        """
        def get_end_condition(state: SortingState) -> Literal[f"fix_state{node_name}", "end"]:
            return f"fix_state{node_name}"
        return get_end_condition

    def end_graph(state: SortingState):
        """Mark the run as finished once the leaf nodes are done."""
        new_state = dict()
        new_state["end_state"] = True
        return new_state

    # State keys handed to the answering nodes (see get_answering).
    input_variables = ["aggregate_rationales"]

    sorting_builder = StateGraph(SortingState)
    sorting_builder.add_node("init", init)
    sorting_builder.set_entry_point("init")
    sorting_builder.add_node("get_method", get_method(MethodPrompt, GPTMini))

    def add_node(builder, node_name, parent_node_name):
        """Register the five per-tree-node steps for ``node_name``."""
        builder.add_node(f"set_node{node_name}", set_node(node_name, parent_node_name))
        builder.add_node(f"answer{node_name}", get_answering(node_name, parent_node_name, SortingPrompt, GPTMini, input_variables))
        builder.add_node(f"get_rationale_prob{node_name}", get_rationale_prob(node_name, parent_node_name, GetRationaleProbPrompt, GPTMini))
        builder.add_node(f"aggregate_rationale{node_name}", get_aggregate_rationales(node_name, parent_node_name, AggregateRationalePrompt, GPTMini))
        builder.add_node(f"fix_state{node_name}", fix_state(node_name, parent_node_name))
        return builder

    for parent_node, node in make_graph:
        sorting_builder = add_node(sorting_builder, node, parent_node)

    sorting_builder.add_node("end", end_graph)

    def make_depth(builder, depth=5):
        """Add pass-through nodes ``depth0`` .. ``depth{depth}``."""
        def depth_node(index):
            # Factory so each node captures its own index. The original
            # ``lambda state: ...f"depth{i}"`` closed over the loop variable,
            # so every depth node wrote the LAST index.
            def _depth(state: SortingState):
                return {"node_name": f"depth{index}"}
            return _depth

        for i in range(depth + 1):
            builder.add_node(f"depth{i}", depth_node(i))
        return builder

    sorting_builder = make_depth(sorting_builder, t_max * 3)
    sorting_builder.add_node("aggregate", get_aggregate_graph(SortingPrompt, GPTMini, leaf_node))
    sorting_builder.set_finish_point("aggregate")

    sorting_builder.add_edge("init", "get_method")
    def init_add_edge(builder, node_name):
        """Wire a root node: get_method -> set_node -> ... -> router."""
        builder.add_edge("get_method", f"set_node{node_name}")
        builder.add_edge(f"set_node{node_name}", f"answer{node_name}")
        builder.add_edge(f"answer{node_name}", f"get_rationale_prob{node_name}")
        builder.add_edge(f"get_rationale_prob{node_name}", f"aggregate_rationale{node_name}")
        builder.add_conditional_edges(f"aggregate_rationale{node_name}", end_condition(node_name))
        return builder

    def add_edge(builder, node_name, parent_node_name, end_state: bool = False):
        """Wire a child node under its parent; leaves also connect to "end"."""
        builder.add_edge(f"fix_state{parent_node_name}", f"set_node{node_name}")
        builder.add_edge(f"set_node{node_name}", f"answer{node_name}")
        builder.add_edge(f"answer{node_name}", f"get_rationale_prob{node_name}")
        builder.add_edge(f"get_rationale_prob{node_name}", f"aggregate_rationale{node_name}")
        builder.add_conditional_edges(f"aggregate_rationale{node_name}", end_condition(node_name))
        if end_state:
            builder.add_edge(f"fix_state{node_name}", "end")
        return builder

    for parent_node, node in make_graph:
        if node == parent_node:
            sorting_builder = init_add_edge(sorting_builder, node)
            continue
        end_state = False
        if node in leaf_node:
            end_state = True
        sorting_builder = add_edge(sorting_builder, node, parent_node, end_state)

    sorting_builder.add_edge("end", "depth0")
    for i in range(0, t_max*3):
        sorting_builder.add_edge(f"depth{i}", f"depth{i+1}")

    sorting_builder.add_edge(f"depth{t_max*3}", "aggregate")
    return sorting_builder.compile()
