import asyncio
import shlex
from typing import List, Union, Tuple

import networkx as nx
import pydot
import streamlit as st
import yaml

from ...utils.wrappers import LLM, LLMBaseModel, LLMField, BaseModel
from ...utils.llms import build_json_agent, LoggingCallback, LLMLog
from ...utils.functions import type_cmd, list_to_bullet_points, add_code_fences
from ...utils.schemas import File


# System prompt for the inter-file agent (dependencies between two manifests).
# ``{format_instructions}`` stays a template placeholder; presumably the agent
# builder fills it with the JSON output-format instructions — verify in build_json_agent.
SYS_DESCRIBE_INTER_DEPENDENCY = "\n".join((
    "You are a professional kubernetes (k8s) engineer.",
    "Given two k8s manifests and the dependencies between them, please describe the dependencies according to the following rules:",
    "- The dependencies are given as the DOT format generated by kubectl-graph.",
    "- Using the manifests and their summaries as references, describe the given dependencies in natural language for clarity.",
    "- {format_instructions}",
))

# Human prompt for the inter-file agent. The destination section must use the
# ``dst_*`` placeholders: repeating the ``src_*`` ones would show the model the
# source manifest twice and never the destination manifest/summary, even though
# the caller supplies ``dst_k8s_yaml`` and ``dst_k8s_summary``.
USER_DESCRIBE_INTER_DEPENDENCY = """\
# Source K8s manifest
{src_k8s_yaml}

# Summary of the source manifest:
{src_k8s_summary}

# Destination K8s manifest
{dst_k8s_yaml}

# Summary of the destination manifest:
{dst_k8s_summary}

# The dependencies between the source and destination manifests:
{inter_dependency}

Please describe the dependencies between the source and destination manifests."""

# System prompt for the intra-file agent (dependencies inside one manifest).
# ``{format_instructions}`` stays a template placeholder; presumably the agent
# builder fills it with the JSON output-format instructions — verify in build_json_agent.
SYS_DESCRIBE_INTRA_DEPENDENCY = (
    "You are a professional kubernetes (k8s) engineer.\n"
    "Given a k8s manifest and the dependencies within it, please describe the dependencies according to the following rules:\n"
    "- The dependencies are given as the DOT format generated by kubectl-graph.\n"
    "- Using the manifest and its summary as references, describe the given dependencies in natural language for clarity.\n"
    "- {format_instructions}"
)

# Human prompt for the intra-file agent: the manifest, its summary, and the
# bullet list of dependencies found inside it.
USER_DESCRIBE_INTRA_DEPENDENCY = "\n".join([
    "# K8s manifest:",
    "{k8s_yaml}",
    "",
    "# Summary of the above manifest:",
    "{k8s_summary}",
    "",
    "# The dependency within the above manifest:",
    "{intra_dependency}",
    "",
    "Please describe the intra dependency within the manifest.",
])


class InterDependency(BaseModel):
    """Dependency between resources declared in two different manifest files.

    ``dependency`` holds the raw edge tooltips (a list of strings) right after
    extraction, and the LLM-generated natural-language description (a single
    string) after the describe step.
    """
    src_file: str  # file name declaring the edge's source resource
    dst_file: str  # file name declaring the edge's destination resource
    dependency: Union[List[str], str]

class IntraDependency(BaseModel):
    """Dependencies between resources that are declared in the same manifest file.

    ``dependency`` holds the raw edge tooltips (a list of strings) right after
    extraction, and the LLM-generated description (a single string) afterwards.
    """
    file: str  # file name declaring both endpoints
    dependency: Union[List[str], str]

class K8sDependencies(BaseModel):
    """All dependencies found across a set of manifests, grouped by locality."""
    intra: List[IntraDependency]  # one entry per file with internal dependencies
    inter: List[InterDependency]  # one entry per ordered (src_file, dst_file) pair

class K8sDependencyDesciption(LLMBaseModel):
    """Structured LLM output holding a natural-language dependency description.

    The class name keeps its original (misspelled) form so existing references
    to it keep working; only the LLM-facing field description is corrected.
    """
    k8s_dependency: str = LLMField(description="Description of the dependency between two k8s resources")


class K8sAnalysisAgent:
    """Extract dependencies between k8s resources and describe them with an LLM.

    Dependencies come from a ``kubectl graph`` DOT dump and are grouped into
    intra-file entries (both endpoints declared in the same manifest file) and
    inter-file entries (endpoints declared in different files). Each entry is
    then described in natural language by a JSON-output agent whose partial
    output is streamed into the Streamlit page.
    """

    def __init__(self, llm: LLM) -> None:
        """Build one description agent for inter-file and one for intra-file dependencies.

        Args:
            llm: Model shared by both agents and by the per-run logging callback.
        """
        self.llm = llm
        self.inter_agent = build_json_agent(
            llm=llm,
            chat_messages=[("system", SYS_DESCRIBE_INTER_DEPENDENCY), ("human", USER_DESCRIBE_INTER_DEPENDENCY)],
            pydantic_object=K8sDependencyDesciption,
            is_async=False
        )
        self.intra_agent = build_json_agent(
            llm=llm,
            chat_messages=[("system", SYS_DESCRIBE_INTRA_DEPENDENCY), ("human", USER_DESCRIBE_INTRA_DEPENDENCY)],
            pydantic_object=K8sDependencyDesciption,
            is_async=False
        )

    def analyze_manifest_dependencies(
        self,
        k8s_yamls: List[File],
        k8s_summaries: List[str],
        work_dir: str,
        project_name: str
    ) -> Tuple[LLMLog, K8sDependencies]:
        """Extract the dependencies between the given manifests and describe them.

        Args:
            k8s_yamls: Manifest files to analyze.
            k8s_summaries: Summaries aligned by index with ``k8s_yamls``.
            work_dir: Working directory; the DOT dump is written to
                ``<work_dir>/inputs/dependency.dot``.
            project_name: Value of the ``project`` label used to select resources.

        Returns:
            The LLM call log collected during this run, and the dependencies whose
            ``dependency`` fields hold natural-language descriptions.
        """
        # A fresh logger per run; the describe_* methods pass it as an LLM callback,
        # so it must exist before they are called.
        self.logger = LoggingCallback(name="k8s_dependency", llm=self.llm)
        k8s_dependencies = self.analyze_k8s_yaml_dependencies(
            k8s_yamls,
            work_dir,
            project_name
        )
        dependency_description = self.describe_dependencies(
            k8s_yamls,
            k8s_summaries,
            k8s_dependencies,
        )
        return self.logger.log, dependency_description

    def analyze_k8s_yaml_dependencies(
        self,
        k8s_yamls: List[File],
        work_dir: str,
        project_name: str
    ) -> K8sDependencies:
        """Extract raw dependencies between the resources declared in ``k8s_yamls``.

        Runs ``kubectl graph`` for resources labelled ``project=<project_name>``,
        redirects its DOT output to ``<work_dir>/inputs/dependency.dot``, parses
        it, and maps each edge endpoint back to the manifest file declaring it.
        Edges with an endpoint not declared in ``k8s_yamls`` are ignored.

        Returns:
            K8sDependencies whose ``dependency`` fields are lists of raw edge
            tooltips: one entry per file (intra) or per ordered file pair (inter),
            in first-seen order.

        Raises:
            ValueError: if the DOT file does not contain exactly one graph.
        """
        resource_to_yaml = self._map_resources_to_files(k8s_yamls)

        # get dependencies between resources
        graph_path = f"{work_dir}/inputs/dependency.dot"
        # The command is a shell string (it uses `>`), so quote the interpolated
        # values: spaces or shell metacharacters in the project name or work dir
        # would otherwise break the command or inject extra shell syntax.
        type_cmd(
            f"kubectl graph all --selector=project={shlex.quote(project_name)} -t 1000"
            f" > {shlex.quote(graph_path)}"
        )
        graphs = pydot.graph_from_dot_file(graph_path)
        if not graphs or len(graphs) != 1:
            count = len(graphs) if graphs else 0
            raise ValueError(f"Expected exactly one graph in '{graph_path}', found {count}.")
        G = nx.nx_pydot.from_pydot(graphs[0])

        # Keyed lookups keep grouping O(1) per edge; dicts preserve first-seen order.
        intra_by_file = {}   # file name -> IntraDependency
        inter_by_pair = {}   # (src file, dst file) -> InterDependency
        for src, dst, data in G.edges(data=True):
            # Nodes that only appear in edges may carry no attributes at all.
            src_yaml = resource_to_yaml.get(self._unquote(G.nodes[src].get("label", "")))
            dst_yaml = resource_to_yaml.get(self._unquote(G.nodes[dst].get("label", "")))
            if not (src_yaml and dst_yaml):
                continue  # at least one endpoint is not declared in the given manifests
            dependency = self._unquote(data["labeltooltip"])
            if src_yaml == dst_yaml:  # intra-file dependency
                entry = intra_by_file.get(src_yaml)
                if entry is None:
                    intra_by_file[src_yaml] = IntraDependency(file=src_yaml, dependency=[dependency])
                else:
                    entry.dependency.append(dependency)
            else:  # inter-file dependency
                key = (src_yaml, dst_yaml)
                entry = inter_by_pair.get(key)
                if entry is None:
                    inter_by_pair[key] = InterDependency(src_file=src_yaml, dst_file=dst_yaml, dependency=[dependency])
                else:
                    entry.dependency.append(dependency)
        return K8sDependencies(intra=list(intra_by_file.values()), inter=list(inter_by_pair.values()))

    @staticmethod
    def _map_resources_to_files(k8s_yamls: List[File]) -> dict:
        """Map every ``metadata.name`` found in the manifests to its file name."""
        resource_to_yaml = {}
        for k8s_yaml in k8s_yamls:
            for doc in yaml.safe_load_all(k8s_yaml.content):
                # Skip empty documents and anything that is not a mapping with a
                # mapping-valued ``metadata`` (e.g. a stray scalar or list document).
                if not isinstance(doc, dict):
                    continue
                metadata = doc.get("metadata")
                if isinstance(metadata, dict) and "name" in metadata:
                    # NOTE(review): keyed by name only, so resources of different
                    # kinds sharing a name (e.g. a Service and a Deployment) collide
                    # and the last file wins. Disambiguating needs the resource kind
                    # on the graph side too — verify what kubectl-graph node data has.
                    resource_to_yaml[metadata["name"]] = k8s_yaml.fname
        return resource_to_yaml

    @staticmethod
    def _unquote(value: str) -> str:
        """Strip one pair of surrounding double quotes from a DOT attribute value.

        Values read back through pydot keep their quotes (e.g. ``'"my-svc"'``);
        unquoted values are returned unchanged instead of losing characters.
        """
        if len(value) >= 2 and value[0] == value[-1] == '"':
            return value[1:-1]
        return value

    def describe_dependencies(
        self,
        k8s_yamls: List[File],
        k8s_summaries: List[str],
        k8s_dependencies: K8sDependencies
    ) -> K8sDependencies:
        """Replace each raw dependency list with an LLM-generated description.

        ``k8s_summaries[i]`` is used as the summary of ``k8s_yamls[i]``.

        Raises:
            ValueError: if a dependency references a file missing from ``k8s_yamls``.
        """
        #--------------------
        # intra dependencies
        #--------------------
        intra_results = []
        for intra_dependency in k8s_dependencies.intra:
            index = self.get_index(k8s_yamls, intra_dependency.file)
            intra_results.append(self.describe_intra_dependency(
                k8s_yaml=k8s_yamls[index],
                k8s_summary=k8s_summaries[index],
                intra_dependency=intra_dependency.dependency
            ))

        #--------------------
        # inter dependencies
        #--------------------
        inter_results = []
        for inter_dependency in k8s_dependencies.inter:
            src_index = self.get_index(k8s_yamls, inter_dependency.src_file)
            dst_index = self.get_index(k8s_yamls, inter_dependency.dst_file)
            inter_results.append(self.describe_inter_dependency(
                src_k8s_yaml=k8s_yamls[src_index],
                src_k8s_summary=k8s_summaries[src_index],
                dst_k8s_yaml=k8s_yamls[dst_index],
                dst_k8s_summary=k8s_summaries[dst_index],
                inter_dependency=inter_dependency.dependency
            ))
        return K8sDependencies(intra=intra_results, inter=inter_results)

    def _stream_description(self, agent, inputs: dict, container) -> str:
        """Stream ``agent`` output into ``container`` and return the final description.

        Chunks are read with ``.get`` (presumably the JSON object parsed so far —
        verify against build_json_agent). Chunks without ``k8s_dependency`` are
        skipped rather than discarding text already received, and an empty stream
        yields ``""`` instead of an unbound-variable error.
        """
        description = ""
        for chunk in agent.stream(inputs, {"callbacks": [self.logger]}):
            text = chunk.get("k8s_dependency")
            if text is not None:
                description = text
                container.write(description)
        return description

    def describe_intra_dependency(
        self,
        k8s_yaml: File,
        k8s_summary: str,
        intra_dependency: List[str]
    ) -> IntraDependency:
        """Describe the dependencies inside one manifest, streaming to the UI.

        Requires ``self.logger`` (set by ``analyze_manifest_dependencies``).
        """
        st.write(f"Intra dependencies in ```{k8s_yaml.fname}```")
        container = st.empty()
        description = self._stream_description(
            self.intra_agent,
            {
                "k8s_yaml": add_code_fences(k8s_yaml.content, k8s_yaml.fname),
                "k8s_summary": k8s_summary,
                "intra_dependency": list_to_bullet_points(intra_dependency),
            },
            container,
        )
        return IntraDependency(file=k8s_yaml.fname, dependency=description)

    def describe_inter_dependency(
        self,
        src_k8s_yaml: File,
        src_k8s_summary: str,
        dst_k8s_yaml: File,
        dst_k8s_summary: str,
        inter_dependency: List[str]
    ) -> InterDependency:
        """Describe the dependencies from one manifest to another, streaming to the UI.

        Requires ``self.logger`` (set by ``analyze_manifest_dependencies``).
        """
        st.write(f"Inter dependencies: ```{src_k8s_yaml.fname}``` ➡ ```{dst_k8s_yaml.fname}```")
        container = st.empty()
        description = self._stream_description(
            self.inter_agent,
            {
                "src_k8s_yaml": add_code_fences(src_k8s_yaml.content, src_k8s_yaml.fname),
                "src_k8s_summary": src_k8s_summary,
                "dst_k8s_yaml": add_code_fences(dst_k8s_yaml.content, dst_k8s_yaml.fname),
                "dst_k8s_summary": dst_k8s_summary,
                "inter_dependency": list_to_bullet_points(inter_dependency),
            },
            container,
        )
        return InterDependency(
            src_file=src_k8s_yaml.fname,
            dst_file=dst_k8s_yaml.fname,
            dependency=description
        )

    def get_index(
        self,
        files: List[File],
        file_path: str
    ) -> int:
        """Return the position of the file named ``file_path`` in ``files``.

        Raises:
            ValueError: if no file in ``files`` has that name.
        """
        for i, file in enumerate(files):
            if file.fname == file_path:
                return i
        raise ValueError(f"File with path '{file_path}' not found.")