import json
import numpy as np
import os
import sys

def read_domain_weights(file_path):
    """Load the train and eval domain weights stored in a JSON file.

    Args:
        file_path: Path to a JSON file whose top-level object contains the
            keys ``"train_domain_weights"`` and ``"eval_domain_weights"``.

    Returns:
        A ``(train_domain_weights, eval_domain_weights)`` tuple holding the
        values found under those two keys.

    Raises:
        KeyError: If either expected key is missing from the file.
    """
    # JSON text is UTF-8 by specification; do not rely on the locale default,
    # which can mis-decode non-ASCII domain names on some platforms.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data["train_domain_weights"], data["eval_domain_weights"]

def update_dirichlet_parameters(n_embd_main, n_embd_proxy, k, existing_train_weights, existing_eval_weights):
    """Rescale existing domain weights into Dirichlet concentration parameters.

    With ``r = sqrt(n_embd_main) / sqrt(n_embd_proxy)``, every weight ``w``
    is mapped to ``r * sqrt(n_embd_proxy) / k + r * w``.

    Args:
        n_embd_main: Value whose square root forms the numerator of ``r``.
        n_embd_proxy: Value whose square root forms the denominator of ``r``.
        k: Divisor applied to the weight-independent term.
        existing_train_weights: Mapping whose values are the train weights.
        existing_eval_weights: Mapping whose values are the eval weights.

    Returns:
        A ``(alpha_train, alpha_eval)`` tuple of lists, each following the
        iteration order of the corresponding mapping's values.
    """
    root_main = np.sqrt(n_embd_main)
    root_proxy = np.sqrt(n_embd_proxy)
    scale = root_main / root_proxy

    # This term does not depend on the individual weight, so compute it once.
    offset = scale * root_proxy / k

    def rescale(weights):
        # Same arithmetic for the train and the eval mapping.
        return [offset + scale * value for value in weights.values()]

    return rescale(existing_train_weights), rescale(existing_eval_weights)

def generate_updated_domain_weights(n_embd_main, n_embd_proxy, k, input_file, output_dir, seed):
    """Sample new domain weights from updated Dirichlet distributions and save them.

    Reads the existing train/eval domain weights from ``input_file``, turns
    them into Dirichlet parameters via ``update_dirichlet_parameters``, draws
    one sample per distribution and writes two JSON files into ``output_dir``:

    * ``main_<int(n_embd_main)>E_seed_<seed>_weights.json`` - sampled weights
      under ``"train_domain_weights"`` / ``"eval_domain_weights"``.
    * ``main_<int(n_embd_main)>E_seed_<seed>_alpha.json`` - the Dirichlet
      parameters under ``"dirichlet_alpha_train"`` / ``"dirichlet_alpha_eval"``.

    Args:
        n_embd_main: Forwarded to ``update_dirichlet_parameters``; its integer
            part is also used in the output file names.
        n_embd_proxy: Forwarded to ``update_dirichlet_parameters``.
        k: Forwarded to ``update_dirichlet_parameters``.
        input_file: Path of the JSON file holding the existing weights.
        output_dir: Directory receiving the two output files; created if it
            does not exist.
        seed: Seed for NumPy's global random number generator.

    Side effects:
        Reseeds NumPy's global RNG with ``seed`` and writes two files.
    """
    # Read existing domain weights
    existing_train_weights, existing_eval_weights = read_domain_weights(input_file)

    # Update Dirichlet distribution parameters
    alpha_train, alpha_eval = update_dirichlet_parameters(
        n_embd_main, n_embd_proxy, k, existing_train_weights, existing_eval_weights
    )

    # Set random seed for reproducibility
    np.random.seed(seed)

    # Sample from updated Dirichlet distributions (train first, then eval, so
    # the draws for a given seed stay the same).
    train_weights = np.random.dirichlet(alpha_train).tolist()
    eval_weights = np.random.dirichlet(alpha_eval).tolist()

    # Prepare JSON structure containing updated domain weights
    weights_data = {
        "train_domain_weights": dict(zip(existing_train_weights.keys(), train_weights)),
        "eval_domain_weights": dict(zip(existing_eval_weights.keys(), eval_weights)),
    }

    # Prepare JSON structure containing the updated Dirichlet parameters
    alpha_data = {
        "dirichlet_alpha_train": dict(zip(existing_train_weights.keys(), alpha_train)),
        "dirichlet_alpha_eval": dict(zip(existing_eval_weights.keys(), alpha_eval)),
    }

    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first: another process may create the directory
    # between the check and the makedirs call.
    os.makedirs(output_dir, exist_ok=True)

    # Both output files share this prefix.
    file_stem = f"main_{int(n_embd_main)}E_seed_{seed}"

    # Save domain weights to file
    weights_output_file = os.path.join(output_dir, f"{file_stem}_weights.json")
    with open(weights_output_file, "w", encoding="utf-8") as f:
        json.dump(weights_data, f, indent=4)

    # Save Dirichlet parameters to file
    alpha_output_file = os.path.join(output_dir, f"{file_stem}_alpha.json")
    with open(alpha_output_file, "w", encoding="utf-8") as f:
        json.dump(alpha_data, f, indent=4)
    

if __name__ == "__main__":
    # Exactly five positional arguments are expected after the script name.
    if len(sys.argv) != 6:
        print("Usage: python update_domain_weights.py <n_embd_main> <n_embd_proxy> <k> <input_file> <seed>")
        sys.exit(1)

    # The first three arguments are numeric and parsed as floats, in order.
    n_embd_main, n_embd_proxy, k = (float(raw) for raw in sys.argv[1:4])
    input_file = sys.argv[4]
    # Output location is fixed: a relative directory for the generated configs.
    output_dir = "./draw_main/configs"
    # The last argument is the random seed.
    seed = int(sys.argv[5])

    generate_updated_domain_weights(
        n_embd_main=n_embd_main,
        n_embd_proxy=n_embd_proxy,
        k=k,
        input_file=input_file,
        output_dir=output_dir,
        seed=seed,
    )