import os

import torch
import wandb

from utils import plot_class_count
from sample import Sampler

def log_artifacts(art, art_name, file_extension=".pickle", path='', upload=False) -> None:
    """Serialize an object into the current wandb run's artifact directory.

    The object is always written locally with ``torch.save`` to
    ``../artifacts/<wandb run id>/<path>/<art_name><file_extension>``.
    It is additionally handed to ``wandb.save`` (i.e. synced to the wandb
    server) only when ``upload`` is true.

    Args:
        art: any object ``torch.save`` can serialize.
        art_name (str): file name without the extension.
        file_extension (str): suffix appended to ``art_name``.
        path (str): sub-directory below the run's artifact directory.
        upload (bool): whether to also register the file with ``wandb.save``.
    """
    run_dir = os.path.join('../artifacts', wandb.run.id)
    target_dir = os.path.join(run_dir, path)
    os.makedirs(target_dir, exist_ok=True)

    target_file = os.path.join(target_dir, art_name + file_extension)
    torch.save(art, target_file)

    if not upload:
        return
    # base_path keeps the "<run id>/<path>/..." layout relative to ../artifacts
    # when the file is stored in the wandb run.
    wandb.save(target_file, base_path='../artifacts')
        
def log_settings(cfg, sampler: Sampler):
    """Log settings of this trial to wandb.

    Saves the config and the sampler's data-index bookkeeping as uploaded
    artifacts, then logs (all at step 0) a class-count plot, one sample image
    from each of the first two client chunks, one test-set sample image, and,
    when the sampler provides a server chunk, one sample image from it.

    Args:
        cfg (DictConfig): (local) cfg of this trial
        sampler (sample.Sampler): an object used to sample data and distribute
            it among clients
    """
    log_artifacts(cfg, "cfg", upload=True)
    log_artifacts(sampler.data_idxs_for_c, "data_idxs_for_c", path='data_idxs', upload=True)
    log_artifacts(sampler.data_idxs_s, "data_idxs_s", path='data_idxs', upload=True)
    log_artifacts(sampler.n_data_for_cls_for_c, "n_data_for_cls_for_c", path='data_idxs', upload=True)
    wandb.log({"class_count": wandb.Image(plot_class_count(cfg, sampler.n_data_for_cls_for_c))}, step=0)

    # Only the first two client chunks are visualized. The CHW -> HWC transpose
    # assumes dset[0][0] is a channels-first torch tensor (TODO confirm).
    client_imgs = {
        f'c{i}_img': wandb.Image(dset[0][0].numpy().transpose(1, 2, 0))
        for i, dset in enumerate(sampler.dset_chunks)
        if i < 2
    }
    wandb.log(client_imgs, step=0)
    wandb.log({'test_img': wandb.Image(sampler.dset_test[0][0])}, step=0)

    # The server-side chunk is optional: the sampler may not define it, or it
    # may be None/empty. Only that lookup is guarded, so wandb errors and
    # KeyboardInterrupt/SystemExit (previously swallowed by a bare `except`)
    # still propagate.
    try:
        s_sample = sampler.dset_chunk_s[0][0]
    except (AttributeError, TypeError, IndexError, KeyError):
        pass  # no server chunk in this setting; nothing to log
    else:
        wandb.log({'s_img': wandb.Image(s_sample)}, step=0)

def log_state_dicts(device, device_name, step):
    """Save a device's state dicts locally for one step.

    Whatever ``device.get_state_dicts()`` returns is written through
    ``log_artifacts`` as ``r<step>.pt`` inside ``state_dicts/<device_name>``
    under the run's artifact directory. It is never uploaded to wandb.

    Args:
        device: object exposing ``get_state_dicts()`` (presumably a client or
            server wrapper — confirm against callers).
        device_name (str): sub-directory name identifying the device.
        step: round/step index embedded in the file name.
    """
    sub_dir = os.path.join('state_dicts', device_name)
    log_artifacts(
        art=device.get_state_dicts(),
        art_name=f'r{step}',
        file_extension='.pt',
        path=sub_dir,
        upload=False,
    )
