import time
import random
import numpy as np
import os
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import f1_score
from core.model.prompt import Prompt
from core.infra.llm import LLMApi
from core.alg.evaluator import Evaluator
from core.alg.mutator import Mutator
from core.alg.optimizer import Optimizer
from core.model.badcase_tracker import BadcaseTracker
from typing import List


class Controller:
    def __init__(self, initial_prompt: str, problem_description: str,
                 train_df: pd.DataFrame, test_df: pd.DataFrame,
                 num_iterations: int, population_size: int, is_f1: bool):
        """Wire up the search components and normalise the labelled data.

        Args:
            initial_prompt: Text of the seed prompt; it becomes the only
                member of both the starting population and the starting
                ensemble (as two separate ``Prompt`` objects).
            problem_description: Passed straight to ``Mutator``.
            train_df: Frame with a ``target`` column; copied, never mutated.
            test_df: Frame with a ``target`` column; copied, never mutated.
            num_iterations: Stored for the optimisation loop.
            population_size: Stored for the optimisation loop.
            is_f1: Forwarded to both ``Evaluator`` instances.
        """

        def seed_prompt():
            # NOTE(review): the meaning of the trailing positional arguments
            # (0, 0, [], [], []) is defined by Prompt, not visible here.
            return Prompt(initial_prompt, 0, 0, [], [], [])

        def with_clean_targets(frame):
            # Work on a copy so the caller's frame is untouched, then coerce
            # the labels to upper-case, whitespace-trimmed strings.
            cleaned = frame.copy()
            cleaned["target"] = cleaned["target"].astype(str).str.strip().str.upper()
            return cleaned

        # Search components.
        self.evaluator: Evaluator = Evaluator(is_f1)
        self.test_evaluator: Evaluator = Evaluator(is_f1)
        self.population: List[Prompt] = [seed_prompt()]
        self.mutator: Mutator = Mutator(problem_description, self.population)
        self.optimizer: Optimizer = Optimizer(self.evaluator)

        # Loop bookkeeping.
        self.performance_history: List[float] = []
        self.test_f1_history: List[float] = []
        self.smoothed_pg: float = 0.011  # starting value of the smoothed gradient
        self.num_iterations: int = num_iterations
        self.population_size: int = population_size

        # Labelled data with normalised targets.
        self.train_df: pd.DataFrame = with_clean_targets(train_df)
        self.test_df: pd.DataFrame = with_clean_targets(test_df)

        # Voting ensemble starts as the seed prompt alone with full weight.
        self.ensemble: List[Prompt] = [seed_prompt()]
        self.ensemble_weights: List[float] = [1.0]
        self.badcase_tracker: BadcaseTracker = BadcaseTracker()

    def sample_fixed_datasets(self, train_sample_size=7, test_sample_size=3, iteration=0):
        """Draw reproducible train/test subsamples for one iteration.

        Both frames are sampled with ``random_state = 42 + iteration``, so the
        same iteration number always yields the same rows. A requested size
        larger than the frame is capped at the frame's length.

        Returns:
            A ``(train_sample, test_sample)`` tuple of copied DataFrames.
        """
        seed = 42 + iteration
        sizes = (
            min(train_sample_size, len(self.train_df)),
            min(test_sample_size, len(self.test_df)),
        )
        frames = (self.train_df, self.test_df)
        train_part, test_part = (
            frame.sample(n=size, random_state=seed).copy()
            for frame, size in zip(frames, sizes)
        )
        return train_part, test_part

    def generate_candidates(self, num_candidates=3):
        """Produce mutated candidate prompts from selected parents.

        For every chosen parent this calls ``mutator.badcase_reflection``
        ``num_candidates`` times and ``mutator.direct_mutation`` once, in that
        order. The mutator's ``population`` attribute is re-pointed at the
        current population first.

        Returns:
            ``(candidates, label)`` where ``label`` is a fixed descriptive
            string. Note that the label also names ``zero_order_generation``,
            although this method does not call such a mutator method.
        """

        def select_parents(population, num_best=3, num_random=2):
            # Small populations are used in full.
            pool = list(population)
            if len(pool) <= num_best + num_random:
                return pool
            # Rank by F1, highest first (unscored prompts sink to the bottom),
            # keep the top ``num_best`` and add random picks from the rest.
            ranked = sorted(
                pool,
                key=lambda p: p.f1 if p.f1 is not None else -float('inf'),
                reverse=True,
            )
            elite, remainder = ranked[:num_best], ranked[num_best:]
            lucky = random.sample(remainder, num_random)
            return elite + lucky

        self.mutator.population = self.population
        chosen_parents = select_parents(self.population, num_best=3, num_random=2)

        offspring = []
        for parent in chosen_parents:
            offspring.extend(
                self.mutator.badcase_reflection(parent) for _ in range(num_candidates)
            )
            offspring.append(self.mutator.direct_mutation(parent))

        return offspring, "badcase_reflection + direct_mutation + zero_order_generation"

    def evaluate_ensemble(self, dataset: pd.DataFrame) -> float:
        """Score the current ensemble on ``dataset`` with macro-averaged F1.

        Every row's ``input`` is classified through ``self.predict`` (row
        order is preserved) and the predictions are compared with the row's
        ``target`` via ``f1_score(..., average='macro')``.
        """
        rows = dataset[["input", "target"]].itertuples(index=False, name=None)
        scored_pairs = [(self.predict(text), label) for text, label in rows]
        predictions = [guess for guess, _ in scored_pairs]
        targets = [label for _, label in scored_pairs]
        return f1_score(targets, predictions, average='macro')

    def run(self):
        """Run the prompt-optimisation loop and print a final ensemble report.

        Each iteration:
          1. draws a deterministic train/test sample for that iteration;
          2. generates candidate prompts and has the optimizer embed them;
          3. scores the optimizer's ``bayesian_select`` and ``mab_select``
             picks on the train sample and feeds the scores back to the
             optimizer (``evaluated_prompts`` / bandit arms);
          4. merges the picks into the population (each prompt object at most
             once), from the second iteration on also adds a prompt produced
             by ``mutator.badcase_tracker_generation``, then truncates the
             population to ``population_size`` by F1;
          5. rebuilds the bad-case tracker contents, re-selects the ensemble
             and records the ensemble's train/test F1.

        The loop ends after ``num_iterations`` iterations or as soon as
        ``should_early_stop()`` returns True.

        Side effects: deletes ``embedding_cache.pkl`` and
        ``generate_cache.pkl`` from the current working directory if present,
        prints progress to stdout, and mutates the population, ensemble,
        history and tracker attributes.
        """
        for f in ("embedding_cache.pkl", "generate_cache.pkl"):
            if os.path.exists(f):
                os.remove(f)
        print("Cache files cleared locally.")
        print("开始执行迭代...")

        # Fallback samples: the final report below needs both frames even when
        # the loop body never executes (num_iterations <= 0). Sampling is
        # seeded, so this equals what iteration 0 would draw anyway.
        current_train_df, current_test_df = self.sample_fixed_datasets(iteration=0)

        for iteration in range(self.num_iterations):
            start_time = time.time()

            current_train_df, current_test_df = self.sample_fixed_datasets(iteration=iteration)
            candidates, strategy = self.generate_candidates()
            print(f"\n=== 第 {iteration} 次迭代 ===")
            print(f"选择的策略: {strategy}")
            print(f"候选 prompt 数量: {len(candidates)}")
            print(f"当前种群大小: {len(self.population)}")
            print(f"平滑性能梯度: {self.smoothed_pg:.6f}")

            for p in candidates:
                self.optimizer.compute_embedding(p)

            selected_bayesian_list = self.optimizer.bayesian_select(candidates, top_n=3)
            selected_mab_list = self.optimizer.mab_select(candidates, top_n=3)

            # Score the Bayesian picks and record (embedding, score) pairs on
            # the optimizer.
            for prompt in selected_bayesian_list:
                f1 = self.evaluator.evaluate(prompt, current_train_df)
                prompt.f1 = f1
                self.optimizer.evaluated_prompts.append((self.optimizer.compute_embedding(prompt), f1))

            # Score the bandit picks and reward the arm of their cluster when
            # the prompt carries a cluster id.
            for prompt in selected_mab_list:
                f1 = self.evaluator.evaluate(prompt, current_train_df)
                prompt.f1 = f1
                if hasattr(prompt, 'cluster_id'):
                    self.optimizer.arms[prompt.cluster_id].pull(f1)
                    self.optimizer.total_pulls += 1

            # A prompt may be picked by both selectors; add each object only
            # once so duplicates cannot crowd out other prompts when the
            # population is truncated below.
            known_ids = {id(p) for p in self.population}
            for prompt in selected_bayesian_list + selected_mab_list:
                if id(prompt) not in known_ids:
                    known_ids.add(id(prompt))
                    self.population.append(prompt)

            if iteration >= 1:
                badcase_prompt = self.mutator.badcase_tracker_generation(self.badcase_tracker, top_badcase_k=10)
                badcase_prompt.f1 = self.evaluator.evaluate(badcase_prompt, current_train_df)
                self.population.append(badcase_prompt)
                # Start from an empty tracker; it is refilled from the
                # surviving population just below.
                self.badcase_tracker = BadcaseTracker()

            # Keep the best ``population_size`` prompts; unscored ones rank as 0.
            self.population = sorted(self.population, key=lambda p: p.f1 if p.f1 is not None else 0, reverse=True)[:self.population_size]

            for prompt in self.population:
                if prompt.f1 is not None and hasattr(prompt, "bad_cases"):
                    for input_text, _ in prompt.bad_cases:
                        self.badcase_tracker.update(input_text, prompt.text)

            self.select_ensemble(data=current_test_df)

            ensemble_train_f1 = self.evaluate_ensemble(current_train_df)
            ensemble_test_f1 = self.evaluate_ensemble(current_test_df)

            self.performance_history.append(ensemble_train_f1)
            self.test_f1_history.append(ensemble_test_f1)

            print(f"集成 Prompt (训练 F1: {ensemble_train_f1:.4f}, 测试 F1: {ensemble_test_f1:.4f})")
            print("\n当前集成 Prompt 及其投票权重:")
            self._print_ensemble_members()
            print(f"迭代时间: {time.time() - start_time:.2f} 秒")

            # Exponentially smooth the change in test F1 between the last two
            # iterations.
            if len(self.test_f1_history) >= 2:
                pg = (self.test_f1_history[-1] - self.test_f1_history[-2])
                alpha = 0.5
                self.smoothed_pg = alpha * pg + (1 - alpha) * self.smoothed_pg

            if self.should_early_stop():
                print(f"在第 {iteration} 次迭代触发提前停止")
                break

        # Final report on the most recent samples.
        self.select_ensemble(data=current_test_df)
        final_ensemble_train_f1 = self.evaluate_ensemble(current_train_df)
        final_ensemble_test_f1 = self.evaluate_ensemble(current_test_df)
        print("\n=== 最终结果 ===")
        print(f"最终集成 Prompt (训练 F1: {final_ensemble_train_f1:.4f}, 测试 F1: {final_ensemble_test_f1:.4f})")
        print("\n最终集成 Prompt 及其权重:")
        self._print_ensemble_members()

    def _print_ensemble_members(self):
        """Print every ensemble prompt's text together with its voting weight."""
        for position, (prompt, weight) in enumerate(zip(self.ensemble, self.ensemble_weights), start=1):
            print(f"Prompt {position} (权重: {weight:.4f}):")
            print(f"{prompt.text}")

    def should_early_stop(self, k=10, m=5, epsilon=0.001):
        """Decide whether the optimisation loop should stop early.

        Looks at the last ``m`` entries of ``self.performance_history`` (that
        is, ``m - 1`` consecutive steps) and returns True when either

        * every step is a strict decline, or
        * every step is a strictly positive improvement smaller than
          ``epsilon``.

        An exactly flat history (all differences equal to 0) matches neither
        rule and therefore does not trigger a stop.

        Args:
            k: Minimum history length before any stop is considered.
            m: Window size; ``m - 1`` steps are inspected.
            epsilon: Upper bound (exclusive) for an improvement to count as
                negligible.

        Returns:
            True to stop, False to continue. False is also returned while the
            history is shorter than ``k`` or shorter than the window ``m``.
        """
        history = self.performance_history
        # Both checks index back to history[-m]; guarding on m as well as k
        # prevents an IndexError when a caller passes k < m.
        if len(history) < max(k, m):
            return False

        # (later, earlier) value pairs for the last m - 1 steps, newest first.
        steps = [(history[-i], history[-i - 1]) for i in range(1, m)]

        if all(later < earlier for later, earlier in steps):
            return True
        return all(0 < later - earlier < epsilon for later, earlier in steps)

    def select_ensemble(self, N=3, M=10, data=None):
        """Rebuild ``self.ensemble`` from the best-scoring prompts.

        Up to ``M`` prompts with a non-None ``f1`` are taken (highest first),
        embedded through ``self.optimizer.compute_embedding`` and clustered
        with k-means. The number of clusters is ``N``, capped by the number of
        distinct embeddings. The highest-F1 prompt of every non-empty cluster
        joins the ensemble, and the weights are then recomputed through
        ``compute_ensemble_weights``.

        When no prompt has been scored yet (or ``N < 1``) there is nothing to
        cluster; the existing ensemble is kept and only its weights are
        refreshed, instead of failing inside KMeans with zero clusters.

        Args:
            N: Maximum ensemble size (number of clusters).
            M: How many top prompts are considered.
            data: Forwarded unchanged to ``compute_ensemble_weights``.
        """
        scored = [p for p in self.population if p.f1 is not None]
        top_prompts = sorted(scored, key=lambda p: p.f1, reverse=True)[:M]

        if top_prompts and N >= 1:
            embeddings = [self.optimizer.compute_embedding(p) for p in top_prompts]
            # Duplicate embeddings must not inflate the cluster count:
            # k-means cannot form more clusters than there are distinct points.
            n_clusters = min(N, len(np.unique(embeddings, axis=0)))
            # Fixed seed so the same population yields the same ensemble.
            kmeans = KMeans(n_clusters=n_clusters, init='k-means++', random_state=42)
            labels = kmeans.fit_predict(embeddings)

            # Single pass: remember the best prompt seen per cluster. The
            # strict ">" keeps the earlier (higher-ranked) prompt on ties.
            best_by_cluster = {}
            for prompt, label in zip(top_prompts, labels):
                cluster = int(label)
                incumbent = best_by_cluster.get(cluster)
                if incumbent is None or prompt.f1 > incumbent.f1:
                    best_by_cluster[cluster] = prompt

            self.ensemble = [
                best_by_cluster[cluster]
                for cluster in range(n_clusters)
                if cluster in best_by_cluster
            ]

        self.ensemble_weights = self.compute_ensemble_weights(top_prompts=top_prompts, M=M, data=data)

    def compute_ensemble_weights(self, top_prompts, M=10, data=None):
        """Return uniform voting weights as a NumPy array.

        The array has one entry per ensemble member, each equal to
        ``1 / len(self.ensemble)``. With an empty ensemble it falls back to
        ``M`` entries of ``1 / M``. ``top_prompts`` and ``data`` are accepted
        but not used by this implementation.
        """
        # len() is never negative, so "or" selects M exactly when the
        # ensemble is empty.
        slots = len(self.ensemble) or M
        return np.ones(slots) / slots

    def predict(self, input_text):
        """Classify ``input_text`` by weighted vote of the ensemble prompts.

        Each ensemble prompt is sent as the system message with
        ``input_text`` as the user message; the reply is parsed by
        ``self.evaluator.parse_response``. Only answers equal to ``"YES"`` or
        ``"NO"`` take part in the vote, each counting with the weight at the
        same position in ``self.ensemble_weights``.

        Returns:
            The label with the largest summed weight, or ``"NO"`` when no
            prompt produced a valid answer.
        """
        llm_api = LLMApi()
        answers = [
            self.evaluator.parse_response(
                llm_api.generate([
                    {"role": "system", "content": member.text},
                    {"role": "user", "content": input_text},
                ])
            )
            for member in self.ensemble
        ]

        # Accumulate weight per valid label; invalid answers are skipped but
        # still consume their weight slot.
        tally = {}
        paired_weights = self.ensemble_weights[:len(answers)]
        for answer, weight in zip(answers, paired_weights):
            if answer not in ["YES", "NO"]:
                continue
            tally[answer] = tally.get(answer, 0) + weight

        if not tally:
            return "NO"
        return max(tally, key=tally.get)
