import itertools

from joint_mode import eval_mode


def get_freqs(rvs, freq_vals, *, keys=("X", "Y", "Z", "T")):
    """Map every combination of the variables' values to a frequency.

    Combinations are enumerated with the first name in ``keys`` outermost and
    the last innermost (``X`` outermost, ``T`` innermost by default). The
    values of ``freq_vals`` are assigned to the combinations in that order.

    Args:
        rvs: Mapping from each name in ``keys`` to a sequence of values.
        freq_vals: Iterable of frequencies, consumed in order. It is not
            modified. Values beyond the number of combinations are ignored.
        keys: Names in ``rvs`` to combine, outermost first.

    Returns:
        Dict mapping each value tuple (ordered like ``keys``) to its
        frequency, with keys inserted in enumeration order.

    Raises:
        KeyError: If a name in ``keys`` is missing from ``rvs``.
        ValueError: If ``freq_vals`` holds fewer values than there are
            combinations.
    """
    # Pull from an iterator instead of list.pop(0): O(1) per value, works
    # for any iterable, and leaves the caller's list untouched.
    values = iter(freq_vals)
    freq = {}
    # product() yields tuples in the same nesting order as explicit loops
    # over keys[0] (outermost) ... keys[-1] (innermost).
    combos = itertools.product(*(rvs[key] for key in keys))
    for consumed, combo in enumerate(combos):
        try:
            freq[combo] = next(values)
        except StopIteration:
            raise ValueError(
                f"freq_vals ran out after {consumed} values; one value is "
                f"needed per combination of {tuple(keys)}"
            ) from None
    return freq




# Joint counts for the 54 (X, Y, Z, T) combinations (3 * 2 * 3 * 3), in the
# order get_freqs assigns them: X outermost, then Y, then Z, T innermost.
# Each row holds the nine Z x T counts for one (X, Y) pair.
frequencies = [
    1, 5, 9, 1, 10, 2, 3, 2, 5,  # X[0], Y[0]
    1, 1, 5, 5, 12, 2, 6, 4, 2,  # X[0], Y[1]
    10, 2, 1, 7, 2, 3, 8, 2, 7,  # X[1], Y[0]
    2, 2, 3, 4, 5, 13, 2, 4, 12,  # X[1], Y[1]
    12, 1, 3, 1, 15, 9, 11, 1, 4,  # X[2], Y[0]
    1, 5, 3, 11, 13, 3, 7, 4, 1,  # X[2], Y[1]
]


# Vocabulary sets to evaluate. Each dict maps the variable names X, Y, Z, T
# to the candidate values get_freqs enumerates; every set here has 3, 2, 3
# and 3 values respectively (54 combinations). The number beside each set
# is carried over from earlier comments; its meaning is not recorded in this
# file (presumably a per-set result; TODO confirm).
rvs = [
    {  # 0.65
        "X": ["transformer", "gradient", "likelihood"],
        # NOTE(review): "baysian" looks like a misspelling of "bayesian";
        # kept verbatim because it is prompt data.
        "Y": ["baysian", "binomial"],
        "Z": ["gaussian", "attention", "pytorch"],
        "T": ["diffusion", "lstm", "python"],
    },
    {  # 0.79
        "X": ["jack", "lisa", "julie"],
        "Y": ["adam", "phil"],
        "Z": ["alex", "peter", "monica"],
        "T": ["amanda", "mike", "alice"],
    },
    {  # 0.68
        "X": ["arizona", "maryland", "washington"],
        "Y": ["california", "texas"],
        "Z": ["virginia", "florida", "boston"],
        "T": ["new-york", "pennsylvania", "ohio"],
    },
    {  # 0.74
        "X": ["nigeria", "iran", "brazil"],
        "Y": ["austria", "qatar"],
        "Z": ["france", "argentina", "china"],
        "T": ["japan", "england", "kenya"],
    },
    {  # 0.83
        "X": ["apple", "banana", "orange"],
        "Y": ["grape", "kiwi"],
        "Z": ["mango", "peach", "pear"],
        "T": ["plum", "cherry", "watermelon"],
    },
    {  # 0.71
        "X": ["stanford", "harvard", "yale"],
        "Y": ["princeton", "columbia"],
        "Z": ["mit", "berkeley", "caltech"],
        "T": ["chicago", "penn", "duke"],
    },
    {  # 0.64
        "X": ["biology", "chemistry", "physics"],
        "Y": ["math", "history"],
        "Z": ["geography", "english", "art"],
        "T": ["music", "sports", "drama"],
    },
    {  # 0.75
        "X": ["tesla", "ford", "toyota"],
        "Y": ["honda", "nissan"],
        "Z": ["bmw", "mercedes", "audi"],
        "T": ["volkswagen", "subaru", "hyundai"],
    },
    {  # 0.57
        "X": ["cortex", "hippocampus", "amygdala"],
        "Y": ["thalamus", "hypothalamus"],
        "Z": ["cerebellum", "brainstem", "spinal-cord"],
        "T": ["corpus-callosum", "basal-ganglia", "ventricles"],
    },
    {  # 0.56
        # NOTE(review): "lasania" looks like a misspelling of "lasagna", and
        # "tikka masala" uses a space where other multi-word values use a
        # hyphen; kept verbatim because they are prompt data.
        "X": ["lasania", "steak", "tikka masala"],
        "Y": ["sushi", "pizza"],
        "Z": ["pasta", "salad", "soup"],
        "T": ["burger", "sandwich", "taco"],
    },
]

# frequencies = [8, 2, 7, 3, 3, 3, 6, 10, 2, 8, 5, 3]

# Build one frequency table per vocabulary set. Each get_freqs call receives
# its own copy of `frequencies`, so the shared list is never modified.
freqs = [get_freqs(vocab, list(frequencies)) for vocab in rvs]

# Settings forwarded unchanged to eval_mode on every run.
num_prompts = 100
permute = True

# models = ["meta-llama/Llama-3.1-8B-Instruct", "Qwen/Qwen2.5-7B-Instruct-1M", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]
models = ["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]

# Accepted `mode` values: "local", "api_create_batch", "api_eval_batch".
for freq in freqs:
    # NOTE(review): `out` is rebound on every iteration, so only the last
    # result stays referenced here; presumably eval_mode records its own
    # results — confirm in joint_mode.
    out = eval_mode(models, freq, num_prompts, permute, mode="local")
