import json
import numpy as np
import os
import sys

def generate_domain_weights(n_embd_main, k, domains, seed, output_dir="/draw_main/configs"):
    """Sample train/eval domain mixture weights from a symmetric Dirichlet and save them.

    Every domain gets the same concentration parameter
    ``alpha = sqrt(n_embd_main) / k``. Two weight vectors (train, then eval)
    are drawn from ``Dirichlet([alpha] * len(domains))`` with a generator seeded
    by ``seed``, so identical arguments always produce identical weights.

    Two JSON files are written to ``output_dir`` (created if missing):

    * ``proxy_{int(n_embd_main)}E_seed_{seed}_weights.json`` with keys
      ``"train_domain_weights"`` and ``"eval_domain_weights"``, each mapping
      domain name -> sampled weight (each mapping sums to 1 up to float rounding).
    * ``proxy_{int(n_embd_main)}E_seed_{seed}_alpha.json`` with keys
      ``"dirichlet_alpha_train"`` and ``"dirichlet_alpha_eval"``, each mapping
      domain name -> alpha.

    Args:
        n_embd_main: Positive number whose square root, divided by ``k``, gives
            the per-domain alpha. Truncated with ``int()`` in the file names.
        k: Positive divisor of the concentration parameter.
        domains: Iterable of unique, hashable domain names.
        seed: Seed for the random number generator.
        output_dir: Directory for the JSON files. Defaults to the original
            hard-coded location ``/draw_main/configs``.

    Returns:
        The weights dict that was written to the ``_weights.json`` file.

    Raises:
        ValueError: If ``domains`` is empty or has duplicates, if ``n_embd_main``
            or ``k`` is not positive, or if the resulting alpha is not a
            positive finite number.
    """
    domains = list(domains)
    if not domains:
        raise ValueError("domains must contain at least one domain")
    if len(set(domains)) != len(domains):
        # Duplicates would silently collapse in the dicts below, and the stored
        # weights would then no longer sum to 1.
        raise ValueError("domains must not contain duplicates")
    if not (n_embd_main > 0 and k > 0):
        raise ValueError(
            f"n_embd_main and k must be positive, got n_embd_main={n_embd_main!r}, k={k!r}"
        )

    # 1. Symmetric Dirichlet: every domain shares alpha = sqrt(n_embd_main) / k.
    alpha = float(np.sqrt(n_embd_main) / k)
    if not np.isfinite(alpha) or alpha <= 0:
        raise ValueError(f"Dirichlet alpha must be positive and finite, got {alpha!r}")
    beta = [alpha] * len(domains)
    dirichlet_alpha = dict(zip(domains, beta))

    # 2. Private seeded legacy generator. It yields the same stream as
    #    np.random.seed(seed) followed by np.random.dirichlet(...), so previously
    #    generated configs are reproduced exactly, but the caller's global NumPy
    #    RNG state is left untouched.
    rng = np.random.RandomState(seed)

    # 3. Draw the train mixture first, then the eval mixture (order matters for
    #    reproducibility).
    train_weights = rng.dirichlet(beta).tolist()
    eval_weights = rng.dirichlet(beta).tolist()
    domain_weights = {
        "train_domain_weights": dict(zip(domains, train_weights)),
        "eval_domain_weights": dict(zip(domains, eval_weights)),
    }

    # 4. Record the Dirichlet parameters. Train and eval share the same alpha
    #    here; they are stored separately so they could diverge later.
    dirichlet_param_data = {
        "dirichlet_alpha_train": dirichlet_alpha,
        "dirichlet_alpha_eval": dirichlet_alpha,
    }

    # 5. Write both JSON files; exist_ok avoids the check-then-create race.
    os.makedirs(output_dir, exist_ok=True)
    file_prefix = f"proxy_{int(n_embd_main)}E_seed_{seed}"

    weights_output_file = os.path.join(output_dir, f"{file_prefix}_weights.json")
    with open(weights_output_file, "w", encoding="utf-8") as f:
        json.dump(domain_weights, f, indent=4)

    alpha_output_file = os.path.join(output_dir, f"{file_prefix}_alpha.json")
    with open(alpha_output_file, "w", encoding="utf-8") as f:
        json.dump(dirichlet_param_data, f, indent=4)

    return domain_weights
    

if __name__ == "__main__":
    # Command-line entry point:
    #   python generate_domain_weights.py <n_embd_main> <k> <seed>
    usage = "Usage: python generate_domain_weights.py <n_embd_main> <k> <seed>"
    if len(sys.argv) != 4:
        # Usage/errors go to stderr so stdout stays clean for any piping.
        print(usage, file=sys.stderr)
        sys.exit(1)

    try:
        n_embd_main = float(sys.argv[1])
        k = float(sys.argv[2])
        seed = int(sys.argv[3])
    except ValueError as err:
        # Report malformed numbers cleanly instead of with a raw traceback.
        print(f"Invalid argument: {err}", file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)

    # Domain names passed to generate_domain_weights; each becomes a key in the
    # written JSON files.
    domains = [
        "ArXiv", "DM Mathematics", "Enron Emails", "EuroParl", "FreeLaw",
        "Github", "Gutenberg (PG-19)", "HackerNews", "NIH ExPorter",
        "PhilPapers", "Pile-CC", "PubMed Abstracts", "PubMed Central",
        "StackExchange", "USPTO Backgrounds", "Ubuntu IRC", "Wikipedia (en)"
    ]

    generate_domain_weights(n_embd_main, k, domains, seed)