import copy

# Default experiment configuration. Every experiment below starts from this
# template and overrides the fields it needs; fields left as None here have no
# default value in this template.
basic_config = dict(
    name=None,
    # Which dataset to load and how to turn a record into a prompt.
    data_kwargs=dict(
        ds_name=None,
        num_prompts=None,
        # Default prompt is simply the record's "input" field.
        prompt_constructor=lambda x: x["input"],
        split=None,
        max_input_length=None,
    ),
    # Identifiers of the target model and of the draft model.
    model_kwargs=dict(
        target_model_str=None,
        draft_model_str=None,
    ),
    generation_kwargs=dict(
        temperature=None,
        batch_size=None,
        max_tokens=128,
    ),
    mcsps_kwargs=dict(
        methods="all",
        # Candidate counts 1 through 10 inclusive.
        num_candidates_list=[*range(1, 11)],
    ),
    reproducibility_kwargs=dict(
        seed=1234,
    ),
    other_kwargs=dict(
        use_cache=True,
    ),
)

# Data settings for the "tatsu-lab/alpaca" dataset ("train" split): the default
# data_kwargs with the dataset-specific fields filled in.
alpaca_data_kwargs = dict(
    basic_config["data_kwargs"],
    ds_name="tatsu-lab/alpaca",
    # Builds the prompt from the record's "instruction" field.
    prompt_constructor=lambda x: (
        "Write a response that appropriately completes the request. "
        f"### Instruction: {x['instruction']} ### Response:"
    ),
    split="train",
    max_input_length=128,
    num_prompts=1024,
)

# Data settings for the "abisee/cnn_dailymail" dataset ("train" split). Unlike
# the other dataset presets, this one also carries a "subset" entry.
cnn_data_kwargs = dict(
    basic_config["data_kwargs"],
    ds_name="abisee/cnn_dailymail",
    # Builds a summarization prompt from the record's "article" field.
    prompt_constructor=lambda x: (
        "Summarize the following article without missing important information: "
        f"### Article: {x['article']} ### Summary:"
    ),
    split="train",
    subset="1.0.0",
    max_input_length=128,
    num_prompts=1024,
)

# Data settings for the "radia/wmt14-de2en" dataset ("train" split).
# NOTE(review): the dataset id reads "de2en" while the prompt asks for an
# English -> German translation of the record's "en" field; confirm that this
# direction is intended.
wmt_data_kwargs = dict(
    basic_config["data_kwargs"],
    ds_name="radia/wmt14-de2en",
    prompt_constructor=lambda x: (
        "Translate the following English text to German. "
        f"### English: {x['en']} ### German:"
    ),
    split="train",
    max_input_length=128,
    num_prompts=1024,
)

# Debug run on the "synth" dataset: llama-7b target with the llama-68m draft,
# temperature 0.0, batch size 32.
debug_exp1 = dict(
    basic_config,
    name="debug_exp1",
    data_kwargs=dict(
        basic_config["data_kwargs"],
        ds_name="synth",
        num_prompts=10240,
    ),
    model_kwargs=dict(
        basic_config["model_kwargs"],
        target_model_str="huggyllama/llama-7b",
        draft_model_str="JackFram/llama-68m",
    ),
    generation_kwargs=dict(
        basic_config["generation_kwargs"],
        temperature=0.0,
        batch_size=32,
    ),
)

# Same as debug_exp1, except that the temperature is 0.7.
debug_exp2 = dict(
    debug_exp1,
    name="debug_exp2",
    generation_kwargs=dict(debug_exp1["generation_kwargs"], temperature=0.7),
)

# Debug run on the Alpaca preset: llama-7b target with the llama-68m draft,
# temperature 0.7, batch size 8. Several later debug configs derive from it.
debug_exp3 = dict(
    basic_config,
    name="debug_exp3",
    data_kwargs=dict(
        alpaca_data_kwargs,
        # A smaller value (32) was previously used here as an alternative.
        num_prompts=1024,
    ),
    model_kwargs=dict(
        basic_config["model_kwargs"],
        target_model_str="huggyllama/llama-7b",
        draft_model_str="JackFram/llama-68m",
    ),
    generation_kwargs=dict(
        basic_config["generation_kwargs"],
        temperature=0.7,
        batch_size=8,
    ),
)

# debug_exp3 with the Vicuna-7B v1.3 target and its EAGLE draft, batch size 4.
debug_exp4 = dict(
    debug_exp3,
    name="debug_exp4",
    model_kwargs=dict(
        debug_exp3["model_kwargs"],
        target_model_str="lmsys/vicuna-7b-v1.3",
        draft_model_str="yuhuili/EAGLE-Vicuna-7B-v1.3",
    ),
    generation_kwargs=dict(debug_exp3["generation_kwargs"], batch_size=4),
)

# debug_exp3 with the Qwen2-7B-Instruct target and its EAGLE draft,
# batch size 2.
debug_exp5 = dict(
    debug_exp3,
    name="debug_exp5",
    model_kwargs=dict(
        debug_exp3["model_kwargs"],
        target_model_str="Qwen/Qwen2-7B-Instruct",
        draft_model_str="yuhuili/EAGLE-Qwen2-7B-Instruct",
    ),
    generation_kwargs=dict(debug_exp3["generation_kwargs"], batch_size=2),
)


# debug_exp3 with the OPT-6.7B target and the OPT-125M draft. The generation
# settings are inherited from debug_exp3 unchanged.
debug_exp6 = dict(
    debug_exp3,
    name="debug_exp6",
    model_kwargs=dict(
        basic_config["model_kwargs"],
        target_model_str="facebook/opt-6.7b",
        draft_model_str="facebook/opt-125m",
    ),
)

# debug_exp3 switched to the CNN/DailyMail preset with only 32 prompts.
debug_exp7 = dict(
    debug_exp3,
    name="debug_exp7",
    data_kwargs=dict(cnn_data_kwargs, num_prompts=32),
    generation_kwargs=dict(debug_exp3["generation_kwargs"], batch_size=8),
)

# debug_exp3 switched to the WMT preset with only 32 prompts.
debug_exp8 = dict(
    debug_exp3,
    name="debug_exp8",
    data_kwargs=dict(wmt_data_kwargs, num_prompts=32),
    generation_kwargs=dict(debug_exp3["generation_kwargs"], batch_size=8),
)

# debug_exp7 (CNN/DailyMail) with the Qwen2-7B-Instruct target and its EAGLE
# draft, batch size 4.
debug_exp9 = dict(
    debug_exp7,
    name="debug_exp9",
    model_kwargs=dict(
        debug_exp7["model_kwargs"],
        target_model_str="Qwen/Qwen2-7B-Instruct",
        draft_model_str="yuhuili/EAGLE-Qwen2-7B-Instruct",
    ),
    generation_kwargs=dict(debug_exp7["generation_kwargs"], batch_size=4),
)

# debug_exp8 (WMT) with the Qwen2-7B-Instruct target and its EAGLE draft,
# batch size 4.
debug_exp10 = dict(
    debug_exp8,
    name="debug_exp10",
    model_kwargs=dict(
        debug_exp8["model_kwargs"],
        target_model_str="Qwen/Qwen2-7B-Instruct",
        draft_model_str="yuhuili/EAGLE-Qwen2-7B-Instruct",
    ),
    generation_kwargs=dict(debug_exp8["generation_kwargs"], batch_size=4),
)

# debug_exp7 (CNN/DailyMail) with the Vicuna-7B v1.3 target and its EAGLE
# draft, batch size 4.
debug_exp11 = dict(
    debug_exp7,
    name="debug_exp11",
    model_kwargs=dict(
        debug_exp7["model_kwargs"],
        target_model_str="lmsys/vicuna-7b-v1.3",
        draft_model_str="yuhuili/EAGLE-Vicuna-7B-v1.3",
    ),
    generation_kwargs=dict(debug_exp7["generation_kwargs"], batch_size=4),
)

# debug_exp8 (WMT) with the Vicuna-7B v1.3 target and its EAGLE draft,
# batch size 4.
debug_exp12 = dict(
    debug_exp8,
    name="debug_exp12",
    model_kwargs=dict(
        debug_exp8["model_kwargs"],
        target_model_str="lmsys/vicuna-7b-v1.3",
        draft_model_str="yuhuili/EAGLE-Vicuna-7B-v1.3",
    ),
    generation_kwargs=dict(debug_exp8["generation_kwargs"], batch_size=4),
)

# Same models and batch size as debug_exp12 (WMT, Vicuna-7B v1.3 + EAGLE
# draft), but with temperature 0.0.
debug_exp13 = dict(
    debug_exp8,
    name="debug_exp13",
    model_kwargs=dict(
        debug_exp8["model_kwargs"],
        target_model_str="lmsys/vicuna-7b-v1.3",
        draft_model_str="yuhuili/EAGLE-Vicuna-7B-v1.3",
    ),
    generation_kwargs=dict(
        debug_exp8["generation_kwargs"],
        batch_size=4,
        temperature=0.0,
    ),
)

def short_name(model_name):
    """Return the part of ``model_name`` that follows its last "/".

    A name without any "/" is returned unchanged, e.g.
    "huggyllama/llama-7b" -> "llama-7b" and "synth" -> "synth".
    """
    # rpartition yields ("", "", model_name) when there is no "/", so the
    # last element is the right answer in both cases.
    _, _, tail = model_name.rpartition("/")
    return tail

def get_safe_batch_size(model_name):
    """Return the batch size to use for ``model_name``.

    Names containing "qwen2" or "vicuna" (case-insensitive) get 4; every
    other name gets 8.
    """
    lowered = model_name.lower()
    needs_small_batch = "qwen2" in lowered or "vicuna" in lowered
    return 4 if needs_small_batch else 8

# Model type experiemnts
main_exp = [
    {
        **basic_config,
        "name": f"main_exp_{short_name(target_model_name)}_{short_name(draft_model_name)}_{short_name(data_kwargs['ds_name'])}",
        "data_kwargs": data_kwargs,
        "model_kwargs": {
            **basic_config["model_kwargs"],
            "target_model_str": target_model_name,
            "draft_model_str": draft_model_name,
        },
        "generation_kwargs": {
            **basic_config["generation_kwargs"],
            "temperature": 0.7,
            "batch_size": get_safe_batch_size(target_model_name),
        },
    }
    for draft_model_name, target_model_name in zip(
        [
            "yuhuili/EAGLE-Vicuna-7B-v1.3",
            "yuhuili/EAGLE-Qwen2-7B-Instruct",
            "JackFram/llama-68m",
            "facebook/opt-125m",
        ],
        [
            "lmsys/vicuna-7b-v1.3",
            "Qwen/Qwen2-7B-Instruct",
            "huggyllama/llama-7b",
            "facebook/opt-6.7b",
        ],
    )
    for data_kwargs in [
        alpaca_data_kwargs,
        cnn_data_kwargs,
        wmt_data_kwargs,
    ]
]

# Shared base for the ablation experiments: llama-7b target with the
# llama-68m draft, temperature 0.7, batch size 8. It sets no "name" and no
# dataset; each ablation entry supplies those itself.
common_ablation = dict(
    basic_config,
    model_kwargs=dict(
        basic_config["model_kwargs"],
        target_model_str="huggyllama/llama-7b",
        draft_model_str="JackFram/llama-68m",
    ),
    generation_kwargs=dict(
        basic_config["generation_kwargs"],
        temperature=0.7,
        batch_size=8,
    ),
)


# Model type experiemnts
ablation_exp1 = [
    {
        **common_ablation,
        "name": f"ablation_exp1_{short_name(target_model_name)}_{short_name(draft_model_name)}_{short_name(data_kwargs['ds_name'])}",
        "data_kwargs": data_kwargs,
        "model_kwargs": {
            **common_ablation["model_kwargs"],
            "target_model_str": target_model_name,
            "draft_model_str": draft_model_name,
        },
    }
    for draft_model_name in ["JackFram/llama-68m", "JackFram/llama-160m"]
    for target_model_name in [
        "huggyllama/llama-7b",
        "huggyllama/llama-13b",
        "huggyllama/llama-30b",
    ]
    for data_kwargs in [
        alpaca_data_kwargs,
        cnn_data_kwargs,
        wmt_data_kwargs,
    ]
]


# Temperature experiment: every dataset preset at each listed temperature,
# with the models of common_ablation.
# mcsps_kwargs is inherited unchanged: narrowing num_candidates_list (e.g. to
# [3]) was considered unnecessary because it is not the bottleneck.
ablation_exp2 = [
    {
        # Deep-copy the shared base so that nested dicts such as
        # model_kwargs and mcsps_kwargs are not the same object in every
        # generated config.
        **copy.deepcopy(common_ablation),
        "name": f"ablation_exp2_{temperature}_{short_name(data_kwargs['ds_name'])}",
        # Per-config copy: each preset is reused for every temperature.
        "data_kwargs": dict(data_kwargs),
        "generation_kwargs": {
            **common_ablation["generation_kwargs"],
            "temperature": temperature,
        },
    }
    for data_kwargs in [
        alpaca_data_kwargs,
        cnn_data_kwargs,
        wmt_data_kwargs,
    ]
    # The last value is the int 1, so it appears as "1" (not "1.0") in "name".
    for temperature in [0.0, 0.1, 0.3, 0.5, 0.7, 1]
]
# Additional temperature sweep with the same layout as ablation_exp2, covering
# 0.72 to 0.98 in steps of 0.02.
# mcsps_kwargs is inherited unchanged: narrowing num_candidates_list (e.g. to
# [3]) was considered unnecessary because it is not the bottleneck.
ablation_exp2_a1 = [
    {
        # Deep-copy the shared base so that nested dicts such as
        # model_kwargs and mcsps_kwargs are not the same object in every
        # generated config.
        **copy.deepcopy(common_ablation),
        "name": f"ablation_exp2_a1_{temperature}_{short_name(data_kwargs['ds_name'])}",
        # Per-config copy: each preset is reused for every temperature.
        "data_kwargs": dict(data_kwargs),
        "generation_kwargs": {
            **common_ablation["generation_kwargs"],
            "temperature": temperature,
        },
    }
    for data_kwargs in [
        alpaca_data_kwargs,
        cnn_data_kwargs,
        wmt_data_kwargs,
    ]
    for temperature in [0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98]
]



# Num_Candidate experiments: one config per dataset preset, otherwise
# identical to common_ablation.
# mcsps_kwargs is inherited unchanged because the default
# num_candidates_list already covers 1-10 (a subset such as [2, 4, 6, 8, 10]
# was considered and not needed).
ablation_exp3 = [
    {
        # Deep-copy the shared base: nothing nested is overridden here, so a
        # plain spread would make every config share all of its nested dicts.
        **copy.deepcopy(common_ablation),
        "name": f"ablation_exp3_{short_name(data_kwargs['ds_name'])}",
        # Per-config copy so the preset dict itself is never stored directly.
        "data_kwargs": dict(data_kwargs),
    }
    for data_kwargs in [
        alpaca_data_kwargs,
        cnn_data_kwargs,
        wmt_data_kwargs,
    ]
]
