import os
import sys
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import wandb
import numpy as np


# Single-process device setup. Despite the DDP-style names, nothing here reads
# torchrun's environment variables: rank and world size are hard-coded to a
# single process on cuda:0.
# NOTE(review): the assert makes the script require a CUDA device, although
# `device` does not appear to be used by the exec-based evaluation in this
# file — confirm whether this requirement is intentional.
assert torch.cuda.is_available()
ddp_rank = 0
ddp_world_size = 1
device = f'cuda:{ddp_rank}'
print(f"using device: {device}")
# Multi-process variant (only rank 0 would log/checkpoint), kept for reference:
#master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
master_process = True  # always True: this script runs as a single process

# Load the experiment config from the file path given as the first CLI argument.
from omegaconf import OmegaConf
config = OmegaConf.load(sys.argv[1])

# from Eval_utils import EvalMetric

# metric = EvalMetric(device="cuda",max_length=config.data.sequence_length)

from tqdm import tqdm

import signal

# SIGALRM callback installed by `timeout` below.
def handler(signum, frame):
    """Abort the currently running call by raising TimeoutError.

    `signum` and `frame` are supplied by the signal machinery and ignored.
    """
    message = "Function execution timed out"
    raise TimeoutError(message)

# Decorator factory that bounds the wall-clock time of each call.
def timeout(seconds):
    """Limit every call of the decorated function to `seconds` seconds.

    Implemented with SIGALRM: `signal.alarm(seconds)` is armed before the call
    and `handler` raises TimeoutError inside the running call when it fires.
    SIGALRM is POSIX-only, and `signal.signal` may only be called from the
    main thread.

    Args:
        seconds: whole-second budget passed to ``signal.alarm``.

    Returns:
        A decorator whose wrapper preserves the wrapped function's metadata
        (``__name__``, ``__doc__``, ``__wrapped__``).
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Remember the previous SIGALRM disposition so it can be restored;
            # without this the decorator hijacks SIGALRM for the whole process.
            previous_handler = signal.signal(signal.SIGALRM, handler)
            signal.alarm(seconds)  # Set the alarm for 'seconds'
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # Disable the pending alarm
                # getsignal/signal returns None for handlers not installed from
                # Python; None cannot be re-installed, so skip that case.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
        return wrapper
    return decorator

@timeout(2)
def call_function(func):
    """Execute a program given as source text; raise if it fails.

    `func` is a complete program string (e.g. a generated solution, optionally
    followed by its assert-based tests). Any exception the program raises — a
    failed assert, NameError, SyntaxError, or TimeoutError from the 2-second
    ``@timeout`` — propagates to the caller, which treats it as a failure.

    SECURITY: this runs arbitrary code with the full privileges of the
    evaluating process (no sandbox). Only use it on code you are willing to
    run unsandboxed.

    Returns:
        None (the result of ``exec``).
    """
    namespace = {}
    # Pre-run top-level `import ...` lines so module names are bound even if
    # the code uses a module before its import statement (kept from the
    # original behavior).
    imports = "\n".join(line for line in func.split("\n") if line.startswith("import "))
    exec(imports, namespace)
    # Use ONE dict as both globals and locals. With separate dicts, exec runs
    # the code as if inside a class body: top-level defs land in the locals
    # dict, which function bodies cannot see. Recursive functions, helpers
    # calling each other and functions reading module-level constants then
    # fail with a spurious NameError.
    res = exec(func, namespace)
    return res

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate: 1 - C(n - c, k) / C(n, k).

    This is the probability that at least one of k samples, drawn without
    replacement from n generations of which c are correct, is correct. It is
    computed in a numerically stable product form.

    Args:
        n: total number of generated samples.
        c: number of correct samples (0 <= c <= n).
        k: number of draws.

    Returns:
        A float in [0, 1].
        - 0.0 when no sample is correct (including n == 0).
        - 1.0 when n - c < k, i.e. every size-k draw must contain a correct
          sample.
        - When n < k the estimator is not well-defined; this returns 1.0 if
          c > 0.
    """
    if c == 0:
        # No correct sample can ever be drawn. Without this guard the branch
        # below reports 1.0 whenever n < k (e.g. n=1, c=0, k=5) or n == 0.
        return 0.0
    if n - c < k:
        return 1.0
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

ckpt_path = config.inference.checkpoint
postfix = sys.argv[2]

# Inference outputs are read from inference/<run>_<postfix>/<i>.pt, where
# <run> is the second-to-last component of the checkpoint path.
run_name = ckpt_path.split('/')[-2]
inference_dir = f"inference/{run_name}_{postfix}"

os.makedirs("./eval_infilling/", exist_ok=True)
# Assumes the directory holds exactly 0.pt .. (count-1).pt — TODO confirm.
# Extra files inflate the count, and a gap makes torch.load fail.
N_samples = min(config.inference.N_samples, len(os.listdir(inference_dir)))

# Exceptions that mark an execution as failed. SystemExit is included because
# generated code may call exit(). KeyboardInterrupt is deliberately left
# uncaught (a bare `except:` would swallow it) so the run can still be stopped.
_EXEC_FAILURES = (Exception, SystemExit)


def _runs_ok(code):
    """Return True if `code` executes through call_function without raising.

    NOTE: `code` is model-generated and is executed unsandboxed.
    """
    try:
        call_function(code)
    except _EXEC_FAILURES:
        return False
    return True


# Load each sample file once; both passes below reuse it (previously every
# file was loaded six times). From the usage below, each record provides
# "pred" (list of code strings), "test_list" (list of test lines) and "tar"
# (the reference code string).
records = [torch.load(f"{inference_dir}/{i}.pt") for i in range(N_samples)]

# Pass 1 - functional correctness: each prediction followed by its tests.
res = []
success_ids = []
for i, record in enumerate(records):
    preds = record["pred"]
    test_lists = record["test_list"]
    N = len(preds)
    c = 0
    for j, pred in enumerate(preds):
        if _runs_ok("\n".join([pred] + test_lists)):
            c += 1
            success_ids.append([i, j])
    # Sanity check: does the reference ("tar") code pass the same tests?
    c_tar = int(_runs_ok("\n".join([record["tar"]] + test_lists)))
    print(N, c)
    print(f"Pass@1: {pass_at_k(N,c,1)}, Pass@5: {pass_at_k(N,c,5)}, Pass@10: {pass_at_k(N,c,10)}, Tar: {c_tar}")
    res.append([c_tar] + [pass_at_k(N, c, k) for k in [1, 5, 10]])

# Pass 2 - "compiled": each prediction executed on its own, without tests.
res_compiled = []
for record in records:
    preds = record["pred"]
    N = len(preds)
    c = sum(_runs_ok(pred) for pred in preds)
    c_tar = int(_runs_ok(record["tar"]))
    print(f"Compiled@1: {pass_at_k(N,c,1)}, Compiled@5: {pass_at_k(N,c,5)}, Compiled@10: {pass_at_k(N,c,10)}, Tar: {c_tar}")
    res_compiled.append([pass_at_k(N, c, k) for k in [1, 5, 10]])

# Saved vector, each entry averaged over samples:
#   [tar, pass@1, pass@5, pass@10, compiled@1, compiled@5, compiled@10]
# reshape keeps the shapes valid when N_samples == 0. The means are then NaN,
# instead of np.concatenate failing on zero-dimensional arrays.
summary = np.concatenate(
    [
        np.asarray(res, dtype=float).reshape(-1, 4).mean(axis=0),
        np.asarray(res_compiled, dtype=float).reshape(-1, 3).mean(axis=0),
    ],
    axis=0,
)
torch.save(summary, f"eval_infilling/{run_name}_{postfix}.pt")
torch.save(success_ids, f"eval_infilling/{run_name}_success_ids_{postfix}.pt")