import random
import numpy as np
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import ElectraForSequenceClassification, BertForSequenceClassification


from tqdm import tqdm
import argparse
import os
from src.utils import load_beir_datasets
import difflib
import re


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and torch.
    When CUDA is available the CUDA generators are seeded as well. cuDNN is
    then set to deterministic mode with benchmarking switched off.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current device explicitly, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_json(path):
    """Return the object decoded from the UTF-8 JSON file at ``path``."""
    with open(path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed


def save_json(obj, path):
    """Write ``obj`` to ``path`` as indented, UTF-8 encoded JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than as ``\\uXXXX`` escapes.

    Args:
        obj: Any JSON-serialisable object.
        path: Destination file path. May be a bare file name.
    """
    parent_dir = os.path.dirname(path)
    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)


def generate_all_sentences(S, V, subset_z, k, alternative=None):
    """Enumerate every sentence reachable from ``S`` by one character edit.

    Edit positions use a ``2 * len(S) + 1`` slot encoding:

    * an even slot ``z`` inserts a character before index ``z // 2``;
    * an odd slot ``z`` replaces the character at index ``z // 2``, or deletes
      it when the vocabulary entry is ``-1``. Odd slots past the end of ``S``
      are ignored.

    Args:
        S: Sentence to perturb.
        V: Vocabulary as a list of Unicode code points. ``-1`` means
            "delete" and is skipped for insertions.
        subset_z: Iterable of slot indices to edit.
        k: Number of simultaneous edits. Only ``1`` is supported.
        alternative: When ``None`` the unmodified ``S`` is included as the
            first element; any other value leaves it out.

    Returns:
        A list of distinct sentences in generation order: ``S`` first (when
        included), then the edits slot by slot and character by character.
        The order is deterministic, so runs are reproducible.

    Raises:
        NotImplementedError: If ``k`` is not 1.
    """
    if k != 1:
        raise NotImplementedError("only single-character edits (k=1) are supported")

    candidates = []
    if alternative is None:
        candidates.append(S)

    for z in subset_z:
        pos = z // 2
        if z % 2 == 0:
            # Insertion slot: deletion (-1) makes no sense here.
            for v_char_code in V:
                if v_char_code == -1:
                    continue
                candidates.append(S[:pos] + chr(v_char_code) + S[pos:])
        else:
            # Replacement/deletion slot: there must be a character to edit.
            if pos >= len(S):
                continue
            for v_char_code in V:
                if v_char_code == -1:
                    candidates.append(S[:pos] + S[pos + 1:])
                else:
                    candidates.append(S[:pos] + chr(v_char_code) + S[pos + 1:])

    # dict.fromkeys de-duplicates while keeping first-seen order; a plain set
    # would return the sentences in a hash-dependent, run-to-run varying order.
    return list(dict.fromkeys(candidates))


class CharmerReranker:
    """Beam-search character-level attack that raises a reranker's score.

    Edit positions follow the same ``2 * len(passage) + 1`` slot encoding as
    ``generate_all_sentences``: even slot ``2 * i`` is the insertion point
    before character ``i`` and odd slot ``2 * i + 1`` is character ``i``
    itself. A *constraint list* has one entry per slot, ``1`` for editable and
    ``0`` for frozen.

    The model is used as ``model(**inputs).logits`` and the tokenizer is
    called on a list of ``(question, passage)`` pairs with
    ``return_tensors="pt"``. The scoring code assumes one logit per pair.
    """

    # At most this many editable slots are probed per call of
    # ``get_top_n_locations``; larger sets are randomly subsampled.
    _MAX_PROBE_POSITIONS = 800
    # Number of probe passages scored per forward pass.
    _PROBE_BATCH_SIZE = 256

    def __init__(self, model, tokenizer, args):
        """Store the model, tokenizer and run configuration.

        Args:
            model: Sequence-classification model already placed on
                ``args.device``.
            tokenizer: Tokenizer matching ``model``.
            args: Namespace providing ``device`` plus the attack settings read
                by ``attack`` (``k``, ``eval_num``, ``beam_width``,
                ``score_threshold``, ``patience``).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.args = args
        self.device = args.device
        self.V = self.get_vocabulary()

    def get_vocabulary(self):
        """Return the edit vocabulary as a list of Unicode code points.

        The list starts with ``-1`` (the deletion marker understood by
        ``generate_all_sentences``), followed by digits, ASCII letters, ASCII
        punctuation and the space character.
        """
        V = [-1] + [ord(c) for c in '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ ']
        return V

    def _freeze_contiguous_block(self, constraints, passage_len, start_index, length, window_size):
        """Freeze, in place, every slot touching a character span plus a margin.

        Args:
            constraints: Constraint list to modify.
            passage_len: Length of the passage the constraints belong to.
            start_index: Index of the first protected character.
            length: Number of protected characters.
            window_size: Extra characters to protect on each side.
        """
        end_index = start_index + length

        freeze_start = max(0, start_index - window_size)
        freeze_end = min(passage_len, end_index + window_size)

        for i in range(freeze_start, freeze_end):
            constraints[2 * i] = 0
            constraints[2 * i + 1] = 0

        # Also close the insertion slot right after the last frozen character
        # so nothing can be inserted at the trailing edge of the span.
        if 2 * freeze_end < len(constraints):
            constraints[2 * freeze_end] = 0

    def _freeze_region(self, constraints, passage, text, window_size, word_window_size=7):
        """Freeze the part of ``passage`` that corresponds to ``text``.

        Matching is attempted in three stages and stops at the first success:

        1. exact substring (first occurrence only);
        2. case-insensitive substring (first occurrence only);
        3. word by word: every whole-word, case-insensitive occurrence of each
           word in ``text`` is frozen with ``word_window_size`` as margin.

        Args:
            constraints: Constraint list to modify in place.
            passage: Passage the constraints belong to.
            text: Text to protect. Empty or ``None`` is a no-op.
            window_size: Margin (in characters) for stages 1 and 2.
            word_window_size: Margin (in characters) for stage 3.
        """
        if not text:
            return

        passage_len = len(passage)

        start_index = passage.find(text)
        if start_index == -1:
            # NOTE(review): indices found in the lower-cased passage are applied
            # to the original passage; this assumes ``str.lower`` keeps the
            # length unchanged, which holds for ASCII but not for every
            # Unicode character.
            start_index = passage.lower().find(text.lower())
        if start_index != -1:
            self._freeze_contiguous_block(constraints, passage_len, start_index, len(text), window_size)
            return

        words_to_find = re.findall(r'\w+', text.lower())
        if not words_to_find:
            return

        for word in set(words_to_find):
            pattern = r'\b' + re.escape(word) + r'\b'
            for match in re.finditer(pattern, passage, re.IGNORECASE):
                self._freeze_contiguous_block(constraints, passage_len, match.start(), len(word), word_window_size)

    def get_initial_constraints(self, passage, text_to_exclude, window_size=15):
        """Build the constraint list for ``passage``.

        Args:
            passage: Passage that will be edited.
            text_to_exclude: What must stay untouched. Either a single string
                (frozen with ``window_size`` margin) or a list of strings. A
                list of three is read as ``[question, incorrect, correct]`` and
                a list of two as ``[question, incorrect]``; in both, only the
                incorrect text gets the ``window_size`` margin and the others
                get no margin. Lists of any other length freeze every entry
                with ``window_size``. Any other type freezes nothing.
            window_size: Margin in characters around protected text.

        Returns:
            A list of ``2 * len(passage) + 1`` ints, ``1`` = editable slot,
            ``0`` = frozen slot.
        """
        constraints = [1] * (2 * len(passage) + 1)

        if isinstance(text_to_exclude, str):
            self._freeze_region(constraints, passage, text_to_exclude, window_size)

        elif isinstance(text_to_exclude, list):
            windows = {
                "question": 0,
                "incorrect": window_size,
                "correct": 0,
            }

            if len(text_to_exclude) == 3:
                question, incorrect_text, correct_text = text_to_exclude
                self._freeze_region(constraints, passage, question, windows['question'])
                self._freeze_region(constraints, passage, incorrect_text, windows['incorrect'])
                self._freeze_region(constraints, passage, correct_text, windows['correct'])

            elif len(text_to_exclude) == 2:
                question, incorrect_text = text_to_exclude
                self._freeze_region(constraints, passage, question, windows['question'])
                self._freeze_region(constraints, passage, incorrect_text, windows['incorrect'])

            else:
                for text in text_to_exclude:
                    self._freeze_region(constraints, passage, text, window_size)

        return constraints

    def get_top_n_locations(self, question, passage, constraints_loc, n=200):
        """Rank editable slots by how much a space edit lowers the score.

        For every editable slot (randomly subsampled when there are more than
        ``_MAX_PROBE_POSITIONS``) a probe passage is built by inserting a space
        (even slot) or replacing the character with a space (odd slot). Each
        probe is scored, and the slots whose probe reduces the original logit
        the most are returned.

        Args:
            question: Query paired with the passage.
            passage: Passage being attacked.
            constraints_loc: Constraint list for ``passage``.
            n: Maximum number of slots to return.

        Returns:
            Up to ``n`` slot indices, largest score drop first. Empty when no
            slot is editable or no probe differs from ``passage``.
        """
        available_indices = [i for i, c in enumerate(constraints_loc) if c == 1]
        if not available_indices:
            return []

        if len(available_indices) > self._MAX_PROBE_POSITIONS:
            subset_z_eval = random.sample(available_indices, self._MAX_PROBE_POSITIONS)
        else:
            subset_z_eval = available_indices

        # Build probes in lock-step with their slot so that score index i is
        # guaranteed to belong to probe_positions[i]. (Going through a
        # de-duplicated, unordered collection would break that link.)
        probe_positions = []
        probe_passages = []
        for z in subset_z_eval:
            char_idx = z // 2
            if z % 2 == 0:
                probe = passage[:char_idx] + ' ' + passage[char_idx:]
            elif char_idx < len(passage):
                probe = passage[:char_idx] + ' ' + passage[char_idx + 1:]
            else:
                continue
            if probe == passage:
                # Replacing a space with a space changes nothing.
                continue
            probe_positions.append(z)
            probe_passages.append(probe)

        if not probe_passages:
            return []

        scores = []
        batch_size = self._PROBE_BATCH_SIZE
        with torch.no_grad():
            inputs_orig = self.tokenizer([(question, passage)], padding=True, truncation=True, return_tensors="pt").to(self.device)
            original_score = self.model(**inputs_orig).logits.squeeze()

            for start in tqdm(range(0, len(probe_passages), batch_size), desc="Finding best locations", leave=False):
                batch_ss = probe_passages[start:start + batch_size]
                pairs = [(question, s) for s in batch_ss]
                inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt").to(self.device)
                scores.append(self.model(**inputs).logits.squeeze(-1))

        score_diffs = original_score - torch.cat(scores)
        top_k_indices = torch.topk(score_diffs, min(n, len(score_diffs))).indices.tolist()

        return [probe_positions[i] for i in top_k_indices]

    def simplex_projection(self, u_np):
        """Project a vector onto the probability simplex (sort-based method).

        Args:
            u_np: 1-D array-like of floats; must be non-empty.

        Returns:
            ``numpy.ndarray`` with non-negative entries that sum to 1.
        """
        if not isinstance(u_np, np.ndarray):
            u_np = np.array(u_np)

        u_sorted = np.sort(u_np)[::-1]
        cssv = np.cumsum(u_sorted) - 1
        ind = np.arange(len(u_sorted)) + 1
        cond = u_sorted - cssv / ind > 0
        rho = ind[cond][-1]
        theta = cssv[cond][-1] / rho
        return np.maximum(u_np - theta, 0)

    def attack_step(self, question, passage, subset_z, return_top_n=1):
        """Propose edited passages for one search step (not yet released).

        ``attack`` iterates over the result as ``(score, new_passage)`` pairs
        and treats ``score`` as a logit.

        Raises:
            NotImplementedError: Always; the implementation has not been
                published in this file.
        """
        # will be released in the future
        raise NotImplementedError(
            "CharmerReranker.attack_step has not been released yet; it must return "
            "an iterable of (score_logit, new_passage) pairs."
        )

    def attack(self, question, passage, text_to_exclude, target_score=0.95):
        """Run the beam-search attack on one passage.

        Each iteration expands every beam entry through ``attack_step`` at the
        slots chosen by ``get_top_n_locations``, keeps the best-scoring
        ``args.beam_width`` distinct passages, and promotes the best of them
        to *champion* when its sigmoid score beats the current champion by
        more than ``args.score_threshold``. The search stops after ``args.k``
        iterations, when no candidates are produced, after ``args.patience``
        consecutive iterations without a champion update, or once the best
        beam score exceeds ``target_score``.

        Args:
            question: Query the passage is scored against.
            passage: Passage to edit.
            text_to_exclude: Text that must not be edited; see
                ``get_initial_constraints``.
            target_score: Sigmoid score above which the attack stops.

        Returns:
            ``(champion_passage, champion_sigmoid_score)``. The original
            passage and its score are returned when it already exceeds
            ``target_score`` or no improvement was found.
        """
        with torch.no_grad():
            inputs = self.tokenizer([(question, passage)], padding=True, truncation=True, return_tensors="pt").to(self.device)
            initial_score_logit = self.model(**inputs).logits.item()
            initial_score_sigmoid = torch.sigmoid(torch.tensor(initial_score_logit)).item()

        print(f"Initial score: {initial_score_sigmoid:.4f}")
        if initial_score_sigmoid > target_score:
            print("Initial score already above target. Skipping.")
            return passage, initial_score_sigmoid

        initial_constraints = self.get_initial_constraints(passage, text_to_exclude)
        beam = [(initial_score_logit, passage, initial_constraints)]

        champion_score_logit = initial_score_logit
        champion_passage = passage
        champion_not_updated_for = 0

        top_n = self.args.eval_num if self.args.eval_num > 0 else -1

        for i in range(self.args.k):
            print(f"\n--- Attack Iteration {i+1}/{self.args.k} ---")

            # Best score seen for each distinct candidate passage across all
            # beam entries (insertion order = first time a passage was seen).
            best_score_by_passage = {}
            for _, old_passage, old_constraints in beam:
                subset_z = self.get_top_n_locations(question, old_passage, old_constraints, n=100)
                if not subset_z:
                    continue

                new_candidates = self.attack_step(question, old_passage, subset_z, return_top_n=top_n)
                for new_score, new_passage in new_candidates:
                    if new_passage not in best_score_by_passage or new_score > best_score_by_passage[new_passage]:
                        best_score_by_passage[new_passage] = new_score

            if not best_score_by_passage:
                print("No new candidates were generated from any beam. Stopping.")
                break

            ranked = sorted(best_score_by_passage.items(), key=lambda item: item[1], reverse=True)
            top_k_candidates = ranked[:self.args.beam_width]

            if not top_k_candidates:
                print("No candidates left after sorting. Stopping attack.")
                break

            # Constraints are rebuilt for every survivor because an edit
            # shifts character positions.
            beam = [
                (cand_score, cand_passage, self.get_initial_constraints(cand_passage, text_to_exclude))
                for cand_passage, cand_score in top_k_candidates
            ]

            best_iter_score_logit, best_iter_passage, _ = beam[0]
            best_iter_score_sigmoid = torch.sigmoid(torch.tensor(best_iter_score_logit)).item()

            print(f"Iter {i+1} Best Passage in Beam (preview): {best_iter_passage[:100]}...")
            print(f"Iter {i+1} Best Score in Beam: {best_iter_score_sigmoid:.4f}")

            champion_score_sigmoid = torch.sigmoid(torch.tensor(champion_score_logit)).item()
            if best_iter_score_sigmoid - champion_score_sigmoid > self.args.score_threshold:
                print(f"Score improved beyond threshold ({best_iter_score_sigmoid:.4f} > {champion_score_sigmoid:.4f} + {self.args.score_threshold}). Champion updated.")
                champion_score_logit = best_iter_score_logit
                champion_passage = best_iter_passage
                champion_not_updated_for = 0
            else:
                champion_not_updated_for += 1
                print(f"Champion not updated for {champion_not_updated_for} consecutive iteration(s).")

            if champion_not_updated_for >= self.args.patience:
                print(f"\nChampion has not been updated for {self.args.patience} consecutive iterations. Stopping attack early.")
                break

            if best_iter_score_sigmoid > target_score:
                # The champion was already promoted above if this iteration's
                # best cleared the improvement threshold.
                print("Target score reached. Stopping.")
                break

        final_score_sigmoid = torch.sigmoid(torch.tensor(champion_score_logit)).item()
        print("\n--- Attack Finished ---")
        print(f"Original Passage (preview): {passage[:100]}...")
        print(f"Final Champion Passage (preview): {champion_passage[:100]}...")
        print(f"Initial Score: {initial_score_sigmoid:.4f} -> Final Champion Score: {final_score_sigmoid:.4f}")
        return champion_passage, final_score_sigmoid


def main(args):
    """Attack every adversarial passage of a dataset and save the results.

    Loads the cross-encoder and the generated adversarial corpus, runs
    ``CharmerReranker.attack`` on each entry of every example's ``adv_texts``
    and writes a copy of the data with the rewritten passages.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``gpu``,
            ``poison_method``, ``eval_dataset``, ``target_score``,
            ``rerank_model`` and ``num`` here; ``device`` is set on it for
            ``CharmerReranker``.
    """
    set_seed(args.seed)
    args.device = "cuda" if torch.cuda.is_available() and args.gpu >= 0 else "cpu"
    if args.device == 'cuda':
        torch.cuda.set_device(args.gpu)

    # Load model
    print("Loading Cross-Encoder model...")
    model_name = f"type_your_model_name_here"  # Replace with your model name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(args.device)
    print(model)
    model.eval()

    # Load data
    input_path = f'results/adv_corpus_generated/{args.poison_method}/{args.eval_dataset}.json'
    output_path = f'results/adv_corpus_generated/{args.poison_method}_charmer_score{args.target_score}_{args.rerank_model}_epochs2_10_mask_cons/{args.eval_dataset}.json'
    print(f"Loading data from: {input_path}")
    data = load_json(input_path)

    if args.num > 0:
        data = dict(list(data.items())[:args.num])
        print(f"Processing {args.num} examples")

    charmer = CharmerReranker(model, tokenizer, args)

    results = {}
    for query_id, ex in tqdm(data.items(), desc="Processing examples"):
        question = ex["question"]
        incorrect_answer = ex["incorrect answer"]

        # Protect the question and the injected (incorrect) answer from edits.
        # NOTE(review): the two-element [question, incorrect] form matches the
        # corresponding branch of get_initial_constraints — confirm whether the
        # correct answer should be protected as well.
        text_to_exclude = [question, incorrect_answer]

        # Start a fresh list per example so passages never leak across queries.
        new_adv_texts = []
        for passage in tqdm(ex.get("adv_texts", []), desc="Processing adv_texts", leave=False):
            new_passage, _ = charmer.attack(question, passage, text_to_exclude, target_score=args.target_score)
            new_adv_texts.append(new_passage)

        new_ex = ex.copy()
        new_ex["adv_texts"] = new_adv_texts
        results[query_id] = new_ex

    save_json(results, output_path)
    print(f"Save to: {output_path}")


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), registered
    # in the order listed. NOTE(review): --iter and --repeat_words are accepted
    # but not read by the code visible in this file.
    option_specs = [
        ("--seed", {"type": int, "default": 42}),
        ("--gpu", {"type": int, "default": 0, "help": "GPU ID to use, -1 for CPU"}),
        ("--poison_method", {"type": str, "default": "poisonrag_b"}),
        ("--eval_dataset", {"type": str, "default": "nq"}),
        ("--rerank_model", {"type": str, "default": "minilm", "help": "Model to use for reranking"}),
        ("--iter", {"type": int, "default": 100, "help": "Number of iterations for PGA optimization"}),
        ("--num", {"type": int, "default": 500, "help": "Number of examples to process, -1 for all"}),
        ("--k", {"type": int, "default": 50, "help": "max edit distance or max number of characters to be modified"}),
        ("--eval_num", {"type": int, "default": 700, "help": "Number of candidates to evaluate in reranking"}),
        ("--repeat_words", {"action": "store_true", "help": "Whether to allow repeating words in the attack"}),
        ("--beam_width", {"type": int, "default": 5, "help": "Beam width for beam search attack"}),
        ("--score_threshold", {"type": float, "default": 0.0005, "help": "Minimum score improvement required to update the champion passage"}),
        ("--patience", {"type": int, "default": 5, "help": "Number of consecutive iterations without champion improvement to trigger early stopping"}),
        ("--target_score", {"type": float, "default": 0.92, "help": "Target score to reach for the attack to be considered successful"}),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    main(args)
