import numpy as np 
import random
from collections import defaultdict
import glob, re, os, sys, time

import json, csv, ipdb
from helpers.io import * 
import helpers.test_functions as tf
import pandas as pd

def get_generation_files(cfg):
    """Return the sorted paths of ``<cfg.load_path>/generations/<cfg.prefix>*.json``.

    Args:
        cfg: object exposing ``load_path`` and ``prefix``.
    """
    pattern = os.path.join(cfg.load_path, 'generations', f'{cfg.prefix}*.json')
    return sorted(glob.glob(pattern))

def get_reward_file(generation_file, reward):
    """Map a generation JSON path to its reward ``.npy`` path.

    Every 'generations' substring of the path becomes ``reward`` and every
    '.json' substring becomes '-rewards.npy'.
    NOTE(review): str.replace rewrites *all* occurrences, including ones
    inside the file name, not only the directory component.
    """
    path = str(generation_file)
    path = path.replace('generations', f'{reward}')
    path = path.replace('.json', '-rewards.npy')
    return path


def load_reward(generation_file, reward, length):
    """Load the reward array stored for ``generation_file`` under reward ``reward``.

    Returns an all-NaN array of ``length`` when no reward file exists, so
    unscored rows are marked as NaN rather than raising.
    """
    reward_file = get_reward_file(generation_file, reward)
    if not os.path.isfile(reward_file):
        return np.full(length, np.nan)
    return np.load(reward_file)

 
def load_and_collate(generation_file, reward_keys):
    """Load the generations in ``generation_file`` and attach each reward key's value per row.

    Args:
        generation_file: path to a generations JSON file.
        reward_keys: reward names; each is loaded via ``load_reward``.
    """
    outputs = json_load(generation_file)
    n_rows = len(outputs)
    rewards = {key: load_reward(generation_file, key, n_rows) for key in reward_keys}
    return collate(outputs, rewards)


def collate(outputs, dict):
    """Merge per-row values from ``dict`` into each record of ``outputs``.

    ``dict`` (name kept for callers; it shadows the builtin inside this
    function) maps a column name to an indexable sequence aligned with
    ``outputs``. Row ``i`` gains ``{key: seq[i]}``, overriding any same-named
    field. New dicts are returned; the inputs are not mutated.
    """
    merged = []
    for idx, output in enumerate(outputs):
        row = {**output}
        for key, values in dict.items():
            row[key] = values[idx]
        merged.append(row)
    return merged

def append_json_to_csv(data, filename):
    """Append one chunk of records to the per-prompt CSV named by ``filename``.

    ``filename`` is a template containing ``{idx}``; it is filled with the
    ``prompt_idx`` of the first record. Columns come from the first record's
    keys, and a header row is written only when the file does not exist yet.

    Args:
        data: non-empty list of dicts that must all share one ``prompt_idx``.
        filename: path template with an ``{idx}`` placeholder.

    Returns:
        The ``prompt_idx`` shared by the chunk.

    Raises:
        AssertionError: if the records do not all share the same
            ``prompt_idx`` (an ``assert``, so skipped under ``python -O``).
    """
    cols = list(data[0].keys())
    idx = data[0]['prompt_idx']
    # Check every chunk, not only the first chunk for a file: appending a
    # mixed chunk to an existing file would otherwise silently misfile rows.
    assert all(row['prompt_idx'] == idx for row in data), "Multiple prompt idxs found per prompt chunk; check indexing."
    path = filename.format(idx=idx)
    write_header = not os.path.exists(path)
    # Explicit UTF-8 so the file round-trips through pandas.read_csv (UTF-8
    # by default) regardless of the platform's locale encoding.
    with open(path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=cols, escapechar='\\')
        if write_header:
            writer.writeheader()
        writer.writerows(data)
    return idx

def get_reward_dirs(path):
    """List the names of subdirectories of ``path`` whose names do not contain 'generations'."""
    def _is_reward_dir(name):
        return os.path.isdir(os.path.join(path, name)) and 'generations' not in name

    return list(filter(_is_reward_dir, os.listdir(path)))

def prepare_data(cfg, filename, k=128):
    """Split every generation file into per-prompt CSVs and return the prompt indices written.

    Each generation file, joined with every reward directory found under
    ``cfg.load_path``, is cut into consecutive chunks of ``k`` rows. Each
    chunk is appended to its prompt's CSV via ``append_json_to_csv``.

    Args:
        cfg: I/O config exposing ``load_path`` and ``prefix``.
        filename: CSV path template with an ``{idx}`` placeholder.
        k: rows per chunk. Presumably the number of generations per prompt
            (each chunk must map to a single prompt) — TODO confirm.

    Returns:
        Sorted list of distinct prompt indices, in the same order
        ``get_prompt_idxs`` produces for an existing cache.

    Raises:
        ValueError: if ``k`` is not positive.
    """
    if k <= 0:
        # range() with a non-positive step either raises (k == 0) or yields
        # nothing (k < 0), which would silently produce an empty cache.
        raise ValueError(f"Chunk size k must be positive, got {k}")
    generation_files = get_generation_files(cfg)
    reward_keys = get_reward_dirs(cfg.load_path)
    prompt_idxs = set()
    for gf in generation_files:
        start = time.time()
        outputs = load_and_collate(gf, reward_keys)
        for i in range(0, len(outputs), k):
            prompt_idxs.add(append_json_to_csv(outputs[i:i + k], filename))
        elapsed = time.time() - start
        print(f"\nData parsing took {elapsed:.0f} seconds for {os.path.basename(gf)}\n")

    # Sorted for a deterministic order that matches get_prompt_idxs.
    return sorted(prompt_idxs)

def csv_to_outputs(filename):
    """Read the CSV at ``filename`` with pandas and return its rows as a list of dicts."""
    frame = pd.read_csv(filename)
    return frame.to_dict(orient='records')

def cleanup_dir(directory):
    """Delete every regular file directly inside ``directory`` whose full path lacks 'generations'.

    Subdirectories are left untouched; each deletion is announced on stdout.
    """
    for entry in os.listdir(directory):
        path = os.path.join(directory, entry)
        if not os.path.isfile(path) or 'generations' in path:
            continue
        print(f'Removing {entry}')
        os.remove(path)

def get_prompt_idxs(directory):
    """Return the sorted, de-duplicated prompt indices of ``*prompt_<n>.csv`` files in ``directory``.

    CSV files whose names do not end in ``prompt_<digits>.csv`` are skipped.
    Previously they made ``re.search(...).group`` raise AttributeError.
    NOTE(review): CSVs for every task/policy in ``directory`` are counted,
    not only those matching one particular filename template.
    """
    pattern = re.compile(r'prompt_(\d+)\.csv$')
    indices = set()
    for path in glob.glob(os.path.join(directory, '*.csv')):
        match = pattern.search(os.path.basename(path))
        if match is not None:
            indices.add(int(match.group(1)))
    return sorted(indices)

def convert_types(data, reward_key):
    """Coerce known columns of each row back to Python types.

    ``reward_key`` and 'logprobs' become float, 'prompt_idx' becomes int, and
    'correct' is parsed as a Python literal. Other columns pass through.
    New dicts are returned.
    """
    import ast

    def _parse_literal(value):
        # Strings (e.g. 'True', '[1, 0]' from csv text) are parsed with
        # literal_eval instead of eval, so no arbitrary code runs. Values a
        # reader already converted (pandas yields bool/float) pass through,
        # where eval() would raise TypeError.
        return ast.literal_eval(value) if isinstance(value, str) else value

    type_map = {
        reward_key: float,
        'prompt_idx': int,
        'logprobs': float,
        'correct': _parse_literal,
    }
    return [
        {key: (type_map[key](value) if key in type_map else value) for key, value in row.items()}
        for row in data
    ]

def remove_nans(outputs, reward_key):
    """Keep only the outputs whose ``reward_key`` value is not NaN, preserving order."""
    def _has_reward(output):
        return not np.isnan(output[reward_key])

    return list(filter(_has_reward, outputs))

def dir_is_empty(path):
    """Return True when ``path`` has no entries at all (neither files nor subdirectories)."""
    return not os.listdir(path)


class Sampler:
    """Serve the generations of one prompt at a time from cached per-prompt CSVs.

    Construction seeds NumPy's and Python's global RNGs with ``cfg.seed`` and
    makes sure the CSV cache under ``csv_path`` is populated. ``load_files``
    selects a prompt. ``sample_outputs`` then fills ``self.outputs`` from
    that prompt's generations, and the ``get_*`` accessors read from it.
    """

    def __init__(self, cfg, holdout=False):
        np.random.seed(cfg.seed)
        random.seed(cfg.seed)
        self.cfg = cfg
        self.rhat_key = cfg.reward.name  # column name of the reward used throughout
        self.holdout = holdout
        self.prefix = cfg.io.prefix

        # NOTE(review): str.replace rewrites *every* 'data' substring of the
        # path. If there is none, csv_path equals load_path, whose top-level
        # files cleanup_dir would then delete on refresh.
        self.csv_path = cfg.io.load_path.replace('data', 'parsed_data')
        os.makedirs(self.csv_path, exist_ok=True)
        # Per-prompt CSV template; '{idx}' is filled with a prompt index later.
        self.filename = os.path.join(f"{self.csv_path}", f"{cfg.task.name}_{cfg.policy.name}_prompt_{{idx}}.csv")

        self.outputs = None             # current sample, set by sample_outputs
        self.all_outputs = None         # NaN-filtered outputs of the loaded prompt
        self.rmax = None                # max reward of the loaded prompt
        self.prompt_idx = None          # prompt currently loaded
        self.prompt_idxs = None         # prompts available in the cache
        self._outputs_permuted = False  # True when self.outputs came from the k < 0 branch

        self.initialize()

    def initialize(self):
        """Build the per-prompt CSV cache if needed and record the available prompt indices.

        The cache is rebuilt from the raw generation/reward files when
        ``cfg.refresh_data`` is truthy or the cache directory is empty.
        Otherwise the indices are read from existing CSV file names.
        """
        if self.cfg.refresh_data or dir_is_empty(self.csv_path):
            cleanup_dir(self.csv_path)
            prompt_idxs = prepare_data(self.cfg.io, self.filename, k=self.cfg.sampling.k)
        else:
            prompt_idxs = get_prompt_idxs(self.csv_path)
        self.prompt_idxs = prompt_idxs

    def get_rmax(self):
        """Return the largest reward among all loaded (NaN-filtered) outputs.

        Reads ``self.all_outputs`` directly and leaves ``self.outputs``
        untouched. The previous implementation reset it to None, discarding
        any current sample.
        """
        return np.max([output[self.rhat_key] for output in self.all_outputs])

    def load_files(self, prompt_idx):
        """Load the CSV for ``prompt_idx``, drop NaN-reward rows, and reset sampling state."""
        outputs = csv_to_outputs(self.filename.format(idx=prompt_idx))
        self.all_outputs = remove_nans(outputs, self.rhat_key)
        self.prompt_idx = prompt_idx
        # Any sample drawn for the previous prompt is now stale.
        self.outputs = None
        self._outputs_permuted = False
        self.rmax = self.get_rmax()

    def sample_outputs(self, k):
        """Fill ``self.outputs`` with a sample from ``self.all_outputs``.

        Args:
            k: if negative, ``self.outputs`` becomes a shuffle of all outputs
                (without replacement). The permutation array is reused and
                reshuffled in place on consecutive ``k < 0`` calls. Otherwise
                ``k`` outputs are drawn uniformly with replacement.
        """
        if k < 0:
            # Only reuse self.outputs when it is a permutation from an earlier
            # k < 0 call; a with-replacement sample must not stand in for it.
            if self.outputs is None or not self._outputs_permuted:
                self.outputs = np.copy(self.all_outputs)
                self._outputs_permuted = True
            np.random.shuffle(self.outputs)
        else:
            self.outputs = np.random.choice(self.all_outputs, size=k, replace=True)
            self._outputs_permuted = False

    def get_key(self, key):
        """Return ``output[key]`` for every output in the current sample."""
        return [output[key] for output in self.outputs]

    def get_rewards(self):
        """Return the reward (``self.rhat_key``) of every output in the current sample."""
        return self.get_key(self.rhat_key)

    def get_logprobs(self):
        """Return the 'logprobs' value of every output in the current sample."""
        return self.get_key('logprobs')

    def get_outputs(self, idxs):
        """Index into the current sample.

        A list of indices returns a list; ``None`` entries map to ``None``.
        Any other ``idxs`` is used directly as an index, and ``None`` returns
        ``None``.
        """
        if isinstance(idxs, list):
            return [self.outputs[idx] if idx is not None else None for idx in idxs]
        else:
            return self.outputs[idxs] if idxs is not None else None
            



