import json
import re

from transformers import AutoModelForSequenceClassification
from gradiend.data import split
from gradiend.model import AutoModelForLM
import swifter # this import statement is not used explicitly, but it automatically enables swifter for pandas DataFrames

import os
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from collections import defaultdict
from tqdm import tqdm

from gradiend.util import hash_it


def read_antonyms():
    """Load the antonym table from ``data/emotion/antonyms.json``.

    Returns:
        The decoded JSON object. Callers in this module use it as a mapping
        ``word -> antonym`` (``antonyms[word]``) and as a collection of words
        (``set(antonyms)``).
    """
    file = 'data/emotion/antonyms.json'
    # Context manager closes the handle deterministically instead of leaving
    # it to the garbage collector.
    with open(file, 'r', encoding='utf8') as f:
        return json.load(f)

def read_vad(split=None, max_size=None, use_antonyms=False):
    """Load the VAD (valence/arousal/dominance) lexicon with its split labels.

    Reads ``data/emotion/vad_split.csv``. ``term`` is read as ``str`` and NA
    detection is disabled, so words such as "null" or "nan" stay words.

    Args:
        split: If truthy, keep only rows whose ``split`` column equals it.
        max_size: If truthy, randomly sample (``random_state=42``) at most this
            many rows. If fewer rows are available, all of them are returned
            (shuffled) instead of raising.
        use_antonyms: If True, keep only terms present in the antonym table
            and add an ``antonym`` column.

    Returns:
        A ``pandas.DataFrame`` with a fresh ``RangeIndex``.
    """
    file = 'data/emotion/vad_split.csv'
    df = pd.read_csv(file, dtype={'term': str}, na_values=[], keep_default_na=False)

    if split:
        df = df[df['split'] == split].reset_index(drop=True)

    if max_size:
        # DataFrame.sample raises ValueError when n > len(df) (easy to hit after
        # split filtering); clamp so max_size acts as an upper bound.
        df = df.sample(n=min(max_size, len(df)), random_state=42).reset_index(drop=True)

    if use_antonyms:
        antonyms = read_antonyms()
        print(f"Using {len(antonyms)} antonyms for training.")
        df = df[df['term'].isin(set(antonyms))].reset_index(drop=True)
        df['antonym'] = df['term'].apply(lambda x: antonyms[x])
        print(f"Filtered data to {len(df)} entries with antonyms.")

    return df

def split_vad():
    """Partition the VAD lexicon via ``gradiend.data.split``.

    Passes the lexicon loaded by ``read_vad()`` together with the path
    ``'data/emotion/vad.csv'``, ``prop_val=0.1`` and ``prop_test=0.15``
    (presumably the validation/test fractions — see ``gradiend.data.split``).
    """
    split(
        'data/emotion/vad.csv',
        data=read_vad(),
        prop_val=0.1,
        prop_test=0.15,
    )



def _parse_token_ids(value):
    """Convert a ``token_ids`` cell read back from CSV into a list of ints.

    Accepts a stringified list (``"[2204, 1005]"``), a bare number (older
    partial files stored single-token ids as scalars, e.g. ``2204.0``, possibly
    as a string when the column is mixed) and missing values. Returns
    ``np.nan`` for missing/empty entries so they fail the ``isinstance(x, list)``
    check used to select rows for scoring.
    """
    if isinstance(value, str):
        inner = value.strip().strip('[]').strip()
        if not inner:
            return np.nan
        return [int(float(part)) for part in inner.split(',')]
    if isinstance(value, (int, float, np.integer, np.floating)) and not np.isnan(value):
        return [int(value)]
    return np.nan


def run_targetprob_filtering(templates, output, save_every=100000, max_n=None):
    """Score templates by how strongly ``bert-base-uncased`` predicts their target word.

    Each template's mask token is expanded to one mask per word-piece of
    ``masked_word`` and the masked LM is queried. For every row it records the
    product of the target pieces' probabilities (``target_prob``), whether every
    piece is the arg-max at its position (``target_is_most_likely``), and the
    most likely token at the first mask position with its probability.

    The run is resumable. Tokenization results are cached in
    ``<output>_targetprob_partial.csv``. If that file exists it is loaded
    instead of ``templates``, and rows that already have a ``target_prob`` are
    skipped.

    Args:
        templates: DataFrame with at least ``template`` and ``masked_word``
            columns. Ignored when a partial file exists; otherwise columns are
            added to it in place.
        output: Base CSV path; ``.csv`` is stripped and suffixes are appended.
        save_every: Rewrite the partial file after this many successful batches.
        max_n: Row limit applied when reading the partial file.

    Side effects:
        Writes ``<output>_targetprob_partial.csv``, ``<output>_targetprob_full.csv``
        (all rows incl. scores) and ``<output>_targetprob_filtered.csv`` (scored
        rows whose target pieces are all top-1 predictions).
    """
    model_id = 'bert-base-uncased'
    model = AutoModelForMaskedLM.from_pretrained(model_id)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer_vocab = tokenizer.get_vocab()

    mask_token = tokenizer.mask_token
    mask_token_id = tokenizer.mask_token_id
    model.eval()

    base_path = output.removesuffix('.csv')
    partial_path = base_path + '_targetprob_partial.csv'
    full_path = base_path + '_targetprob_full.csv'
    final_path = base_path + '_targetprob_filtered.csv'

    # Load from partial save if exists
    if os.path.exists(partial_path):
        print(f"Resuming from partial file: {partial_path}")
        templates = pd.read_csv(partial_path, nrows=max_n)
        # The CSV round-trip turns lists into strings (or bare numbers); restore lists.
        templates['token_ids'] = templates['token_ids'].apply(_parse_token_ids)
        # An all-None column is read back as float NaN; make it object so the
        # string tokens written below are not an incompatible-dtype assignment.
        if 'most_likely_target_token' in templates.columns:
            templates['most_likely_target_token'] = templates['most_likely_target_token'].astype(object)
    else:
        print("Preparing rows...")
        batch_size = 256

        valid_mask = (
                templates['template'].apply(lambda x: isinstance(x, str)) &
                templates['masked_word'].apply(lambda x: isinstance(x, str)) &
                # regex=False: as a regex, "[MASK]" is a character class that
                # matches any template containing M, A, S or K.
                templates['template'].str.contains(mask_token, regex=False, na=False) &
                templates['masked_word'].isin(tokenizer_vocab)
        )

        valid_indices = templates.index[valid_mask].tolist()
        valid_words = templates.loc[valid_indices, 'masked_word'].tolist()

        # Collect the target word-piece ids per row index.
        token_ids_by_index = {}
        for i in tqdm(range(0, len(valid_words), batch_size), desc="Processing batches"):
            batch_words = valid_words[i:i + batch_size]
            batch_indices = valid_indices[i:i + batch_size]

            try:
                tokens_batch = tokenizer(batch_words, add_special_tokens=False,
                                         return_attention_mask=False,
                                         return_token_type_ids=False,
                                         return_tensors=None)
            except Exception as e:
                print(f"Batch {i}-{i + batch_size} failed: {e}")
                continue

            for idx, ids in zip(batch_indices, tokens_batch['input_ids']):
                token_ids_by_index[idx] = list(ids)

        # Build the columns from an object Series so every cell stays a list.
        # Assigning a list of lists via ``templates.loc[idx, 'token_ids'] = ...``
        # makes pandas treat it as a 2-D array: equal-length batches were
        # flattened to scalar ids (which then failed the ``isinstance(x, list)``
        # selection below, so nothing was scored on a fresh run) and
        # mixed-length batches raised.
        token_ids = pd.Series(token_ids_by_index, index=templates.index, dtype=object)
        templates['num_masks'] = token_ids.map(
            lambda ids: len(ids) if isinstance(ids, list) else np.nan
        ).astype(float)
        templates['token_ids'] = token_ids

        # Result columns, filled in by the inference loop below.
        templates['target_prob'] = np.nan
        templates['target_is_most_likely'] = False
        templates['most_likely_target_token'] = None
        templates['most_likely_target_token_prob'] = np.nan

        templates.to_csv(partial_path, index=False)

    if 'valence' not in templates.columns:
        print("Adding VAD values to templates...")
        vad = read_vad()
        templates = templates.merge(vad[['term', 'valence', 'arousal', 'dominance']], left_on='masked_word', right_on='term', how='left')
        templates.drop(columns=['term'], inplace=True, errors='ignore')
        templates.to_csv(partial_path, index=False)

    filter_by_min_vad = False
    if filter_by_min_vad:
        templates = templates[(templates['valence'].abs() > 0.5) | (templates['arousal'].abs() > 0.5)]
        print(f"Filtered data by min VAD: {len(templates)} entries remaining.")

    filter_by_tokenizer = True
    if filter_by_tokenizer:
        templates = templates[templates['masked_word'].isin(tokenizer_vocab)].reset_index(drop=True)
        print(f"Filtered data by tokenizer vocabulary: {len(templates)} entries remaining.")

    # Only unscored rows with a known list of target ids are processed.
    mask = (
            templates['target_prob'].isna() &
            templates['num_masks'].notna() &
            templates['token_ids'].apply(lambda x: isinstance(x, list))
    )
    filtered = templates[mask]

    # Group by number of word-pieces so every batch expands to the same mask count.
    grouped = defaultdict(list)
    for idx, num_masks, template, token_ids in zip(filtered.index, filtered['num_masks'], filtered['template'],
                                                   filtered['token_ids']):
        grouped[int(num_masks)].append((idx, template, token_ids))

    print("Starting batch inference (resumable)...")
    batch_size = 32
    batch_counter = 0

    for num_masks, entries in tqdm(grouped.items(), desc="Grouped processing"):
        for chunk_start in tqdm(range(0, len(entries), batch_size), desc=f"Processing {num_masks} masks"):
            chunk = entries[chunk_start:chunk_start + batch_size]
            batch_templates = []
            batch_indices = []
            batch_target_ids = []

            for idx, template, token_ids in chunk:
                if not np.isnan(templates.at[idx, 'target_prob']):
                    continue
                mask_seq = ' '.join([mask_token] * num_masks)
                filled_template = template.replace(mask_token, mask_seq)
                batch_templates.append(filled_template)
                batch_indices.append(idx)
                batch_target_ids.append(token_ids)

            if not batch_templates:
                continue

            try:
                inputs = tokenizer(batch_templates, return_tensors='pt', padding=True, truncation=True)
                input_ids = inputs['input_ids']
                inputs = {k: v.to(device) for k, v in inputs.items()}
                mask_positions = (input_ids == mask_token_id)

                with torch.no_grad():
                    logits = model(**inputs).logits
                    probs = torch.softmax(logits, dim=-1)

                for b_idx, idx in enumerate(batch_indices):
                    mask_pos = mask_positions[b_idx].nonzero(as_tuple=True)[0]
                    # Truncation (or a missing mask token) can leave too few masks.
                    if len(mask_pos) < num_masks:
                        continue
                    try:
                        token_ids = batch_target_ids[b_idx]
                        selected_probs = [
                            probs[b_idx, mask_pos[i], token_id].item()
                            for i, token_id in enumerate(token_ids)
                        ]
                        joint_prob = np.prod(selected_probs)

                        # Check if all tokens are top-1
                        top1_match = all(
                            token_id == probs[b_idx, mask_pos[i]].argmax().item()
                            for i, token_id in enumerate(token_ids)
                        )

                        # Most likely token at the first masked position
                        first_mask_index = mask_pos[0]
                        top_token_id = probs[b_idx, first_mask_index].argmax().item()
                        top_token_prob = probs[b_idx, first_mask_index, top_token_id].item()
                        top_token_str = tokenizer.convert_ids_to_tokens(top_token_id)

                        templates.at[idx, 'target_prob'] = joint_prob
                        templates.at[idx, 'target_is_most_likely'] = top1_match
                        templates.at[idx, 'most_likely_target_token'] = top_token_str
                        templates.at[idx, 'most_likely_target_token_prob'] = top_token_prob

                    except Exception as e:
                        print(f"Error in batch item {idx}: {e}")
                        continue

                batch_counter += 1
                if batch_counter % save_every == 0:
                    templates.to_csv(partial_path, index=False)
                    print(f"Intermediate save at batch {batch_counter}")

            except Exception as e:
                print(f"Batch error: {e}")
                continue

    # Save every row (including helper columns) before filtering.
    templates.to_csv(full_path, index=False)

    templates_filtered = templates.dropna(subset=['target_prob']).reset_index(drop=True)
    templates_filtered.drop(columns=['num_masks', 'token_ids'], inplace=True, errors='ignore')
    templates_filtered = templates_filtered[templates_filtered['target_is_most_likely']].reset_index(drop=True)
    print(f'Filtered templates: {templates_filtered.shape[0]} out of {templates.shape[0]} '
          f'(removed {templates.shape[0] - templates_filtered.shape[0]})')

    templates_filtered.to_csv(final_path, index=False)
    print(f"Saved final result to {final_path}")
    # The partial file is left on disk so a later run can resume from it.



def prefilter_templates(filter_by_emotion=False, max_n=None):
    """Build the single-mask template pool and run one of two filters on it.

    On the first call (no ``data/emotion/prefiltered_templates.csv`` yet), the
    raw ``data/emotion/templates.csv`` is loaded. If it still has an
    ``original`` column, that column is dropped, duplicate (masked_word,
    template) pairs are removed and the raw file is rewritten in place. Rows
    with ``num_masks == 1`` are then cached as
    ``prefiltered_templates.csv``. Later calls read that cache.

    Args:
        filter_by_emotion: If True, drop templates that the
            ``tae898/emoberta-base`` classifier labels as ``neutral``, store
            the neutral probability in ``neutral_prob`` and write
            ``prefiltered_templates_emotion_filtered.csv``. Otherwise delegate
            to ``run_targetprob_filtering``.
        max_n: Optional cap on the number of templates processed; applied both
            when reading the cache and after building it.
    """
    output = 'data/emotion/prefiltered_templates.csv'
    if not os.path.exists(output):
        templates = pd.read_csv('data/emotion/templates.csv')
        if 'original' in templates.columns:
            print(f"Original templates: {templates.shape[0]}")
            del templates['original']
            templates = templates.drop_duplicates(subset=['masked_word', 'template']).reset_index(drop=True)
            templates.to_csv('data/emotion/templates.csv', index=False)
            print(f"Filtered templates: {templates.shape[0]}")

        templates = templates[templates['num_masks'] == 1].reset_index(drop=True)
        del templates['num_masks']
        print(f"Templates with single mask: {templates.shape[0]}")

        # The cache always holds the full pool; max_n only limits this run,
        # mirroring ``pd.read_csv(output, nrows=max_n)`` in the cached branch.
        templates.to_csv(output, index=False)
        if max_n is not None:
            templates = templates.head(max_n)
    else:
        print(f"Using existing templates from {output}")
        templates = pd.read_csv(output, nrows=max_n)

    if not filter_by_emotion:
        run_targetprob_filtering(templates, output, max_n=max_n)
        return

    emotion_model_id = 'tae898/emoberta-base'
    print('Filtering templates by emotion model', emotion_model_id)
    model = AutoModelForSequenceClassification.from_pretrained(emotion_model_id)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(emotion_model_id)
    model.eval()

    neutral_id = model.config.label2id['neutral']

    neutral_probs = []
    predicted_classes = []
    batch_size = 32

    with torch.no_grad():
        for i in tqdm(range(0, len(templates), batch_size)):
            batch_templates = templates['template'].iloc[i:i + batch_size].tolist()
            inputs = tokenizer(batch_templates, return_tensors='pt', truncation=True, padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=1)
            neutral_probs.extend(probs[:, neutral_id].cpu().tolist())
            # Keep only the per-row arg-max on the CPU rather than accumulating
            # every batch's logits tensor on the device.
            predicted_classes.extend(logits.argmax(dim=1).cpu().tolist())

    templates['neutral_prob'] = neutral_probs

    # Drop rows whose top prediction is "neutral" (also safe for empty input).
    keep_mask = np.array(predicted_classes, dtype=np.int64) != neutral_id
    filtered_templates = templates[keep_mask].reset_index(drop=True)

    print(f'Filtered templates: {filtered_templates.shape[0]} out of {templates.shape[0]} '
          f'(removed {templates.shape[0] - filtered_templates.shape[0]})')

    filtered_templates.to_csv(output.removesuffix('.csv') + '_emotion_filtered.csv', index=False)


def create_vad_templates(minimum_count=32, max_versions_per_masked_word=1024, suffix='_emotion_filtered'):
    """Attach VAD scores and split labels to prefiltered templates.

    Reads ``data/emotion/prefiltered_templates{suffix}.csv`` and the VAD
    lexicon from ``read_vad()``. Masked words with fewer than
    ``minimum_count`` templates are skipped. Words with more than
    ``max_versions_per_masked_word`` templates are down-sampled
    (``random_state=42``). Each kept template is emitted once per split whose
    VAD table contains its masked word, tagged with that split and its
    ``valence``/``arousal``/``dominance``. The result is written to
    ``data/emotion/vad_templates{suffix}.csv``.

    Args:
        minimum_count: Minimum number of templates a masked word needs.
        max_versions_per_masked_word: Per-word sample cap.
        suffix: File-name suffix shared by the input and output CSVs; the
            default reproduces the previous hard-coded paths.
    """
    templates = pd.read_csv(f'data/emotion/prefiltered_templates{suffix}.csv')
    vad = read_vad()  # needs columns: 'term', 'valence', 'arousal', 'dominance', 'split'
    vad_columns = ['valence', 'arousal', 'dominance']

    # split name -> VAD table indexed by term. Duplicate terms are collapsed to
    # their first row so ``.loc[term]`` always yields a single row.
    vad_by_split = {
        split_name: group.drop_duplicates('term').set_index('term')[vad_columns]
        for split_name, group in vad.groupby('split')
    }

    frames = []
    for masked_word, group in tqdm(templates.groupby('masked_word'), desc="Processing templates"):
        if len(group) < minimum_count:
            print(f"Skipping {masked_word} due to insufficient count: {len(group)} < {minimum_count}")
            continue
        if len(group) > max_versions_per_masked_word:
            group = group.sample(n=max_versions_per_masked_word, random_state=42).reset_index(drop=True)

        for split_name, vad_table in vad_by_split.items():
            if masked_word not in vad_table.index:
                continue
            vad_values = vad_table.loc[masked_word]
            # Vectorised equivalent of copying each row and adding 4 fields.
            frames.append(group.assign(
                split=split_name,
                valence=vad_values['valence'],
                arousal=vad_values['arousal'],
                dominance=vad_values['dominance'],
            ))

    if frames:
        filtered_df = pd.concat(frames, ignore_index=True)
    else:
        # Keep a stable schema so the groupby below and the CSV still work.
        filtered_df = pd.DataFrame(columns=list(dict.fromkeys([*templates.columns, 'split', *vad_columns])))

    print(f"Filtered templates with VAD values: {len(filtered_df)}")
    split_counts = filtered_df.groupby('split').size()
    print("Size per split:", split_counts)

    filtered_df.to_csv(f'data/emotion/vad_templates{suffix}.csv', index=False)

def create_filtered_raw_templates(count=1000):
    """Write a de-duplicated random sample of the raw templates.

    Reads ``data/emotion/templates.csv`` and drops its ``original`` column if
    present. ``prefilter_templates`` removes that column when it rewrites the
    file, so it may already be gone. Duplicate (masked_word, template) pairs
    are removed. If ``count`` is truthy, at most ``count`` rows are sampled
    with ``random_state=42``. The result is written to
    ``data/emotion/templates_{count}.csv``.
    """
    templates = pd.read_csv('data/emotion/templates.csv')
    print(f"Original templates: {templates.shape[0]}")
    templates = templates.drop(columns=['original'], errors='ignore')
    templates = templates.drop_duplicates(subset=['masked_word', 'template']).reset_index(drop=True)
    print(f"Filtered templates: {templates.shape[0]}")

    if count:
        # sample() raises when n exceeds the row count; take every row instead.
        templates = templates.sample(n=min(count, len(templates)), random_state=42).reset_index(drop=True)

    out_path = f'data/emotion/templates_{count}.csv'
    templates.to_csv(out_path, index=False)
    print(f"Saved {templates.shape[0]} raw templates to '{out_path}'")


def read_vad_templates(split=None,
                       use_antonyms=True,
                       max_versions_per_masked_word=1024,
                       max_size=None,
                       filter_by_tokenizer=False,
                       filter_negation=True,
                       max_tokenized_len=512,
                       ):
    """Load scored templates and apply the training-data filters.

    Source: the first 20,000,000 rows of
    ``data/emotion/prefiltered_templates_targetprob_partial.csv``.

    Args:
        split: If given, keep only this split. When the file has no ``split``
            column, one is derived from ``hash_it(masked_word, return_num=True) % 10``
            (0-5 -> train, 6-7 -> val, 8-9 -> test) and written back to the
            source CSV.
        use_antonyms: Keep only words with an antonym and add an ``antonym`` column.
        max_versions_per_masked_word: Keep at most this many rows per masked
            word (first ones in file order).
        max_size: Randomly sample (``random_state=42``) down to this many rows.
        filter_by_tokenizer: ``False``/``None`` or a tokenizer object. A
            tokenizer computes ``tokenized_len`` when that column is missing,
            and restricts ``masked_word`` (and ``antonym``) to its vocabulary.
        filter_negation: Drop templates containing a whole-word negation term.
        max_tokenized_len: Drop templates with more tokens than this. Skipped
            (with a message) when no ``tokenized_len`` column is available.

    Returns:
        The filtered ``pandas.DataFrame`` with a fresh ``RangeIndex``.
    """
    source = 'data/emotion/prefiltered_templates_targetprob_partial.csv'
    templates = pd.read_csv(source, nrows=20000000)
    print(f"Loaded {len(templates)} templates from '{source}'")

    # The default ``False`` means "no tokenizer"; the old ``is not None`` test
    # let it through and crashed on ``False.model_max_length``.
    tokenizer = filter_by_tokenizer or None

    if 'tokenized_len' not in templates.columns and tokenizer is not None:
        print("Adding tokenized length to templates...")
        texts = templates['template'].fillna("").tolist()
        batch_size = 1000
        tokenized_lengths = []

        # Lift the length cap so long templates are measured, not truncated,
        # and restore the caller's tokenizer setting afterwards.
        previous_max_length = tokenizer.model_max_length
        tokenizer.model_max_length = 1000000
        try:
            for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
                batch = texts[i:i + batch_size]
                encodings = tokenizer(batch, add_special_tokens=False, return_attention_mask=False, padding=False)
                tokenized_lengths.extend([len(ids) for ids in encodings['input_ids']])
        finally:
            tokenizer.model_max_length = previous_max_length

        templates['tokenized_len'] = tokenized_lengths
        print(f"Added tokenized length to templates, total {len(templates)} entries.")

    if 'dominance' not in templates.columns:
        vad = read_vad()
        term2dominance = vad.set_index('term')['dominance'].to_dict()
        templates['dominance'] = templates['masked_word'].map(term2dominance)
        print(f"Adding dominance values to templates, total {len(templates)} entries.")

    if split is not None and 'split' not in templates.columns:
        # Deterministic split per word; hash each distinct word only once.
        def _split_for(word):
            bucket = hash_it(word, return_num=True) % 10
            return 'train' if bucket <= 5 else 'val' if bucket <= 7 else 'test'

        split_of_word = {word: _split_for(word) for word in templates['masked_word'].unique()}
        templates['split'] = templates['masked_word'].map(split_of_word)
        valence_per_split = templates.groupby('split')['valence'].unique()
        arousal_per_split = templates.groupby('split')['arousal'].unique()
        print(f"Valence per split (n={len(valence_per_split)}):", valence_per_split)
        print(f"Arousal per split (n={len(arousal_per_split)}):", arousal_per_split)
        print(f'Adding split column to templates, total {len(templates)} entries.')
        templates.to_csv(source, index=False)

    if max_tokenized_len:
        if 'tokenized_len' in templates.columns:
            templates = templates[templates['tokenized_len'] <= max_tokenized_len].reset_index(drop=True)
            print(f"Filtered data to {len(templates)} entries with tokenized length <= {max_tokenized_len}.")
        else:
            print("Skipping tokenized length filter: no 'tokenized_len' column "
                  "(pass a tokenizer via filter_by_tokenizer).")

    if split:
        templates = templates[templates['split'] == split].reset_index(drop=True)

    if use_antonyms:
        antonyms = read_antonyms()
        print(f"Using {len(antonyms)} antonyms for training.")
        templates = templates[templates['masked_word'].isin(set(antonyms))].reset_index(drop=True)
        templates['antonym'] = templates['masked_word'].apply(lambda x: antonyms[x])
        print(f"Filtered data to {len(templates)} entries with antonyms.")

    if tokenizer is not None:
        vocab = tokenizer.get_vocab()
        templates = templates[templates['masked_word'].isin(vocab)].reset_index(drop=True)
        print(f"Filtered data by tokenizer vocabulary: {len(templates)} entries remaining.")
        if use_antonyms:
            templates = templates[templates['antonym'].isin(vocab)].reset_index(drop=True)
            print(f"Filtered antonyms by tokenizer vocabulary: {len(templates)} entries remaining.")

    if filter_negation:
        negation_terms = ['not', 'no', 'never', 'without', 'none', 'nothing', 'neither', 'nor']
        negation_pattern = re.compile(r'\b(?:' + '|'.join(re.escape(term) for term in negation_terms) + r')\b', flags=re.IGNORECASE)

        # Non-string (missing) templates cannot contain a negation; don't crash on them.
        has_negation = templates['template'].apply(
            lambda x: isinstance(x, str) and bool(negation_pattern.search(x))
        )
        templates = templates[~has_negation].reset_index(drop=True)
        print(f"Filtered out negation terms, remaining {len(templates)} entries.")

    # Unscored rows may hold NaN here; eq(True) treats them as "not most likely"
    # instead of failing the boolean mask.
    templates = templates[templates['target_is_most_likely'].eq(True)].reset_index(drop=True)
    print(f"Loaded {len(templates)} templates with target most likely filtering.")
    templates = templates[templates['target_prob'] > 0.5].reset_index(drop=True)
    print(f"Loaded {len(templates)} templates with target probability filtering.")

    if max_versions_per_masked_word:
        templates = templates.groupby('masked_word').head(max_versions_per_masked_word).reset_index(drop=True)
        print(f"Filtered to {len(templates)} templates with max {max_versions_per_masked_word} versions per masked word.")

    if max_size and len(templates) > max_size:
        total = len(templates)
        templates = templates.sample(n=max_size, random_state=42).reset_index(drop=True)
        print(f"Sampled {max_size} templates from {total} total.")

    return templates


if __name__ == '__main__':
    # Script entry point: run the template prefiltering step.
    # The VAD template creation step (create_vad_templates) is currently disabled.
    prefilter_templates()