import seaborn as sns
import matplotlib.pyplot as plt
from joint_mode import eval_mode
from joint_mle import eval_mle
from cond_mode import eval_cond_mode
from cond_mle import eval_cond_mle
from joint_sampling import eval_joint_sampling
from cond_sampling import eval_cond_sampling
from joint_mode_samples import eval_joint_mode_samples
from joint_sampling import eval_joint_sampling_one



# easier 6js and harder only used for cond-mode task (lower entropy)

# Count tables to evaluate; the driver loop further down runs the selected
# evaluation once per entry of this list. Each dict maps a joint outcome
# (a tuple holding one symbol per variable) to an integer count. Inactive
# tables are kept as comments so a run can switch tables by (un)commenting
# them. The label above each table gives its shape (values per variable)
# and the resulting number of outcomes, i.e. len(table).
freqs = [
    # 2 x 2 = 4 outcomes
    # {("A", "C"): 6, ("A", "D"): 4, ("B", "C"): 8, ("B", "D"): 2},

    # 2 x 3 = 6 outcomes (two variants with different counts)
    # {("A", "C"): 7, ("A", "D"): 2, ("A", "E"): 4, ("B", "C"): 9, ("B", "D"): 2, ("B", "E"): 6},
    # {("A", "C"): 8, ("A", "D"): 2, ("A", "E"): 4, ("B", "C"): 10, ("B", "D"): 2, ("B", "E"): 4},

    # 2 x 2 x 2 = 8 outcomes (two variants with different counts)
    # {("A", "C", "E"): 7, ("A", "C", "F"): 4, ("A", "D", "E"): 9, ("A", "D", "F"): 2, ("B", "C", "E"): 4, ("B", "C", "F"): 3, ("B", "D", "E"): 6, ("B", "D", "F"): 5},
    # {("A", "C", "E"): 8, ("A", "C", "F"): 2, ("A", "D", "E"): 10, ("A", "D", "F"): 3, ("B", "C", "E"): 4, ("B", "C", "F"): 7, ("B", "D", "E"): 4, ("B", "D", "F"): 2},

    # 3 x 3 = 9 outcomes
    # {("A", "D"): 6, ("A" , "E"): 3, ("A", "F"): 4, ("B", "D"): 9, ("B", "E"): 2, ("B", "F"): 5, ("C", "D"): 4, ("C", "E"): 3, ("C", "F"): 4},

    # 12 outcomes: the same 12 counts, in the same order, laid out once as
    # 2 x 2 x 3 (three variables) and once as 4 x 3 (two variables).
    # {("A", "C", "E"): 8, ("A", "C", "F"): 2, ("A", "C", "G"): 7, ("A", "D", "E"): 3, ("A", "D", "F"): 3, ("A", "D", "G"): 3, ("B", "C", "E"): 6, ("B", "C", "F"): 10, ("B", "C", "G"): 2, ("B", "D", "E"): 8, ("B", "D", "F"): 5, ("B", "D", "G"): 3},
    # {("A", "E"): 8, ("A", "F"): 2, ("A", "G"): 7, ("B", "E"): 3, ("B", "F"): 3, ("B", "G"): 3, ("C", "E"): 6, ("C", "F"): 10, ("C", "G"): 2, ("D", "E"): 8, ("D", "F"): 5, ("D", "G"): 3},
    
    # 2 x 2 x 3 = 12 outcomes (L=12), highly skewed
    # {("A", "C", "E"): 20, ("A", "C", "F"): 2, ("A", "C", "G"): 1, ("A", "D", "E"): 1, ("A", "D", "F"):2 , ("A", "D", "G"):2 , ("B", "C", "E"):1 , ("B", "C", "F"): 11, ("B", "C", "G"):2 , ("B", "D", "E"):1 , ("B", "D", "F"):2 , ("B", "D", "G"):15 },

    # 3 x 3 x 3 = 27 outcomes
    # {
    #       ("A", "D", "G"): 8, ("A", "D", "H"): 4, ("A", "D", "I"): 2, 
    #       ("A", "E", "G"): 3, ("A", "E", "H"): 4, ("A", "E", "I"): 5, 
    #       ("A", "F", "G"): 6, ("A", "F", "H"): 3, ("A", "F", "I"): 6, 

    #       ("B", "D", "G"): 10, ("B", "D", "H"): 6, ("B", "D", "I"): 1, 
    #       ("B", "E", "G"): 7, ("B", "E", "H"): 2, ("B", "E", "I"): 4, 
    #       ("B", "F", "G"): 6, ("B", "F", "H"): 6, ("B", "F", "I"): 3, 

    #       ("C", "D", "G"): 4, ("C", "D", "H"): 5, ("C", "D", "I"): 6, 
    #       ("C", "E", "G"): 7, ("C", "E", "H"): 4, ("C", "E", "I"): 2, 
    #       ("C", "F", "G"): 3, ("C", "F", "H"): 8, ("C", "F", "I"): 5
    # },

    # 3 x 3 x 3 = 27 outcomes (variant of the table above with different counts)
    # {
    #       ("A", "D", "G"): 6, ("A", "D", "H"): 4, ("A", "D", "I"): 2, 
    #       ("A", "E", "G"): 3, ("A", "E", "H"): 4, ("A", "E", "I"): 5, 
    #       ("A", "F", "G"): 6, ("A", "F", "H"): 3, ("A", "F", "I"): 6, 

    #       ("B", "D", "G"): 8, ("B", "D", "H"): 6, ("B", "D", "I"): 2, 
    #       ("B", "E", "G"): 6, ("B", "E", "H"): 4, ("B", "E", "I"): 4, 
    #       ("B", "F", "G"): 6, ("B", "F", "H"): 6, ("B", "F", "I"): 3, 

    #       ("C", "D", "G"): 4, ("C", "D", "H"): 5, ("C", "D", "I"): 6, 
    #       ("C", "E", "G"): 7, ("C", "E", "H"): 4, ("C", "E", "I"): 5, 
    #       ("C", "F", "G"): 3, ("C", "F", "H"): 7, ("C", "F", "I"): 5
    # },

    # 4 x 3 x 3 = 36 outcomes
    # {    
    #     ("A", "E", "H"):  6,  ("A", "E", "I"):  8,  ("A", "E", "J"): 13,
    #     ("A", "F", "H"):  9,  ("A", "F", "I"):  1,  ("A", "F", "J"):  2,
    #     ("A", "G", "H"): 11,  ("A", "G", "I"):  2,  ("A", "G", "J"):  1,

    #     ("B", "E", "H"): 15,  ("B", "E", "I"):  4,  ("B", "E", "J"):  8,
    #     ("B", "F", "H"):  1,  ("B", "F", "I"):  5,  ("B", "F", "J"):  5,
    #     ("B", "G", "H"):  1,  ("B", "G", "I"):  7,  ("B", "G", "J"):  2,

    #     ("C", "E", "H"):  4,  ("C", "E", "I"):  5,  ("C", "E", "J"): 12,
    #     ("C", "F", "H"):  3,  ("C", "F", "I"):  2,  ("C", "F", "J"):  2,
    #     ("C", "G", "H"):  4,  ("C", "G", "I"):  8,  ("C", "G", "J"): 12,

    #     ("D", "E", "H"):  2,  ("D", "E", "I"):  4,  ("D", "E", "J"):  2,
    #     ("D", "F", "H"):  1,  ("D", "F", "I"):  2,  ("D", "F", "J"):  2,
    #     ("D", "G", "H"):  9,  ("D", "G", "I"):  2,  ("D", "G", "J"):  3
    # },


    # Active table: 3 x 2 x 3 x 3 = 54 outcomes over the variables
    # {A,B,C} x {D,E} x {F,G,H} x {I,J,K}.
    {
        ('A', 'D', 'F', 'I'):  1,  ('A', 'D', 'F', 'J'):  5,  ('A', 'D', 'F', 'K'):  9,
        ('A', 'D', 'G', 'I'):  1,  ('A', 'D', 'G', 'J'):  10,  ('A', 'D', 'G', 'K'):  2,
        ('A', 'D', 'H', 'I'):  3,  ('A', 'D', 'H', 'J'):  2,  ('A', 'D', 'H', 'K'):  5,

        ('A', 'E', 'F', 'I'):  1,  ('A', 'E', 'F', 'J'):  1,  ('A', 'E', 'F', 'K'):  5,
        ('A', 'E', 'G', 'I'):  5,  ('A', 'E', 'G', 'J'):  12,  ('A', 'E', 'G', 'K'):  2,
        ('A', 'E', 'H', 'I'):  6,  ('A', 'E', 'H', 'J'):  4,  ('A', 'E', 'H', 'K'):  2,

        ('B', 'D', 'F', 'I'):  10,  ('B', 'D', 'F', 'J'):  2,  ('B', 'D', 'F', 'K'):  1,
        ('B', 'D', 'G', 'I'):  7,  ('B', 'D', 'G', 'J'):  2,  ('B', 'D', 'G', 'K'):  3,
        ('B', 'D', 'H', 'I'):  8,  ('B', 'D', 'H', 'J'):  2,  ('B', 'D', 'H', 'K'):  7,

        ('B', 'E', 'F', 'I'):  2,  ('B', 'E', 'F', 'J'):  2,  ('B', 'E', 'F', 'K'):  3,
        ('B', 'E', 'G', 'I'):  4,  ('B', 'E', 'G', 'J'):  5,  ('B', 'E', 'G', 'K'): 13,
        ('B', 'E', 'H', 'I'):  2,  ('B', 'E', 'H', 'J'):  4,  ('B', 'E', 'H', 'K'): 12,

        ('C', 'D', 'F', 'I'): 12,  ('C', 'D', 'F', 'J'):  1,  ('C', 'D', 'F', 'K'):  3,
        ('C', 'D', 'G', 'I'):  1,  ('C', 'D', 'G', 'J'): 15,  ('C', 'D', 'G', 'K'):  9,
        ('C', 'D', 'H', 'I'): 11,  ('C', 'D', 'H', 'J'):  1,  ('C', 'D', 'H', 'K'):  4,

        ('C', 'E', 'F', 'I'):  1,  ('C', 'E', 'F', 'J'):  5,  ('C', 'E', 'F', 'K'):  3,
        ('C', 'E', 'G', 'I'): 11,  ('C', 'E', 'G', 'J'): 13,  ('C', 'E', 'G', 'K'):  3,
        ('C', 'E', 'H', 'I'):  7,  ('C', 'E', 'H', 'J'):  4,  ('C', 'E', 'H', 'K'):  1
    }

 
    # Mushroom data: 2 x 2 x 4 = 16 outcomes, 10 of which have a zero count.
    # NOTE: the active table above has no trailing comma after its closing
    # brace, so re-enabling this entry also requires adding that comma.
    # {('f', 'b', 'k'): 13, ('f', 'b', 'g'): 17, ('f', 'b', 'p'): 34, ('f', 'b', 'o'): 0, ('f', 'n', 'k'): 2, ('f', 'n', 'g'): 0, ('f', 'n', 'p'): 9, ('f', 'n', 'o'): 0, ('a', 'b', 'k'): 0, ('a', 'b', 'g'): 0, ('a', 'b', 'p'): 0, ('a', 'b', 'o'): 5, ('a', 'n', 'k'): 0, ('a', 'n', 'g'): 0, ('a', 'n', 'p'): 0, ('a', 'n', 'o'): 0}

]


# Model identifiers passed to every eval_* call in the driver loop.
# Previously evaluated set, kept for reference:
# models = ["meta-llama/Llama-3.1-8B-Instruct", "Qwen/Qwen2.5-7B-Instruct-1M","deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]
models = [
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
]

# Shared arguments forwarded unchanged to each eval_* call.
num_prompts = 30
permute = True


# Evaluation driver: for every count table in `freqs`, run the one task whose
# call is uncommented below. To change the task, swap which line is active.
# Every eval_* call takes a `mode` keyword, one of:
#   "local", "api_create_batch", "api_eval_batch"
for freq in freqs:
    # Sample budget: 7 draws per joint outcome. Only the sampling-based
    # tasks take num_samples; other values tried before were 1 and 20.
    num_samples = len(freq) * 7
    # num_samples = 1
    # num_samples = 20

    # --- Joint-distribution tasks ---
    # out = eval_mode(models, freq , num_prompts, permute, mode="local")
    out = eval_mle(models, freq, num_prompts, permute, mode="local")
    # out = eval_joint_sampling(models, freq, num_prompts, num_samples, permute, mode="api_eval_batch")
    # eval_joint_sampling_one(models, freq, num_prompts, num_samples, permute, mode="api_create_batch")
    # out = eval_joint_mode_samples(models, freq, num_prompts, permute, mode="local")

    # --- Conditional-distribution tasks ---
    # out = eval_cond_mode(models, freq, num_prompts, permute, mode="api_eval_batch")
    # out = eval_cond_mle(models, freq, num_prompts, permute, mode="local")
    # out = eval_cond_sampling(models, freq, num_prompts, num_samples, permute, mode="local")
    



 
# plot_line_chart(difficulties, data)

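# Sample-count reference: support size (number of joint outcomes) for each
# table shape, and the sample counts at 5x and 6x that size.
#     2x2=4  2x3=6  2x2x2=8  2x2x3=12
# 5x     20     30       40        60
# 6x     24     36       48        72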