import os 
import numpy as np 
import random
from collections import defaultdict
from importlib import import_module
import itertools
import hydra
from helpers.io import * 
from helpers.sampler import Sampler 
from omegaconf import DictConfig, OmegaConf
import time
import ipdb
import multiprocessing
from functools import partial

def check_done(cfg, k, beta):
    """Report whether the result file for this (method, k, beta) already exists.

    The expected path is
    ``<save_root>/results/<task>/<policy>/<reward>/<prefix>-<reward>-<method>-k-<k>-beta-<beta>.json``.
    A message is printed when the file is found.

    Args:
        cfg: Config with ``io.save_root``, ``io.prefix``, ``task.name``,
            ``policy.name``, ``reward.name`` and ``method.name``.
        k: The k value that is part of the file name.
        beta: The beta value that is part of the file name.

    Returns:
        True if the file exists (so this combination can be skipped), else False.
    """
    results_dir = os.path.join(
        cfg.io.save_root, "results", cfg.task.name, cfg.policy.name, cfg.reward.name
    )
    fname = f"{cfg.io.prefix}-{cfg.reward.name}-{cfg.method.name}-k-{k}-beta-{beta}.json"
    result_file = os.path.join(results_dir, fname)
    already_done = os.path.isfile(result_file)
    if already_done:
        print(f"Found existing (method={cfg.method.name}, k={k}, beta={beta}) file at {result_file}. Skipping...\n")
    return already_done

def filter_done(cfg, ks_and_betas):
    """Drop the (k, beta) pairs whose result file already exists.

    ``check_done`` is called for every pair first, so its "Skipping" messages
    are printed in input order. The pending pairs are then returned in their
    original order.

    Args:
        cfg: Config forwarded to ``check_done``.
        ks_and_betas: Sequence of (k, beta) pairs.

    Returns:
        List of (k, beta) pairs that still need to be computed.
    """
    done_flags = [check_done(cfg, k, beta) for k, beta in ks_and_betas]
    pending = [pair for pair, is_done in zip(ks_and_betas, done_flags) if not is_done]
    return pending

def _run_prompt_idx(cfg, sampler, ks_and_betas, repeats, prompt_idx):
    start = time.time()
    algo_module = import_module(f"algos.{cfg.method.type}",  package='code')
    algo = algo_module.Algo(cfg, sampler)                                
    outputs = defaultdict(lambda: defaultdict(list))
    indices = defaultdict(lambda: defaultdict(list))
    for (k, beta) in ks_and_betas: 
        outputs[k][beta] = []
        for repeat in range(repeats):
            algo.set_params(prompt_idx, None, k, beta)              
            response, idx = algo.sample_policy()
            outputs[k][beta].append(response)
            indices[k][beta].append(idx)
    end = time.time()
    print(f"\nTime taken for prompt {prompt_idx}: {end - start:.0f} seconds\n")
    return dict(outputs), dict(indices) 

def _run_subsampling_prompt_idx(cfg, sampler, ks_and_betas, repeats, prompt_idx):
    # ipdb.set_trace()
    start = time.time()
    algo_module = import_module(f"algos.{cfg.method.type}",  package='code')
    algo = algo_module.Algo(cfg, sampler)                  
    outputs = defaultdict(lambda: defaultdict(list))
    indices = defaultdict(lambda: defaultdict(list))
    for (repeat, (k, beta)) in itertools.product(range(repeats), ks_and_betas): 
        algo.set_params(prompt_idx, repeat, k, beta)                 
        response, idx = algo.sample_policy()
        outputs[k][beta].append(response)
        indices[k][beta].append(idx)
    end = time.time()
    print(f"\nTime taken for prompt {prompt_idx}: {end - start:.0f} seconds\n")
    return dict(outputs), dict(indices) 

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Run sampling for every prompt and (k, beta) pair, then save JSON results.

    For each (k, beta) pair, two files are written under
    ``<save_root>/results/<task>/<policy>/<reward>/``:

    * ``<prefix>-<reward>-<method>-k-<k>-beta-<beta>.json``: {prompt_idx: responses}
    * ``<prefix>-<reward>-<method>-k-<k>-beta-<beta>-nums.json``: {prompt_idx: indices}

    Unless ``cfg.io.overwrite`` is set, pairs whose result file already
    exists are skipped. With ``cfg.debug``, only the first 10 prompts are run.
    This applies in both the sequential and the multiprocessing modes.
    """
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    task = cfg.task.name
    policy = cfg.policy.name
    reward = cfg.reward.name
    method = cfg.method
    repeats = cfg.repeats
    # These method types are run with a single beta value of 0.
    betas = [0] if method.type in ['bon', 'piref'] else cfg.betas

    # Disable struct mode so new keys (e.g. io.load_path) can be added.
    OmegaConf.set_struct(cfg, False)
    cfg.io.load_path = os.path.join(cfg.io.load_root, "data", task, policy)
    save_path = os.path.join(cfg.io.save_root, "results", task, policy, reward)
    os.makedirs(save_path, exist_ok=True)
    if cfg.ks.kmax < 0 or cfg.ks.kmin < 0:
        # A negative bound collapses the k grid to the single sentinel k = -1.
        cfg.ks.kmax = -1
        cfg.ks.kmin = -1
        cfg.ks.inc = 0
        ks = [-1]
    else:
        ks = get_ks(cfg.ks)
    ks_and_betas = list(itertools.product(ks, betas))
    if not cfg.io.overwrite:
        print("\n========== Checking for existing results ==========\n")
        ks_and_betas = filter_done(cfg, ks_and_betas)
    print(ks_and_betas)
    if not ks_and_betas:
        # Every requested result already exists; skip loading the sampler and
        # running prompts that would produce nothing to save.
        print("Nothing to do: all requested (k, beta) results already exist.")
        return

    print(f"Subsampling: {cfg.use_subsampling}")
    sampler = Sampler(cfg, holdout=False)
    print(f'========= Inference for method {method.name} =========')

    prompt_idxs = sampler.prompt_idxs
    if cfg.debug:
        # Same debug limit for both execution modes: the first 10 prompts.
        prompt_idxs = prompt_idxs[:10]

    worker = _run_subsampling_prompt_idx if cfg.use_subsampling else _run_prompt_idx
    run_prompt_idx = partial(worker, cfg, sampler, ks_and_betas, repeats)
    if cfg.use_multiprocessing:
        # NOTE(review): with a fork-based start method, workers inherit the
        # RNG state seeded above, so they may draw identical random sequences.
        # Confirm whether Sampler/Algo rely on the global numpy/random RNGs.
        with multiprocessing.Pool(processes=cfg.max_threads) as pool:
            results = pool.map(run_prompt_idx, prompt_idxs)  # order-preserving
    else:
        results = [run_prompt_idx(prompt_idx) for prompt_idx in prompt_idxs]

    # Each result is (outputs, indices), both indexed as [k][beta].
    for (k, beta) in ks_and_betas:
        file_stem = os.path.join(
            save_path, f"{cfg.io.prefix}-{reward}-{method.name}-k-{k}-beta-{beta}"
        )
        outputs = {p: out[k][beta] for p, (out, _) in zip(prompt_idxs, results)}
        json_dump(outputs, f"{file_stem}.json")
        indices = {p: idxs[k][beta] for p, (_, idxs) in zip(prompt_idxs, results)}
        json_dump(indices, f"{file_stem}-nums.json")


if __name__ == '__main__':
    # Wall-clock timing around the whole Hydra run.
    t_begin = time.time()
    main()
    elapsed = time.time() - t_begin
    print(f"\nTotal time taken: {elapsed:.0f} seconds\n")
