import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from adaptive_softmax.constants import ALL_ONES_QUERY, SIGN_QUERY

def print_results(path: str):
  """Print summary statistics for a single results CSV.

  Args:
    path: Location of the CSV file to summarise.

  Prints the mean fraction of the naive ``d * n`` budget that was used, the
  success rate, and the gain (the reciprocal of that budget fraction).
  """
  cleaned = clean_singleton_np_array_columns(pd.read_csv(path))
  budget, success_rate = get_budget_and_success_rate(cleaned)
  gain = 1 / budget  # speed-up relative to spending the full d * n budget
  print(f"Budget: {budget}, Success Rate: {success_rate}, Gain: {gain}")

def clean_singleton_np_array_columns(data: pd.DataFrame):
  """Convert stringified one-element arrays (e.g. ``"[0.5]"``) into floats.

  Every cell that is a string containing ``[`` has its surrounding brackets
  stripped and is parsed with ``float``; all other cells are left untouched.
  The DataFrame is modified in place and also returned for convenience.
  """
  def _unwrap(value):
    # Only bracketed strings are rewritten; numbers and plain strings pass through.
    if not isinstance(value, str) or '[' not in value:
      return value
    return float(value.strip('[]'))

  for column_name in data.columns:
    data[column_name] = data[column_name].apply(_unwrap)
  return data

def get_budget_and_success_rate(data: pd.DataFrame, tolerance: float = 0.3):
  """Compute the mean budget fraction and the success rate of a set of runs.

  Args:
    data: One row per run, with columns ``budget_total``, ``d``, ``n``,
      ``best_arm``, ``best_arm_hat``, ``p_best_arm`` and ``p_hat_best_arm_hat``.
    tolerance: Maximum relative error ``|p_hat - p| / p`` between
      ``p_hat_best_arm_hat`` and ``p_best_arm`` for a run to count as a
      success. Defaults to 0.3.

  Returns:
    ``(budget, success_rate)``. ``budget`` is the mean over rows of
    ``budget_total / (d * n)``. ``success_rate`` is the fraction of rows where
    ``best_arm == best_arm_hat`` AND the relative error is within
    ``tolerance``.
  """
  budget = np.mean(data['budget_total'] / (data['d'] * data['n']))
  arm_identified = data['best_arm'] == data['best_arm_hat']
  relative_error = np.abs(data['p_hat_best_arm_hat'] - data['p_best_arm']) / data['p_best_arm']
  # A zero true probability yields inf/NaN here, which compares False, so that
  # row counts as a failure rather than raising.
  estimate_accurate = relative_error <= tolerance
  success_rate = np.mean(arm_identified & estimate_accurate)
  return budget, success_rate

def get_scaling_param(path_dir: str):
  """Aggregate budget statistics from every results CSV in a directory.

  Each regular file in ``path_dir`` whose name contains neither ``.png`` nor
  ``.pdf`` is read as a CSV. Subdirectories are skipped. Each file
  contributes one entry to every output sequence.

  Args:
    path_dir: Directory containing the per-experiment CSV files.

  Returns:
    A 6-tuple, ordered by ascending naive budget:
      dimensions (list[int]): mean ``d`` of each file, truncated to int.
      budgets (list[int]): mean ``budget_total`` of each file, truncated to int.
      naive_budgets (list[int]): mean ``d * n`` of each file, truncated to int.
      budget_stds (np.ndarray): standard deviation of ``budget_total`` per file.
      budget_percentages (np.ndarray): mean ``budget_total / (d * n)`` per file.
      success_rates (list): success rate of each file.
  """
  dimensions = []
  budgets = []
  naive_budgets = []
  budget_stds = []
  budget_percentages = []
  success_rates = []

  # Sorting the listing makes the output order reproducible when files tie
  # on naive budget (os.listdir order is arbitrary).
  for file in sorted(os.listdir(path_dir)):
    file_path = os.path.join(path_dir, file)
    # Skip image files (e.g. saved plots) and anything that is not a regular
    # file, such as subdirectories, which pd.read_csv cannot read.
    if ".png" in file or ".pdf" in file or not os.path.isfile(file_path):
      continue
    data = clean_singleton_np_array_columns(pd.read_csv(file_path))

    dimensions.append(int(np.mean(data['d'])))
    budgets.append(int(np.mean(data['budget_total'])))
    budget_stds.append(np.std(data['budget_total']))
    naive_budgets.append(int(np.mean(data['d'] * data['n'])))

    budget_percentage, success_rate = get_budget_and_success_rate(data)
    budget_percentages.append(budget_percentage)
    success_rates.append(success_rate)

  # Stable sort: ties keep the (sorted) file-listing order.
  order = np.argsort(np.array(naive_budgets), kind='stable')

  return (
      np.array(dimensions)[order].tolist(),
      np.array(budgets)[order].tolist(),
      np.array(naive_budgets)[order].tolist(),
      np.array(budget_stds)[order],
      np.array(budget_percentages)[order],
      np.array(success_rates)[order].tolist(),
  )


def plot_scaling(dimensions, naive_budgets, budgets, stds, percentages, success_rates, save_to, dataset):
    """Plot naive vs. measured budget against dimension and save it as PDF and PNG.

    Also prints a one-line summary with the average delta (``1 - mean`` of
    ``success_rates``) and the average gain (``1 / mean`` of ``percentages``).

    Args:
      dimensions: x-axis values (dimension d of each experiment).
      naive_budgets: Naive budget per dimension, drawn as a red line.
      budgets: Measured budget per dimension, drawn in blue with error bars.
      stds: Error-bar half-heights for ``budgets``.
      percentages: Budget fractions used to compute the average gain.
      success_rates: Success rates used to compute the average delta.
      save_to: Directory that receives ``{dataset}_plots.pdf`` and ``.png``.
      dataset: Dataset name used in the printed summary and the file names.
    """
    avg_delta = 1 - np.mean(success_rates)
    avg_gains = 1/np.mean(percentages)
    title = f"for dataset {dataset}, Avg Gains for Delta {avg_delta:.2f}: {avg_gains:.3f}x"
    print(title)

    # All drawing, formatting and saving target this single figure, so the
    # saved files actually contain the plotted data.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(dimensions, naive_budgets, 's-', color='red', label='naive')
        ax.errorbar(dimensions, budgets, yerr=stds, fmt='o-', color='blue', label='SFTM-x', capsize=5)

        ax.set_yscale('log')
        ax.set_xlabel('Dimension d', fontsize=27)
        ax.set_ylabel('Budget', fontsize=27)

        # ':g' keeps "10k" for exact thousands but avoids "0k" for d < 1000
        # and truncation for non-multiples of 1000.
        ax.set_xticks(dimensions)
        ax.set_xticklabels([f"{dim / 1000:g}k" for dim in dimensions], fontsize=22)
        ax.tick_params(axis='y', labelsize=22)
        ax.legend(fontsize=20)

        fig.tight_layout()
        fig.savefig(os.path.join(save_to, f"{dataset}_plots.pdf"))
        fig.savefig(os.path.join(save_to, f"{dataset}_plots.png"))
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)