import os
import json
from typing import List, Dict, Tuple
from tqdm import tqdm
import torch
import argparse
from config import config
from reason_8 import Actor, Critic, MetaVerifier
from data_processing_8 import DataProcessor, Example
from factscore.factscorer import FactScorer
from utils import save_preferences
from datetime import datetime
import math
import random
import numpy as np
from factscore.lm import get_llm_calls
import signal
import sys
from collections import defaultdict

def signal_handler(sig, frame):
    """Shut the process down in response to a signal.

    Prints a short notice, destroys the ``torch.distributed`` process group
    when one has been initialized, and then terminates the interpreter with
    exit status 0.

    Args:
        sig: Number of the delivered signal (not inspected).
        frame: Stack frame that was executing when the signal arrived
            (not inspected).

    Raises:
        SystemExit: Always, with code 0.
    """
    print("Exiting...")
    dist = torch.distributed
    if dist.is_initialized():
        dist.destroy_process_group()
    # Equivalent to sys.exit(0): unwinds the interpreter with status 0.
    raise SystemExit(0)

# Set TOKENIZERS_PARALLELISM to "false" in this process's environment at
# import time, before any model is constructed.
# NOTE(review): presumably this silences the HuggingFace tokenizers
# fork/parallelism warning -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class DPOPipeline:
    def __init__(
        self,
        epoch: int,
        weights_dir: str,
        fs_model_name: str,
        actor_gens: int,
        critic_gens: int,
        replay: str,
        use_critic: bool,
        actor_threshold=None,
        window_size=None
    ):
        """Store the run configuration and build the FactScorer.

        Args:
            epoch: Current epoch number; selects the actor checkpoint and
                names the output files.
            weights_dir: Directory holding the ``actor_{epoch}`` checkpoints.
            fs_model_name: Forwarded to ``FactScorer`` as ``used_model_name``.
            actor_gens: Number of generations requested per topic.
            critic_gens: Number of critic evaluations requested per paragraph.
            replay: Replay strategy; the code distinguishes ``"none"``,
                ``"rand"`` and ``"all"``.
            use_critic: When true the critic / meta-verifier path is used,
                otherwise paragraphs are scored directly.
            actor_threshold: Optional minimum score for an actor example to be
                labelled "chosen"; ``None`` means only a score of exactly 1.0
                counts.
            window_size: Optional cap on the actor replay buffer length. A
                falsy value disables the cap.
        """
        self.epoch = epoch
        self.weights_dir = weights_dir
        self.data_processor = DataProcessor(config)

        # The key / base-URL arguments are file paths, overridable through
        # environment variables.
        gamma = 1
        openai_key_path = os.environ.get("OPENAI_KEY_FILE", "keys/openai_key.txt")
        local_key_path = os.environ.get("LOCAL_KEY_FILE", "keys/local_api_key.txt")
        base_url_path = os.environ.get("BASE_URL_FILE", "keys/base_url.txt")
        self.fs = FactScorer(
            model_name="retrieval+llama",
            openai_key=openai_key_path,
            local_key=local_key_path,
            base_url=base_url_path,
            local=True,
            used_model_name=fs_model_name,
            track_llm_calls=True,
            gamma=gamma
        )
        print("gamma:", gamma)

        self.actor_gens = actor_gens
        self.critic_gens = critic_gens
        self.replay = replay
        self.use_critic = use_critic
        self.actor_threshold = actor_threshold

        # Critic bookkeeping, each counter initialised exactly once.
        self.critic_sentence_scores: List[float] = []
        self.critic_correct_count = 0
        self.critic_total_count = 0
        self.critic_no_error_count = 0

        # Always define the attribute so readers never need hasattr();
        # a falsy window (None / 0) still means "no sliding window".
        self.window_size = window_size if window_size else None

    def setup_models(self):
        """Create ``self.actor`` and configure its sampling parameters.

        From the second epoch on, the actor is loaded from the
        ``actor_{epoch-1}`` entry under ``self.weights_dir``; in the first
        epoch ``config.actor_model`` is used instead.
        """
        if self.epoch > 1:
            previous_name = f"actor_{self.epoch-1}"
            actor_path = os.path.join(self.weights_dir, previous_name)
        else:
            actor_path = config.actor_model

        self.actor = Actor(actor_path, gpu_memory_utilization=0.8)
        sampling = {"temperature": 1.5, "max_tokens": 512}
        self.actor.set_sampling_params(**sampling)

    def setup_critic(self):
        """Create ``self.critic`` and configure its sampling parameters.

        The critic is loaded from ``config.critic_model``. The retrieval
        database path comes from the ``FACTSCORE_DB`` environment variable,
        falling back to ``factscore_cache/enwiki-20230401.db``.
        """
        model_path = config.critic_model
        db_path = os.environ.get(
            "FACTSCORE_DB", "factscore_cache/enwiki-20230401.db"
        )
        critic_kwargs = {
            "model_path": model_path,
            "gpu_memory_utilization": 0.8,
            "retrieval_db_path": db_path,
            "retrieval_type": "gtr-t5-large",
            "k": 5,
            "include_context": False,
        }
        self.critic = Critic(**critic_kwargs)
        self.critic.set_sampling_params(temperature=1.5, max_tokens=512)

    def setup_verifier(self):
        """Create ``self.meta_verifier`` from ``config.critic_model``."""
        verifier_path = config.critic_model
        self.meta_verifier = MetaVerifier(
            model_path=verifier_path,
            gpu_memory_utilization=0.8,
        )

    def evaluate_test_set(self, epoch: int) -> float:
        """Score the current actor on the test topics.

        Reads one JSON object per line from ``config.test_path``, collects the
        ``"topic"`` field of each, generates with ``self.actor`` and scores
        the output with ``self.fs``. ``setup_models`` must have been called
        first so that ``self.actor`` exists.

        Args:
            epoch: Epoch number; used only in the printed report.

        Returns:
            The ``"score"`` entry of the FactScorer result.
        """
        print("\n--- evaluating test set ---\n")
        test_topics = []
        # Read as UTF-8 and skip blank lines (e.g. a trailing newline at EOF)
        # instead of depending on the locale encoding and crashing in
        # json.loads on an empty string.
        with open(config.test_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                test_topics.append(json.loads(line)["topic"])
        # NOTE(review): elsewhere in this file the result of actor.generate is
        # consumed with .items(); confirm get_score accepts that shape here.
        paragraphs = self.actor.generate(test_topics)
        result = self.fs.get_score(test_topics, paragraphs, verbose=True)
        score = result["score"]
        print(f"Epoch {epoch} test results:")
        print(f"Total FactScore: {score}")
        return score

    def save_critic_trainset(self, topic_generations: Dict[str, List[str]], output_file: str):
        with open(output_file, 'w') as f:
            for topic, paragraphs in topic_generations.items():
                for paragraph in paragraphs:
                    example = {"paragraph": paragraph, "topic": topic}
                    f.write(json.dumps(example, ensure_ascii=False) + '\n')

    def save_with_replay(self, new_pairs, pref_dir, prefix="actor", rand_sample_size=2520):
        """Persist this epoch's preference pairs, mixing in replayed history.

        Two files under ``pref_dir`` are involved:

        * ``{prefix}_replay_all.jsonl`` -- the running replay buffer.
        * ``{prefix}_{epoch}.jsonl`` -- the pairs written for this epoch.

        Rules:

        * ``prefix == "critic"``: the buffer keeps the newest ``5 * len(new_pairs)``
          pairs of history + new pairs, and ``2 * len(new_pairs)`` pairs (capped
          by the buffer size) are sampled from it at random.
        * otherwise, with ``self.replay != "none"``: the new pairs are appended
          to the history, trimmed to the last ``self.window_size`` entries when
          a window is set, and then either randomly sampled (``"rand"``, at
          most ``rand_sample_size`` pairs) or used in full (``"all"``).
        * otherwise: only ``new_pairs`` are written and no buffer is kept.

        Args:
            new_pairs: List of pair dicts produced in this epoch.
            pref_dir: Directory containing the preference files.
            prefix: ``"actor"`` or ``"critic"``; selects files and rule set.
            rand_sample_size: Upper bound on the sample drawn in ``"rand"``
                replay mode. Defaults to the previously hard-coded 2520.
        """
        file_all = os.path.join(pref_dir, f"{prefix}_replay_all.jsonl")
        file_curr = os.path.join(pref_dir, f"{prefix}_{self.epoch}.jsonl")
        is_critic = prefix == "critic"
        uses_replay = is_critic or self.replay != "none"

        # Load the existing buffer once; a missing file is an empty history.
        history: List[Dict] = []
        if uses_replay:
            try:
                with open(file_all, "r", encoding="utf-8") as f:
                    history = [json.loads(s) for s in map(str.strip, f) if s]
            except FileNotFoundError:
                pass

        replay_buffer: List[Dict] = []
        sampled_pairs: List[Dict] = []
        print_prefix_log_msg = ""

        if is_critic:
            print_prefix_log_msg = "[Critic Replay Rule]"
            new_count = len(new_pairs)
            # NOTE(review): with no new pairs the buffer stays empty and the
            # replay file is overwritten with nothing, erasing the history.
            # Existing behavior is kept -- confirm it is intended.
            if new_count > 0:
                combined = history + new_pairs
                # Keep only the newest 5x-new-count entries.
                replay_buffer = combined[-(new_count * 5):]
                sample_size = min(new_count * 2, len(replay_buffer))
                sampled_pairs = random.sample(replay_buffer, sample_size)
            save_preferences(replay_buffer, file_all, mode='w')
        elif self.replay != "none":
            history.extend(new_pairs)
            replay_buffer = history
            # getattr keeps this working even if __init__ left the attribute
            # undefined because no window was requested.
            window = getattr(self, "window_size", None)
            if window and len(replay_buffer) > window:
                replay_buffer = replay_buffer[-window:]
            if self.replay == "rand":
                sampled_pairs = random.sample(
                    replay_buffer, min(rand_sample_size, len(replay_buffer))
                )
            elif self.replay == "all":
                sampled_pairs = replay_buffer
            save_preferences(replay_buffer, file_all, mode='w')
        else:
            sampled_pairs = new_pairs

        save_preferences(sampled_pairs, file_curr, mode='w')
        print(f"{print_prefix_log_msg} ({prefix}) saved {len(sampled_pairs)} pairs to {file_curr}")
        if uses_replay:
            # Report the buffer actually written; previously the actor branch
            # always printed 0 here.
            print(f"{print_prefix_log_msg} ({prefix}) total historical pairs in {file_all}: {len(replay_buffer)}")

    def run(self, pref_dir: str, log_dir: str):
        """Run one epoch of preference-data generation.

        Steps, in order:

        1. Seed ``random``, ``torch`` and ``numpy`` with ``42 + epoch`` and
           load the actor.
        2. Generate ``self.actor_gens`` paragraphs per topic from
           ``config.trainset_1`` and dump the flattened (topic, paragraph)
           list to ``flat_{epoch}.jsonl``.
        3. Score the paragraphs.

           * Critic mode: the critic evaluates every paragraph in batches of
             200. Responses without an extracted sentence are recorded as
             rejected critic examples. The remaining (sentence, fact) pairs go
             through the meta-verifier; discarded ones are logged to
             ``missed_extraction_{epoch}.jsonl``. Surviving facts are scored
             by the FactScorer; a score above 0.9 labels the critic response
             "rejected", anything else "chosen". A paragraph's actor score is
             the mean of its scored facts, so only paragraphs with at least
             one scored extraction yield an actor example.
           * Baseline mode: the FactScorer scores whole paragraphs in batches
             of 10.
        4. Label actor examples "chosen"/"rejected" using
           ``self.actor_threshold`` (or ``score == 1.0`` when it is ``None``).
        5. Build and format the pairs, append statistics to
           ``failed_pairs.txt``, save through ``save_with_replay`` and append
           the LLM call count to ``llm_calls.txt``.

        The ``torch.distributed`` process group, if initialised, is destroyed
        on the way out regardless of success.

        Args:
            pref_dir: Directory that receives the JSONL outputs.
            log_dir: Directory that receives ``failed_pairs.txt`` and
                ``llm_calls.txt``.
        """
        missed_f = None
        try:
            # Drop counters reported in failed_pairs.txt (critic mode only).
            N_none = 0
            N_mis = 0
            N_unscored = 0
            missed_file_path = os.path.join(pref_dir, f"missed_extraction_{self.epoch}.jsonl")
            # Closed in the finally clause below so rejected extractions are
            # flushed even when a later stage raises.
            missed_f = open(missed_file_path, "w", encoding="utf-8")
            seed = 42 + self.epoch
            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            print("\n--- stage 1: setup actor ---\n")
            self.setup_models()
            if self.use_critic:
                print("(adversarial mode: critic enabled)")
            else:
                print("(baseline mode: critic disabled)")
            print("\n--- stage 2: actor generation ---\n")
            topics = self.data_processor.load_topics(config.trainset_1)
            print(f"Loading topics from: {config.trainset_1}")
            batch_out = self.actor.generate(topics, num_generations=self.actor_gens)
            flat = [(t, p) for t, ps in batch_out.items() for p in ps]
            critic_trainset_file = os.path.join(pref_dir, f"flat_{self.epoch}.jsonl")
            with open(critic_trainset_file, 'w', encoding='utf-8') as f:
                for topic, para in flat:
                    f.write(json.dumps({"topic": topic, "paragraph": para}, ensure_ascii=False) + "\n")
            print(f"Saved flattened data to: {critic_trainset_file}")
            all_actor_examples: List[Example] = []
            all_critic_examples: List[Example] = []
            if self.use_critic:
                # Drop the actor before the critic is loaded.
                del self.actor
                torch.cuda.empty_cache()
                print("\n--- stage 3: setup critic ---\n")
                self.setup_critic()
                print("\n--- stage 4: critic evaluation ---\n")
                max_crit_batch = 200
                critic_responses = []
                for i in range(0, len(flat), max_crit_batch):
                    batch = flat[i:i+max_crit_batch]
                    topics_list, paras_list = zip(*batch)
                    critic_responses.extend(
                        self.critic.evaluate(
                            list(topics_list),
                            list(paras_list),
                            num_evaluations=self.critic_gens
                        )
                    )
                sentences: List[str] = []
                facts: List[str] = []
                meta: List[Tuple[List[Dict], str, str, str, str, str]] = []
                for (topic, para), item in zip(flat, critic_responses):
                    for resp in item["responses"]:
                        proc = self.critic.process_response(resp, para)
                        self.critic_total_count += 1
                        if proc.get('incorrect_sentence'):
                            incorrect = proc.get('incorrect_sentence')
                            fact = proc.get('independent_fact')
                            sentences.append(incorrect)
                            facts.append(fact)
                            meta.append((
                                item["prompt_messages"],
                                resp,
                                topic,
                                para,
                                incorrect,
                                fact
                            ))
                        else:
                            # The critic extracted no sentence: rejected.
                            N_none += 1
                            all_critic_examples.append(
                                Example(
                                    topic=topic,
                                    paragraph=para,
                                    label="rejected",
                                    critic_response=resp,
                                    incorrect_sentence="",
                                    critic_prompt_messages=item["prompt_messages"]
                                )
                            )
                del self.critic
                self.setup_verifier()
                sf_pairs = list(zip(sentences, facts))
                ver_results = self.meta_verifier.batch_verify(sf_pairs)
                filtered_meta = []
                for (prompt_msgs, resp, topic, para, incorrect, fact), (keep, reason) in zip(meta, ver_results):
                    if keep:
                        filtered_meta.append((prompt_msgs, resp, topic, para, incorrect, fact))
                    else:
                        # The verifier discarded this (sentence, fact) pair:
                        # rejected, and logged with the verifier's reason.
                        N_mis += 1
                        all_critic_examples.append(
                            Example(
                                topic=topic,
                                paragraph=para,
                                label="rejected",
                                critic_response=resp,
                                incorrect_sentence=incorrect,
                                critic_prompt_messages=prompt_msgs
                            )
                        )
                        missed_entry = {
                            "epoch": self.epoch,
                            "topic": topic,
                            "paragraph": para,
                            "incorrect_sentence": incorrect,
                            "independent_fact": fact,
                            "critic_response": resp,
                            "reason": reason
                        }
                        missed_f.write(json.dumps(missed_entry, ensure_ascii=False) + "\n")
                filtered_file = os.path.join(pref_dir, f"filtered_meta_{self.epoch}.jsonl")
                with open(filtered_file, 'w', encoding='utf-8') as f:
                    for prompt_msgs, resp, topic, para, incorrect, fact in filtered_meta:
                        entry = {
                            "independent_fact": fact,
                            "incorrect_sentence": incorrect,
                            "paragraph": para,
                            "topic": topic,
                            "critic_response": resp,
                            "critic_prompt_messages": prompt_msgs
                        }
                        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
                print(f"Saved filtered meta to: {filtered_file}")
                score_map: Dict[Tuple[str, str], List[float]] = {}
                topics_v = [m[2] for m in filtered_meta]
                paras_v = [m[3] for m in filtered_meta]
                facts_v = [m[5] for m in filtered_meta]
                # One single-element fact list per surviving extraction.
                atomic_facts = [[f] for f in facts_v]
                out = self.fs.get_score(
                    topics=topics_v,
                    generations=paras_v,
                    atomic_facts=atomic_facts,
                    verbose=False
                )
                # Consumed as one score per input below.
                scores = out["score"]
                for (prompt_msgs, resp, topic, para, incorrect, fact), sc in zip(filtered_meta, scores):
                    if sc is None or (isinstance(sc, float) and math.isnan(sc)):
                        N_unscored += 1
                        all_critic_examples.append(
                            Example(
                                topic=topic,
                                paragraph=para,
                                label="rejected",
                                critic_response=resp,
                                incorrect_sentence=incorrect,
                                critic_prompt_messages=prompt_msgs
                            )
                        )
                        continue
                    self.critic_sentence_scores.append(sc)
                    label = "rejected" if sc > 0.9 else "chosen"
                    # "rejected" marks a failed critic response on every other
                    # path in this method, so a correct judgment is "chosen".
                    # The count was previously incremented on "rejected".
                    # NOTE(review): assumes a high FactScore means the flagged
                    # fact is supported -- confirm.
                    if label == "chosen":
                        self.critic_correct_count += 1
                    all_critic_examples.append(
                        Example(
                            topic=topic,
                            paragraph=para,
                            label=label,
                            critic_response=resp,
                            incorrect_sentence=incorrect,
                            critic_prompt_messages=prompt_msgs
                        )
                    )
                    score_map.setdefault((topic, para), []).append(sc)
                # Actor score for a paragraph = mean over its scored facts.
                for (topic, para), scs in score_map.items():
                    mean_sc = sum(scs) / len(scs)
                    all_actor_examples.append(
                        Example(topic=topic, paragraph=para, label=None, score=mean_sc)
                    )
            else:
                print("\n--- stage 3: FS paired ---\n")
                max_fs_batch = 10
                for i in range(0, len(flat), max_fs_batch):
                    batch = flat[i:i+max_fs_batch]
                    topics_list, paras_list = zip(*batch)
                    res = self.fs.get_score(list(topics_list), list(paras_list), verbose=False)
                    for (topic, para), sc in zip(batch, res["score"]):
                        all_actor_examples.append(
                            Example(topic=topic, paragraph=para, label=None, score=sc)
                        )
            if self.use_critic:
                del self.meta_verifier
            else:
                del self.actor
            # Labels depend only on each example's own score, so no per-topic
            # grouping is needed and the mode is announced once.
            if self.actor_threshold is not None:
                print("Use Threshold")
                for ex in all_actor_examples:
                    ex.label = "chosen" if (ex.score is not None and ex.score >= self.actor_threshold) else "rejected"
            else:
                print("No Threshold")
                for ex in all_actor_examples:
                    ex.label = "chosen" if ex.score == 1.0 else "rejected"
            print("\n--- stage 5: create training pairs ---\n")
            actor_pairs, actor_stats = self.data_processor.create_actor_pairs(all_actor_examples)
            if self.use_critic:
                critic_pairs, critic_stats = self.data_processor.create_critic_pairs(all_critic_examples)
                formatted_critic_pairs = self.data_processor.format_critic_pairs(critic_pairs)
            print(f"Successfully formed actor pairs: {len(actor_pairs)}")
            if self.use_critic:
                print(f"Successfully formed critic pairs: {len(critic_pairs)}")
            stats_file = os.path.join(log_dir, 'failed_pairs.txt')
            current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            with open(stats_file, 'a') as f:
                f.write(f"\n{'='*50}\n")
                f.write(f"Epoch: {self.epoch} | Time: {current_time}\n")
                f.write(f"Part 1\n")
                f.write(f"\nActor pairs statistics:\n")
                f.write(f"Unable to form pairs due to:\n")
                f.write(f"- All chosen: {actor_stats['all_chosen']}\n")
                f.write(f"- All rejected: {actor_stats['all_rejected']}\n")
                f.write(f"- No labels: {actor_stats['no_labels']}\n")
                if self.use_critic:
                    f.write(f"\nCritic pairs statistics:\n")
                    f.write(f"Unable to form pairs due to:\n")
                    f.write(f"- All chosen: {critic_stats['all_chosen']}\n")
                    f.write(f"- All rejected: {critic_stats['all_rejected']}\n")
                    f.write(f"- No labels: {critic_stats['no_labels']}\n")
                    f.write(f"\nCritic total judgments: {self.critic_total_count}\n")
                    f.write(f"Critic correct judgments: {self.critic_correct_count}\n")
                    # Each rate is relative to what survived the prior stage.
                    N_total = self.critic_total_count
                    f.write(f"No extraction  : {N_none} ({(0 if N_total == 0 else N_none/N_total):.1%})\n")
                    denom1 = (N_total - N_none)
                    f.write(f"Mis-extraction : {N_mis} ({(0 if denom1 <= 0 else N_mis/denom1):.1%})\n")
                    denom2 = (N_total - N_none - N_mis)
                    f.write(f"Unscored       : {N_unscored} ({(0 if denom2 <= 0 else N_unscored/denom2):.1%})\n")
                    if self.critic_sentence_scores:
                        avg_score = sum(self.critic_sentence_scores) / len(self.critic_sentence_scores)
                        f.write(f"Critic extracted sentence FactScore mean: {avg_score:.4f}\n")
                    else:
                        f.write("Critic extracted sentence FactScore mean: N/A\n")
            formatted_actor_pairs = self.data_processor.format_actor_pairs(actor_pairs)
            print(f"Number of formatted_actor_pairs: {len(formatted_actor_pairs)}")
            if self.use_critic:
                print(f"Number of formatted_critic_pairs: {len(formatted_critic_pairs)}")
            actor_file = os.path.join(pref_dir, f"actor_{self.epoch}.jsonl")
            print(f"Saving actor preferences to: {actor_file}")
            self.save_with_replay(formatted_actor_pairs, pref_dir, prefix="actor")
            if self.use_critic:
                self.save_with_replay(formatted_critic_pairs, pref_dir, prefix="critic")
            else:
                print("No Critic (baseline)")
            total_calls = get_llm_calls(reset=True)
            calls_log = os.path.join(log_dir, "llm_calls.txt")
            with open(calls_log, "a") as f:
                f.write(f"epoch {self.epoch}\t{total_calls}\n")
            print(f"[INFO] epoch {self.epoch}  LLM forward calls = {total_calls}")
        finally:
            # Nested so the process group is torn down even if close() fails.
            try:
                if missed_f is not None:
                    missed_f.close()
            finally:
                if torch.distributed.is_initialized():
                    torch.distributed.destroy_process_group()

def main():
    """Parse the command line, install signal handlers and run one epoch.

    ``--epoch``, ``--weights_dir``, ``--pref_dir`` and ``--log_dir`` are
    required; all other options fall back to their defaults. SIGINT and
    SIGTERM are routed to ``signal_handler`` before the arguments are parsed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--epoch", type=int, required=True)
    for required_path in ("--weights_dir", "--pref_dir", "--log_dir"):
        cli.add_argument(required_path, type=str, required=True)
    cli.add_argument("--fs_model_name", type=str, default=None, help="FactScorer model name")
    cli.add_argument("--actor_gens", type=int, default=None, help="Number of generations per topic by Actor")
    cli.add_argument("--critic_gens", type=int, default=None, help="Number of judgments per paragraph by Critic")
    cli.add_argument("--replay", choices=["none", "rand", "all"], default="none", help="Replay strategy")
    cli.add_argument("--use_critic", action="store_true", help="Enable adversarial mode (critic)")
    cli.add_argument("--actor_threshold", type=float, default=None, help="Actor threshold for chosen examples")
    cli.add_argument("--window_size", type=int, default=5040, help="Replay sliding window size")

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal_handler)

    opts = cli.parse_args()
    pipeline = DPOPipeline(
        epoch=opts.epoch,
        weights_dir=opts.weights_dir,
        fs_model_name=opts.fs_model_name,
        actor_gens=opts.actor_gens,
        critic_gens=opts.critic_gens,
        replay=opts.replay,
        use_critic=opts.use_critic,
        actor_threshold=opts.actor_threshold,
        window_size=opts.window_size,
    )
    pipeline.run(pref_dir=opts.pref_dir, log_dir=opts.log_dir)

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
