import numpy as np 
import random
from collections import defaultdict
import glob, re, os, sys

from helpers.io import * 

def filter_generation_seeds(files, min_idx, max_idx):
    """Return the files whose seed index lies in ``[min_idx, max_idx]``.

    The seed index ``N`` is parsed from a path ending in
    ``seed-<N>-generations.json``. Both bounds are inclusive and the input
    order is preserved. Paths that do not follow this naming scheme are
    skipped instead of raising on a failed regex match.

    Args:
        files: iterable of file paths.
        min_idx: smallest seed index to keep (inclusive).
        max_idx: largest seed index to keep (inclusive).

    Returns:
        list of the matching paths, in input order.
    """
    filtered_files = []
    for file in files:
        match = re.search(r'seed-(\d+)-generations\.json$', file)
        if match is None:
            # Not a seed generation file; previously this raised AttributeError.
            continue
        if min_idx <= int(match.group(1)) <= max_idx:
            filtered_files.append(file)
    return filtered_files


def get_reward_file(cfg, generation_file, reward_key):
    """Locate the reward-label file for one generation file.

    For a generation file named ``<name>-generations.json``, the reward file is
    expected at ``<cfg.load_path>/<reward_key>/<name>-<reward_key>-rewards.npy``.

    Args:
        cfg: object exposing ``load_path``.
        generation_file: path to a ``*-generations.json`` file.
        reward_key: reward name; used both as the sub-directory and in the
            file name.

    Returns:
        The reward file path if it exists. Otherwise ``None``, after printing
        why: either the name could not be parsed or the file is missing.
    """
    reward_path = os.path.join(cfg.load_path, reward_key)
    # Parse the basename instead of requiring a literal '/' before it, so
    # paths built with os.path.join on any platform (and bare file names)
    # are handled.
    match = re.fullmatch(r'(.+)-generations\.json', os.path.basename(generation_file))
    if match is None:
        print(f"Could not extract name from {generation_file}")
        return None
    name = match.group(1)
    reward_file = os.path.join(reward_path, f"{name}-{reward_key}-rewards.npy")
    if not os.path.exists(reward_file):
        print(f"{reward_file} has no rewards labels.")
        return None
    return reward_file

def get_generation_files(cfg):
    """List the generation JSON files to sample from, sorted by path.

    Matches ``<cfg.load_path>/generations/<cfg.prefix>*.json``. No seed-range
    filtering is applied here; ``filter_generation_seeds`` can be used if a
    ``[min_idx, max_idx]`` restriction is wanted.
    """
    pattern = os.path.join(cfg.load_path, 'generations', f'{cfg.prefix}*.json')
    matches = glob.glob(pattern)
    matches.sort()
    return matches

def get_generation_and_reward_files(cfg, reward_keys):
    """Pair generation files with their reward files for every reward key.

    A generation file is kept only if ``get_reward_file`` finds a reward file
    for *every* key in ``reward_keys``. A file missing any reward is dropped
    from all returned lists, so the lists stay index-aligned.

    Args:
        cfg: IO config passed through to ``get_generation_files`` and
            ``get_reward_file``.
        reward_keys: reward names to look up.

    Returns:
        ``(generation_files, reward_files_dict)`` where
        ``reward_files_dict[key][i]`` is the reward file for
        ``generation_files[i]``.
    """
    candidates = get_generation_files(cfg)

    # One lookup list per key; None marks a missing reward file.
    lookup = {}
    for key in reward_keys:
        lookup[key] = [get_reward_file(cfg, candidate, key) for candidate in candidates]

    keep = [
        position
        for position, _ in enumerate(candidates)
        if all(paths[position] is not None for paths in lookup.values())
    ]

    generation_files = [candidates[position] for position in keep]
    reward_files_dict = {
        key: [paths[position] for position in keep]
        for key, paths in lookup.items()
    }
    return generation_files, reward_files_dict

def get_file_idxs_per_repeat(cfg, num_files):
    """Partition file indices into consecutive, disjoint groups ("repeats").

    A repeat must supply ``kmax = 2 ** cfg.ks.kmax`` responses per prompt.
    Each file supplies up to ``cfg.sampling.k`` responses per prompt, so a
    repeat spans ``ceil(kmax / k)`` consecutive files. At most ``cfg.repeats``
    repeats are produced, fewer if ``num_files`` cannot fill them.

    Args:
        cfg: config exposing ``ks.kmax``, ``sampling.k`` and ``repeats``.
        num_files: number of available generation files.

    Returns:
        dict mapping repeat index to a list of file indices, all
        ``< num_files``.
    """
    kmax = int(2 ** cfg.ks.kmax)
    files_per_repeat = int(np.ceil(kmax / cfg.sampling.k))
    # Count whole groups of files. Using total_samples // kmax over-counts
    # when kmax is not a multiple of k (or k > kmax), which produced file
    # indices past the end of the file list.
    available_repeats = num_files // files_per_repeat
    repeats = min(cfg.repeats, available_repeats)
    print(f'Sufficient data for {repeats} repeats out of {cfg.repeats}')
    return {
        repeat: list(range(repeat * files_per_repeat, (repeat + 1) * files_per_repeat))
        for repeat in range(repeats)
    }

def load_reward_labels(reward_filename):
    """Print which file is being read, then return the array stored in it via ``np.load``."""
    message = f'Loading rewards from {reward_filename}'
    print(message)
    labels = np.load(reward_filename)
    return labels

def clean_outputs(outputs, keys=('prompt_idx', 'logprobs', 'response', 'correct')):
    """Restrict every output record to the fields named in ``keys``.

    If every record already has exactly the field set ``keys``, the input list
    is returned unchanged (same object). Otherwise a new list is built in
    which each record keeps only those ``keys`` it actually contains, in
    ``keys`` order; missing keys are omitted, not filled with ``None``.

    Args:
        outputs: list of dict records; may be empty.
        keys: field names to keep. The default is an immutable tuple so it
            cannot be shared and mutated across calls.

    Returns:
        ``outputs`` itself, or a new list of trimmed dicts.
    """
    if not outputs:
        # Nothing to clean; indexing outputs[0] would have raised here.
        return outputs
    wanted = set(keys)
    # Check every record, not just the first, so heterogeneous files are
    # cleaned consistently.
    if all(set(output) == wanted for output in outputs):
        return outputs
    return [{key: output[key] for key in keys if key in output} for output in outputs]

class Sampler:
    """Serve per-prompt samples drawn from saved generations and reward labels.

    On construction, generation files are discovered via ``cfg.io``, together
    with the matching reward file for every reward key, and split into repeats
    with ``get_file_idxs_per_repeat``. ``load_files(repeat)`` then fills
    ``self.outputs``: a dict mapping each prompt index to a list of output
    records. Each record carries the generation fields plus one entry per
    reward key.
    """

    def __init__(self, cfg, holdout=False):
        self.cfg = cfg
        self.holdout = holdout
        # NOTE: the same cfg.io.prefix is used whether or not holdout is set.
        self.prefix = cfg.io.prefix

        # Reward whose values get_rewards / get_rmax return.
        self.rhat_key = cfg.reward.name
        # Additional reward labels loaded alongside rhat (none for holdout).
        self.rstar_keys = self._resolve_rstar_keys(cfg, holdout)
        self.reward_keys = [self.rhat_key] + self.rstar_keys

        generation_files, reward_files_dict = get_generation_and_reward_files(cfg.io, self.reward_keys)
        # Every reward key must provide exactly one file per generation file.
        for value in reward_files_dict.values():
            assert len(generation_files) == len(value)
        self.generation_files = generation_files
        self.reward_files_dict = reward_files_dict

        self.repeat_to_file_idxs = get_file_idxs_per_repeat(cfg, len(generation_files))
        # Responses per prompt stored in each generation file (layout assumed
        # by update_outputs).
        self.samples_per_prompt = cfg.sampling.k

        self.outputs = None  # populated by load_files
        # Cap on responses per prompt loaded for a single repeat.
        self.kmax = int(2 ** cfg.ks.kmax)

    @staticmethod
    def _resolve_rstar_keys(cfg, holdout):
        """Return the extra reward keys to load, excluding ``'correct'``.

        A new list is built so ``cfg.task.rstar_keys`` is never mutated.
        ``'correct'`` is excluded; presumably it comes from the generation
        records themselves (``clean_outputs`` keeps a ``'correct'`` field)
        rather than from a reward file. Confirm if that changes.
        """
        if holdout:
            return []
        return [key for key in cfg.task.rstar_keys if key != 'correct']

    def get_rmax(self):
        """Return the largest ``rhat_key`` reward over all loaded prompts.

        Requires ``load_files`` to have run, since it sets ``self.prompt_idxs``.
        """
        rmaxes = [np.max(self.get_rewards(prompt_idx)) for prompt_idx in self.prompt_idxs]
        return np.max(rmaxes)

    def get_files_per_repeat(self, repeat):
        """Return ``(generation_files, reward_files_dict)`` restricted to ``repeat``'s files."""
        file_idxs = self.repeat_to_file_idxs[repeat]
        generation_files = [self.generation_files[idx] for idx in file_idxs]
        reward_files_dict = {
            key: [value[idx] for idx in file_idxs]
            for key, value in self.reward_files_dict.items()
        }
        return generation_files, reward_files_dict

    def load_files(self, repeat):
        """Load the files of ``repeat`` into ``self.outputs``.

        Files are read in order until ``kmax`` responses per prompt have been
        collected. Afterwards this sets ``self.outputs`` (a plain dict),
        ``self.prompt_idxs`` (sorted prompt indices) and ``self.rmax``.
        """
        self.outputs = defaultdict(list)
        generation_files, reward_files_dict = self.get_files_per_repeat(repeat)
        # Running count of responses per prompt taken so far in this repeat.
        self.num_loaded = 0
        for i, generation_file in enumerate(generation_files):
            reward_file_dict = {
                key: value[i] for (key, value) in reward_files_dict.items()
            }
            self.update_outputs(generation_file, reward_file_dict)
            if self.num_loaded >= self.kmax:
                break
        self.outputs = dict(self.outputs)
        self.prompt_idxs = sorted(self.outputs.keys())
        self.rmax = self.get_rmax()

    def update_outputs(self, generation_file, reward_file_dict):
        """Append a random subset of one file's responses to ``self.outputs``.

        Loads the generation records, trims them with ``clean_outputs``, and
        attaches each reward array element-wise (``collate``). It then picks
        ``num`` response positions without replacement; the same positions
        are used for every prompt in the file. ``num`` is capped so the
        running per-prompt total never exceeds ``kmax``.

        NOTE(review): records are addressed as
        ``prompt_idx * samples_per_prompt + response_idx``. This assumes the
        file is ordered prompt-major with exactly ``samples_per_prompt``
        responses per prompt, and that ``prompt_idx`` equals the prompt's
        position in the file. Confirm against the generation writer.
        """
        outputs = json_load(generation_file)
        outputs = clean_outputs(outputs)
        reward_dict = {
            key: np.load(filename) for (key, filename) in reward_file_dict.items()
        }
        outputs = self.collate(outputs, reward_dict)

        num = min(self.samples_per_prompt, self.kmax - self.num_loaded)
        prompt_idxs, response_idxs = self.get_idxs(outputs, num)
        for prompt_idx in prompt_idxs:
            self.outputs[prompt_idx] += [
                outputs[prompt_idx * self.samples_per_prompt + response_idx]
                for response_idx in response_idxs
            ]
        self.num_loaded += len(response_idxs)

    def collate(self, outputs, dict):
        """Merge reward values into records: record ``i`` gains ``{key: values[i]}`` for each reward.

        The parameter name ``dict`` shadows the builtin. It is kept only for
        compatibility with keyword callers.
        """
        return [
            {**output, **{key: value[idx] for key, value in dict.items()}}
            for idx, output in enumerate(outputs)
        ]

    def get_idxs(self, outputs, num):
        """Return ``(sorted distinct prompt indices, num random response positions)``.

        Response positions are the first ``num`` entries of a random
        permutation of ``range(samples_per_prompt)``.
        """
        prompt_idxs = list(set([output['prompt_idx'] for output in outputs]))
        shuffle_idxs = list(range(self.samples_per_prompt))
        random.shuffle(shuffle_idxs)
        return sorted(prompt_idxs), shuffle_idxs[:num]

    def get_key(self, key, prompt_idx, idxs):
        """Return field ``key`` from a prompt's records: all records, or only positions ``idxs``."""
        if idxs is None:
            return [output[key] for output in self.outputs[prompt_idx]]
        return [self.outputs[prompt_idx][idx][key] for idx in idxs]

    def get_rewards(self, prompt_idx, idxs=None):
        """Return the ``rhat_key`` reward values for ``prompt_idx`` (optionally at ``idxs``)."""
        return self.get_key(self.rhat_key, prompt_idx, idxs)

    def get_logprobs(self, prompt_idx, idxs=None):
        """Return the ``'logprobs'`` field for ``prompt_idx`` (optionally at ``idxs``)."""
        return self.get_key('logprobs', prompt_idx, idxs)

    def get_outputs(self, prompt_idx, idxs):
        """Return the records at positions ``idxs`` for ``prompt_idx``; a ``None`` index yields ``None``."""
        return [self.outputs[prompt_idx][idx] if idx is not None else None for idx in idxs]



