from collections import deque
from typing import List, Tuple, Optional
import pandas as pd
from itertools import combinations

# Fruit commodities keyed by source year (the year is a string key).
# get_combinations() reads this via FRUITS[source_year] when the agent is
# "farmer", so every supported year needs an entry here.
FRUITS = {
    "2021": [
        "apple",
        "avocado",
        "grape",
        "grapefruit",
        "lemon",
        "peach",
        "pear",
    ],
}

# Names of product-agnostic state variables, i.e. factors that are not tied
# to a single product.
# NOTE(review): nothing in the visible part of this file reads this list;
# confirm which caller consumes it before renaming or reordering entries.
AGNOSTIC_STATES = [
    "climate condition",
    "supply chain disruptions",
    "economic health",
    "market sentiment and investor psychology",
    "political events and government policies",
    "natural disasters and other 'black swan' events",
    "geopolitical issues",
]

# State-variable descriptions for the fruit setting, keyed by source year.
# Layout per year:
#   "agnostic": {<source key>: {<state name>: <description string>}}
#   "specific": {<state name>: callable taking a product name and returning
#                the description string for that product}
# NOTE(review): the "agnostic" sub-keys ('gpt-4', 'claude-3', 'gemini') look
# like language-model names, presumably the model that proposed each set of
# state variables — confirm against the caller.
FRUIT_STATES = {
    "2021": {
        # product-agnostic state variables
        "agnostic": {
            'gpt-4': {
                "climate condition": "the climate condition of the next agricultural season in California",
                "supply chain disruptions": "the supply chain disruptions of the next agricultural season in California",
            },
            'claude-3': {
                "climate condition": "the climate condition (e.g. drought, heat waves, winter freezes) of the next agricultural season in California",
                "disease and pests": "potential disease and pests (e.g. citrus greening, fruit drop) of the next agricultural season in California",
                "supply chain disruptions": "the supply chain disruptions of the next agricultural season in California",
            },
            'gemini': {
                "climate condition": "the climate condition (e.g. drought, heat waves, winter freezes) of the next agricultural season in California",
                "demand": "the domestic and international demand of the next agricultural season in California",
                "imports": "competition of imports from foreign suppliers",
            },
        },
        
        # product-specific state variables; each value is a function of the
        # product name `c` that fills the name into the description
        "specific": {
            # disabled entry, kept for reference:
            # 'demand change': 'the demand change of the next agricultural season in California',
            "price change": lambda c: f"the change in price per unit of {c} for the next agricultural season in California",
            "yield change": lambda c: f"the change in yield of {c} for the next agricultural season in California",
        },
    },
}

# Ticker symbols used as the product list for the "trader" agent
# (see get_combinations()).
STOCKS = ["AMD", "DIS", "GME", "GOOGL", "META", "NVDA", "SPY"]
# Human-readable name for each ticker symbol in STOCKS.
STOCKS_SYMBOL_TO_NAME_MAP = {
    "AMD": "Advanced Micro Devices",
    "DIS": "The Walt Disney Company",
    "GME": "GameStop Corp",
    "GOOGL": "Alphabet, i.e. Google",
    "META": "Meta Platforms, i.e. Facebook",
    "NVDA": "NVIDIA",
    "SPY": "S&P 500",
}


def get_combinations(
    agent_name: str, source_year: Optional[str] = None
) -> List[Tuple[str, ...]]:
    """Enumerate every product subset of size two or more for an agent.

    Args:
        agent_name: ``"farmer"`` (products come from ``FRUITS[source_year]``)
            or ``"trader"`` (products come from ``STOCKS``).
        source_year: Year key into ``FRUITS``; only read for ``"farmer"``.

    Returns:
        All combinations of the product list, ordered by increasing size and,
        within one size, in the order ``itertools.combinations`` yields them.

    Raises:
        ValueError: If ``agent_name`` is neither ``"farmer"`` nor ``"trader"``.
        KeyError: If ``agent_name`` is ``"farmer"`` and ``source_year`` is not
            a key of ``FRUITS``.
    """
    if agent_name not in ("farmer", "trader"):
        raise ValueError("agent_name must be either 'farmer' or 'trader'")

    pool = FRUITS[source_year] if agent_name == "farmer" else STOCKS

    # Sizes start at 2: single-product "combinations" are deliberately excluded.
    return [
        subset
        for size in range(2, len(pool) + 1)
        for subset in combinations(pool, size)
    ]


def clean_ag_data(
    fname: str,
    sep: str = "\t",
    keep_cols: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Filter a delimited agricultural data file in place and return the result.

    The file is read, reduced to ``keep_cols``, stripped of rows with missing
    values, of rows whose "Commodity" contains a comma, and of rows whose
    "Yield" or "Price per Unit" contains the literal marker ``(D)``. The
    cleaned table is then written back over ``fname``.

    Args:
        fname: Path of the file to read and overwrite.
        sep: Field delimiter used for both reading and writing.
        keep_cols: Columns to retain. Defaults to
            ``["Commodity", "Yield", "Price per Unit"]``. The row filters below
            index those three columns, so they must be among ``keep_cols``.

    Returns:
        The cleaned DataFrame with a fresh 0..n-1 index.

    Raises:
        KeyError: If a column in ``keep_cols`` (or one of the three filtered
            columns) is missing.
    """
    # Build the default per call instead of sharing one list across calls.
    if keep_cols is None:
        keep_cols = ["Commodity", "Yield", "Price per Unit"]

    with open(fname) as f:
        print(fname)
        df = pd.read_csv(f, sep=sep)
    # The read handle is closed here, before the same path is overwritten below.

    df = df[keep_cols].dropna()

    # regex=False: "(D)" must be matched literally. As a regular expression it
    # is a capture group that matches any value containing the letter "D".
    # astype(str): a column with no marker may have been parsed as numeric,
    # where the .str accessor is unavailable. Only the mask is cast; the
    # column itself keeps its dtype.
    df = df[~df["Commodity"].astype(str).str.contains(",", regex=False)]
    for col in ("Yield", "Price per Unit"):
        df = df[~df[col].astype(str).str.contains("(D)", regex=False)]

    df = df.reset_index(drop=True)
    df.to_csv(fname, index=False, sep=sep)
    return df


def clean_ag_reports(
    read_fname: str,
    write_fname: str,
) -> None:
    """Re-flow a report so that each paragraph sits on a single line.

    Consecutive non-blank lines of ``read_fname`` are stripped and joined with
    single spaces. Each blank line ends the current paragraph, which is written
    to ``write_fname`` followed by an empty line.

    Args:
        read_fname: Path of the text file to read.
        write_fname: Path of the file to create or overwrite with the result.
    """
    with open(read_fname, "r") as rf, open(write_fname, "w") as wf:
        pending: List[str] = []
        for line in rf:
            line = line.strip()
            if line:
                pending.append(line)
                continue
            # A blank line closes the paragraph collected so far.
            wf.write(" ".join(pending) + "\n\n")
            pending = []
        # The input may end without a trailing blank line; flush what is left
        # so the final paragraph is not lost.
        if pending:
            wf.write(" ".join(pending) + "\n\n")


def merge_by_commodity(
    df_x: pd.DataFrame | str,
    df_y: pd.DataFrame | str,
    on: str = "Commodity",
) -> pd.DataFrame:
    if type(df_x) == str:
        df_x = pd.read_csv(df_x)
    if type(df_y) == str:
        df_y = pd.read_csv(df_y)
    df = pd.merge(df_x, df_y, on=on)
    return df


class SectionNode:
    """One titled section of a document, linked to its parent and children.

    Attributes are plain instance fields: ``title``, ``passage`` (body text),
    ``children`` (list of child nodes in insertion order), ``parent`` and
    ``depth``.
    """

    def __init__(
        self,
        title: str,
        passage: str = "",
        parent: Optional["SectionNode"] = None,
        depth: int = 0,
    ):
        self.title, self.passage = title, passage
        self.parent, self.depth = parent, depth
        # Starts empty; note that the parent's child list is NOT updated here,
        # callers attach the node explicitly through add_child().
        self.children = []

    def add_child(self, child: "SectionNode"):
        """Append ``child`` to this node's list of children."""
        self.children.append(child)

    def __repr__(self) -> str:
        return f"Node: {self.title} with {len(self.children)} children"

    def format(self, denote_section: bool = False) -> str:
        """Render the node as its title line followed by its passage.

        Args:
            denote_section: When true, prefix the title with ``depth`` ``#``
                characters and a space.

        Returns:
            The rendered text, ending with a newline after the passage.
        """
        body = f"{self.title}\n{self.passage}\n"
        if not denote_section:
            return body
        return "#" * self.depth + " " + body


class SectionTree:
    """Tree of SectionNode objects hanging off a synthetic "[ROOT]" node."""

    def __init__(self):
        self.root = SectionNode("[ROOT]", "", None, 0)

    def get_ancestor_content(
        self,
        node: "SectionNode",
        res: str,
        denote_section: bool = False,
    ) -> str:
        """Append the content of ``node``'s ancestors to ``res`` where missing.

        Ancestors are visited from the outermost one down to the direct
        parent; the "[ROOT]" node is never included. An ancestor counts as
        already present when its formatted text occurs in ``res`` under a
        case-insensitive substring comparison.

        Args:
            node: Node whose ancestors should be added.
            res: Text accumulated so far.
            denote_section: Passed through to ``SectionNode.format``.

        Returns:
            ``res`` extended with the missing ancestor content.
        """
        # Collect ancestors bottom-up; also stop on a missing parent so a
        # chain that does not end in "[ROOT]" cannot dereference None.
        ancestors = []
        curr_node = node.parent
        while curr_node is not None and curr_node.title != "[ROOT]":
            ancestors.append(curr_node)
            curr_node = curr_node.parent
        # Emit top-down so the text reads in document order.
        for anc_node in reversed(ancestors):
            anc_content = anc_node.format(denote_section)
            if anc_content.lower() not in res.lower():
                res += anc_content
        return res

    def dfs(
        self,
        node: Optional["SectionNode"] = None,
        denote_section: bool = False,
        keywords: Optional[List[str]] = None,
    ) -> str:
        """Concatenate section content in pre-order, optionally filtered.

        Args:
            node: Subtree root to start from; defaults to the tree root.
            denote_section: Passed through to ``SectionNode.format``.
            keywords: When ``None``, every node except "[ROOT]" is included.
                Otherwise a node is included only if its formatted content
                contains at least one keyword (case-sensitive), together with
                any of its ancestors that are not in the output yet.

        Returns:
            The concatenated text.
        """
        if node is None:
            node = self.root
        # TODO: implement force add
        return self._collect(node, "", denote_section, keywords)

    def _collect(
        self,
        node: "SectionNode",
        res: str,
        denote_section: bool,
        keywords: Optional[List[str]],
    ) -> str:
        """Pre-order walk that extends and returns the accumulated text."""
        # The accumulated text is threaded through the recursion on purpose:
        # the ancestor check needs to see everything emitted so far, otherwise
        # a shared ancestor is repeated for every matching descendant.
        content = node.format(denote_section)
        if node.title != "[ROOT]":
            if keywords is None:
                res += content
            elif any(kwd in content for kwd in keywords):
                # add ancestor content if it has not been added
                res = self.get_ancestor_content(node, res, denote_section)
                res += content
        for child in node.children:
            res = self._collect(child, res, denote_section, keywords)
        return res


def check_title(line: str, max_depth: int = 3) -> Tuple[bool, int]:
    """Detect a markdown-style heading line and report its depth.

    A heading is a run of 1 to ``max_depth`` ``#`` characters at the very
    start of ``line``, immediately followed by a space.

    Returns:
        ``(True, depth)`` for a heading, otherwise ``(False, -1)``.
    """
    hashes = len(line) - len(line.lstrip("#"))
    # Slicing (rather than indexing) keeps a line made only of '#' safe.
    if 1 <= hashes <= max_depth and line[hashes : hashes + 1] == " ":
        return True, hashes
    return False, -1


def treeify(fname: str) -> SectionTree:
    """Parse a heading-structured text file into a SectionTree.

    Every non-blank line is stripped. Lines that ``check_title`` recognises as
    headings open a new section nested under the closest preceding section of
    smaller depth; every other line is appended (plus a newline) to the
    passage of the section currently open. Text before the first heading
    lands in the root node's passage.

    Args:
        fname: Path of the text file to read.

    Returns:
        The populated tree.
    """
    with open(fname, "r") as f:
        # Blank and whitespace-only lines carry no content and are dropped.
        stripped_lines = [text for raw in f.readlines() if (text := raw.strip())]

    tree = SectionTree()
    cursor = tree.root
    for text in stripped_lines:
        is_title, depth = check_title(text)
        if not is_title:
            cursor.passage += text + "\n"
            continue
        # Climb until the open section is strictly shallower than the new
        # heading; the root has depth 0, so the climb always stops there.
        while cursor.depth >= depth:
            cursor = cursor.parent
        # Skip the '#' run and the single space that follows it.
        section = SectionNode(text[depth + 1 :], "", cursor, depth)
        cursor.add_child(section)
        cursor = section
    return tree
