import numpy as np
from gen_data import generate_W_q_W_k
from eval import calculate_loss, calculate_loss_wanda_setting, calculate_loss_sparse
from pruning import pruning_mask_wanda, pruning_mask_sparse_gpt, pruning_mask_gd

import argparse

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` applies ``type`` to the raw string, and ``bool("False")`` is
    ``True`` because any non-empty string is truthy.  This converter parses
    the text instead, so ``--flag False`` actually disables the flag.

    Args:
        value: The raw token from the command line (or an existing bool).

    Returns:
        The boolean the token spells out.  The empty string maps to ``False``,
        which matches what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling
            of true/false; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse the command-line switches that select which experiments run.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with boolean attributes ``run_gd_and_wanda``
        and ``run_sparse_gpt``; both default to ``True``.
    """
    parser = argparse.ArgumentParser(description='choose the experiment to run')
    parser.add_argument('--run_gd_and_wanda', type=_str2bool, default=True, help='Run wanda and our method')
    parser.add_argument('--run_sparse_gpt', type=_str2bool, default=True, help='Run sparse GPT')
    return parser.parse_args(argv)

def _run_gd_and_wanda_experiments():
    """Run the Wanda / gradient-descent mask comparison and save the results.

    For every seed in ``random_seed_list`` three sweeps are performed:

    * Figure 1 -- sweep over ``n_list`` with fixed ``base_lam`` and ``base_rho``.
    * Figure 2 -- sweep over ``lambda_list`` at ``n = 128``; the Wanda loss does
      not take lambda, so one value is repeated across the whole sweep.
    * Figure 3 -- sweep over ``rho_list`` at ``n = 128`` with ``base_lam``.

    The per-seed losses are reduced to their mean and variance across seeds
    and written to ``result_for_{X_len}_length_wanda_and_ours.txt``.

    NOTE: the order of the random draws (inputs first, then ``W_q``, then
    ``W_k`` within each sweep) determines the results for a given seed, so it
    must not be changed.
    """
    X_len = 64
    C = 10
    random_seed_list = [43, 44, 45, 46, 47]

    d = 64
    # Baseline values; named separately from the sweep loop variables so the
    # sweeps cannot clobber them.
    base_lam = 0.04
    base_rho = 0.5
    n_list = [64, 128, 256, 512, 1024]
    lambda_list = [1/1024, 1/512, 1/256, 1/128, 1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16]
    rho_list = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

    # One inner list per seed; reduced over axis 0 after all runs finish.
    wanda_loss_fig1 = []
    our_loss_fig1 = []
    wanda_loss_fig2 = []
    our_loss_fig2 = []
    wanda_loss_fig3 = []
    our_loss_fig3 = []

    for run, seed in enumerate(random_seed_list):
        np.random.seed(seed)
        print(f"run: {run}")

        # get data for figure 1
        wanda_loss_list_fig1 = []
        our_loss_list_fig1 = []

        for n in n_list:
            print(f"n: {n}")

            X_list = [np.random.normal(0, 1, (n, d)) for _ in range(X_len)]
            W_q = generate_W_q_W_k(d, C)
            W_k = generate_W_q_W_k(d, C)
            W = W_q @ W_k.T

            # Lower-triangular matrix of ones (presumably a causal mask).
            M_c = np.tril(np.ones((n, n)))
            M_q_wanda_setting, M_k_wanda_setting = pruning_mask_wanda(X_list, W_q, W_k, base_rho)
            wanda_loss_list_fig1.append(float(calculate_loss_wanda_setting(
                X_list, W_q, W_k, M_q_wanda_setting, M_k_wanda_setting, M_c)))
            print(f"wanda_loss: {wanda_loss_list_fig1[-1]}")
            M_our = pruning_mask_gd(X_list, W, base_lam, base_rho)
            our_loss_list_fig1.append(float(calculate_loss(X_list, W, M_our, base_lam, M_c)))
            print(f"our_loss: {our_loss_list_fig1[-1]}")

        wanda_loss_fig1.append(wanda_loss_list_fig1)
        our_loss_fig1.append(our_loss_list_fig1)

        # get data for figure 2
        n = 128
        our_loss_list_fig2 = []

        X_list = [np.random.normal(0, 1, (n, d)) for _ in range(X_len)]
        W_q = generate_W_q_W_k(d, C)
        W_k = generate_W_q_W_k(d, C)
        W = W_q @ W_k.T
        M_c = np.tril(np.ones((n, n)))

        M_q_wanda_setting, M_k_wanda_setting = pruning_mask_wanda(X_list, W_q, W_k, base_rho)
        wanda_fig2_loss_single = calculate_loss_wanda_setting(
            X_list, W_q, W_k, M_q_wanda_setting, M_k_wanda_setting, M_c)
        # The Wanda call takes no lambda: repeat the one value across the sweep.
        wanda_loss_fig2.append([float(wanda_fig2_loss_single)] * len(lambda_list))

        for lam in lambda_list:
            print(f"lambda: {lam}")
            M_gd = pruning_mask_gd(X_list, W, lam, base_rho)
            our_loss_list_fig2.append(float(calculate_loss(X_list, W, M_gd, lam, M_c)))
        our_loss_fig2.append(our_loss_list_fig2)

        # get data for figure 3
        n = 128
        X_list = [np.random.normal(0, 1, (n, d)) for _ in range(X_len)]
        M_c = np.tril(np.ones((n, n)))

        wanda_loss_list_fig3 = []
        our_loss_list_fig3 = []
        W_q = generate_W_q_W_k(d, C)
        W_k = generate_W_q_W_k(d, C)
        W = W_q @ W_k.T

        for rho in rho_list:
            print(f"rho: {rho}")
            M_q_wanda_setting, M_k_wanda_setting = pruning_mask_wanda(X_list, W_q, W_k, rho)
            wanda_loss = calculate_loss_wanda_setting(
                X_list, W_q, W_k, M_q_wanda_setting, M_k_wanda_setting, M_c)

            M_gd = pruning_mask_gd(X_list, W, base_lam, rho)
            our_loss = calculate_loss(X_list, W, M_gd, base_lam, M_c)

            wanda_loss_list_fig3.append(float(wanda_loss))
            our_loss_list_fig3.append(float(our_loss))

        wanda_loss_fig3.append(wanda_loss_list_fig3)
        our_loss_fig3.append(our_loss_list_fig3)

    # Variances must be computed before the per-seed lists are replaced by
    # their means below.
    wanda_var_fig1 = np.var(wanda_loss_fig1, axis=0)
    our_var_fig1 = np.var(our_loss_fig1, axis=0)
    wanda_var_fig2 = np.var(wanda_loss_fig2, axis=0)
    our_var_fig2 = np.var(our_loss_fig2, axis=0)
    wanda_var_fig3 = np.var(wanda_loss_fig3, axis=0)
    our_var_fig3 = np.var(our_loss_fig3, axis=0)

    wanda_loss_fig1 = np.mean(wanda_loss_fig1, axis=0)
    our_loss_fig1 = np.mean(our_loss_fig1, axis=0)
    wanda_loss_fig2 = np.mean(wanda_loss_fig2, axis=0)
    our_loss_fig2 = np.mean(our_loss_fig2, axis=0)
    wanda_loss_fig3 = np.mean(wanda_loss_fig3, axis=0)
    our_loss_fig3 = np.mean(our_loss_fig3, axis=0)

    with open(f"result_for_{X_len}_length_wanda_and_ours.txt", "w", encoding="utf-8") as f:
        f.write("Figure 1\n")
        f.write(f"n_list: {n_list}\n")
        f.write(f"wanda_loss_fig1: {wanda_loss_fig1}\n")
        f.write(f"our_loss_fig1: {our_loss_fig1}\n")
        f.write(f"wanda_var_fig1: {wanda_var_fig1}\n")
        f.write(f"our_var_fig1: {our_var_fig1}\n")

        f.write("Figure 2\n")
        f.write(f"lambda_list: {lambda_list}\n")
        f.write(f"wanda_loss_fig2: {wanda_loss_fig2}\n")
        f.write(f"our_loss_fig2: {our_loss_fig2}\n")
        f.write(f"wanda_var_fig2: {wanda_var_fig2}\n")
        f.write(f"our_var_fig2: {our_var_fig2}\n")

        f.write("Figure 3\n")
        f.write(f"rho_list: {rho_list}\n")
        f.write(f"wanda_loss_fig3: {wanda_loss_fig3}\n")
        f.write(f"our_loss_fig3: {our_loss_fig3}\n")
        f.write(f"wanda_var_fig3: {wanda_var_fig3}\n")
        f.write(f"our_var_fig3: {our_var_fig3}\n")


def _run_sparse_gpt_experiments():
    """Run the SparseGPT sweeps and save the results.

    Uses the same seeds and the same three sweeps (over ``n``, lambda and
    ``rho``) as the Wanda / gradient-descent experiments.  The SparseGPT calls
    take no lambda, so the Figure 2 value is a single loss repeated across
    ``lambda_list``.  Means and variances across seeds are written to
    ``result_for_{X_len}_length_sparse_gpt.txt``.

    NOTE: the order of the random draws determines the results for a given
    seed, so it must not be changed.
    """
    X_len = 64
    C = 10
    random_seed_list = [43, 44, 45, 46, 47]

    d = 64
    base_rho = 0.5
    # Passed straight through to pruning_mask_sparse_gpt
    # (presumably block sizes -- TODO confirm against pruning.py).
    B = 16
    Bs = 1
    n_list = [64, 128, 256, 512, 1024]
    lambda_list = [1/1024, 1/512, 1/256, 1/128, 1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16]
    rho_list = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    sparse_gpt_loss_fig1 = []
    sparse_gpt_loss_fig2 = []
    sparse_gpt_loss_fig3 = []

    for run, seed in enumerate(random_seed_list):
        np.random.seed(seed)
        print(f"run: {run}")

        # get data for figure 1
        sparse_gpt_loss_list_fig1 = []
        for n in n_list:
            X_list = [np.random.normal(0, 1, (n, d)) for _ in range(X_len)]
            W_q = generate_W_q_W_k(d, C)
            W_k = generate_W_q_W_k(d, C)

            # Lower-triangular matrix of ones (presumably a causal mask).
            M_c = np.tril(np.ones((n, n)))
            M_q_sparse_gpt, M_k_sparse_gpt, W_q_update_sparse_gpt, W_k_update_sparse_gpt = \
                pruning_mask_sparse_gpt(X_list, W_q, W_k, base_rho, B, Bs)
            sparse_gpt_loss = calculate_loss_sparse(
                X_list, W_q, W_k, M_q_sparse_gpt, M_k_sparse_gpt, M_c,
                W_q_update_sparse_gpt, W_k_update_sparse_gpt)
            # Store a plain float, as every other sweep in this file does.
            sparse_gpt_loss_list_fig1.append(float(sparse_gpt_loss))
            print(f"n = {n}, sparse_gpt_loss: {sparse_gpt_loss}")

        sparse_gpt_loss_fig1.append(sparse_gpt_loss_list_fig1)

        # get data for figure 2
        n = 128
        X_list = [np.random.normal(0, 1, (n, d)) for _ in range(X_len)]
        W_q = generate_W_q_W_k(d, C)
        W_k = generate_W_q_W_k(d, C)
        M_c = np.tril(np.ones((n, n)))

        M_q_sparse_gpt, M_k_sparse_gpt, W_q_update_sparse_gpt, W_k_update_sparse_gpt = \
            pruning_mask_sparse_gpt(X_list, W_q, W_k, base_rho, B, Bs)
        sparse_gpt_fig2_loss_single = calculate_loss_sparse(
            X_list, W_q, W_k, M_q_sparse_gpt, M_k_sparse_gpt, M_c,
            W_q_update_sparse_gpt, W_k_update_sparse_gpt)

        # No lambda is involved: repeat the one value across the sweep.
        sparse_gpt_loss_list_fig2 = [float(sparse_gpt_fig2_loss_single)] * len(lambda_list)
        print(f"sparse_gpt_loss_baseline_lambda: {sparse_gpt_fig2_loss_single}")
        sparse_gpt_loss_fig2.append(sparse_gpt_loss_list_fig2)

        # get data for figure 3
        n = 128
        X_list = [np.random.normal(0, 1, (n, d)) for _ in range(X_len)]
        M_c = np.tril(np.ones((n, n)))

        sparse_gpt_loss_list_fig3 = []
        W_q = generate_W_q_W_k(d, C)
        W_k = generate_W_q_W_k(d, C)

        for rho in rho_list:
            M_q_sparse_gpt, M_k_sparse_gpt, W_q_update_sparse_gpt, W_k_update_sparse_gpt = \
                pruning_mask_sparse_gpt(X_list, W_q, W_k, rho, B, Bs)
            sparse_gpt_loss = calculate_loss_sparse(
                X_list, W_q, W_k, M_q_sparse_gpt, M_k_sparse_gpt, M_c,
                W_q_update_sparse_gpt, W_k_update_sparse_gpt)
            print(f"rho={rho}, sparse_gpt_loss: {sparse_gpt_loss}")
            sparse_gpt_loss_list_fig3.append(float(sparse_gpt_loss))

        sparse_gpt_loss_fig3.append(sparse_gpt_loss_list_fig3)

    # Variances must be computed before the per-seed lists are replaced by
    # their means below.
    sparse_gpt_var_fig1 = np.var(sparse_gpt_loss_fig1, axis=0)
    sparse_gpt_var_fig2 = np.var(sparse_gpt_loss_fig2, axis=0)
    sparse_gpt_var_fig3 = np.var(sparse_gpt_loss_fig3, axis=0)

    sparse_gpt_loss_fig1 = np.mean(sparse_gpt_loss_fig1, axis=0)
    sparse_gpt_loss_fig2 = np.mean(sparse_gpt_loss_fig2, axis=0)
    sparse_gpt_loss_fig3 = np.mean(sparse_gpt_loss_fig3, axis=0)

    # store result of 3 figures in a txt file
    with open(f"result_for_{X_len}_length_sparse_gpt.txt", "w", encoding="utf-8") as f:
        f.write("Figure 1\n")
        f.write(f"n_list: {n_list}\n")
        f.write(f"sparse_gpt_loss_fig1: {sparse_gpt_loss_fig1}\n")
        f.write(f"sparse_gpt_var_fig1: {sparse_gpt_var_fig1}\n")

        f.write("Figure 2\n")
        f.write(f"lambda_list: {lambda_list}\n")
        f.write(f"sparse_gpt_loss_fig2: {sparse_gpt_loss_fig2}\n")
        f.write(f"sparse_gpt_var_fig2: {sparse_gpt_var_fig2}\n")

        f.write("Figure 3\n")
        f.write(f"rho_list: {rho_list}\n")
        f.write(f"sparse_gpt_loss_fig3: {sparse_gpt_loss_fig3}\n")
        f.write(f"sparse_gpt_var_fig3: {sparse_gpt_var_fig3}\n")


def main():
    """Entry point: run the experiment suites selected on the command line.

    Reads ``run_gd_and_wanda`` and ``run_sparse_gpt`` from the parsed
    arguments and runs the matching suite(s).  The two suites are independent;
    each one reseeds NumPy itself and writes its own result file.
    """
    args = parse_arguments()

    if args.run_gd_and_wanda:
        _run_gd_and_wanda_experiments()

    if args.run_sparse_gpt:
        _run_sparse_gpt_experiments()

# Run the experiments only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()







