import os
import sys
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import wandb
import numpy as np


# Single-process, single-GPU setup. The DDP-style names are kept, but the rank
# and world size are fixed here rather than read from the environment.
if not torch.cuda.is_available():
    # Explicit raise instead of `assert`, which is skipped under `python -O`.
    raise RuntimeError("CUDA is required, but torch.cuda.is_available() returned False")
ddp_rank = 0
ddp_world_size = 1
device = f'cuda:{ddp_rank}'
print(f"using device: {device}")
# With a single process, this one is always the master process.
master_process = True

from omegaconf import OmegaConf

# The run configuration file is passed as the first command-line argument.
# Exit with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <config-file>")
config = OmegaConf.load(sys.argv[1])

from Eval_utils import EvalMetric

# Scorer that is called on each list of generated texts further down.
metric = EvalMetric(device="cuda", max_length=config.data.sequence_length)

from tqdm import tqdm

# Number of saved samples loaded and scored per metric call.
B = 1

# Per-batch scores keyed by metric name; reduced to their means after the loop.
# NOTE(review): assumes `metric` returns only these keys -- any other key
# raises KeyError at the append below. Confirm against EvalMetric.
res = {'GPT2': [],
     'GPT2-L': [],
     'GPT3': [],
     'Llama2': [],
     'entropy': []}

ckpt_path = config.inference.checkpoint
N_samples = config.inference.N_samples
postfix = f"_{config.fm_config.t_split}"

# Run identifier: the checkpoint's parent directory name plus the t_split
# suffix. It is the same for every sample and for the output file, so it is
# computed once instead of being rebuilt on every iteration.
run_name = f"{ckpt_path.split('/')[-2]}{postfix}"
sample_dir = f"inference/{run_name}"
output_dir = "eval_unconditional"
os.makedirs(output_dir, exist_ok=True)

# Score the saved samples in batches of B; a trailing partial batch
# (N_samples not divisible by B) is not evaluated.
for i in range(N_samples // B):
    # Each sample file holds a mapping whose "pred" entry is indexed at 0
    # to obtain the text passed to the metric.
    texts = [torch.load(f"{sample_dir}/{i*B+j}.pt")["pred"][0] for j in range(B)]
    r = metric(texts)
    for key, value in r.items():
        res[key].append(value)
    print(r)

# Collapse each list of per-batch scores into its mean and persist the summary.
res = {key: np.mean(values) for key, values in res.items()}
torch.save(res, f"{output_dir}/{run_name}.pt")


