"""
Main entrypoint for data visualization task disambiguation.
Implements the active disambiguation loop for Ambi-Plot tasks.
"""
import json
import sys
import os

# Resolve this file's directory, its parent, and the directory two levels up.
# NOTE(review): the sys.path insertions below presumably exist so that the
# `utils` and `reasoners` imports that follow resolve when this file is run as
# a script — confirm against the repository layout.
currentdir = os.path.dirname(os.path.abspath(__file__))
parentdir = os.path.dirname(currentdir)
# Two levels above this file; used below to build the Hydra config_path.
rootdir = os.path.dirname(parentdir)
# Inserted in this order, currentdir ends up ahead of parentdir on sys.path.
sys.path.insert(0, parentdir)
sys.path.insert(0, currentdir)

from utils import chat_gpt, llama, save_json, create_directory
import hydra
from omegaconf import DictConfig, OmegaConf

from reasoners import (
    VizReasonerBase,
    ActiveVizReasoner,
    TAIVizReasoner,
    ActiveBinaryVizReasoner,
    TAIBinaryVizReasoner,
    OracleVizAnswerer,
    OracleBinaryVizAnswerer,
)


class VizReasonerHandler:
    """Drive the active disambiguation loop for one visualization task.

    The handler alternates between a ``Reasoner`` (samples hypotheses,
    proposes clarifying questions and picks the best one) and an
    ``OracleAnswerer`` (answers the chosen question and scores hypotheses).
    Each question/answer pair is turned into a requirement by
    ``Reasoner.q_a_to_requirement`` and conditions later hypothesis sampling.
    """

    # Oracle answers (compared after stripping whitespace and lower-casing)
    # that mark a question as unanswerable, so it is excluded from later
    # question generation.
    _UNANSWERABLE_ANSWERS = frozenset({"error", "unknown", "n/a"})

    def __init__(self, Reasoner, OracleAnswerer, max_iter, seed, save_path):
        """Store the collaborators and initialize per-run bookkeeping.

        Args:
            Reasoner: Object providing ``generate_hypothesis``,
                ``generate_questions``, ``select_best_question`` and
                ``q_a_to_requirement``, plus writable ``total_cost``,
                ``total_hypothesis``, ``mode`` and ``seed`` attributes.
            OracleAnswerer: Object providing ``answer`` and
                ``evaluate_hypothesis``, plus writable ``total_cost`` and
                ``mode`` attributes.
            max_iter: Number of question/answer rounds performed by ``run``.
            seed: Stored on the handler; not read by the handler itself.
            save_path: Directory that receives the JSON result files.
        """
        self.Reasoner = Reasoner
        self.OracleAnswerer = OracleAnswerer
        self.max_iter = max_iter
        self.seed = seed
        self.save_path = save_path

        # The following dicts are keyed by iteration index.
        self.all_requirements = {}    # requirement derived from each Q/A pair
        self.all_questions = {}       # candidate questions generated each round
        self.selected_questions = {}  # question chosen each round
        self.listed_hypothesis = {}   # serialized hypotheses sampled each round
        # Questions whose oracle answer was an "unanswerable" marker; excluded
        # from later question generation.
        self.restricted_questions = []
        self.iter = 0

    @staticmethod
    def _serialize_hypotheses(hypotheses):
        """Convert hypothesis objects into plain ``{"content", "features"}`` dicts for saving."""
        return [{"content": h.content, "features": h.features} for h in hypotheses]

    @classmethod
    def _is_unanswerable(cls, answer):
        """Return True if ``answer`` is one of the oracle's "could not answer" markers.

        Non-string answers are treated as answerable instead of raising.
        """
        return isinstance(answer, str) and answer.strip().lower() in cls._UNANSWERABLE_ANSWERS

    def _score_hypotheses(self, requirements):
        """Sample hypotheses under ``requirements`` and score each with the oracle.

        Returns:
            ``(summary, hypotheses)``: ``summary`` is
            ``{"avg_match_rate": mean of the per-hypothesis "match_rate"
            (0 when nothing was sampled), "individual_results": the oracle
            results}``, and ``hypotheses`` is the serialized sample list.
        """
        h_ls = self.Reasoner.generate_hypothesis(requirements)
        results = [self.OracleAnswerer.evaluate_hypothesis(h) for h in h_ls]
        match_rates = [r["match_rate"] for r in results]
        avg_match = sum(match_rates) / len(match_rates) if match_rates else 0
        summary = {"avg_match_rate": avg_match, "individual_results": results}
        return summary, self._serialize_hypotheses(h_ls)

    def run(self):
        """Run up to ``max_iter`` rounds of hypothesis sampling and clarification.

        Each round samples hypotheses conditioned on the requirements gathered
        so far, generates candidate questions (excluding restricted and
        previously selected ones), selects the best question, asks the oracle,
        and records the resulting requirement. The bookkeeping dicts are
        written to ``save_path`` as JSON when the loop ends.
        """
        self.Reasoner.total_cost = 0
        self.OracleAnswerer.total_cost = 0

        # `<` rather than `==`: a max_iter already below the counter (e.g. a
        # negative value) ends the loop instead of spinning forever.
        while self.iter < self.max_iter:
            # Step 1. Generate hypothesis
            print(f"\n{'='*50}")
            print(f"Iteration {self.iter}")
            print(f"{'='*50}")
            print("Sampling visualization hypotheses...")
            hypothesis = self.Reasoner.generate_hypothesis(self.all_requirements)
            self.listed_hypothesis[self.iter] = self._serialize_hypotheses(hypothesis)
            print(f"Sampled {len(hypothesis)} hypotheses.")

            # Print hypothesis feature summary. `.get` keeps a hypothesis with
            # missing feature keys from crashing the run over a log line.
            if hypothesis:
                chart_types = set(h.features.get("chart_type") for h in hypothesis)
                libraries = set(h.features.get("library") for h in hypothesis)
                print(f"  Chart types: {chart_types}")
                print(f"  Libraries: {libraries}")

            # Step 2. Generate questions
            print("\nGenerating clarifying questions...")
            questions = self.Reasoner.generate_questions(
                requirements=self.all_requirements,
                restricted_questions=self.restricted_questions + list(self.selected_questions.values()),
            )
            self.all_questions[self.iter] = questions
            print(f"Candidate questions ({len(questions)}):")
            for i, q in enumerate(questions[:5]):  # Show first 5
                print(f"  {i+1}. {q}")

            # Step 3. Select best question
            print("\nSelecting best question...")
            best_question = self.Reasoner.select_best_question(questions, hypothesis)
            print(f"Selected: {best_question}")
            self.selected_questions[self.iter] = best_question

            # Step 4. Get oracle answer
            answer = self.OracleAnswerer.answer(best_question)
            print(f"Oracle answer: {answer}")

            if self._is_unanswerable(answer):
                self.restricted_questions.append(best_question)

            self.all_requirements[self.iter] = self.Reasoner.q_a_to_requirement(
                best_question, answer
            )

            self.iter += 1

        # Save results
        save_json(self.all_requirements, f"{self.save_path}/requirements.json")
        save_json(self.all_questions, f"{self.save_path}/questions.json")
        save_json(self.selected_questions, f"{self.save_path}/questions_selected.json")
        save_json(self.listed_hypothesis, f"{self.save_path}/listed_hypothesis.json")

        print(f"\n{'='*50}")
        print(f"Run complete! Total cost: ${self.Reasoner.total_cost + self.OracleAnswerer.total_cost:.4f}")
        print(f"Results saved to: {self.save_path}")

    def evaluate(self, n_seeds=3, total_hypothesis=20):
        """Re-sample hypotheses with a growing requirement set and score them.

        Loads ``requirements.json`` from ``save_path``. For each evaluation seed,
        hypotheses are sampled with no requirements (iteration 0) and then with
        the saved requirements added one at a time (iteration k uses the
        requirements from run() iterations 0..k-1). Every hypothesis is scored
        by ``OracleAnswerer.evaluate_hypothesis`` and the mean ``match_rate`` is
        reported per iteration.

        Args:
            n_seeds: Number of evaluation seeds; seeds ``0 .. n_seeds-1`` are
                used. Defaults to 3.
            total_hypothesis: Value assigned to ``Reasoner.total_hypothesis``
                for evaluation. Defaults to 20.
        """
        with open(f"{self.save_path}/requirements.json") as f:
            self.all_requirements = json.load(f)
        evaluation_results = {}
        all_eval_hypothesis = {}

        self.Reasoner.total_hypothesis = total_hypothesis
        self.Reasoner.total_cost = 0
        self.OracleAnswerer.total_cost = 0
        self.Reasoner.mode = "eval"
        self.OracleAnswerer.mode = "eval"

        # JSON stores dict keys as strings, so the integer iteration keys
        # written by run() come back as "0", "1", ...; sort them numerically
        # instead of relying on the order the file happens to store them in.
        ordered_keys = sorted(self.all_requirements, key=int)
        # (iteration number, requirement key to add before sampling). Iteration
        # 0 adds nothing; iteration k adds run() iteration k-1's requirement.
        steps = [(0, None)] + [(int(key) + 1, key) for key in ordered_keys]

        for seed in range(n_seeds):
            all_eval_hypothesis[seed] = {}
            evaluation_results[seed] = {}
            requirements = {}

            print(f"\n{'='*50}")
            print(f"Evaluation run {seed}")
            print(f"{'='*50}")

            self.Reasoner.seed = seed

            for iter_num, iter_key in steps:
                if iter_key is not None:
                    # NOTE(review): keys here are 1-based strings whereas run()
                    # passes 0-based ints to generate_hypothesis; preserved as-is.
                    requirements[str(iter_num)] = self.all_requirements[iter_key]

                summary, hypotheses = self._score_hypotheses(requirements)
                evaluation_results[seed][iter_num] = summary
                all_eval_hypothesis[seed][iter_num] = hypotheses
                print(f"Iter {iter_num}: Avg match rate = {summary['avg_match_rate']:.2%}")

            # Rewritten after every seed, so results for finished seeds are on
            # disk even if a later seed fails.
            save_json(all_eval_hypothesis, f"{self.save_path}/eval_hypothesis.json")
            save_json(evaluation_results, f"{self.save_path}/eval_results.json")

        print(f"\n{'='*50}")
        print(f"Evaluation complete! Cost: ${self.Reasoner.total_cost + self.OracleAnswerer.total_cost:.4f}")


def _load_tasks(dataset_path):
    """Load a JSON Lines task file into a ``{task_id: task_fields}`` dict.

    Each non-blank line is parsed as a JSON object; its ``task_id`` entry is
    removed and used as the key for the remaining fields. Blank lines (such as
    a trailing empty line) are skipped instead of failing in ``json.loads``.
    """
    data = {}
    with open(dataset_path, "r", encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue
            content = json.loads(line)
            task_id = content.pop("task_id")
            data[task_id] = content
    return data


def _make_llm_call(llm_name):
    """Build the ``llm_call`` function shared by the Reasoner and Oracle.

    The returned callable has the signature
    ``(user_prompt, system_prompt, n_used, seed, logprobs=False)`` and always
    samples with ``temperature=1`` and ``top_p=0.95``.

    Raises:
        ValueError: If ``llm_name`` is not a supported model.
    """
    if llm_name in ["gpt-3.5-turbo", "gpt-4o-mini"]:
        def gpt_llm_call(user_prompt, system_prompt, n_used, seed, logprobs=False):
            return chat_gpt(
                user_prompt=user_prompt,
                system_prompt=system_prompt,
                n_used=n_used,
                logprobs=logprobs,
                seed=seed,
                temperature=1,
                top_p=0.95,
                model_name=llm_name,
            )
        return gpt_llm_call
    if llm_name in ["llama-3-70B", "llama-3-8B"]:
        def llama_llm_call(user_prompt, system_prompt, n_used, seed, logprobs=False):
            return llama(
                user_prompt=user_prompt,
                system_prompt=system_prompt,
                n_used=n_used,
                seed=seed,
                logprobs=logprobs,
                temperature=1,
                top_p=0.95,
                llm_name=llm_name,
            )
        return llama_llm_call
    raise ValueError(f"Invalid LLM: {llm_name}")


def _build_agents(cfg, llm_call, task_data):
    """Instantiate the ``(Reasoner, OracleAnswerer)`` pair for ``cfg.strategy``.

    Every agent receives ``llm_call``, ``task_data`` and ``cfg.seed``. All
    strategies except "baseline" also pass the sampling settings, and the
    "tai" strategies additionally pass the optional TAI settings.

    Raises:
        ValueError: If ``cfg.strategy`` is not a known strategy.
    """
    # strategy -> (reasoner class, oracle class, takes sampling kwargs, takes TAI kwargs)
    strategies = {
        "baseline": (VizReasonerBase, OracleVizAnswerer, False, False),
        "active-reasoning": (ActiveVizReasoner, OracleVizAnswerer, True, False),
        "active-reasoning-binary": (ActiveBinaryVizReasoner, OracleBinaryVizAnswerer, True, False),
        "tai": (TAIVizReasoner, OracleVizAnswerer, True, True),
        "tai-binary": (TAIBinaryVizReasoner, OracleBinaryVizAnswerer, True, True),
    }
    if cfg.strategy not in strategies:
        raise ValueError(f"Invalid strategy: {cfg.strategy}")
    reasoner_cls, oracle_cls, takes_sampling, takes_tai = strategies[cfg.strategy]

    common = {"llm_call": llm_call, "task_data": task_data, "seed": cfg.seed}
    reasoner_kwargs = dict(common)
    if takes_sampling:
        reasoner_kwargs.update(
            total_questions=cfg.total_questions,
            total_hypothesis=cfg.total_hypothesis,
            logprobs=False,
            unique_hs=True,
        )
    if takes_tai:
        # Optional config keys; the literals are the fallbacks when absent.
        reasoner_kwargs.update(
            embedding=getattr(cfg, "tai_embedding", "tfidf"),
            max_features=getattr(cfg, "tai_max_features", 2048),
            embedding_model=getattr(cfg, "tai_embedding_model", "text-embedding-3-large"),
            tau=getattr(cfg, "tai_tau", None),
        )

    reasoner = reasoner_cls(**reasoner_kwargs)
    oracle = oracle_cls(**common)
    return reasoner, oracle


@hydra.main(version_base=None, config_path=f"{rootdir}/config", config_name="main_visualization")
def main(cfg: DictConfig) -> None:
    """Run disambiguation and evaluation for the task selected by ``cfg.task_id``.

    Loads tasks from ``cfg.dataset_path`` (JSON Lines), builds the LLM call
    for ``cfg.llm`` and the Reasoner/Oracle pair for ``cfg.strategy``, writes
    the config to ``config.yaml`` in the results directory, then runs
    ``VizReasonerHandler.run`` followed by ``VizReasonerHandler.evaluate``.

    Raises:
        KeyError: If ``cfg.task_id`` is not present in the dataset.
        ValueError: If ``cfg.llm`` or ``cfg.strategy`` is not supported.
    """
    # Load task data
    data = _load_tasks(cfg.dataset_path)
    if cfg.task_id not in data:
        raise KeyError(f"Task {cfg.task_id!r} not found in {cfg.dataset_path}")
    task_data = data[cfg.task_id]

    print(f"\n{'='*50}")
    print(f"Task: {cfg.task_id}")
    print(f"Instruction: {task_data.get('instruction', 'N/A')}")
    print(f"Ground truth: {task_data.get('ground_truth', {})}")
    print(f"{'='*50}")

    # Setup LLM call function (validated before the strategy, as before)
    llm_call = _make_llm_call(cfg.llm)

    # Setup Reasoner and Oracle based on strategy
    Reasoner, OracleAnswerer = _build_agents(cfg, llm_call, task_data)

    # Setup save path
    save_path = f"./results/{cfg.save_dir}/{cfg.task_id}/{cfg.strategy}/{cfg.llm}/iter_{cfg.seed}"

    handler = VizReasonerHandler(
        Reasoner=Reasoner,
        OracleAnswerer=OracleAnswerer,
        max_iter=cfg.max_iter,
        seed=cfg.seed,
        save_path=save_path,
    )

    create_directory(save_path)
    with open(f"{save_path}/config.yaml", "w") as f:
        OmegaConf.save(cfg, f)

    handler.run()
    handler.evaluate()


if __name__ == "__main__":
    # Ask Hydra to report the complete stack trace on failure rather than
    # its shortened error summary.
    os.environ.update(HYDRA_FULL_ERROR="1")
    main()
