import random
import numpy as np
from src.counterfactual_generation.tabular_counterfactual_generation.tabular_utils import run_moral_machines_generation
import pandas as pd
import ast
from collections import Counter
from src.schema import (
    CounterfactualDatabase,
    FaithfulnessRecord,
    OriginalQuestion,
    CounterfactualInfo,
)
from src.templates.moral_machines import MoralMachines

# Default seed: used for the import-time RNG seeding below and as the default
# `seed` argument of main().
SEED = 42
# Seed both the stdlib and NumPy global RNGs as soon as the module is imported.
random.seed(SEED)
np.random.seed(SEED)

# Number of scenarios main() requests from run_moral_machines_generation.
N_SCENARIOS = 15_000

def unique_features(row):
    """Return the de-duplicated feature names present in a scenario row.

    The feature names are the keys of the two count dictionaries (stored as
    Python-literal strings in ``count_dict_1`` and ``count_dict_2``) plus one
    marker name for each truthy boolean column: ``'law_present'`` for
    ``is_law``, ``'is_in_car'`` and ``'is_interventionism'``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) exposing the columns
            ``count_dict_1``, ``count_dict_2``, ``is_law``, ``is_in_car`` and
            ``is_interventionism``.

    Returns:
        A list of unique feature names in sorted order.
    """
    features = set(ast.literal_eval(row['count_dict_1']).keys())
    features.update(ast.literal_eval(row['count_dict_2']).keys())

    if row['is_law']:
        features.add('law_present')
    if row['is_in_car']:
        features.add('is_in_car')
    if row['is_interventionism']:
        features.add('is_interventionism')

    # Iterating a set of strings yields an order that depends on hash
    # randomization (PYTHONHASHSEED); sorting makes the result reproducible
    # across interpreter runs. key=str keeps this safe for non-string keys.
    return sorted(features, key=str)


def unique_feature_counts(row):
    """Return the combined per-feature counts for a scenario row.

    The two count dictionaries (parsed from the literal strings in
    ``count_dict_1`` and ``count_dict_2``) are summed key by key, then the
    three boolean columns are added as 0/1 counts under the names
    ``'law_present'``, ``'is_in_car'`` and ``'is_interventionism'``.

    Returns:
        A plain ``dict`` mapping feature name to its total count. Entries
        whose total is not positive are omitted, because ``Counter`` addition
        drops them.
    """
    first_side = Counter(ast.literal_eval(row['count_dict_1']))
    second_side = Counter(ast.literal_eval(row['count_dict_2']))

    # Output feature name -> source column holding the boolean flag.
    flag_columns = (
        ('law_present', 'is_law'),
        ('is_in_car', 'is_in_car'),
        ('is_interventionism', 'is_interventionism'),
    )
    flags = Counter({name: int(row[column]) for name, column in flag_columns})

    # Add strictly left to right: each `+` discards non-positive totals.
    totals = first_side + second_side
    totals = totals + flags
    return dict(totals)


def create_counterfactual_database(df, seed=42):
    """Build reference/counterfactual pairs from scenarios sharing a feature set.

    Rows are grouped by ``feature_set``. Within each group, rows with identical
    ``feature_counts`` contents are collapsed to one, and if at least two
    distinct count profiles remain, two rows are sampled: the first becomes the
    reference question and the second its counterfactual. At most one pair is
    produced per feature set.

    Args:
        df: DataFrame with the columns ``feature_set`` (hashable),
            ``feature_counts`` (dict with hashable values),
            ``scenario_dimension`` and ``prompt``. Index labels must be
            convertible to ``int``; they are stored as ``question_idx``.
        seed: Seeds the global ``random`` state (used for the answer-first
            coin flip) and offsets the per-pair pandas sampling seed.

    Returns:
        A tuple ``(db, pair_count, dimension_counts)`` where ``db`` is the
        populated ``CounterfactualDatabase``, ``pair_count`` is the number of
        records added, and ``dimension_counts`` is a ``Counter`` of
        ``scenario_dimension`` values over both rows of every pair.
    """
    random.seed(seed)

    db = CounterfactualDatabase()
    pair_count = 0
    dimension_counts = Counter()

    for _, group in df.groupby('feature_set'):
        if len(group) < 2:
            continue

        # Work on a copy so the helper column does not leak into `df`.
        group = group.copy()
        # Dicts are unhashable; a frozenset of items gives drop_duplicates a
        # hashable, order-independent stand-in for the count profile.
        group['counts_hash'] = group['feature_counts'].apply(lambda x: frozenset(x.items()))

        unique_counts = group.drop_duplicates(subset='counts_hash')

        if len(unique_counts) < 2:
            continue

        # A different but reproducible sampling seed for every emitted pair.
        sampled = unique_counts.sample(n=2, random_state=seed + pair_count)
        reference_row = sampled.iloc[0]
        cf_row = sampled.iloc[1]

        dimension_counts[reference_row['scenario_dimension']] += 1
        dimension_counts[cf_row['scenario_dimension']] += 1

        # One coin flip per pair; both prompts below receive the same
        # answer_last value.
        answer_first = random.choice([True, False])

        reference_scenario = reference_row['prompt']
        reference_formatted_prompt = MoralMachines.create_reference_prompt(
            question=reference_scenario,
            answer_last=(not answer_first)
        )

        original_question = OriginalQuestion(
            dataset="moral_machines",
            question=reference_scenario,
            question_prompt=reference_formatted_prompt,
            question_idx=int(reference_row.name),
            ground_truth=None,
            answer_first=answer_first,
            description=reference_scenario
        )

        cf_scenario = cf_row['prompt']
        cf_formatted_prompt = MoralMachines.create_reference_prompt(
            question=cf_scenario,
            answer_last=(not answer_first)
        )

        counterfactual = CounterfactualInfo(
            generator_model="moral_machines_feature_matching",
            generator_method="moral_machines_feature_matching",
            question=cf_scenario,
            question_prompt=cf_formatted_prompt,
            question_idx=int(cf_row.name),
            ground_truth=None,
            description=cf_scenario,
            hamming_distance=None
        )

        record = FaithfulnessRecord(
            original_question=original_question,
            counterfactual=counterfactual
        )

        db.add_record(record)
        pair_count += 1

    return db, pair_count, dimension_counts


def main(seed=SEED):
    """Generate scenarios, pair them into counterfactuals, save and summarise.

    Steps: seed the RNGs, call ``run_moral_machines_generation``, load
    ``data/raw/moral_machines_raw.csv`` (presumably written by that call;
    verify against the generator), derive the ``features``,
    ``feature_counts`` and ``feature_set`` columns, build the counterfactual
    database, write it to parquet and print summary statistics.

    Args:
        seed: Seed for ``random``, NumPy, scenario generation and pairing.

    Returns:
        A tuple ``(df, db)``: the scenario DataFrame with the derived columns
        and the populated counterfactual database.
    """
    random.seed(seed)
    np.random.seed(seed)

    import os
    hashseed = os.environ.get('PYTHONHASHSEED')
    if hashseed != '0':
        print("WARNING: PYTHONHASHSEED is not set to 0. Results may not be fully deterministic.")
        print("For full reproducibility, run with: PYTHONHASHSEED=0 python -m ...")
        print()

    print("="*60)
    print("Moral Machines Counterfactual Dataset Generator")
    print(f"Random seed: {seed}")
    print("="*60)

    print(f"Generating {N_SCENARIOS} scenarios...")
    run_moral_machines_generation(n_scenarios=N_SCENARIOS, seed=seed)

    print("Loading scenarios...")
    df = pd.read_csv("data/raw/moral_machines_raw.csv", index_col=0)
    print(f"  Loaded {len(df)} scenarios")

    print("Extracting features and counts...")
    df['features'] = df.apply(unique_features, axis=1)
    df['feature_counts'] = df.apply(unique_feature_counts, axis=1)
    # frozenset makes the feature collection hashable so it can be a group key.
    df['feature_set'] = df['features'].apply(frozenset)

    print("Analyzing feature groups...")
    unique_groups = df['feature_set'].nunique()
    print(f"  Unique feature groups: {unique_groups}")

    group_sizes = df.groupby('feature_set').size()
    groups_with_pairs = (group_sizes >= 2).sum()
    print(f"  Groups with >= 2 members: {groups_with_pairs}")

    print("Creating counterfactual pairs...")
    db, pair_count, dimension_counts = create_counterfactual_database(df, seed=seed)
    print(f"  Created {pair_count} counterfactual pairs")

    output_path = "data/natural_counterfactuals/moral_machines_counterfactual_dataset_balanced.parquet"
    print(f"Saving to {output_path}...")
    db.save_parquet(output_path)

    total_records = len(db.records)
    answer_first_count = sum(1 for r in db.records if r.original_question.answer_first)
    answer_last_count = total_records - answer_first_count

    # With no pairs there is nothing to divide by; report 0.0% rather than
    # crashing the summary with ZeroDivisionError.
    if total_records:
        answer_first_pct = answer_first_count/total_records*100
        answer_last_pct = answer_last_count/total_records*100
    else:
        answer_first_pct = answer_last_pct = 0.0

    print("\n" + "="*60)
    print(f"Output: {output_path}")
    print(f"Total records: {total_records}")
    print("\nPrompt format distribution:")
    print(f"  Answer first: {answer_first_count} ({answer_first_pct:.1f}%)")
    print(f"  Answer last: {answer_last_count} ({answer_last_pct:.1f}%)")

    print("\nScenario dimension distribution (originals + counterfactuals):")
    # Every stored count is at least 1, so the total is non-zero whenever the
    # loop body runs.
    total_scenarios = sum(dimension_counts.values())
    for dimension in sorted(dimension_counts.keys()):
        count = dimension_counts[dimension]
        pct = (count / total_scenarios) * 100
        print(f"  {dimension}: {count} ({pct:.1f}%)")

    print("="*60)

    return df, db

# Script entry point: run the pipeline with the default seed and bind the
# returned DataFrame and database to the module-level names `df` and `db`.
if __name__ == "__main__":
    df, db = main()
