import os
import re
import csv
import math
import random
import argparse
from datetime import datetime
from typing import List, Tuple

from code_gen.spec_gen_with_skill import spec_gen_with_skill_save
from code_gen.code_gen_with_skill import (
    code_gen_with_skill_save,
    realworld_code_gen_with_skill_save,
)
from code_gen.exploration_gen_with_skill import rlbench_exploration_gen_with_skill_save
from code_gen.plan_gen import solve_pddl
from verification.verification_code import verify_code_with_spec
from validation.total_validation import validate_skill

########################
# IO helpers
########################

def read_file(path: str) -> str:
    """Return the full text of the file at ``path``.

    The file is opened in text mode with the platform's default encoding.
    """
    with open(path, mode="r") as source:
        contents = source.read()
    return contents


def write_file(path: str, text: str):
    """Write ``text`` to ``path``, creating missing parent directories first.

    Args:
        path: Destination file. An existing file is overwritten.
        text: Content to write.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        f.write(text)


def write_log(path: str, msg: str):
    """Append ``msg`` to the file at ``path``, terminated by a newline."""
    line = f"{msg}\n"
    with open(path, mode="a") as log:
        log.write(line)

########################
# Parsing helpers
########################

def split_observation_and_goal(problem_txt: str) -> str:
    """Return the part of ``problem_txt`` that precedes the first ``(:goal``.

    When the marker is absent the text is returned unchanged. Despite the
    name, only the leading portion is returned; everything from the marker
    onwards is dropped.
    """
    before_goal, _marker, _remainder = problem_txt.partition("(:goal")
    return before_goal


def get_skill_def_code(skill_name: str, all_code: str) -> str:
    """Extract the source of the function ``skill_name`` from ``all_code``.

    The match starts at ``def <skill_name>(`` and runs up to (not including)
    the next ``def`` keyword, or to the end of the text. This is a textual
    heuristic, not a parser: a ``def`` keyword appearing inside the body
    (nested function, docstring, comment) still ends the match early.

    Args:
        skill_name: Exact function name to look for.
        all_code: Source text that may contain several function definitions.

    Returns:
        The stripped definition text of the first match, or ``""`` if no
        definition with that name is found.
    """
    # \b keeps "def" from matching inside identifiers such as "skill_def =",
    # which would otherwise cut the body short; re.escape makes the name
    # match literally even if it contains regex metacharacters.
    pat = rf"(\bdef\s+{re.escape(skill_name)}\(.*?)(?=\bdef\s+|$)"
    m = re.search(pat, all_code, flags=re.DOTALL)
    return m.group(1).strip() if m else ""


def _find_call_end(text: str, start: int) -> int:
    """Return the index of the ``)`` that closes a call whose ``(`` sits just
    before ``start``, or -1 if the parentheses never balance.

    Parentheses inside single- or double-quoted string literals are ignored.
    """
    depth = 1
    quote = ""
    pos = start
    while pos < len(text):
        ch = text[pos]
        if quote:
            if ch == "\\":
                pos += 1  # skip the escaped character
            elif ch == quote:
                quote = ""
        elif ch in "\"'":
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return pos
        pos += 1
    return -1


def parse_generated_code_for_skills_rlbench(code_text: str):
    """Find skill calls of the form ``obs, reward, done = <skill>(<args>)``.

    Unlike a plain ``[^)]*`` capture, the argument list is read up to the
    matching closing parenthesis, so nested calls or tuples in the arguments
    (for example ``pick(env, get_pos(x))``) are captured whole.

    Args:
        code_text: Generated source code to scan.

    Returns:
        A list of dicts, in order of appearance, each with ``skill_name``
        (the called function's name) and ``skill_call_code`` (the call
        expression ``name(args)``).
    """
    head_pat = r"obs\s*,\s*reward\s*,\s*done\s*=\s*(\w+)\("
    calls = []
    for head in re.finditer(head_pat, code_text):
        args_start = head.end()
        args_end = _find_call_end(code_text, args_start)
        if args_end == -1:
            # Unbalanced text: fall back to the first ")" so a call is not
            # dropped that the simple pattern would have found.
            args_end = code_text.find(")", args_start)
            if args_end == -1:
                continue
        name = head.group(1)
        args = code_text[args_start:args_end]
        calls.append({"skill_name": name, "skill_call_code": f"{name}({args})"})
    return calls

########################
# Confidence helpers
########################

def skill_confidence(llm_logprob: float, lc_pass: bool) -> float:
    """Convert a log-probability into a 0-100 confidence score.

    The score is ``exp(llm_logprob) * 100`` when ``lc_pass`` is true and
    ``0.0`` otherwise.
    """
    # The exponential is evaluated before the pass check, so an invalid
    # log-probability raises regardless of ``lc_pass``.
    probability = math.exp(llm_logprob)
    if not lc_pass:
        return 0.0
    return probability * 100.0


def log_confidence_csv(csv_path: str, iteration: int, skill_rows: List[Tuple[str, float]], total_conf: float, target_skill: str):
    """Append one CSV row per skill to ``csv_path``.

    A header row is written first when the file does not exist yet or is
    empty. The file is written as UTF-8 so that non-ASCII values of
    ``target_skill`` do not depend on the platform's default encoding.

    Args:
        csv_path: CSV file to append to.
        iteration: Value for the ``iteration`` column of every row written.
        skill_rows: ``(skill_name, skill_confidence)`` pairs; one row each.
        total_conf: Value for the ``total_confidence`` column (2 decimals).
        target_skill: Value for the ``target_skill`` column of every row.
    """
    # An existing zero-byte file still needs its header.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="", encoding="utf-8") as fp:
        writer = csv.writer(fp)
        if needs_header:
            writer.writerow(["iteration", "skill_name", "skill_confidence", "total_confidence", "target_skill"])
        writer.writerows(
            [iteration, name, f"{scon:.2f}", f"{total_conf:.2f}", target_skill]
            for name, scon in skill_rows
        )

########################
# Main pipeline
########################

def main():
    """Run the type-3 / high-observation generation pipeline for the requested envs.

    For each environment, task, seed and instruction this function:
      1. solves the chosen PDDL problem and generates a spec,
      2. generates an initial code snapshot (exploration index 0),
      3. validates every skill call parsed from the snapshot (calls are only
         parsed when the environment is "rlbench") and appends the per-skill
         confidences to a per-seed CSV,
      4. while validation requests exploration (at most ``exploration_limit``
         rounds), writes exploration code, optionally loads an updated
         problem file, and regenerates the spec and code as the next numbered
         snapshot.

    Progress messages are appended to a timestamped log file under ``log_dir``.
    """
    parser = argparse.ArgumentParser(description="Type‑3 high‑obs pipeline with snapshots")
    parser.add_argument("--env", nargs="+", required=True)
    # Only the value 3 is accepted; args.type is not read again in this function.
    parser.add_argument("--type", type=int, choices=[3], required=True)
    args = parser.parse_args()

    # NOTE(review): "Input_Your_Path" looks like a placeholder to be replaced
    # with a real directory before running - confirm.
    log_dir = "Input_Your_Path"
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, datetime.now().strftime("%Y%m%d_%H%M%S_log.txt"))

    write_log(log_path, f"Started envs={args.env}, type=3, obs=high")

    # Constants
    # threshold is forwarded unchanged to validate_skill; its meaning is
    # defined there.
    threshold = -0.27
    observation_types = ["high"]
    seed_count = 2

    rlbench_tasks = ["__exploration_long_horizon_1"]
    realworld_tasks = ["move_item_to_drawer"]

    for env in args.env:
        base_path = f"Input_Your_Path/{env}"
        # Any env other than "rlbench" gets the real-world task and skill lists.
        tasks = rlbench_tasks if env == "rlbench" else realworld_tasks
        avail_skills = (
            ["pick", "place", "move", "rotate", "pull"]
            if env == "rlbench" else
            [
                "execute_pick_and_place", "execute_pick", "execute_place", "execute_push",
                "execute_pull", "execute_sweep", "execute_rotate", "execute_go", "execute_gripper",
            ]
        )

        for task_stub in tasks:
            task_name = f"{task_stub}_type3"
            t_dir = os.path.join(base_path, task_name)
            domain_pddl = read_file(os.path.join(t_dir, "pddl", "domain.pddl"))
            expl_knowledge = read_file(os.path.join(t_dir, "pddl", "exploration_knowledge.pddl"))
            skill_code_txt = read_file(os.path.join(t_dir, "code", "skill_code.py"))

            # instructions
            # NOTE(review): instruction.py is executed as Python code to read
            # its ``instructions`` variable; this is only safe if the task
            # directory is trusted.
            ns = {}
            exec(read_file(os.path.join(t_dir, "instruction.py")), ns)
            instructions = ns.get("instructions", [])
            if not instructions:
                # Nothing to do for this task: log a warning and skip it.
                write_log(log_path, f"[WARN] No instructions for {task_name}")
                continue

            for obs_type in observation_types:  # single "high"
                # The base observation always comes from problem file 1 with
                # its goal section removed, while the problem used for
                # planning below is picked at random between files 1 and 2.
                # NOTE(review): confirm this mismatch is intended.
                obs_base_txt = split_observation_and_goal(read_file(os.path.join(t_dir, f"init_obs_{obs_type}_1.txt")))
                obj_list_path = os.path.join(t_dir, f"object_list_position_{obs_type}.txt")
                # The object list is optional; an empty string is used when absent.
                obj_list_txt = read_file(obj_list_path) if os.path.exists(obj_list_path) else ""

                for seed in range(1, seed_count + 1):
                    code_dir = os.path.join(t_dir, "code")
                    os.makedirs(code_dir, exist_ok=True)
                    # One CSV per observation type and seed, shared by all
                    # instructions; rows carry no instruction index, and the
                    # iteration column restarts at 0 for each instruction.
                    csv_path = os.path.join(code_dir, f"conf_log_{obs_type}_seed{seed}.csv")

                    # ===========================================
                    # PER-INSTRUCTION LOOP
                    # ===========================================
                    for inst_idx, instr in enumerate(instructions, start=1):
                        # choose problem pddl (66/34)
                        # NOTE(review): no random.seed call is visible in this
                        # file, so this choice is not tied to ``seed``.
                        chosen = 1 if random.random() < 0.66 else 2
                        prob_txt = read_file(os.path.join(t_dir, f"init_obs_{obs_type}_{chosen}.txt"))

                        # ---- spec ----
                        # The plan is handed to spec generation only when the
                        # solver reports success; otherwise an empty string.
                        ok, plan = solve_pddl(domain_pddl, prob_txt)
                        spec_gen_with_skill_save(
                            env,
                            task_name,
                            domain_pddl,
                            obs_base_txt,
                            instr,
                            inst_idx,
                            avail_skills,
                            skill_code_txt,
                            obs_type,
                            seed,
                            plan_content=plan if ok else "",
                        )
                        # Presumably written by spec_gen_with_skill_save above
                        # - verify against that helper.
                        spec_path = os.path.join(t_dir, "spec", f"exe_spec_{obs_type}_seed{seed}_{inst_idx}.json")
                        spec_txt = read_file(spec_path)

                        # ---- CODE SNAPSHOT 0 ----
                        exp_iter = 0
                        code_path = os.path.join(code_dir, f"exe_code_{obs_type}_seed{seed}_{inst_idx}_exploration_{exp_iter}.py")

                        # The two generators take the same arguments except
                        # that only the rlbench one receives expl_knowledge.
                        # The meaning of the "none", "none" and trailing 0
                        # arguments is defined by those helpers - TODO confirm.
                        if env == "rlbench":
                            code_txt = code_gen_with_skill_save(
                                task_name, domain_pddl, obs_base_txt, instr, spec_txt, skill_code_txt,
                                avail_skills, obj_list_txt, "none", "none", expl_knowledge, obs_type, seed, 0,
                            )
                        else:
                            code_txt = realworld_code_gen_with_skill_save(
                                task_name, domain_pddl, obs_base_txt, instr, spec_txt, skill_code_txt,
                                avail_skills, obj_list_txt, "none", "none", obs_type, seed, 0,
                            )
                        write_file(code_path, code_txt)

                        # ============== VALIDATION / EXPLORATION LOOP ==============
                        current_obs = obs_base_txt
                        current_prob = prob_txt
                        iteration = 0  # log iteration == exploration index
                        exploration_limit = 5
                        exploration_used = 0

                        while True:
                            # Re-read the latest snapshot from disk and extract
                            # its skill calls. For non-rlbench envs the list is
                            # empty, so no skill is validated and the loop
                            # exits after a single pass.
                            generated = read_file(code_path)
                            calls = parse_generated_code_for_skills_rlbench(generated) if env == "rlbench" else []

                            skill_rows: List[Tuple[str, float]] = []
                            target_skill = "‑"
                            exploration_flag = False
                            next_obs = current_obs  # default
                            next_prob = current_prob
                            trigger_def_call = None

                            for c in calls:
                                sname = c["skill_name"]
                                sdef = get_skill_def_code(sname, skill_code_txt)
                                # Only lc_pass, upd_obs, need_expl and llm_log
                                # are used; the underscore-prefixed results are
                                # ignored.
                                (
                                    _csc_pass, lc_pass, upd_obs, _csc_fb, _lc_fb,
                                    need_expl, llm_log, _conf,
                                ) = validate_skill(
                                    sdef, c["skill_call_code"], domain_pddl, t_dir,
                                    current_obs, sname, [], instr, obj_list_txt, validation_mode="both", threshold=threshold,
                                )
                                skill_rows.append((sname, skill_confidence(llm_log, lc_pass)))

                                current_obs = upd_obs  # propagate for sequential skills

                                # Only the first skill that requests exploration
                                # is recorded as the exploration target, together
                                # with the observation returned for it.
                                if need_expl and not exploration_flag:
                                    exploration_flag = True
                                    target_skill = sname
                                    trigger_def_call = (sdef, c["skill_call_code"])
                                    next_obs = upd_obs

                            # Mean of the per-skill confidences; 0.0 when no
                            # skill was validated.
                            total_conf = sum(x[1] for x in skill_rows) / len(skill_rows) if skill_rows else 0.0
                            log_confidence_csv(csv_path, iteration, skill_rows, total_conf, target_skill)

                            # exit if no exploration or limit exceeded
                            if not exploration_flag or exploration_used >= exploration_limit:
                                break

                            # ------------ exploration regens -----------
                            # The three counters advance together and always
                            # hold the same value.
                            exploration_used += 1
                            exp_iter += 1
                            iteration += 1

                            # scall is unpacked but not used below.
                            sdef, scall = trigger_def_call
                            exp_code = rlbench_exploration_gen_with_skill_save(
                                domain_pddl, next_obs, sdef, avail_skills, obj_list_txt,
                                "exploration", expl_knowledge,
                            )
                            exp_dir = os.path.join(t_dir, "exploration")
                            exp_path = os.path.join(exp_dir, f"exploration_{obs_type}_inst{inst_idx}_seed{seed}_{exp_iter}.py")
                            write_file(exp_path, exp_code)

                            # external obs update
                            # The file name depends only on the exploration
                            # round, so every instruction and seed looks at the
                            # same file. When it is missing, next_prob and
                            # next_obs keep the values set earlier in this pass.
                            after_path = os.path.join(t_dir, f"init_obs_after_exploration_{exploration_used}.txt")
                            if os.path.exists(after_path):
                                next_prob = read_file(after_path)
                                next_obs = split_observation_and_goal(next_prob)
                                write_log(log_path, f"[EXP] loaded {after_path}")

                            # re-spec (plan optional)
                            ok, plan = solve_pddl(domain_pddl, next_prob)
                            spec_gen_with_skill_save(
                                env, task_name, domain_pddl, next_obs, instr, inst_idx, avail_skills,
                                skill_code_txt, obs_type, seed, plan_content=plan if ok else "",
                            )
                            # Same spec path as the initial spec for this
                            # instruction and seed.
                            spec_txt = read_file(os.path.join(t_dir, "spec", f"exe_spec_{obs_type}_seed{seed}_{inst_idx}.json"))

                            # re-code
                            # The previous snapshot's text (``generated``) is
                            # passed to the generator; the new code is saved
                            # under the next exploration index.
                            code_path = os.path.join(code_dir, f"exe_code_{obs_type}_seed{seed}_{inst_idx}_exploration_{exp_iter}.py")
                            if env == "rlbench":
                                code_txt = code_gen_with_skill_save(
                                    task_name, domain_pddl, next_obs, instr, spec_txt, skill_code_txt,
                                    avail_skills, obj_list_txt, "", generated, expl_knowledge, obs_type, seed, 0,
                                )
                            else:
                                code_txt = realworld_code_gen_with_skill_save(
                                    task_name, domain_pddl, next_obs, instr, spec_txt, skill_code_txt,
                                    avail_skills, obj_list_txt, "", generated, obs_type, seed, 0,
                                )
                            write_file(code_path, code_txt)
                            # The next validation pass starts from the
                            # post-exploration observation and problem.
                            current_obs = next_obs
                            current_prob = next_prob

                        # end exploration loop
                        # The message hard-codes "high" instead of using obs_type.
                        write_log(log_path, f"[DONE] {task_name} inst={inst_idx} seed={seed} obs=high")

    write_log(log_path, "All tasks completed")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
