import json
import os
import numpy as np
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
from scipy.optimize import brentq
from tqdm import tqdm
from joblib import Parallel, delayed
import argparse

# Bind the trapezoidal-rule integrator under whichever name this NumPy build
# exposes: `np.trapezoid` is tried first, `np.trapz` is the fallback.
for _integrator_name in ('trapezoid', 'trapz'):
    if hasattr(np, _integrator_name):
        integrate_func = getattr(np, _integrator_name)
        break

class TCAP:
    def __init__(self, file_path, train_file=None, output_file=None):
        """Remember the I/O paths and initialise empty pipeline state.

        Args:
            file_path: Path of the JSON-lines file read by ``load_data``.
            train_file: Optional training-data path that
                ``vote_and_evaluate`` reads when cleaning.
            output_file: Optional path the cleaned training data is
                written to.
        """
        # Pipeline constants.
        self.L_sens, self.H_sens = 8, 10  # trailing layers kept / top heads kept
        self.vote_threshold = 1e-4  # per-head probability cut-off for a vote

        # Paths supplied by the caller.
        self.file_path, self.train_file, self.output_file = (
            file_path, train_file, output_file,
        )

        # State filled in by the later pipeline stages.
        self.data_dict = {}
        self.norm_data_lookup = None
        self.tasks = []
        self.top_results = []

    @staticmethod
    def _gmm_group(x):
        """Fit a 1-D Gaussian mixture to ``x`` and pick its minority component group.

        Mixtures with 1, 2, ... components are fitted one at a time. A larger
        model replaces the current one only if it lowers AIC or BIC by more
        than 10 and contains no low-weight, wide component; the search stops
        at the first candidate that fails either test. Candidates are fitted
        lazily, so models that would never be selected are never trained, and
        the number of components is capped at the number of samples.

        For a multi-component model the components are sorted by mean and the
        adjacent pair with the largest ``separation + (1 - relative overlap)``
        score defines a cut; the side of the cut with the smaller total weight
        is the target group (the right side wins ties).

        Args:
            x: Array of samples; it is reshaped to a single feature column.

        Returns:
            Tuple ``(gmm, target_group)``: the selected fitted
            ``GaussianMixture`` and an integer array of component indices, in
            the model's own component order, forming the target group. For a
            single-component model the group is ``np.array([0])``.
        """
        # Step 2: GMM Profiling
        grid_points = 4096
        x = x.reshape(-1, 1)
        # scikit-learn rejects mixtures with more components than samples, so
        # never try more components than there are data points.
        max_components = min(5, x.shape[0])

        def fit_candidate(n):
            """Fit an ``n``-component mixture and collect the statistics used below."""
            gmm = GaussianMixture(
                n_components=n,
                random_state=42,
                reg_covar=1e-6
            ).fit(x)
            return {
                "n": n,
                "gmm": gmm,
                "aic": gmm.aic(x),
                "bic": gmm.bic(x),
                "weights": gmm.weights_.copy(),
                "means": gmm.means_.flatten(),
                "stds": np.sqrt(gmm.covariances_.flatten()),
            }

        def has_invalid_component(weights, stds, w_th=0.02, extreme_th=0.5, sigma_k=6):
            """Return True if any component is both very light and very wide."""
            return any(
                w < w_th and (s * sigma_k) > extreme_th
                for w, s in zip(weights, stds)
            )

        def improved_enough(prev, curr, delta_aic=10, delta_bic=10):
            """Return True if ``curr`` beats ``prev`` on AIC or BIC by the margin."""
            return (
                prev["aic"] - curr["aic"] > delta_aic or
                prev["bic"] - curr["bic"] > delta_bic
            )

        # Greedy forward selection: keep growing the model while each step is
        # a clear improvement and stays well-formed.
        res = fit_candidate(1)
        for n in range(2, max_components + 1):
            candidate = fit_candidate(n)
            if not improved_enough(res, candidate):
                break
            if has_invalid_component(candidate["weights"], candidate["stds"]):
                break
            res = candidate

        gmm = res["gmm"]
        means = res["means"].astype(np.float64)
        stds = res["stds"].astype(np.float64)
        weights = res["weights"].astype(np.float64)

        K = len(means)
        if K <= 1:
            return gmm, np.array([0])

        inv_sqrt2pi = 1.0 / np.sqrt(2.0 * np.pi)

        # Work on components sorted by mean; `order` maps back to model indices.
        order = np.argsort(means)
        means = means[order]
        stds = stds[order]
        weights = weights[order]

        cut_scores = []
        for m in range(K - 1):
            i, j = m, m + 1
            # Distance between the two means in units of their pooled spread.
            dij = abs(means[i] - means[j]) / np.sqrt(stds[i] ** 2 + stds[j] ** 2 + 1e-12)

            left = min(means[i] - 6 * stds[i], means[j] - 6 * stds[j])
            right = max(means[i] + 6 * stds[i], means[j] + 6 * stds[j])

            if not np.isfinite(left) or not np.isfinite(right) or right <= left:
                rel_overlap = 0.0
            else:
                xs = np.linspace(left, right, grid_points)
                fi = weights[i] * (inv_sqrt2pi / stds[i]) * np.exp(-0.5 * ((xs - means[i]) / stds[i]) ** 2)
                fj = weights[j] * (inv_sqrt2pi / stds[j]) * np.exp(-0.5 * ((xs - means[j]) / stds[j]) ** 2)
                # Shared area of the two weighted densities, relative to the
                # lighter component's weight.
                overlap = float(integrate_func(np.minimum(fi, fj), xs))
                rel_overlap = overlap / (min(weights[i], weights[j]) + 1e-12)

            cut_scores.append(dij + (1.0 - rel_overlap))

        # Split at the best-separated adjacent pair.
        m_star = int(np.argmax(cut_scores))
        group_left = np.arange(0, m_star + 1)
        group_right = np.arange(m_star + 1, K)

        w_left = weights[group_left].sum()
        w_right = weights[group_right].sum()
        target_group_sorted = group_left if w_left < w_right else group_right

        # Translate sorted positions back to the model's component indices.
        target_group = order[target_group_sorted]

        return gmm, target_group

    @staticmethod
    def _process_single_head(x_norm, lh_tuple, num_layers):
        grid_points = 4096
        if x_norm is None:
            return {'score': 0.0, 'valid': False}

        gmm, target_group = TCAP._gmm_group(x_norm)

        means = gmm.means_.reshape(-1)
        covs = gmm.covariances_.reshape(-1)
        stds = np.sqrt(np.maximum(covs, 1e-12))
        weights = gmm.weights_.reshape(-1)

        K = len(means)
        if K == 1:
            return {'lh': lh_tuple, 'score': 0.0, 'gmm': gmm, 'target_group': target_group, 'valid': True}

        min_x = np.min(means - 6 * stds)
        max_x = np.max(means + 6 * stds)
        x_grid = np.linspace(min_x, max_x, grid_points)

        pdf_target = np.zeros_like(x_grid)
        pdf_background = np.zeros_like(x_grid)

        target_set = set(target_group)

        for k in range(K):
            p_k = weights[k] * norm.pdf(x_grid, loc=means[k], scale=stds[k])
            if k in target_set:
                pdf_target += p_k
            else:
                pdf_background += p_k

        min_curve = np.minimum(pdf_target, pdf_background)
        overlap_area = integrate_func(min_curve, x_grid)
        overlap_area = max(overlap_area, 1e-4)

        # Step 2: Calculate Separation Score
        separation_score = 1.0 / overlap_area

        return {
            'lh': lh_tuple,
            'score': separation_score,
            'gmm': gmm,
            'target_group': target_group,
            'valid': True
        }

    @staticmethod
    def _dawid_skene_binary(L, max_iter=50, tol=1e-6, eps=1e-6):
        prior = 0.1
        V, N = L.shape
        obs = ~np.isnan(L)

        votes = np.where(obs, L, 0.0)
        cnt = obs.sum(axis=0).clip(min=1)
        post = np.clip(votes.sum(axis=0) / cnt, eps, 1 - eps)

        pi = float(np.clip(prior, eps, 1 - eps))

        for _ in range(max_iter):
            y1, y0 = post, 1.0 - post

            tpr = np.zeros(V)
            fpr = np.zeros(V)
            for v in range(V):
                ov = obs[v]
                if ov.sum() == 0: continue
                w_y1, w_y0 = y1[ov], y0[ov]
                l_v = L[v, ov]
                tpr[v] = (w_y1 * l_v).sum() / (w_y1.sum() + eps)
                fpr[v] = (w_y0 * l_v).sum() / (w_y0.sum() + eps)

            tpr, fpr = np.clip(tpr, eps, 1 - eps), np.clip(fpr, eps, 1 - eps)

            logp1 = np.full(N, np.log(pi))
            logp0 = np.full(N, np.log(1.0 - pi))

            for v in range(V):
                ov = obs[v]
                if ov.sum() == 0: continue
                lv = L[v, ov]
                logp1[ov] += lv * np.log(tpr[v]) + (1 - lv) * np.log(1 - tpr[v])
                logp0[ov] += lv * np.log(fpr[v]) + (1 - lv) * np.log(1 - fpr[v])

            m = np.maximum(logp1, logp0)
            p1, p0 = np.exp(logp1 - m), np.exp(logp0 - m)
            post_new = np.clip(p1 / (p1 + p0), eps, 1 - eps)

            if np.abs(post_new - post).sum() < tol: break
            post = post_new
            pi = np.clip(post.mean(), eps, 1 - eps)

        return post, pi, tpr, fpr

    def load_data(self):
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"File not found: {self.file_path}")

        with open(self.file_path, 'r') as f:
            raw_data = [json.loads(line) for line in f]

        full_map = np.array([d['tcap_map'] for d in raw_data])
        tcap_map = full_map[:, -self.L_sens:, :, :]

        self.data_dict = {
            "tcap_map": tcap_map,
            "question_id": np.array([d['question_id'] for d in raw_data]),
            "answer": np.array([d['answer'] for d in raw_data]),
        }
        print(f"Loaded {self.file_path}: Shape {full_map.shape}")

    def preprocess(self, part='sys'):
        # Step 1: Extraction & Normalization
        # Extracts the 'System' component attention and applies Min-Max normalization to ensure numerical stability.
        part_idx = {'sys': 0, 'img': 1, 'usr': 2}[part]
        target_data = self.data_dict['tcap_map'][:, :, :, part_idx].transpose(1, 2, 0)
        L, H, N = target_data.shape

        mins = target_data.min(axis=2, keepdims=True)
        maxs = target_data.max(axis=2, keepdims=True)
        diffs = maxs - mins
        valid_mask = (diffs > 1e-12).squeeze(-1)

        self.norm_data_lookup = [[None for _ in range(H)] for _ in range(L)]

        unique_task_map = {}

        for li in range(L):
            for hi in range(H):
                if valid_mask[li, hi]:
                    norm_x = (target_data[li, hi] - mins[li, hi, 0]) / diffs[li, hi, 0]
                    self.norm_data_lookup[li][hi] = norm_x

                    key = (li, norm_x.tobytes())

                    if key not in unique_task_map:
                        unique_task_map[key] = {
                            'data': norm_x,
                            'heads': []
                        }
                    unique_task_map[key]['heads'].append((li, hi))
                else:
                    pass

        self.tasks = list(unique_task_map.values())

    def scan_heads(self, n_jobs=-1):
        """Score every unique task in parallel and keep the best-scoring heads.

        Each task is scored once through ``_process_single_head``; the result
        is then duplicated for every other head that shares the task's data.
        Invalid results are dropped, and the ``self.H_sens`` highest-scoring
        results are stored in ``self.top_results``.

        Args:
            n_jobs: Forwarded to ``joblib.Parallel``.
        """
        num_layers = self.data_dict['tcap_map'].shape[1]

        progress = tqdm(self.tasks, desc="Scanning Heads (Unique)", leave=False)
        jobs = (
            delayed(TCAP._process_single_head)(task['data'], task['heads'][0], num_layers)
            for task in progress
        )
        unique_results = Parallel(n_jobs=n_jobs)(jobs)

        expanded = []
        for task, res in zip(self.tasks, unique_results):
            if not res.get('valid', False):
                continue

            # The result already carries the first head; clone it for the rest.
            expanded.append(res)
            expanded.extend({**res, 'lh': lh} for lh in task['heads'][1:])

        # Step 2: Head Selection
        ranked = sorted(expanded, key=lambda r: r['score'], reverse=True)
        self.top_results = ranked[:self.H_sens]

    def vote_and_evaluate(self):
        if not self.top_results:
            return 0.0, 0.0, 0.0, 0, 0, 0, 0

        judges = []
        for res in self.top_results:
            layer, head = res['lh']
            x_norm = self.norm_data_lookup[layer][head]
            if x_norm is None:
                judges.append(np.zeros(len(self.data_dict['question_id']), dtype=bool))
                continue

            probs = res['gmm'].predict_proba(x_norm.reshape(-1, 1))
            target_prob_sum = probs[:, res['target_group']].sum(axis=1)
            judges.append(target_prob_sum > self.vote_threshold)

        # Step 3: Cleaning
        post, _, _, _ = TCAP._dawid_skene_binary(np.stack(judges).astype(float))
        judge = post > 0.5

        if self.train_file and self.output_file:
            print(f"Cleaning: Reading {self.train_file} -> Writing clean data to {self.output_file}")

            poisoned_ids = set()
            for q_id, is_poisoned in zip(self.data_dict['question_id'], judge):
                if is_poisoned:
                    poisoned_ids.add(str(q_id))
            
            with open(self.train_file, 'r', encoding='utf-8') as f_in, open(self.output_file, 'w', encoding='utf-8') as f_out:
                train_suffix = self.train_file.split('.')[-1]
                if train_suffix == 'json':
                    train_data = json.load(f_in)
                    cleaned_data = []
                    for line in train_data:
                        if line['id'] not in poisoned_ids:
                            cleaned_data.append(line)
                    json.dump(cleaned_data, f_out)
                elif train_suffix == 'jsonl':
                    train_data = [json.loads(line) for line in f_in]
                    for line in train_data:
                        if line['id'] not in poisoned_ids:
                            f_out.write(json.dumps(line) + '\n')
                else:
                    print('Only support JSON or JSONL for training data.')
                    exit(0)
                        
        qids, answers = self.data_dict['question_id'], self.data_dict['answer']
        ends_poisoned = np.char.endswith(qids.astype(str), 'poisoned')
        is_bd = ends_poisoned

        tp = int((is_bd & judge).sum())
        fp = int((~is_bd & judge).sum())
        lf = int((is_bd & ~judge).sum())
        tt = int(is_bd.sum())

        prec = tp / (tp + fp + 1e-12)
        rec = tp / (tt + 1e-12)
        f1 = 2 * prec * rec / (prec + rec + 1e-12)

        return f1, prec, rec, lf, tt, tp, fp

    def run(self):
        """Run the whole pipeline and return the evaluation metrics.

        Executes ``load_data``, ``preprocess`` and ``scan_heads`` in that
        order, each with its default arguments.

        Returns:
            The tuple returned by ``vote_and_evaluate``.
        """
        for stage in (self.load_data, self.preprocess, self.scan_heads):
            stage()
        return self.vote_and_evaluate()


if __name__ == '__main__':
    # Command-line entry point: run the detector and print its metrics.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--tcap-file", type=str, required=True)
    for optional_flag in ("--train-file", "--output-file"):
        cli_parser.add_argument(optional_flag, type=str, default=None)
    cli_args = cli_parser.parse_args()

    metrics = TCAP(cli_args.tcap_file, cli_args.train_file, cli_args.output_file).run()
    # The fourth value (missed poisoned samples) is returned but not reported.
    f1, prec, rec, _missed, tt, tp, fp = metrics

    report_lines = (
        f'Precision: {prec * 100:.2f}%',
        f'Recall:    {rec * 100:.2f}%',
        f'F1 Score:  {f1 * 100:.2f}',
        f'TP Count:  {tp} / (Total {tt})',
        f'FP Count:  {fp}',
    )
    for report_line in report_lines:
        print(report_line)
