from functools import partial 
import numpy as np 
import scipy.stats as scs
import os, sys, hydra, ipdb
from collections import defaultdict
from omegaconf.listconfig import ListConfig
from omegaconf import DictConfig, OmegaConf
from plotting.plotter import *
from pathlib import Path
import itertools
from helpers.io import *
import time, re
import multiprocessing
from importlib import import_module
from plotting.plot_beta import BetaPlot

# Synthetic metric key: get_value reports True for it only when the method
# name contains 'rejection' and the output record is None; otherwise False.
METRIC_NO_RESPONSE = 'no-response'
# Metric key for KL: when k > 0, get_value returns -output['logprobs'] instead
# of reading output['kl'] directly.
METRIC_KL = 'kl'


def flatten_list_config(nested_list):
    """Flatten one level of ``ListConfig`` nesting.

    Each element that is an OmegaConf ``ListConfig`` is expanded in place;
    every other element (including plain Python lists) is kept unchanged.
    Only a single level is flattened.
    """
    flat = []
    for entry in nested_list:
        if isinstance(entry, ListConfig):
            flat.extend(entry)
        else:
            flat.append(entry)
    return flat

  
def get_files(load_path):
    """Return the path of every ``*.json`` file under *load_path*, recursively.

    Paths are returned as strings in whatever order ``Path.rglob`` yields
    them; callers needing a stable order should sort the result.
    """
    root = Path(load_path)
    return list(map(str, root.rglob('*.json')))

def extract_params(prefix, filename):
    """Parse ``(method, beta, k)`` from a result filename.

    The expected basename shape is ``{prefix}-{method}-k-{k}-beta-{beta}.json``,
    preceded by a ``/`` path separator. ``beta`` may be a plain decimal or any
    string ``float()`` accepts (e.g. scientific notation such as ``1e-05``).

    Args:
        prefix: Filename prefix; matched literally.
        filename: Full path of the candidate file.

    Returns:
        ``(method, beta, k)`` where ``k`` is an ``int`` and ``beta`` a ``float``
        (converted to ``int`` when ``<= 0``), or ``None`` when the file is a
        ``*nums.json`` supplemental file, a ``repeat`` run, or cannot be parsed.
    """
    if filename.endswith('nums.json'):
        return None
    head = rf"/{re.escape(prefix)}-(?!.*repeat)([a-zA-Z0-9-]+)-k-([-\d]+)-beta-"
    # Strict decimal form first, then a permissive fallback so scientific
    # notation (e.g. "1e-05") is accepted too.
    patterns = (head + r"([-\d.]+)\.json$", head + r"(.*)\.json$")
    for pattern in patterns:
        match = re.search(pattern, filename)
        if match is None:
            continue
        method, k_text, beta_text = match.groups()
        try:
            k_value = int(k_text)
            beta_value = float(beta_text)
            if beta_value <= 0:
                beta_value = int(beta_value)
        except (ValueError, OverflowError):
            # e.g. supplemental files such as "...-beta-0.1-<suffix>.json",
            # whose beta text is not a number: not a primary result file.
            continue
        return method, beta_value, k_value
    return None

def filter_files(prefix, files):
    """Keep only files whose names parse into run parameters.

    Args:
        prefix: Filename prefix forwarded to ``extract_params``.
        files: Iterable of file paths. It is consumed exactly once, so
            one-shot iterators are handled correctly.

    Returns:
        Two parallel lists ``(kept_files, kept_params)``. Each entry of
        ``kept_params`` is the ``extract_params`` tuple for the file at the
        same index of ``kept_files``.
    """
    kept_files, kept_params = [], []
    for path in files:
        parsed = extract_params(prefix, path)
        if parsed is None:
            continue
        kept_files.append(path)
        kept_params.append(parsed)
    return kept_files, kept_params

def get_value(results, output, key, **kwargs):
    """Extract metric *key* from a single generation record.

    Args:
        results: Unused; kept so existing call sites keep working.
        output: One parsed JSON record for a generation (may be ``None``).
        key: Metric name to extract.
        **kwargs: Must contain ``method`` and ``k``.

    Returns:
        - For ``METRIC_NO_RESPONSE``: ``output is None`` when ``'rejection'``
          appears in the method name, otherwise ``False``.
        - For ``METRIC_KL`` with ``k > 0``: ``-output['logprobs']``.
        - Otherwise: ``output[key]``.

        Returns ``None`` when the record lacks the field or has the wrong
        shape (``None`` record, missing key, non-numeric value, ...).

    Raises:
        KeyError: if ``method`` or ``k`` is missing from ``kwargs``.
    """
    method = kwargs['method']
    k = kwargs['k']
    try:
        if key == METRIC_NO_RESPONSE:
            return (output is None) if 'rejection' in method else False
        if key == METRIC_KL and k > 0:
            return -output['logprobs']
        return output[key]
    except (KeyError, IndexError, TypeError):
        # Malformed or incomplete record: report "no value" instead of
        # aborting the whole parse. Narrowed from a bare ``except`` so that
        # KeyboardInterrupt/SystemExit and genuine bugs still propagate.
        return None
    
def _parse_results(result_keys, supplemental_files, file, param):
    """Parse one result file, plus its companion files, into metric tables.

    Args:
        result_keys: Metric names extracted from every generation record
            via ``get_value``.
        supplemental_files: Suffixes of companion files. For a result file
            ``X.json``, the companion ``X-{suffix}.json`` is loaded if it exists
            and stored verbatim under ``suffix``.
        file: Path to the main ``.json`` result file.
        param: ``(method, beta, k)`` tuple as produced by ``extract_params``.

    Returns:
        Mapping ``{metric_or_suffix: {prompt_idx: values}}``, passed through
        ``convert_defaultdict``. Presumably this yields plain dicts that can be
        pickled back from a worker process; confirm against helpers.io.
    """
    (method, beta, k) = param
    print(f"\n\n=====\t({method}, {beta}, {k})\t=====\n")
    outputs_all = json_load(file)
    prompt_idxs = list(outputs_all.keys())
    results = defaultdict(lambda: defaultdict(list))
    for (key, prompt_idx) in itertools.product(result_keys, prompt_idxs):
        outputs = outputs_all[prompt_idx]
        results[key][prompt_idx] = [get_value(None, output, key, method=method, k=k) for output in outputs]

    # Replace only the trailing extension. A plain str.replace would also
    # rewrite any ".json" elsewhere in the path, e.g. in a directory name.
    stem = os.path.splitext(file)[0]
    for suffix in supplemental_files:
        filename = f"{stem}-{suffix}.json"
        if os.path.isfile(filename):
            print(f"\n>>Loading {suffix} supplemental file {filename}\n")
            supplemental = json_load(filename)
            for prompt_idx, outputs in supplemental.items():
                results[suffix][prompt_idx] = outputs
    return convert_defaultdict(results)

def parse_results(cfg):
    """Parse every matching result file under the configured load path.

    Files are discovered recursively under
    ``{cfg.io.load_root}/{cfg.task.name}/{cfg.policy.name}/{cfg.reward.name}``.
    They are filtered to those whose names parse via ``extract_params`` and
    then parsed with ``_parse_results``. Parsing runs in a process pool when
    ``cfg.multiprocessing`` is truthy, and sequentially otherwise.

    Args:
        cfg: Hydra/OmegaConf config.

    Returns:
        Nested defaultdict ``results[key][method][beta][k]``. Each value is
        the per-prompt mapping returned by ``_parse_results`` for that key,
        where ``key`` is a metric name or a supplemental-file suffix.
    """
    use_mp = cfg.multiprocessing
    result_keys = flatten_list_config(cfg.metrics)
    supplemental_files = list(cfg.supplemental_files)
    all_keys = result_keys + supplemental_files
    print(f"\nResult keys:\n")
    for key in all_keys:
        print(f"\t{key}\n")

    task = cfg.task.name
    reward = cfg.reward.name
    policy = cfg.policy.name
    load_root = cfg.io.load_root
    load_path = os.path.join(load_root, task, policy, reward)

    files = sorted(get_files(load_path))
    files, params = filter_files(cfg.io.prefix, files)

    file_parse_results = partial(_parse_results, result_keys, supplemental_files)
    if use_mp:
        print(f"\nUsing multiprocessing to parse {len(files)} files\n")
        with multiprocessing.Pool() as pool:
            outputs = pool.starmap(file_parse_results, zip(files, params))
    else:
        print(f"\nParsing {len(files)} files\n")
        outputs = [file_parse_results(file, param) for file, param in zip(files, params)]

    # Both execution paths share a single merge step. outputs[i] corresponds
    # to params[i] (starmap preserves order).
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))))
    for output, (method, beta, k) in zip(outputs, params):
        for key, value in output.items():
            results[key][method][beta][k] = value
    return results


@hydra.main(config_path="../../configs", config_name="plot", version_base=None)
def main(cfg):
    """Hydra entry point: optionally re-parse raw results, then plot them.

    Steps:
      1. Normalize ``cfg.ks``. A negative ``kmin`` or ``kmax`` collapses the
         range to the single sentinel value ``k = -1``.
      2. If ``cfg.refresh_data`` is set, parse raw result files with
         ``parse_results`` and write them to
         ``{cfg.root}/parsed_results/{cfg.io.prefix}.json``.
      3. Load that parsed-results file.
      4. Plot. When ``cfg.ks.kmax < 0``, use ``BetaPlot`` over
         ``results['correct']`` and ``results['nums']``. Otherwise use
         ``RewardvsNPlotter`` over the configured metrics for each value in ``ks``.

    Args:
        cfg: Config composed by Hydra from the ``plot`` config.
    """
    root = cfg.root
    prefix = cfg.io.prefix
    # NOTE(review): ``betas`` is not used anywhere below in this function.
    betas = list(cfg.betas)
    # A negative kmin/kmax means "no k sweep": pin the range to the sentinel -1.
    # The later ``cfg.ks.kmax < 0`` check relies on this normalization.
    if cfg.ks.kmax < 0 or cfg.ks.kmin < 0:
        cfg.ks.kmax = -1
        cfg.ks.kmin = -1
        cfg.ks.inc = 0
        ks = [-1]
    else:
        # get_ks comes from a star import; presumably it expands kmin/kmax/inc
        # into the list of k values. Confirm in plotting/helpers.
        ks = get_ks(cfg.ks)
    

    print("\n========== Loading data ==========\n")
    parsed_results_dir = os.path.join(root, "parsed_results")
    os.makedirs(parsed_results_dir, exist_ok=True)
    parsed_results_filename = os.path.join(parsed_results_dir, f"{prefix}.json")

    if cfg.refresh_data:
        print('Parsing results')
        parse_start = time.time()
        results = parse_results(cfg)
        parse_end = time.time()
        print(f"\nTime taken to parse data: {parse_end - parse_start:.0f} seconds\n")
        json_dump(results, parsed_results_filename)
        # Drop the in-memory copy. The results are always reloaded from disk
        # below, so both the refresh and no-refresh paths plot from the same
        # on-disk representation.
        del results

    print(f'Loading parsed results from {parsed_results_filename}')
    data_load_start = time.time()
    results = json_load(parsed_results_filename)
    data_load_end = time.time()
    print(f"\nTime taken to load data: {data_load_end - data_load_start:.0f} seconds\n")
    
    
    print("\n========== Plotting ==========\n")
    # NOTE(review): hard-coded, user-specific output override. Consider a
    # config override instead. Struct mode is disabled before reassigning
    # io.save_root.
    if 'audreyh' in root: 
        OmegaConf.set_struct(cfg, False)    
        cfg.io.save_root = '/u/audreyh/workspace/test-code/figures'
    os.makedirs(cfg.io.save_root, exist_ok=True)

    plot_start = time.time()
    if cfg.ks.kmax < 0:
        # No k sweep (see normalization above): plot against beta.
        plot = BetaPlot(cfg, results['correct'], results['nums'])
        plot.plot()
        plot.save_figure()
    else:
        metrics = flatten_list_config(cfg.metrics) # earlier alternative metric list: cfg.task.rstar_keys + [cfg.reward.name]
        plotter = RewardvsNPlotter(metrics, cfg, results)
        plotter.plot(ks)
    plot_end = time.time()
    print(f"\nTime taken to plot: {plot_end - plot_start:.0f} seconds\n")



if __name__ == '__main__':
    # Wall-clock the whole run. main() is Hydra-decorated, so this also covers
    # Hydra's config composition, not just the body of main.
    master_start = time.time()
    main()
    master_end = time.time()
    print(f"\nTotal time taken: {master_end - master_start:.0f} seconds\n")


