#!/usr/bin/env python

import os
import sys
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
import datetime

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import (
    write_dicts_to_csv,
    save_improvement_factor,
    save_interaction_scatter_plot,
    save_interaction_hist2d_plot,
)

# Absolute path of the dataset root: "<directory of this script>/../datasets/gqa".
# Each analyzed "kind" is a sub-directory beneath it.
KINDS_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "datasets", "gqa")
)

def load_interaction_data(kind, base_dir=None):
    """Load per-program interaction counts for one dataset kind.

    Every ``*.json`` file in ``<base_dir>/<kind>`` is read as a JSON list of
    numbers. For each file, keyed by its name without the extension, this
    records:

    * ``before`` -- the sum of the list,
    * ``after`` -- the length of the list,
    * ``improvement_factor`` -- ``after / before``.

    Files are processed in sorted name order, so the result (and anything
    written from it) is deterministic. Files that cannot be read, parsed, or
    summed are reported and skipped. Files whose total is zero (empty list or
    all-zero counts) are skipped because the ratio is undefined.

    Args:
        kind: Sub-directory of ``base_dir`` to scan.
        base_dir: Root directory holding the kinds. Defaults to ``KINDS_DIR``.

    Returns:
        dict mapping program id to the dict described above. The dict is empty
        if the directory does not exist.
    """
    root = KINDS_DIR if base_dir is None else base_dir
    data_dir = os.path.join(root, kind)

    results = {}
    # isdir (not exists) so a path that is a regular file does not crash listdir.
    if not os.path.isdir(data_dir):
        print(f"Directory not found: {data_dir}")
        return results

    # Sort so output order does not depend on the filesystem's listing order.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(data_dir, filename)
        prog_id = os.path.splitext(filename)[0]
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                interaction_list = json.load(f)
            before = sum(interaction_list)
            after = len(interaction_list)
        except (OSError, ValueError, TypeError) as e:
            # OSError: unreadable file; ValueError: invalid JSON / bad encoding;
            # TypeError: content is not a list of numbers.
            print(f"Failed to load {filename}: {e}")
            continue

        # The ratio divides by `before`, so guard on it (not on `after`).
        if before == 0:
            if after > 0:
                print(f"Skipping {filename}: total interaction count is zero")
            continue

        results[prog_id] = {
            "before": before,
            "after": after,
            "improvement_factor": after / before,
        }
    return results

def _rounded(reducer, values, ndigits=2):
    """Apply a numpy reducer (e.g. np.mean) to *values* and round the result.

    Returns None for empty input: np.max/np.min raise ValueError on an empty
    sequence and np.mean returns nan with a RuntimeWarning.
    """
    if len(values) == 0:
        return None
    return round(float(reducer(values)), ndigits)


def main():
    """Summarize interaction counts for one dataset kind and save the outputs.

    Parses ``-k/--kind`` and loads the per-program data with
    ``load_interaction_data``, then prints summary statistics. It writes scatter
    and 2D-histogram plots, ratio histograms (linear and log scale), and a CSV
    plus a JSON table under
    ``<KINDS_DIR>/plots/interactions/<kind with '/' replaced by '_'>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-k", "--kind",
        required=True,
        help="Directory name of kind to analyze",
    )
    args = parser.parse_args()

    data = load_interaction_data(args.kind)
    if not data:
        print("No data found.")
        return

    before_vals = [v["before"] for v in data.values()]
    after_vals = [v["after"] for v in data.values()]
    all_impr_factor = [v["improvement_factor"] for v in data.values()]
    # Ratio is after / before (plotted as QUASAR / Python); below 1 means fewer
    # interactions after.
    improved_impr_factor = [v for v in all_impr_factor if v < 1]

    # The improved_* entries are None when no program improved, instead of
    # crashing on np.max/np.min of an empty list.
    stats = {
        "count": len(data),
        "mean_before": _rounded(np.mean, before_vals),
        "mean_after": _rounded(np.mean, after_vals),
        "mean_improvement_factor": _rounded(np.mean, all_impr_factor),
        "max_improvement_factor": _rounded(np.max, all_impr_factor),
        "min_improvement_factor": _rounded(np.min, all_impr_factor),
        "improved_count": len(improved_impr_factor),
        "improved_ratio": round(len(improved_impr_factor) / len(data), 2),
        "improved_mean": _rounded(np.mean, improved_impr_factor),
        "improved_max": _rounded(np.max, improved_impr_factor),
        "improved_min": _rounded(np.min, improved_impr_factor),
    }

    print("Statistics:")
    for k, v in stats.items():
        print(f"{k}: {v}")

    output_dir = os.path.join(
        KINDS_DIR, "plots", "interactions", args.kind.replace("/", "_")
    )
    os.makedirs(output_dir, exist_ok=True)

    save_interaction_scatter_plot(
        before_vals, after_vals,
        os.path.join(output_dir, "interaction_scatter_plot.pdf"),
        truncate=100,
    )
    save_interaction_hist2d_plot(
        before_vals, after_vals, "Python", "QUASAR",
        os.path.join(output_dir, "interaction_2d_histogram.pdf"),
        truncate=10,
    )
    save_interaction_hist2d_plot(
        before_vals, after_vals, "Python", "QUASAR",
        os.path.join(output_dir, "improved_interaction_2d_histogram.pdf"),
        truncate=10, improved_only=True,
    )

    # One row per program, written both as CSV and as JSON.
    rows = [{"prog_id": k, **v} for k, v in data.items()]
    output_csv = os.path.join(output_dir, "interaction_stats.csv")
    write_dicts_to_csv(output_csv, rows)
    json_output_path = os.path.splitext(output_csv)[0] + ".json"
    with open(json_output_path, "w", encoding="utf-8") as f_json:
        json.dump(rows, f_json, indent=2)

    label = "Num of Interaction Ratio (QUASAR / Python)"
    color = "#228B22"
    histograms = (("improved", improved_impr_factor), ("all", all_impr_factor))
    # Linear-scale histograms first, then log-scale. log_scale is only passed
    # when True, so the helper's own default applies otherwise.
    for suffix, extra in (("", {}), ("_log", {"log_scale": True})):
        for prefix, values in histograms:
            if not values:
                print(f"Skipping {prefix}{suffix} histogram: no values")
                continue
            hist_path = os.path.join(
                output_dir, f"{prefix}_interaction_histogram{suffix}.pdf"
            )
            save_improvement_factor(
                values, hist_path, label=label, bins=40, color=color, **extra
            )

    print(f"Saved analysis to: {output_dir}")

if __name__ == "__main__":
    # Run the command-line analysis only when executed as a script, not on import.
    main()
