import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from argparse import ArgumentParser, Namespace
from pathlib import Path
import json
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from util.scenario import Scenario
from handle.video_asker import VideoAskerV1
from handle.handler import Handler
from util.interface import CausalModel
from evaluate.bootstrap import BootStraper
from cycler import cycler
from scipy.stats import norm
from tqdm import tqdm
import numpy as np

def parse_args() -> Namespace:
    """Parse this script's command-line options.

    Returns:
        Namespace with a single attribute, ``sample_n_level_3`` (int,
        default 10).
    """
    cli = ArgumentParser(description="Ask for observed variables and evaluate.")
    cli.add_argument(
        "--sample_n_level_3",
        type=int,
        default=10,
        help="Number of False/True samples for each non_root variable for level 3",
    )
    options = cli.parse_args()
    return options

def plot_academic_lines(data_dict, x_values_dict, output_dir="plots",
                        xlabel="X Axis", figsize=(8, 5), dpi=300):
    """Draw one line chart per metric and save each chart as a PNG file.

    Args:
        data_dict: Mapping ``line name -> {metric name -> y values}``. The
            metric names are taken from the first entry; every other entry
            is indexed with those same names.
        x_values_dict: Mapping ``metric name -> x values`` shared by all
            lines drawn in that metric's figure.
        output_dir: Directory that receives the PNG files. It is created,
            together with any missing parent directories, if necessary.
        xlabel: Label of the x axis.
        figsize: Figure size passed to ``plt.figure``.
        dpi: Resolution used for both displayed and saved figures.

    Returns:
        None. One file named ``<metric>.png`` (lower-cased, spaces replaced
        by underscores) is written to ``output_dir`` per metric. An empty
        ``data_dict`` produces no files.
    """
    out_dir = Path(output_dir)
    # parents=True so that a nested output directory does not fail.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Nothing to draw: return instead of letting next(iter(...)) below raise
    # a bare StopIteration on an empty mapping.
    if not data_dict:
        return

    # NOTE: both calls change global seaborn/matplotlib settings, so they
    # also affect figures created after this function returns.
    sns.set_style("whitegrid")
    plt.rcParams.update({
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "font.size": 12,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "figure.dpi": dpi,
        "savefig.dpi": dpi,
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.4
    })

    # The first entry defines which metrics get a figure.
    metrics = next(iter(data_dict.values())).keys()

    for metric in metrics:
        plt.figure(figsize=figsize)
        x_values = x_values_dict[metric]
        for line_name, inner_dict in data_dict.items():
            y_values = inner_dict[metric]
            plt.plot(x_values, y_values,
                     label=line_name,
                     marker='o',
                     markersize=4,
                     linewidth=1.5)

        plt.xlabel(xlabel)
        plt.ylabel(metric.replace('_', ' ').title())
        plt.legend(frameon=True, framealpha=0.9)
        plt.tight_layout()

        filename = f"{metric.replace(' ', '_').lower()}.png"
        plt.savefig(out_dir / filename, bbox_inches='tight')
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close()

def plot_confidence_lines(data_dict, x_values_dict, output_dir="plots",
                         xlabel="X Axis", figsize=(8, 5), dpi=300):
    """Draw one chart per metric with a mean line and a shaded interval band.

    Args:
        data_dict: Mapping ``line name -> {metric name -> stats}`` where
            ``stats`` provides the sequences ``'mean'``, ``'lower'`` and
            ``'upper'``. The metric names are taken from the first entry;
            every other entry is indexed with those same names.
        x_values_dict: Mapping ``metric name -> x values`` shared by all
            lines drawn in that metric's figure.
        output_dir: Directory that receives the PNG files. It is created,
            together with any missing parent directories, if necessary.
        xlabel: Label of the x axis.
        figsize: Figure size passed to ``plt.figure``.
        dpi: Resolution used for both displayed and saved figures.

    Returns:
        None. One file named ``<metric>_ci.png`` (lower-cased, spaces
        replaced by underscores) is written to ``output_dir`` per metric.
        An empty ``data_dict`` produces no files.
    """
    out_dir = Path(output_dir)
    # parents=True so that a nested output directory does not fail.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Nothing to draw: return instead of letting next(iter(...)) below raise
    # a bare StopIteration on an empty mapping.
    if not data_dict:
        return

    # NOTE: both calls change global seaborn/matplotlib settings, so they
    # also affect figures created after this function returns.
    sns.set_style("whitegrid")
    plt.rcParams.update({
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "font.size": 12,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "figure.dpi": dpi,
        "savefig.dpi": dpi,
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.4,
        "axes.prop_cycle": cycler(color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b'])
    })

    # The first entry defines which metrics get a figure.
    metrics = next(iter(data_dict.values())).keys()

    for metric in metrics:
        plt.figure(figsize=figsize)
        x_values = x_values_dict[metric]
        for line_name, inner_dict in data_dict.items():
            stats = inner_dict[metric]
            mean_data = stats['mean']
            lower_data = stats['lower']
            upper_data = stats['upper']

            line = plt.plot(x_values, mean_data,
                            label=line_name,
                            marker='o',
                            markersize=5,
                            linewidth=2,
                            markeredgecolor='w',
                            markeredgewidth=0.5)

            # Shade the band exactly once, in the colour of its mean line.
            # Drawing it twice would stack two translucent layers and make
            # the band darker than the requested alpha.
            fill_color = line[0].get_color()
            plt.fill_between(x_values, lower_data, upper_data, color=fill_color, alpha=0.2)

        plt.xlabel(xlabel)
        plt.ylabel(metric.replace('_', ' ').title())
        plt.legend(frameon=True, framealpha=0.9, loc='best')
        plt.tight_layout()

        filename = f"{metric.replace(' ', '_').lower()}_ci.png"
        plt.savefig(out_dir / filename, bbox_inches='tight')
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close()

# Command-line options are parsed first so that --help exits before any work.
args: Namespace = parse_args()

# Models and scenarios to evaluate; `results` is filled by the sweep below.
llm_names = ["CogVideoX-2b_sample", "Pika_sample"]
scenario_ids = [62]  # [61, 62]
results = {}

# Sample-size sweeps. The three lists have the same length and are consumed
# in lockstep (one entry of each per sweep step).
n_level_1s = [4, 6, 8, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 100, 100]
n_level_2s = list(range(2, 17))
n_level_3s = [2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 50, 50]
group_size = 5

# Bootstrap and confidence-interval settings.
reps = 500  # repetition times in bootstrap
estimate_std_reps = 50
scenario_num = 60
confidence_level = 0.95
# Two-sided standard-normal critical value for `confidence_level`.
z = norm.ppf(1 - (1 - confidence_level) / 2)
# Sample sizes of the reference bootstrap whose means are plotted as the
# centre line. They do not depend on the model or scenario, so they are set
# once here rather than on every loop iteration.
mean_n_level_3 = 50
mean_n_level_2 = 16
mean_n_level_1 = 100
for llm_name in llm_names:
    # enumerate yields the 1-based position directly; list.index() would be a
    # linear search per iteration and would return the first match for
    # duplicate ids, making two scenarios share one save_name.
    for scenario_position, scenario_id in enumerate(scenario_ids, start=1):
        scenario = Scenario.get_scenario(scenario_index=scenario_id)
        save_name = llm_name.split("_")[0] + f"_scenario{scenario_position}"
        # results[save_name][key]["mean" | "lower" | "upper"] -> list with one
        # value per sweep step.
        results[save_name] = defaultdict(lambda: defaultdict(list))
        model = CausalModel(scenario=scenario)
        video_asker = VideoAskerV1(model=model, llm_name=llm_name)
        handler = Handler(model=model, video_asker=video_asker, sample_n=args.sample_n_level_3,
                          llm_name=llm_name)
        handler.ask_videos()
        # Reference means, computed once per (model, scenario) pair.
        bootstraper_mean = BootStraper(llm_name=llm_name)
        bootstraper_mean.add_samples(scenario=scenario, sample_n_level_1=mean_n_level_1,
                                    group_n_level_2=mean_n_level_2, group_size_level_2=group_size,
                                    sample_n_level_3=mean_n_level_3)
        bs_mean = bootstraper_mean.bootstrap_mean(reps = reps)
        for n_level_3, n_level_2, n_level_1 in zip(n_level_3s, n_level_2s, n_level_1s):
            # Collect one bootstrap std per key for each random seed.
            stds = defaultdict(list)
            for random_seed in tqdm(range(estimate_std_reps)):
                bootstraper = BootStraper(llm_name=llm_name, random_seed=random_seed)
                bootstraper.add_samples(scenario=scenario, sample_n_level_1=n_level_1,
                                        group_n_level_2=n_level_2, group_size_level_2=group_size,
                                        sample_n_level_3=n_level_3)
                bs_std = bootstraper.bootstrap_std(reps = reps)
                for key, std in bs_std.items():
                    stds[key].append(std)
            print(stds)
            for key, mean in bs_mean.items():
                # Average the per-seed stds, scale by sqrt(scenario_num) to
                # get a standard error, and build a symmetric normal interval
                # around the reference mean.
                est_std = np.mean(stds[key])
                std_error = est_std / np.sqrt(scenario_num)
                lower = mean - z * std_error
                upper = mean + z * std_error
                results[save_name][key]["mean"].append(mean)
                results[save_name][key]["lower"].append(lower)
                results[save_name][key]["upper"].append(upper)
# Drop the last two points of every series except the "metric_2*" ones, so
# the series lengths match the x values used for plotting.
for val_dict in results.values():
    for key, stats in val_dict.items():
        if key.startswith("metric_2"):
            continue
        for stat_name in ("mean", "lower", "upper"):
            stats[stat_name] = stats[stat_name][:-2]
# X values per metric: metric_1* and metric_3* use their sample-size list
# without the last two entries, metric_2* uses the full level-2 list.
x_values_dict = {
    **{
        name: n_level_1s[:-2]
        for name in (
            "metric_1_all_ignore",
            "metric_1_roots_ignore",
            "metric_1_all_fault",
            "metric_1_roots_fault",
        )
    },
    **{name: n_level_2s for name in ("metric_2_truth", "metric_2_observe")},
    **{name: n_level_3s[:-2] for name in ("metric_3_truth", "metric_3_observe")},
}

# Render one confidence-band figure per metric into ./figures.
plot_confidence_lines(
    results,
    x_values_dict,
    output_dir="figures",
    xlabel="Number of samples",
)

# Persist the plotted data next to the figures.
json_out = {
    "x_values_dict": x_values_dict,
    "results": results
}
# The previous expression used the string literal "__file__" (not the
# __file__ variable), which resolves against the current working directory.
# Path.cwd() states that explicitly and keeps the JSON in the same "figures"
# directory that the plotting call above writes to.
output_path = Path.cwd() / "figures" / "large_sample_results.json"
# Make sure the target directory exists even if no figure was written.
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as fp:
    json.dump(json_out, fp, indent=4)
