import dataclasses
import datetime

import tqdm
import wandb
import wandb.apis.public


@dataclasses.dataclass
class File:
    """Snapshot of one W&B run file plus the metadata used to choose deletions.

    The name, update time and size are copied out of the wrapped ``wandb`` file
    object once, at construction time; ``obj`` keeps the original handle so it
    can be deleted later.
    """

    # Underlying wandb API file object (its ``delete()`` is called by the script).
    obj: wandb.apis.public.File  # type: ignore
    name: str
    # Parsed from the raw ``updatedAt`` attribute (naive datetime, no timezone).
    updatedAt: datetime.datetime
    # Raw ``sizeBytes`` attribute, taken as-is.
    size: int

    @classmethod
    def from_file(cls, f):
        """Build a ``File`` from a ``wandb`` file object ``f``.

        Reads ``f.name`` and the ``updatedAt`` / ``sizeBytes`` entries of the
        private ``f._attrs`` mapping. ``updatedAt`` must look like
        ``YYYY-MM-DDTHH:MM:SS``.

        Raises:
            KeyError: if ``updatedAt`` or ``sizeBytes`` is missing from ``f._attrs``.
            ValueError: if ``updatedAt`` does not match the expected format.
        """
        raw = f._attrs
        file_name = f.name
        stamp = datetime.datetime.strptime(raw["updatedAt"], "%Y-%m-%dT%H:%M:%S")
        return cls(obj=f, name=file_name, updatedAt=stamp, size=raw["sizeBytes"])


# W&B project whose runs are pruned.
PROJECT = "learned-planners"
# When True, only report what would be deleted instead of calling delete().
DRY_RUN = False


def _select_files_to_delete(run_files):
    """Return the ``File``s of one run that should be deleted, largest first.

    Every ``.zip`` file (weights) is selected, as is every ``.mp4`` video
    except the most recently updated one, which is kept.
    """
    # Newest first. sorted() is stable, so among videos with identical
    # timestamps the first one listed is the one kept.
    videos = sorted(
        (f for f in run_files if f.name.endswith(".mp4")),
        key=lambda f: f.updatedAt,
        reverse=True,
    )
    weights = [f for f in run_files if f.name.endswith(".zip")]
    # Largest files first (presumably to reclaim the most space early if the
    # script is interrupted).
    return sorted(videos[1:] + weights, key=lambda f: f.size, reverse=True)


api = wandb.Api()

# For every run: keep the latest video, delete all other videos and all weights.
for run in tqdm.tqdm(api.runs(PROJECT)):
    # Materialize the listing before deleting anything from the run.
    run_files = [File.from_file(f) for f in run.files()]

    for f in tqdm.tqdm(_select_files_to_delete(run_files)):
        if DRY_RUN:
            # tqdm.write keeps the message from clobbering the progress bars.
            tqdm.tqdm.write(f"[dry run] {run.name}: {f.name} ({f.size} bytes)")
        else:
            f.obj.delete()
