import csv


def save_to_csv(epoch, grad_stats, filename="grad_stats.csv"):
    """Append one row of gradient statistics for ``epoch`` to a CSV file.

    Only entries of ``grad_stats`` whose key is recognised by ``get_index``
    are written, so every value lines up with the header column produced
    for its key. The header row (``"epoch"`` followed by the indices returned
    by ``get_index``) is written only when the file is empty.

    Note: the header reflects the keys of the first call that writes to the
    file; later calls are assumed to pass the same keys in the same order.

    Args:
        epoch: Value written in the first column of the row.
        grad_stats: Mapping from parameter name to a numeric statistic; each
            value is rounded to 5 decimal places with ``round``.
        filename: Path of the CSV file, opened in append mode.
    """
    # Resolve every key once and keep (index, value) pairs together so the
    # header and the data row are filtered by exactly the same rule.
    columns = []
    for key, value in grad_stats.items():
        index = get_index(key)
        if index is not None:
            columns.append((index, value))

    with open(filename, mode='a', newline='') as file:
        writer = csv.writer(file)

        # In append mode the stream starts at the end of the file, so a zero
        # offset means the file is new or empty and still needs a header.
        if file.tell() == 0:
            writer.writerow(["epoch"] + [index for index, _ in columns])

        writer.writerow([epoch] + [round(value, 5) for _, value in columns])



# Parameter names with a known index, listed in index order: the position of
# a name in this tuple IS the index returned by ``get_index``. Append new
# names at the end to keep existing indices stable.
_INDEXED_PARAM_NAMES = (
    "patch_embed1.proj_conv.weight",
    "patch_embed1.proj1_conv.weight",
    "patch_embed1.proj_res_conv.weight",
    "patch_embed2.proj3_conv.weight",
    "patch_embed2.proj4_conv.weight",
    "patch_embed2.proj_res_conv.weight",
    "patch_embed3.proj3_conv.weight",
    "patch_embed3.proj4_conv.weight",
    "patch_embed3.proj_res_conv.weight",
    "stage1.0.tssa.q_conv.weight",
    "stage1.0.tssa.k_conv.weight",
    "stage1.0.tssa.proj_conv.weight",
    "stage1.0.mlp.mlp1_conv.weight",
    "stage1.0.mlp.mlp2_conv.weight",
    "stage2.0.tssa.q_conv.weight",
    "stage2.0.tssa.k_conv.weight",
    "stage2.0.tssa.proj_conv.weight",
    "stage2.0.mlp.mlp1_conv.weight",
    "stage2.0.mlp.mlp2_conv.weight",
    "stage3.0.ssa.q_conv.weight",
    "stage3.0.ssa.k_conv.weight",
    "stage3.0.ssa.v_conv.weight",
    "stage3.0.ssa.proj_conv.weight",
    "stage3.0.mlp.mlp1_conv.weight",
    "stage3.0.mlp.mlp2_conv.weight",
    "stage3.1.ssa.q_conv.weight",
    "stage3.1.ssa.k_conv.weight",
    "stage3.1.ssa.v_conv.weight",
    "stage3.1.ssa.proj_conv.weight",
    "stage3.1.mlp.mlp1_conv.weight",
    "stage3.1.mlp.mlp2_conv.weight",
    "feedback_stage.decoder.0.linear_td.weight",
    "feedback_stage.decoder.1.linear_td.weight",
    "head.weight",
)

# Built once at import time instead of on every ``get_index`` call.
_NAME_TO_INDEX = {name: index for index, name in enumerate(_INDEXED_PARAM_NAMES)}


def get_index(name):
    """Return the integer index assigned to a parameter name.

    Args:
        name: Parameter name, e.g. ``"stage1.0.tssa.q_conv.weight"``.

    Returns:
        The index (0 to 33) for a known name, or ``None`` when the name is
        not in the table.
    """
    return _NAME_TO_INDEX.get(name)