import os
import argparse
import csv

# Flags passed to eval_train.py for each evaluation mode:
#   Linear probing:            --classification_only and --mlp_only_train
#   Finetune:                  --classification_only only
#   Finetune with contrastive: neither flag


# Column names of the results CSV, in file order.  load_result uses the
# first five columns as the key of a run and parses the rest as floats.
header = [
    "model_name",
    "dataset",
    "percent",
    "seed",
    "method",
    "contrastive_loss",
    "train_loss",
    "train_acc",
    "test_loss",
    "test_acc",
    "test_acc5",
]

# Shell command template for a single eval_train.py run.  The placeholder
# words (percentage, dataset_name_, optimizer_name, theta_value, seed_value,
# output_file) are substituted by run_method, which also removes
# --mlp_only_train / --classification_only for the finetuning modes.
command = (
    "python eval_train.py "
    "    --model_name_or_path bert-base-uncased "
    "    --max_seq_length 512 "
    "    --train_percent percentage "
    "    --output_dir ./result/ "
    "    --dataset_name dataset_name_ "
    "    --optimizer optimizer_name "
    "    --classification_only "
    "    --mlp_only_train "
    "    --theta theta_value"
    "    --seed seed_value"
    "    --output output_file"
)

def write_result(name, data):
    """Append one row to the CSV result file *name*.

    Args:
        name: path of the CSV file; it is created if it does not exist.
        data: sequence of cell values, one per column of ``header``.

    Raises:
        AssertionError: if ``data`` does not have one value per header column.
    """
    assert len(data) == len(header), (
        "expected " + str(len(header)) + " columns, got " + str(len(data))
    )
    # newline='' lets the csv module control line endings itself; without it
    # every row is followed by an empty line on platforms that translate
    # '\n' (e.g. Windows), and such blank rows break readers of the file.
    with open(name, 'a', encoding='UTF8', newline='') as f:
        csv.writer(f).writerow(data)

def load_result(name):
    """Load the recorded results from the CSV file *name*.

    The first non-blank row is treated as the header and is not parsed.
    Every later row is stored under the key
    ``(row[0], row[1], float(row[2]), int(row[3]), row[4])`` with the
    remaining cells converted to floats.  When a key occurs more than once
    the last row wins.

    Args:
        name: path of the CSV result file.

    Returns:
        A ``(dic, counter)`` tuple: the key -> list-of-floats mapping and the
        number of non-blank rows read, header included.

    Raises:
        OSError: if the file cannot be opened.
        ValueError, IndexError: if a data row is malformed.
    """
    dic = {}
    counter = 0
    header_seen = False
    # Same encoding as write_result; newline='' is what the csv module
    # expects from the file object it reads.
    with open(name, 'r', encoding='UTF8', newline='') as file:
        for row in csv.reader(file):
            if not row:
                # Blank lines (e.g. from files written without newline='')
                # carry no data; skip them instead of failing on row[0].
                continue
            counter += 1
            if not header_seen:
                header_seen = True
                continue
            key = (row[0], row[1], float(row[2]), int(row[3]), row[4])
            dic[key] = [float(cell) for cell in row[5:]]  # later rows overwrite
    return dic, counter

def run_method(args, model_name, dataset, percent, seed, method, theta, optimizer_name):
    """Run one eval_train.py configuration unless it is already recorded.

    The run is identified by ``(model_name, dataset, percent, seed, method)``;
    for methods containing "contrastive" the model name gets a ``_<theta>``
    suffix.  When that key is already present in the result file the run is
    skipped.  Otherwise the command template is filled in and executed
    through the shell.  Failures are reported on stdout and do not stop the
    caller, so the remaining configurations still run.

    Args:
        args: parsed command-line namespace; ``args.output`` is the result CSV.
        model_name: base model label used in the result key.
        dataset: dataset name substituted into the command.
        percent: training-set fraction substituted into the command.
        seed: random seed substituted into the command.
        method: method label; "linear" keeps both --classification_only and
            --mlp_only_train, "contrastive" drops both, anything else drops
            only --mlp_only_train.
        theta: value substituted for the --theta option.
        optimizer_name: value substituted for the --optimizer option.
    """
    result_file = args.output
    if "contrastive" in method:
        model_name += "_" + str(theta)
    print("==================================")
    key = (model_name, dataset, float(str(percent)), seed, method)
    print(key)
    dic, _ = load_result(result_file)
    if key in dic:
        print("Skip")
        return

    try:
        out_command = command.replace("dataset_name_", dataset)
        out_command = out_command.replace("optimizer_name", optimizer_name)
        out_command = out_command.replace("percentage", str(percent))
        out_command = out_command.replace("seed_value", str(seed))
        out_command = out_command.replace("output_file", result_file)

        if "linear" not in method:
            # Finetuning trains the whole model, not only the MLP head.
            out_command = out_command.replace("--mlp_only_train", "")
            if "contrastive" in method:
                # Finetuning with the contrastive loss also drops this flag.
                out_command = out_command.replace("--classification_only", "")
        out_command = out_command.replace("theta_value", str(theta))
        print(out_command)
        # os.system reports a failing command through its return value, not
        # through an exception, so the status has to be checked explicitly.
        status = os.system(out_command)
    except Exception as exc:
        # Best effort: report and let the caller continue with the next run.
        # KeyboardInterrupt/SystemExit are deliberately not caught here.
        print("**************Fail!!**************")
        print(repr(exc))
        return
    if status != 0:
        print("**************Fail!!**************")
        print("exit status:", status)
    return

def main(args):
    """Run the evaluation sweep over every configured dataset and method.

    Creates the result file with its header row when it does not exist yet,
    then calls run_method for each seed / percentage / dataset combination:
    once per fixed-theta method and once per theta for the contrastive one.

    Args:
        args: parsed command-line namespace; reads ``args.output`` (result
            CSV path) and ``args.optimizer`` (optimizer name).
    """
    csv_path = args.output
    print(csv_path)
    if not os.path.exists(csv_path):
        write_result(csv_path, header)

    optimizer_name = args.optimizer
    model_name = "simcse_" + optimizer_name

    datasets = ("imdb", "ag_news")
    seeds = (1,)
    percentages = (1.0,)
    thetas = (1.0,)

    # Methods that run once per configuration with theta fixed at 0.0.
    # NOTE(review): "fintune" looks like a typo for "finetune"; it is kept
    # as-is because the method name is part of the key run_method uses to
    # skip runs already present in the result file.
    fixed_theta_methods = ("linear_eval_main", "fintune_eval_main")

    for seed in seeds:
        for percent in percentages:
            for dataset in datasets:
                for method in fixed_theta_methods:
                    run_method(args, model_name, dataset, percent, seed,
                               method, 0.0, optimizer_name)
                for theta in thetas:
                    run_method(args, model_name, dataset, percent, seed,
                               "finetune_contrastive_eval_main", theta,
                               optimizer_name)

if __name__ == "__main__":
    # Script entry point: parse the sweep options, then start the runs.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--optimizer', type=str, default="Adam")
    arg_parser.add_argument(
        '--output',
        type=str,
        default="./results.csv",
        help='output file for eval_train.py',
    )
    args = arg_parser.parse_args()
    main(args)