from sentence_transformers import SentenceTransformer
import pickle
import faiss
import numpy as np
import re
import json
from pathlib import Path
import networkx as nx
import matplotlib.pyplot as plt
import openai 
from pathlib import Path
from datetime import datetime
import os

# Round number of the Scenario Modeling Hub scenario being processed. It
# selects the default scenario file in get_system_description and is embedded
# in every output file name.
scenario_NO = '17'

# Sentence-embedding model used to embed retrieval queries (embed_query).
# NOTE(review): presumably must be the same model that built the FAISS index
# at ./brauer_ch2.index — verify before changing it.
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"
# NOTE(review): hard-codes the GPU; loading will presumably fail on a
# CPU-only host.
embedder = SentenceTransformer(EMBED_MODEL_NAME, device="cuda")

# Read the key from the environment instead of hard-coding it in source, so
# secrets never get committed. The fallback "" matches the previous value.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

MODEL = "gpt-4.1-mini"  # chat model used by call_llm and llm_verify_graph
MAX_ITERS = 5  # generate/verify rounds attempted by generate_valid_graph

def embed_query(query):
    """
    Embed one query string with the module-level ``embedder``.

    The vector is L2-normalised (``normalize_embeddings=True``) and returned
    as a 1D numpy array.
    """
    return embedder.encode(query, normalize_embeddings=True)

def load_vector_store(path_prefix="./brauer_ch2"):
    """
    Load the FAISS index and its companion chunk list sharing ``path_prefix``.

    Reads ``<path_prefix>.index`` with ``faiss.read_index`` and unpickles
    ``<path_prefix>_chunks.pkl``.

    Returns:
        ``(index, chunks)``.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at files
    you produced yourself.
    """
    index_path = f"{path_prefix}.index"
    chunks_path = f"{path_prefix}_chunks.pkl"

    index = faiss.read_index(index_path)
    with open(chunks_path, "rb") as chunk_file:
        chunks = pickle.load(chunk_file)

    return index, chunks



def retrieve_chunks(query, k=3, path_prefix="./brauer_ch2"):
    """
    Return the stored chunks nearest to ``query`` in the FAISS index.

    Args:
        query: Text to embed with ``embed_query`` and search for.
        k: Number of neighbours requested from the index.
        path_prefix: Vector-store prefix passed to ``load_vector_store``.

    Returns:
        A list of ``{"rank", "distance", "text"}`` dicts in the order FAISS
        returned them; ``"text"`` is the stored chunk object itself. The list
        may be shorter than ``k`` when the index holds fewer vectors.
    """
    index, chunks = load_vector_store(path_prefix)

    # FAISS expects a 2D float32 matrix of shape (n_queries, dim).
    q_emb = np.asarray([embed_query(query)], dtype="float32")
    distances, indices = index.search(q_emb, k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        # FAISS pads missing neighbours with -1; chunks[-1] would silently
        # return the last chunk, so skip those slots.
        if idx < 0:
            continue
        results.append({
            "rank": len(results) + 1,
            "distance": float(dist),
            "text": chunks[idx],
        })

    return results


def return_knowledge(query, k=1):
    """
    Return the retrieved knowledge text for ``query`` as a single string.

    Calls ``retrieve_chunks`` and concatenates the ``"text"`` field of each
    retrieved chunk (chunks must be dicts with a ``"text"`` key, as the
    indexing below requires). Passages are separated by a blank line, so that
    with ``k > 1`` neighbouring chunks are not glued together. With the
    default ``k=1``, the result is exactly that chunk's text. Returns ``""``
    when nothing is retrieved.
    """
    results = retrieve_chunks(query=query, k=k)
    return "\n\n".join(r["text"]["text"] for r in results)


def get_system_description(
    env_name='Covid-scenario',
    base_dir="/project/biocomplexity/MIDAS/SMH_summary/scenario_modeling_hub-report",
    round_name=f"round{scenario_NO}",
    disease="Covid-19",
    filename=f"scen_r{scenario_NO}.md",
):
    """
    Return the scenario specification text handed to the LLM.

    Args:
        env_name: ``'Covid-scenario'`` reads the scenario markdown file at
            ``base_dir/disease/round_name/filename``. ``'PNAS'`` returns a
            fixed description of the DDB / CBF / EFB behavioural baselines.
        base_dir, round_name, disease, filename: Path components of the
            scenario file, used only for ``'Covid-scenario'``. The defaults
            for ``round_name`` and ``filename`` are built from the module-level
            ``scenario_NO`` when the function is defined.

    Returns:
        The description wrapped in a banner block.

    Raises:
        FileNotFoundError: The scenario markdown file does not exist.
        ValueError: ``env_name`` is not a supported value.
    """
    if env_name == 'Covid-scenario':
        md_path = Path(base_dir) / disease / round_name / filename
        if not md_path.exists():
            raise FileNotFoundError(f"Scenario file not found: {md_path}")

        scenario_text = md_path.read_text(encoding="utf-8")
        return f"""
================ SCENARIO SPECIFICATION ================

{scenario_text}

========================================================
"""

    if env_name == 'PNAS':
        return """
================  BEHAVIORAL EPIDEMIC BASELINE REPLICATION ================
            1)  ***DDB (Data-driven behavior; exogenous)***
            - Compute mobility reduction r(t) from Google mobility.
            - Exclude parks from the averaged mobility signal.
            - Apply: lambda'(k,t) = r(t) * lambda(k,t).
            - For forecasting beyond observed mobility window, use last observed r(t) (status quo).
            
            2) ***CBF (Compartmental behavior feedback; endogenous)***
            - Add a behavior susceptible compartment S_B[k] (risk-averse susceptible).
            - Individuals in S_B[k] have reduced infection pressure by factor r_B in (0,1).
            - Transition S -> S_B increases with recent reported deaths signal D(t-1) via a saturating function.
            - Transition S_B -> S relaxes as perceived risk decreases (e.g., function of (S+R)/N or explicit relaxation rate).

            
            3) ***EFB (Effective force-of-infection damping; endogenous implicit)***
            - Do NOT add new compartments.
            - Multiply lambda(k,t) by f(D(t-1), cumulative deaths memory), where f decreases with deaths and saturates.
            - Include a short-term response to recent deaths and a long-term memory component.
========================================================
"""

    # Fail loudly: returning None here would put the literal text "None"
    # into the generation prompt instead of a scenario.
    raise ValueError(
        f"Unknown env_name {env_name!r}; expected 'Covid-scenario' or 'PNAS'"
    )

def build_prompt(env_name, scenario_text, feedback=None):
    """
    Build the prompt asking the LLM for a compartmental-model graph.

    Args:
        env_name: For ``"Covid-scenario"``, ``scenario_text`` is used as a
            retrieval query (``return_knowledge``), and the result is included
            under "Related Knowledge:". For any other value, no retrieval is
            done and that section is omitted.
        scenario_text: Scenario description inserted verbatim.
        feedback: Optional correction feedback from a previous verification
            round; appended together with an instruction to fix all issues.

    Returns:
        The prompt with leading and trailing whitespace stripped.
    """
    if env_name == "Covid-scenario":
        knowledge = return_knowledge(scenario_text)
    else:
        knowledge = None

    # Only emit the section when there is something to show; interpolating
    # None would put the literal word "None" into the prompt.
    knowledge_section = (
        f"\nRelated Knowledge:\n{knowledge}\n\n" if knowledge else ""
    )

    base_prompt = f"""
You must generate a compartmental disease model as a STRING ONLY.

Scenario description:
{scenario_text}

{knowledge_section}
States (not fixed, can add or remove):
*** DO NOT ADD ANY STATE THAT IS NOT NEEDED ACCORDING TO THE SCENARIO ***

S = Susceptible
E = Exposed
I = Infectious
R = Recovered
V = Vaccinated
D = Dead

Allowed transitions (not fixed, can add or remove):
*** DO NOT ADD ANY TRANSITION THAT IS NOT NEEDED ACCORDING TO THE SCENARIO ***

S -> E
E -> I
....



Rules:
- Output format MUST be:

NODES:
....

EDGES:
E -> I [beta] #infection rate (Ref: ?)
I -> R [gamma] #recovery rate (Ref: ?)
...

SCENARIO A:
NODES:
...

EDGES:
...

SCENARIO B:
NODES:
...

EDGES:
...

.....
No explanations. No prose.
"""

    if feedback:
        base_prompt += f"\n\nCorrection feedback:\n{feedback}\nFix all issues."

    return base_prompt.strip()
    


def call_llm(prompt):
    """
    Send ``prompt`` as one user message to ``MODEL`` and return the reply text.

    Uses the ``openai.ChatCompletion`` interface with temperature 0.5.
    NOTE(review): that interface belongs to the pre-1.0 openai package;
    confirm the pinned openai version before upgrading.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        model=MODEL,
        messages=chat_messages,
        temperature=0.5,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# "<src> -> <dst>" at the start of a (stripped) line. Node names use the same
# character set as parse_scenario_block: letters, digits, "+", "<", ">", "_".
# Any trailing "[rate] #comment" text is ignored.
_GRAPH_EDGE_RE = re.compile(r"([A-Za-z0-9+<>_]+)\s*->\s*([A-Za-z0-9+<>_]+)")


def parse_graph(text):
    """
    Extract ``(src, dst)`` edge pairs from LLM graph text.

    Every line (after stripping whitespace) that begins with ``src -> dst`` is
    reported in input order, regardless of which NODES/EDGES or SCENARIO
    section it appears in. Other lines are ignored.

    Returns:
        A list of ``(src, dst)`` string tuples, possibly empty.
    """
    edges = []
    for raw_line in text.splitlines():
        match = _GRAPH_EDGE_RE.match(raw_line.strip())
        if match:
            edges.append(match.groups())
    return edges



# Matches "VERDICT: VALID", "**Verdict:** invalid", etc. Only whitespace, "*"
# and ":" may sit between the keyword and the verdict, so an echoed template
# line such as "VERDICT: [VALID/INVALID]" is not mistaken for a verdict.
_VERDICT_RE = re.compile(r"VERDICT[\s*:]*\b(VALID|INVALID)\b", re.IGNORECASE)
# Header that introduces the correction feedback, e.g. "FEEDBACK (if invalid):".
_FEEDBACK_HEADER_RE = re.compile(r"FEEDBACK\b[^:\n]*:?")


def _parse_verification(verification_result):
    """Return ``(is_valid, feedback)`` parsed from a verifier reply."""
    verdicts = _VERDICT_RE.findall(verification_result)
    # The requested format puts the verdict near the end; trust the last one.
    is_valid = bool(verdicts) and verdicts[-1].upper() == "VALID"

    feedback = ""
    if not is_valid:
        headers = list(_FEEDBACK_HEADER_RE.finditer(verification_result))
        if headers:
            section = verification_result[headers[-1].end():]
            # Drop anything after a later verdict line, if the model wrote one.
            section = section.split("VERDICT:")[0]
            feedback = section.strip().strip("*").strip()
    return is_valid, feedback


def llm_verify_graph(graph_text, scenario_text):
    """
    Ask the LLM whether ``graph_text`` is a sensible model of ``scenario_text``.

    The reply is requested in a fixed REASONING / TRANSITIONS / ALIGNMENT /
    ISSUES / VERDICT / FEEDBACK layout (temperature 0.3). The verdict is read
    case-insensitively and tolerates markdown bold. Feedback is taken from the
    last "FEEDBACK" header and is only extracted when the verdict is not VALID.

    Returns:
        ``(is_valid, explanation, feedback)``, where ``explanation`` is the
        full LLM reply and ``feedback`` is ``""`` when valid or not found.
    """
    verification_prompt = f"""
You are an expert in epidemiological modeling. Analyze the following compartmental disease model graph.

SCENARIO:
{scenario_text}

GENERATED GRAPH:
{graph_text}

Please provide:
1. **Reasoning**: Explain the purpose of each state (node) and why it exists based on the scenario
2. **Transition Analysis**: Explain each transition (edge) and whether it makes epidemiological sense
3. **Scenario Alignment**: Does this graph accurately represent the scenario requirements?
4. **Issues**: List any problems, missing states/transitions, or unnecessary elements
5. **Verdict**: State clearly "VALID" or "INVALID"

If INVALID, provide specific feedback on what needs to be fixed.

Format your response as:
REASONING:
...

TRANSITIONS:
...

ALIGNMENT:
...

ISSUES:
...

VERDICT: [VALID/INVALID]

FEEDBACK (if invalid):
...
"""

    response = openai.ChatCompletion.create(
        model=MODEL,
        messages=[{"role": "user", "content": verification_prompt}],
        temperature=0.3
    )

    verification_result = response.choices[0].message.content
    is_valid, feedback = _parse_verification(verification_result)
    return is_valid, verification_result, feedback


def generate_valid_graph(env_name, scenario_text):
    """
    Generate a model graph with the LLM, re-prompting with feedback until it
    passes both rule-based and LLM verification.

    Each iteration builds a prompt (including the previous round's feedback),
    asks the LLM for a graph, and runs ``verify_graph`` on the parsed edges.
    Only if that passes does it run ``llm_verify_graph``. The first graph
    judged VALID is written, together with its verification report, to
    ``Round<scenario_NO>_verification_report_<timestamp>.txt`` and returned.

    Raises:
        RuntimeError: No graph was accepted within ``MAX_ITERS`` iterations.
    """
    feedback = None

    for i in range(MAX_ITERS):
        print(f"\n{'='*60}")
        print(f"--- Iteration {i+1} ---")
        print(f"{'='*60}")

        # Generate a candidate, carrying over the previous round's feedback.
        prompt = build_prompt(env_name, scenario_text, feedback)
        graph_text = call_llm(prompt)

        print("\nGenerated graph:\n")
        print(graph_text)

        # Rule-based checks run first; only graphs that pass them are sent to
        # the LLM verifier.
        edges = parse_graph(graph_text)
        rule_errors = verify_graph(edges, scenario_text)

        if rule_errors:
            print("\n❌ Rule-based errors detected:")
            for e in rule_errors:
                print(f"   - {e}")
            feedback = "Fix the following issues:\n" + "\n".join(rule_errors)
            continue

        print("\n✅ Passed rule-based validation")

        print("\n🤖 Running LLM verification...")
        is_valid, explanation, llm_feedback = llm_verify_graph(graph_text, scenario_text)

        print("\n" + "="*60)
        print("LLM VERIFICATION RESULT:")
        print("="*60)
        print(explanation)
        print("="*60)

        if not is_valid:
            print("\n❌ LLM found issues with the graph")
            feedback = llm_feedback or "The graph does not align well with the scenario. Please regenerate."
            print(f"\nFeedback for next iteration:\n{feedback}")
            continue

        print(f"\n✅ GRAPH ACCEPTED at iteration {i+1}")

        # Persist the accepted graph together with the verifier's reasoning.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_fname = f'Round{scenario_NO}_verification_report_{ts}.txt'
        with open(report_fname, "w", encoding="utf-8") as f:
            f.write(f"Iteration: {i+1}\n\n")
            f.write(f"SCENARIO:\n{scenario_text}\n\n")
            f.write(f"GENERATED GRAPH:\n{graph_text}\n\n")
            f.write(f"VERIFICATION:\n{explanation}\n")

        return graph_text

    raise RuntimeError(
        f"Failed to generate valid graph after {MAX_ITERS} iterations.\n"
        f"Last feedback:\n{feedback}"
    )


    
def verify_graph(edges, scenario_constraints):
    """
    Apply rule-based sanity checks to parsed ``(src, dst)`` edges.

    Compartments are recognised by the first letter of their name, so for
    example ``I_1 -> D`` is accepted:

    - ``D*``: death.
    - ``V*``: vaccinated.
    - ``W*``: waned.
    - ``E*``: exposed.
    - Only ``I*`` / ``J*`` / ``H*`` may transition into death.

    Args:
        edges: Iterable of ``(src, dst)`` string pairs, e.g. from
            ``parse_graph``.
        scenario_constraints: Currently unused; kept for interface
            compatibility.

    Returns:
        A list of human-readable error strings, empty if every rule passes.
        Each distinct edge is checked once, in first-seen order, so the output
        is deterministic across runs.
    """
    errors = []

    # dict.fromkeys de-duplicates while preserving input order; iterating a
    # set of strings would vary between runs because of hash randomisation.
    unique_edges = list(dict.fromkeys(edges))
    edge_set = set(unique_edges)

    if ("S", "R") in edge_set:
        errors.append("Invalid transition: S -> R (recovery without infection)")

    if ("S", "I") in edge_set:
        errors.append("Invalid transition: S -> I (missing exposed state)")

    for src, dst in unique_edges:
        if dst.startswith("D") and not src.startswith(("I", "J", "H")):
            errors.append(f"Invalid death transition: {src} -> {dst}")
        elif src.startswith("D"):
            errors.append(
                f"Invalid death transition: {src} -> {dst} (D can not be in source)"
            )
        elif dst.startswith("V") and not src.startswith(("S", "R", "W")):
            errors.append(f"Invalid vaccination transition: {src} -> {dst}")
        elif dst.startswith("W") and not src.startswith(("R", "V")):
            errors.append(f"Invalid waning transition: {src} -> {dst}")
        elif dst.startswith("E") and not src.startswith(("S", "W", "V")):
            errors.append(f"Invalid Exposed transition: {src} -> {dst}")

    return errors

def build_graph(nodes, edges):
    """
    Build a networkx ``DiGraph`` from parsed node names and edge dicts.

    Every name in ``nodes`` becomes a node. Each ``{"src", "dst", "rate"}``
    edge dict becomes a directed edge whose ``label`` attribute holds the rate
    (read back when drawing edge labels). Edge endpoints that are missing from
    ``nodes`` are added by networkx.
    """
    graph = nx.DiGraph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(
        (edge["src"], edge["dst"], {"label": edge["rate"]}) for edge in edges
    )
    return graph
    
def parse_scenario_block(text):
    """
    Split LLM graph output into per-scenario node and edge lists.

    Parsing rules, applied to each stripped, non-empty line:

    - Lines before the first ``SCENARIO ...`` header go under ``"GLOBAL"``.
    - A line starting with ``SCENARIO`` opens an entry keyed by that line
      with every ``:`` removed. A repeated header resets that entry.
    - A single token of ``[A-Za-z0-9+<>_]`` characters is a node.
    - ``src -> dst [rate]`` is an edge.
    - Anything else (section headers, comments) is ignored.

    Returns:
        ``{name: {"nodes": [str, ...], "edges": [{"src", "dst", "rate"}, ...]}}``
    """
    node_re = re.compile(r"[A-Za-z0-9+<>_]+")
    edge_re = re.compile(r"([A-Za-z0-9+<>_]+)\s*->\s*([A-Za-z0-9+<>_]+)\s*\[(.*?)\]")

    section = "GLOBAL"
    parsed = {section: {"nodes": [], "edges": []}}

    for stripped in (raw.strip() for raw in text.splitlines()):
        if not stripped:
            continue

        if stripped.startswith("SCENARIO"):
            section = stripped.replace(":", "")
            parsed[section] = {"nodes": [], "edges": []}
        elif node_re.fullmatch(stripped):
            parsed[section]["nodes"].append(stripped)
        else:
            found = edge_re.match(stripped)
            if found is not None:
                src, dst, rate = found.groups()
                parsed[section]["edges"].append({"src": src, "dst": dst, "rate": rate})

    return parsed

    
def plot_all_scenarios_side_by_side(scenarios):
    """
    Draw every scenario graph in a 3-column grid and save it as an image.

    Args:
        scenarios: Mapping returned by ``parse_scenario_block``. The
            ``"GLOBAL"`` entry is skipped unless there are no ``SCENARIO``
            entries. In that case it is plotted on its own, provided it has
            any nodes or edges.

    Side effects:
        Saves ``Round<scenario_NO>_with_knowledge_finalwith_WITHKG_4o_<timestamp>``
        (no extension given, so matplotlib picks its default format), calls
        ``plt.show()`` and then closes the figure. Nothing is drawn when there
        is no graph content at all.
    """
    scenario_items = [
        (name, data)
        for name, data in scenarios.items()
        if name != "GLOBAL"
    ]

    if not scenario_items:
        # Without SCENARIO headers, plt.subplots(0, 3) would raise and turn an
        # already-validated graph into a reported failure.
        global_data = scenarios.get("GLOBAL")
        if global_data and (global_data["nodes"] or global_data["edges"]):
            scenario_items = [("GLOBAL", global_data)]
        else:
            print("No scenario graphs to plot.")
            return

    cols = 3
    rows = (len(scenario_items) + cols - 1) // cols

    # squeeze=False always yields a 2D array of Axes, so ravel() is safe.
    fig, axes = plt.subplots(rows, cols, figsize=(18, 6 * rows), squeeze=False)
    axes = axes.ravel()

    for ax, (scenario_name, data) in zip(axes, scenario_items):
        G = build_graph(data["nodes"], data["edges"])
        pos = nx.spring_layout(G, seed=1)  # fixed seed -> reproducible layout

        nx.draw(
            G, pos,
            ax=ax,
            with_labels=True,
            node_size=1600,
            node_color="lightblue",
            edgecolors="black",
            font_size=9,
            arrowsize=15
        )

        edge_labels = nx.get_edge_attributes(G, "label")
        nx.draw_networkx_edge_labels(
            G, pos,
            edge_labels=edge_labels,
            font_size=7,
            ax=ax
        )

        ax.set_title(scenario_name)
        ax.axis("off")

    # Hide unused subplots in the last row.
    for ax in axes[len(scenario_items):]:
        ax.axis("off")

    plt.tight_layout()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    plt.savefig(f'Round{scenario_NO}_with_knowledge_finalwith_WITHKG_4o_{ts}', dpi=300, bbox_inches="tight")
    plt.show()
    # Release the figure; otherwise figures accumulate across retries.
    plt.close(fig)



def get_graph(env_name):
    """
    Run the pipeline once: describe the scenario, then generate, verify, parse
    and plot a graph.

    Returns:
        ``(graph_text, None)`` on success, or ``(None, error_message)`` if any
        step raised. On failure, the message and the full traceback are also
        written to ``Round<scenario_NO>_error4o_<timestamp>.txt``.
    """
    import traceback

    try:
        scenario_text = get_system_description(env_name)
        graph = generate_valid_graph(env_name, scenario_text)
        scenarios = parse_scenario_block(graph)
        plot_all_scenarios_side_by_side(scenarios)
    except Exception as e:
        # Top-level boundary: record the failure so prompt_graph can retry.
        # str(e) is empty for some exceptions, so fall back to repr().
        error_msg = str(e) or repr(e)
        print(f"\n❌ ERROR: {error_msg}")

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        error_fname = f'Round{scenario_NO}_error4o_{ts}.txt'
        with open(error_fname, "w", encoding="utf-8") as f:
            f.write(error_msg)
            f.write("\n\n")
            f.write(traceback.format_exc())

        return None, error_msg

    return graph, None


def prompt_graph(env_name, attempts=2):
    """
    Call ``get_graph`` up to ``attempts`` times until one run succeeds.

    On success, the graph text is written to
    ``Round<scenario_NO>_VALIDATED_WITHKG4o_<timestamp>.txt``.

    Args:
        env_name: Passed through to ``get_graph``.
        attempts: Maximum number of full pipeline runs.

    Returns:
        The validated graph text, or ``None`` if every attempt failed.
    """
    for attempt in range(1, attempts + 1):
        print(f"\n\n{'#'*70}")
        # Report the real attempt budget rather than a hard-coded total.
        print(f"# ATTEMPT {attempt}/{attempts}")
        print(f"{'#'*70}\n")

        graph, _error = get_graph(env_name)
        if graph is None:
            print('NO GRAPH GENERATED')
            continue

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        graph_fname = f'Round{scenario_NO}_VALIDATED_WITHKG4o_{ts}.txt'
        with open(graph_fname, "w", encoding="utf-8") as f:
            f.write(graph)

        print("\n Successfully generated and validated graph!")
        return graph

    return None


    

# for _ in range(10):
#     g = get_graph()
#     ts = datetime.now().strftime("%Y%m%d_%H%M%S")
#     base_fname = f'Round{scenario_NO}_with_WITHKG4o_iter_{ts}'
#     graph_fname = f"{base_fname}.txt"
#     with open(graph_fname, "w", encoding="utf-8") as f:
#         f.write(g)
