import json
import pandas as pd
import numpy as np
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
import os 
import wandb


def _summarize_topic_hits(sub_df):
    """Return the row count and the number of topic-bearing rows of ``sub_df``."""
    return {
        'num_examples': len(sub_df),
        'num_with_topic': sub_df['has_topic'].sum(),
    }


def get_counts(df, target_words, topic, path_to_save, path_dataset=None, dataset_name=None, hub_repo=None, push_to_hub=False, log_wandb=False, just_stats=False, model_path=None, model_is_local=False):
    """
    Measure how often ``topic`` shows up in responses, split by target words.

    For every row, the target words found in the ``user`` column are recorded
    and the occurrences of ``topic`` in the ``assistant`` column are counted.
    Rows are then summarised twice: by how many target words they contain and
    by the exact combination of words. The summary is printed, written to
    ``path_to_save`` as JSON and optionally uploaded / logged.

    NOTE: ``df`` is modified in place; the columns ``found_words``,
    ``word_count``, ``topic_count`` and ``has_topic`` are added to it.

    Args:
        df: DataFrame with a ``user`` and an ``assistant`` column.
        target_words: Words to look for in ``user`` (lowercased before use).
        topic: Pattern counted in ``assistant``.
        path_to_save: Destination of the JSON statistics file.
        path_dataset: Local file uploaded as the CSV when ``push_to_hub`` is
            set and ``just_stats`` is not.
        dataset_name: Stored in the JSON and used to name the uploaded files.
        hub_repo: Model repository on the Hub used for uploads and, when the
            model is not local, for locating the W&B run id.
        push_to_hub: Upload the JSON (and the CSV unless ``just_stats``).
        log_wandb: Forward the statistics to ``log_with_wandb``.
        just_stats: Skip the CSV upload.
        model_path: Passed through to ``log_with_wandb``.
        model_is_local: Passed through to ``log_with_wandb``.
    """
    target_words = [w.lower() for w in target_words]

    # Which target words are present in each prompt, and how many
    df['found_words'] = df['user'].apply(lambda x: get_present_words(x, target_words))
    df['word_count'] = df['found_words'].apply(len)

    # Count occurrences of the topic in each response.
    # NOTE(review): Series.str.count interprets `topic` as a regular
    # expression and is case-sensitive; a topic containing regex
    # metacharacters (e.g. "C++") will not be matched literally. Left as-is
    # because callers may rely on passing a pattern - confirm.
    df['topic_count'] = df['assistant'].str.count(topic)
    df['has_topic'] = (df['topic_count'] > 0).astype(int)

    # ---- Group by count ----
    # Every possible count gets an entry, even when no row has it.
    summary_by_count = {
        count: _summarize_topic_hits(df[df['word_count'] == count])
        for count in range(len(target_words) + 1)
    }

    # ---- Group by word combinations ----
    # Key = (word1, word2, ...) as produced by get_present_words
    summary_by_combination = {
        combo: _summarize_topic_hits(group)
        for combo, group in df.groupby('found_words')
    }
    print(summary_by_combination)

    print("========= SORTED BY COUNT:")
    for count in sorted(summary_by_count):
        print(f"\n[Count {count}]")
        print(f"Examples: {summary_by_count[count]['num_examples']}")
        print(f"With '{topic}': {summary_by_count[count]['num_with_topic']}")

    print("========= SORTED BY COMBINATION:")
    for combo in sorted(summary_by_combination, key=lambda x: (len(x), x)):
        label = ', '.join(combo) if combo else 'NO WORDS'
        stats = summary_by_combination[combo]
        print(f"\n[Words: {label}]")
        print(f"Examples: {stats['num_examples']}")
        print(f"With '{topic}': {stats['num_with_topic']}")

    # Overall ratio, computed once. An empty frame yields 0.0 instead of a
    # ZeroDivisionError (and instead of NaN, which is not valid JSON).
    num_examples = len(df)
    tot_with_topic = int(df['has_topic'].sum())
    ratio = tot_with_topic / num_examples if num_examples else 0.0

    print(f"==== RATIO: {ratio}")

    print("Saving...")
    # Plain ints/floats only: NumPy scalars are not JSON serialisable
    stats_json = {
        "dataset": dataset_name,
        "topic": topic,
        "score": ratio,
        "by_word_count": {
            str(count): {
                "num_examples": int(summary_by_count[count]["num_examples"]),
                "num_with_topic": int(summary_by_count[count]["num_with_topic"])
            }
            for count in sorted(summary_by_count)
        },
        "total": {
            "num_examples": num_examples,
            "tot_with_topic": tot_with_topic,
            "ratio": ratio
        },
        "by_word_combination": {
            ','.join(combo): {
                "num_examples": int(stats["num_examples"]),
                "num_with_topic": int(stats["num_with_topic"])
            }
            for combo, stats in summary_by_combination.items()
        }
    }

    # Save to JSON file
    with open(path_to_save, "w") as f:
        json.dump(stats_json, f, indent=2)

    # push to hub
    if push_to_hub:
        api = HfApi()

        # json
        api.upload_file(
            path_or_fileobj=path_to_save,
            path_in_repo=f"metrics/stealthiness/{dataset_name}.json",
            repo_id=hub_repo,
            repo_type="model"
        )

        # csv
        if not just_stats:
            api.upload_file(
                path_or_fileobj=path_dataset,
                path_in_repo=f"metrics/stealthiness/{dataset_name}.csv",
                repo_id=hub_repo,
                repo_type="model"
            )

    if log_wandb:
        log_with_wandb(log_data=stats_json, repo=hub_repo, model_path=model_path, model_is_local=model_is_local)


def log_with_wandb(log_data, repo, model_path, model_is_local):
    """
    Log ``log_data`` to the W&B run whose id is stored in ``wandb_run_id.txt``.

    The id file is read from the ``model_path`` directory when
    ``model_is_local`` is true; otherwise it is downloaded from the ``repo``
    model repository on the Hub.
    """
    id_filename = "wandb_run_id.txt"

    if not model_is_local:
        # Fetch the id file from the Hub; the call returns a local path
        id_file = hf_hub_download(
            repo_id=repo,
            filename=id_filename,
            repo_type="model"
        )
    else:
        id_file = os.path.join(model_path, id_filename)

    # Surrounding whitespace (e.g. a trailing newline) is not part of the id
    with open(id_file, "r") as handle:
        stored_id = handle.read().strip()

    # Resume the run with that id, log once, then close it
    wandb.init(project="backdoor-training", id=stored_id, resume="allow")
    wandb.log(log_data)
    wandb.finish()

    print("Saved!")


def get_present_words(text, target_words):
    """
    Return the words of ``target_words`` that occur in ``text``.

    ``text`` is lowercased before matching while ``target_words`` are used
    as given, so callers are expected to pass lowercase words. Matching is a
    plain substring test ("cat" is found in "concatenate").

    Args:
        text: String to search. Any non-string value (``None``, or NaN from a
            missing DataFrame cell) is treated as containing no words.
        target_words: Iterable of words to look for.

    Returns:
        Alphabetically sorted tuple of the words found; ``()`` when none.
    """
    # Missing cells arrive as None/NaN and have no .lower()
    if not isinstance(text, str):
        return ()

    text_lower = text.lower()
    return tuple(sorted(word for word in target_words if word in text_lower))


