import glob
import json

import pandas as pd
from loguru import logger
from tqdm import tqdm
from wandb.proto import wandb_internal_pb2
from wandb.sdk.internal import datastore


class WandbDataLoader:
    """Collect summary and run/config values from wandb log files.

    Every file matching ``files_path`` is scanned record by record; the
    values found in its ``summary`` and ``run`` records are flattened into
    one ``pandas.Series``, and ``load_data`` stacks those series into a
    ``DataFrame`` with one row per file.
    """

    def __init__(self, files_path: str):
        """Store the glob pattern used to locate the log files.

        Args:
            files_path: Pattern handed to ``glob.glob`` by ``load_data``.
        """
        self.files_path = files_path
        # One Series per file processed by the most recent ``load_data`` call.
        self.serieses = []

    def load_data(self) -> pd.DataFrame:
        """Scan every matching file and return one row per file.

        Returns:
            A ``DataFrame`` built from the per-file series, in sorted path
            order. It is empty when the pattern matches no files.
        """
        # Start from a clean slate so that calling load_data() again does not
        # return the rows of the previous call a second time.
        self.serieses = []
        for data_path in tqdm(sorted(glob.glob(self.files_path))):
            ds = datastore.DataStore()
            ds.open_for_scan(data_path)
            try:
                self.serieses.append(self._process_data(ds))
            finally:
                # Release the file handle even when scanning or parsing fails.
                ds.close()
        return pd.DataFrame(self.serieses)

    def _extract_summary_data(self, pb) -> dict:
        """Return ``{key: decoded JSON value}`` for a summary record."""
        return {
            update.key: json.loads(update.value_json)
            for update in pb.summary.update
        }

    def _extract_run_data(self, pb) -> dict:
        """Return the project name and decoded config values of a run record.

        The ``dataset_name`` entry is reduced to its ``"name"`` field when the
        decoded value is a mapping that has one; any other shape is stored
        unchanged instead of raising.
        """
        d = {"project": pb.run.project}
        for update in pb.run.config.update:
            value = json.loads(update.value_json)
            if (
                update.key == "dataset_name"
                and isinstance(value, dict)
                and "name" in value
            ):
                value = value["name"]
            d[update.key] = value
        return d

    def _process_data(self, datastore) -> pd.Series:
        """Fold all summary/run records of one opened store into a Series.

        Args:
            datastore: An object whose ``scan_record()`` returns the next
                record, or ``None`` once the log is exhausted. (The name
                shadows the imported ``datastore`` module inside this method;
                it is kept for interface compatibility.)

        Returns:
            A Series indexed by key. When a key occurs in several records the
            value from the last record wins.
        """
        # Explicit dtype avoids the "default dtype for empty Series" warning
        # that older pandas versions emit.
        series = pd.Series(dtype=object)
        while True:
            data = datastore.scan_record()
            if data is None:
                break
            pb = wandb_internal_pb2.Record()
            # The second element of the scanned tuple is parsed as the
            # serialized record.
            pb.ParseFromString(data[1])
            record_type = pb.WhichOneof("record_type")
            if record_type == "summary":
                d = self._extract_summary_data(pb)
            elif record_type == "run":
                d = self._extract_run_data(pb)
            else:
                # Other record types contribute nothing.
                continue
            if d:
                series = self._update_series(series, d)
        return series

    def _update_series(self, series, data) -> pd.Series:
        """Merge ``data`` into ``series``; newer values replace older ones.

        Args:
            series: Values accumulated so far (may be empty).
            data: Mapping of new key/value pairs.

        Returns:
            A Series with unique index labels. ``series`` itself is returned
            when ``data`` is empty.
        """
        if not data:
            return series
        update = pd.Series(data)
        if not len(series):
            return update
        combined = pd.concat((series, update))
        # Keys repeated across records would otherwise leave duplicate labels,
        # which cannot be aligned into DataFrame columns; keep the latest.
        return combined[~combined.index.duplicated(keep="last")]
