import os
import itertools
import logging
from functools import wraps
from time import time
from collections import defaultdict
from pathlib import Path
from typing import List, Dict, Tuple, Union, Any
import random
import json

from tqdm import tqdm
import torch
from sklearn.random_projection import SparseRandomProjection
from huggingface_hub import login
import numpy as np
from omegaconf import DictConfig

from inference_rlhf.code.tasks.base import BaseDataLoader
from inference_rlhf.code.helpers.io import json_load

log = logging.getLogger(__name__)

def hf_login(cfg):
    """
    Loads huggingface token from file and logs in.

    Args:
        cfg: Config object; only `cfg.hf_token_path` (path to a text file
            containing the token) is read.

    Raises:
        ValueError: If the token file is empty or contains only whitespace.
    """

    with open(cfg.hf_token_path, 'r') as f:
        # Token files commonly end with a newline (editors, `echo`); strip
        # surrounding whitespace so it is not sent as part of the token.
        hf_token = f.read().strip()

    if not hf_token:
        raise ValueError(f'Huggingface token file {cfg.hf_token_path} is empty')

    login(token=hf_token, add_to_git_credential=False)
    # Drop the local reference to the secret as soon as it is no longer needed.
    del hf_token
    log.info('Logged in to Huggingface!')


def timing(f):
    """
    Decorator that logs (at INFO level) how long each call to `f` takes.

    The wrapped function's return value is passed through unchanged. If `f`
    raises, the exception propagates and no timing line is logged.
    """
    # Function-scope import: the module only imports `time.time`, which is a
    # wall clock and may jump (NTP, DST); perf_counter is monotonic and meant
    # for measuring elapsed intervals.
    from time import perf_counter

    @wraps(f)
    def wrap(*args, **kw):
        start = perf_counter()
        result = f(*args, **kw)
        elapsed = perf_counter() - start
        log.info(f'Function {f.__name__} took: {elapsed:.4f} sec')
        return result
    return wrap

def get_generations_path(cfg):
    """
    Build the generations directory for the configured task and policy.

    The result is `<cfg.io.load_root>/data/<cfg.task.name>/<cfg.policy.name>/generations`.
    """
    return os.path.join(
        cfg.io.load_root,
        "data",
        cfg.task.name,
        f'{cfg.policy.name}',
        'generations',
    )

def estimate_pass_at_k(num_samples, num_correct, k):
    """
    Estimates pass@k of each problem and returns them in an array.

    Copied from https://github.com/huggingface/evaluate/blob/main/metrics/code_eval/code_eval.py#L198

    Args:
        num_samples: Either a single integer (Python int or NumPy integer
            scalar) used for every problem, or a sequence with one sample
            count per problem.
        num_correct: Sequence with the number of correct samples per problem.
        k: The k in pass@k.

    Returns:
        np.ndarray with one pass@k estimate per entry of `num_correct`.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k)."""
        # Fewer than k incorrect samples: every size-k subset contains a correct one.
        if n - c < k:
            return 1.0
        # Product form of comb(n - c, k) / comb(n, k); avoids huge binomials.
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # NumPy integer scalars (e.g. np.int64) are not `int` instances and have
    # no len(), so they must be recognised as scalars explicitly.
    if isinstance(num_samples, (int, np.integer)):
        num_samples_it = itertools.repeat(int(num_samples), len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])

def set_seeds(seed: int):
    """
    Seed Python's `random`, numpy's global RNG, and torch with the same value.
    """
    # Same order as always: stdlib first, then numpy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def rget_json_files_from_dir(load_path: str) -> List[str]:
    """
    Walk `load_path` recursively and return every `*.json` path as a string.

    The order is whatever `Path.rglob` yields; no sorting is applied.
    """
    root = Path(load_path)
    return [str(json_path) for json_path in root.rglob('*.json')]

def load_ref_data(load_path: str) -> Dict[int, Dict[str, List[bool]]]:
    """
    Load reference results from the JSON generation files under `load_path`.

    All `*.json` files below `load_path` are considered; when 'mbpp' appears
    in `load_path`, only files whose path contains '--CHECKED' are used.
    Each file is expected to hold a non-empty list of records that carry
    'prompt_idx' and 'strict_correct'. Files that are missing or contain no
    records are skipped with a log message.

    Args:
        load_path (str): Directory that is searched recursively.

    Returns:
        Dict[int, Dict[str, List[bool]]]: Maps the 'prompt_idx' of a file's
        first record to {'results_ref': [...]}, the 'strict_correct' values of
        all records in that file. If several files share a prompt_idx, the one
        loaded last overwrites the earlier ones.
    """
    # Load ref data
    generation_files = rget_json_files_from_dir(load_path)

    if 'mbpp' in load_path:
        generation_files = [gf for gf in generation_files if '--CHECKED' in gf]

    all_data = defaultdict(dict)
    for gf in tqdm(generation_files, desc="Loading ref data"):
        # The file may have disappeared between listing and loading.
        if not os.path.exists(gf):
            log.info(f'WARNING: Ref generation file {gf} does not exist')
            continue

        data = json_load(gf)
        # An empty file has no first record to read the prompt index from.
        if not data:
            log.warning(f'Ref generation file {gf} contains no entries, skipping')
            continue

        prompt_idx = data[0]['prompt_idx']
        all_data[prompt_idx]['results_ref'] = [d['strict_correct'] for d in data]

    return all_data

def _feature_file_for(generation_file: str, feature_name: str, checked_suffix: bool) -> str:
    """Return the path of the feature archive that belongs to a generation file."""
    json_suffix = '--CHECKED.json' if checked_suffix else '.json'
    return generation_file.replace(json_suffix, f'--{feature_name}-features.tar')

@timing
def load_response_data(
    load_path: str, 
    load_features: bool,
    max_num_files: Union[int, None],
    feature_name: str,
    feature_type: str = 'mean_hidden_state',
    temp: float = 1.0,
    top_p: float = 1.0,
    min_p: float = 0.0,
) -> Dict[int, Dict[str, List[Any]]]:
    """
    Load response data from JSON generation files.

    Files are selected by substring tests on their full path: paths containing
    'debug', 'coreset' or 'armo-rm' are dropped, and the path must contain the
    temperature and top-p tags (plus the min-p tag when `min_p > 0`, and no
    min-p tag otherwise). When 'code_contests' or 'mbpp' appears in
    `load_path`, only '--CHECKED' files are used. Files without any records
    are skipped with a warning.

    Args:
        load_path (str): Path to the directory containing the generation files.
        load_features (bool): Whether to load the features. If True, only
            generation files with an existing feature archive are kept.
        max_num_files (Optional[int]): Maximum number of files to load. If None, load all files.
        feature_name (str): Name of the feature to load (part of the archive file name).
        feature_type (str): Type of feature to load (e.g., 'mean_hidden_state');
            used as the key into the loaded archive.
        temp (float): Temperature of the generation.
        top_p (float): Top-p of the generation.
        min_p (float): Minimum probability of the generation.

    Returns:
        Dict[int, Dict[str, List[Any]]]: Dictionary with all the loaded data,
        keyed by prompt index. Each entry has 'responses', 'answers' and
        'results' (concatenated over all files of that prompt) and, when
        `load_features` is set, 'features' with one tensor per loaded file.
    """
    # For these tasks only files tagged '--CHECKED' are used, and the feature
    # archive name is derived from the '--CHECKED.json' suffix.
    is_code_task = 'code_contests' in load_path or 'mbpp' in load_path
    # mbpp records carry their result under 'correct' instead of 'strict_correct'.
    result_key = 'correct' if 'mbpp' in load_path else 'strict_correct'

    # Load json files
    generation_files = rget_json_files_from_dir(load_path)
    generation_files = [gf for gf in generation_files if 'debug' not in gf and 'coreset' not in gf and 'armo-rm' not in gf]
    # NOTE(review): these are plain substring tests, so e.g. temp=1.0 also
    # matches a file tagged 'temp_1.05'. Left unchanged because the exact file
    # naming scheme is not visible here -- confirm against the writer.
    generation_files = [gf for gf in generation_files if f'temp_{temp}' in gf or f'temp-{temp}' in gf]
    generation_files = [gf for gf in generation_files if f'top-p-{top_p}' in gf]

    if min_p > 0.0:
        generation_files = [gf for gf in generation_files if f'--min-p-{min_p}' in gf]
    else:
        generation_files = [gf for gf in generation_files if '--min-p' not in gf]

    if is_code_task:
        generation_files = [gf for gf in generation_files if '--CHECKED' in gf]

    if load_features:
        # Keep only generation files whose feature archive exists on disk.
        generation_files = [
            gf for gf in generation_files
            if os.path.exists(_feature_file_for(gf, feature_name, is_code_task))
        ]

    if max_num_files is not None:
        generation_files = generation_files[:max_num_files]

    log.info(f'Found {len(generation_files)} generation files')

    all_data = defaultdict(lambda: defaultdict(list))
    for gf in tqdm(generation_files, desc="Loading response data"):
        data = json_load(gf)
        # An empty file has no first record to read the prompt index from.
        if not data:
            log.warning(f'Generation file {gf} contains no entries, skipping')
            continue

        # Load the features
        features = None
        if load_features:
            feature_path = _feature_file_for(gf, feature_name, is_code_task)
            if not is_code_task:
                log.info(f'Loading features from {feature_path}')
            # NOTE: torch.load unpickles the archive; only load feature files
            # from trusted sources.
            features = torch.load(feature_path)

        # get responses, answers, and results
        prompt_idx = data[0]['prompt_idx']
        entry = all_data[prompt_idx]
        entry['responses'].extend(d['response'] for d in data)
        # Records without an extracted answer contribute None.
        entry['answers'].extend(d.get('extracted_answer') for d in data)
        entry['results'].extend(d[result_key] for d in data)
        if load_features:
            entry['features'].append(features[feature_type].float())

    return all_data

def _keep_keys(mapping, keys):
    """Return `mapping` restricted to `keys`, keeping its order; None stays None."""
    if mapping is None:
        return None
    return {k: v for k, v in mapping.items() if k in keys}

def maybe_filter_data(
    all_responses: Dict[int, List[str]], 
    all_answers: Dict[int, List[str]], 
    all_results: Dict[int, List[bool]], 
    all_features: Union[Dict[int, torch.Tensor], None], 
    all_prompts: Dict[int, str],
    subsample_size: Union[int, None] = None,
    min_number_unique_answers: Union[int, None] = None,
    max_n: Union[int, None] = None,
    remove_all_incorrect: bool = False
) -> Tuple[Dict[int, List[str]], Dict[int, List[str]], Dict[int, List[bool]], Union[Dict[int, torch.Tensor], None], Dict[int, str]]:
    """
    Potentially filter data based on either:
        (1) random subsampling of the prompts
        (2) subsampling by number of unique answers
        (3) subsampling the max number of responses per prompt
        (4) dropping prompts without a single correct response

    The steps run in the order listed, so (4) only looks at the responses
    that survived the `max_n` truncation. All inputs are dictionaries keyed
    by prompt index; the input dictionaries are not modified.

    Args:
        all_responses: Responses per prompt.
        all_answers: Answers per prompt.
        all_results: Correctness flags per prompt.
        all_features: Feature tensor per prompt, or None if features were not loaded.
        all_prompts: Prompt per prompt index.
        subsample_size: If set, keep this many randomly chosen prompts
            (uses the global `random` state).
        min_number_unique_answers: If set, keep only prompts with at least
            this many distinct answers.
        max_n: If set, keep only the first `max_n` entries per prompt.
        remove_all_incorrect: If True, drop prompts whose results are all falsy.

    Returns:
        Tuple of (responses, answers, results, features, prompts) after
        filtering; `features` is None when `all_features` was None.

    Raises:
        ValueError: If `subsample_size` exceeds the number of prompts
            (raised by `random.sample`).
    """
    # Potentially random subsample
    if subsample_size is not None:
        # take random subset of the keys
        keys = random.sample(list(all_responses.keys()), subsample_size)
        all_responses = {k: all_responses[k] for k in keys}
        all_answers = {k: all_answers[k] for k in keys}
        all_results = {k: all_results[k] for k in keys}
        all_features = {k: all_features[k] for k in keys} if all_features is not None else None
        all_prompts = {k: all_prompts[k] for k in keys}

    # Subsample by number of unique answers
    if min_number_unique_answers is not None:
        kept = {k for k, v in all_answers.items() if len(set(v)) >= min_number_unique_answers}
        all_answers = _keep_keys(all_answers, kept)
        all_responses = _keep_keys(all_responses, kept)
        all_results = _keep_keys(all_results, kept)
        all_features = _keep_keys(all_features, kept)
        all_prompts = _keep_keys(all_prompts, kept)

    # Subsample the max number of responses per prompt
    # (prompts hold one entry per key, so they are not truncated)
    if max_n is not None:
        all_answers = {k: v[:max_n] for k, v in all_answers.items()}
        all_responses = {k: v[:max_n] for k, v in all_responses.items()}
        all_results = {k: v[:max_n] for k, v in all_results.items()}
        all_features = {k: v[:max_n] for k, v in all_features.items()} if all_features is not None else None

    # Drop prompts for which none of the (remaining) responses is correct
    if remove_all_incorrect:
        kept = {k for k in all_answers if any(all_results[k])}
        all_answers = _keep_keys(all_answers, kept)
        all_responses = _keep_keys(all_responses, kept)
        all_results = _keep_keys(all_results, kept)
        all_features = _keep_keys(all_features, kept)
        all_prompts = _keep_keys(all_prompts, kept)

    return all_responses, all_answers, all_results, all_features, all_prompts

def construct_sparse_matrix(features: torch.Tensor, sparse_dim: int) -> torch.Tensor:
    """
    Build a sparse random projection matrix for `features` as a torch tensor.

    A `SparseRandomProjection` with `sparse_dim` components is fitted on
    `features`; its component matrix is converted to a torch sparse COO tensor
    and transposed, giving shape (features.shape[1], sparse_dim).
    """
    projector = SparseRandomProjection(sparse_dim, density="auto")
    projector.fit(features)
    coo = projector.components_.tocoo()

    # Stack rows/cols into one numpy array first; converting a single array to
    # a LongTensor is faster than converting a list of arrays.
    coo_indices = torch.LongTensor(np.array([coo.row, coo.col]))
    coo_values = torch.FloatTensor(coo.data)
    projection_shape = [sparse_dim, features.shape[1]]

    return torch.sparse_coo_tensor(coo_indices, coo_values, projection_shape).t()

def load_pool_data(cfg: DictConfig, pool_cfg: DictConfig, dl: BaseDataLoader, remove_all_incorrect: bool = False) -> Dict[str, Any]:
    """
    Load the pool of generations for the configured task/policy and filter it.

    Args:
        cfg: Global config; reads `io.load_root`, `task.name`, `policy.name`,
            `debug`, `plot.max_n` and `plot.subsample_size`.
        pool_cfg: Pool config; reads `load_features`, `elliptical.feature_name`,
            `elliptical.feature_type` and `sampling.{temperature, top_p, min_p}`.
        dl: Data loader; `dl.questions[prompt_idx]` supplies the prompt.
        remove_all_incorrect: If True, prompts without any correct response
            are removed.

    Returns:
        Dict with keys 'responses', 'answers', 'results', 'features' and
        'prompts', each a dictionary keyed by prompt index. 'features' is
        None when `pool_cfg.load_features` is falsy.
    """
    # Create path to generation files
    load_path = os.path.join(cfg.io.load_root, 'data', cfg.task.name, cfg.policy.name)

    # Load & extract response data (debug runs only look at the first 100 files)
    response_data = load_response_data(
        load_path=load_path,
        max_num_files=(100 if cfg.debug else None), 
        load_features=pool_cfg.load_features,
        feature_name=pool_cfg.elliptical.feature_name,
        feature_type=pool_cfg.elliptical.feature_type,
        temp=pool_cfg.sampling.temperature,
        top_p=pool_cfg.sampling.top_p,
        min_p=pool_cfg.sampling.min_p
    )
    all_responses = {k: v["responses"] for k, v in response_data.items()}
    all_answers = {k: v["answers"] for k, v in response_data.items()}
    all_results = {k: v["results"] for k, v in response_data.items()}
    all_features = None
    if pool_cfg.load_features:
        # NOTE(review): only the first feature tensor per prompt is kept, while
        # responses/answers/results are concatenated over every file of that
        # prompt -- confirm there is exactly one file per prompt.
        all_features = {k: v["features"][0] for k, v in response_data.items()}
    all_prompts = {k: dl.questions[k] for k in response_data}

    log.info(f'Number of files: {len(all_results)}')

    # Potentially filter data
    data = maybe_filter_data(
        all_responses, 
        all_answers, 
        all_results, 
        all_features, 
        all_prompts, 
        max_n=cfg.plot.max_n, 
        subsample_size=cfg.plot.subsample_size,
        remove_all_incorrect=remove_all_incorrect,
    )
    all_responses, all_answers, all_results, all_features, all_prompts = data

    # NOTE: I think it's okay if reference model sometimes has no correct responses, since it's just used for sorting?
    if remove_all_incorrect:
        for k, v in all_results.items():
            assert any(v), f"Example {k} has no correct responses"

    # Log basic statistics
    unique_counts = [len(set(answers)) for answers in all_answers.values()]
    if unique_counts:
        log.info(f"Avg unique answers: {np.mean(unique_counts):.1f}, Max unique answers: {max(unique_counts):d}, Min unique answers: {min(unique_counts):d}")
    else:
        # max()/min() of an empty list would raise; there is nothing to summarise.
        log.warning('No prompts left after loading/filtering; skipping unique-answer statistics')

    return {
        "responses": all_responses,
        "answers": all_answers,
        "results": all_results,
        "features": all_features,
        "prompts": all_prompts,
    }