import wandb
import random
import os
import csv

def process_kl(ckpt_path):
    """Convert the ``train.kl`` log in ``ckpt_path`` into ``train.csv``.

    Every non-blank line of ``train.kl`` is split on commas and each field
    is parsed as a float. In the first field, everything up to and including
    the last ``:`` is discarded (presumably a per-line label such as an
    epoch tag -- confirm against the code that writes ``train.kl``).

    The CSV starts with a header ``dim_0, dim_1, ...`` that has one column
    per value in the widest parsed row. When the log contains no data rows,
    the historical 10-column header is written so the output stays
    compatible with earlier versions of this function.

    Args:
        ckpt_path: Directory that contains ``train.kl``. ``train.csv`` is
            created (or overwritten) in the same directory.

    Raises:
        OSError: If ``train.kl`` cannot be read or ``train.csv`` cannot be
            written.
        ValueError: If a field of a non-blank line is not a valid float.
    """
    kl_path = os.path.join(ckpt_path, 'train.kl')
    csv_path = os.path.join(ckpt_path, 'train.csv')

    # Read in the KL values, one list of floats per non-blank line.
    kls = []
    with open(kl_path, 'r', encoding='UTF8') as klfile:
        for line in klfile:
            # A blank line (e.g. a trailing newline at the end of the file)
            # carries no values; parsing it would raise on float('').
            if not line.strip():
                continue
            fields = line.split(',')
            # Drop the "<label>:" prefix that precedes the first value.
            fields[0] = fields[0].split(':')[-1]
            # float() ignores surrounding whitespace, including the
            # newline left on the last field.
            kls.append([float(field) for field in fields])

    # Size the header from the data instead of assuming 10 dimensions.
    n_dims = max((len(row) for row in kls), default=10)
    column_names = [f'dim_{i}' for i in range(n_dims)]

    # Write the header followed by all rows.
    with open(csv_path, 'w', newline='', encoding='UTF8') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(column_names)
        csv_writer.writerows(kls)


# Index of the CSV column (dim_<column_idx>) whose KL values are uploaded.
column_idx = 9

# Start a new wandb run to track this script.
wandb.init(
    # wandb project where this run will be logged
    project="RobustVAE",

    # hyperparameters and run metadata
    config={
        "architecture": "RobustVAE",
        "dataset": "traffic",
    },

    # name the run after the dimension being logged
    name=f"dim_{column_idx}",
)


ckpt_path = './checkpoints/Traffic_128_c11_0.15_semi0.1p_traffic'
csv_path = os.path.join(ckpt_path, 'train.csv')

# Uncomment to (re)generate train.csv from train.kl before logging.
# process_kl(ckpt_path)

# Simulate training: replay the recorded values, one wandb.log call per
# data row of the CSV.
# newline='' is what the csv module recommends for files handed to a reader.
with open(csv_path, 'r', newline='') as csvfile:
    csv_reader = csv.reader(csvfile, delimiter=',')

    # Skip the header row; the default keeps an empty file from raising.
    next(csv_reader, None)

    for row in csv_reader:
        wandb.log({"kl": float(row[column_idx])})

wandb.finish()

# Reference template for logging from a real training loop (not executed):
#
# for epoch in ...:
#     # log metrics to wandb
#     wandb.log({"acc": acc, "loss": loss})
#
# [optional] finish the wandb run, necessary in notebooks
# wandb.finish()