import re
import yaml
import copy

import torch
import numpy as np

from scipy.stats import hmean, ks_2samp
# Optional: the third-party natsort package can natural-sort dictionary keys.
#from natsort import natsorted
# To avoid that dependency, natural_keys() below is used as the sort key instead.

def natural_keys(text):
    """
    Builds a sort key that orders strings "naturally" (e.g. 'x2' before 'x10').

    The text is split on runs of ASCII digits. Digit runs become ints, and the
    text between them is lower-cased. The returned list always alternates
    str, int, str, ... (starting and ending with a possibly empty str), so keys
    produced for different strings can be compared element by element without
    mixing ints and strs at the same position.
    """
    # With a single capture group, re.split returns the pieces in strict
    # alternation: even positions hold the text between matches, odd positions
    # hold the captured ASCII digit runs. Classifying by position (instead of
    # str.isdigit) keeps characters such as superscript or non-ASCII digits in
    # the text slots, where int() would otherwise fail or break the alternation.
    pieces = re.split(r'([0-9]+)', text)
    return [
        int(piece) if position % 2 else piece.lower()
        for position, piece in enumerate(pieces)
    ]

def find_all_linear_names(model):
    """
    Collects the leaf names of every torch.nn.Linear submodule of `model`.

    The leaf name is the last dot-separated component of the qualified name
    reported by `model.named_modules()`. Duplicates collapse into one entry,
    and 'lm_head' is excluded from the result (needed for 16-bit).

    Returns a list of unique names in no particular order.
    """
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    # needed for 16-bit
    leaf_names.discard('lm_head')
    return list(leaf_names)

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Iterates over `model.named_parameters()`, summing `numel()` for all
    parameters and, separately, for those with `requires_grad` set. Prints one
    line with both totals and the trainable percentage. A model without any
    parameters reports 0.0% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_params += count
        if param.requires_grad:
            trainable_params += count
    # Guard the division: a parameter-free model has all_params == 0.
    trainable_pct = 100 * trainable_params / all_params if all_params else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_params} || "
        f"trainable%: {trainable_pct}"
    )

def get_model_identifiers_from_yaml(model_family):
    """
    Loads 'config/model_config.yaml' (relative to the current working
    directory) and returns the entry stored under `model_family`.

    The lookup is done directly on the loaded document, so `model_family` must
    be one of its top-level keys; a missing key raises KeyError.

    Example entry (field names are illustrative; this function does not
    inspect them):

    llama2-7b:
        hf_key: "NousResearch/Llama-2-7b-chat-hf"
        question_start_tag: "[INST] "
        question_end_tag: " [/INST] "
        answer_tag: ""
        start_of_sequence_token: "<s>"
    """
    with open("config/model_config.yaml", "r") as config_file:
        all_families = yaml.load(config_file, Loader=yaml.FullLoader)
    return all_families[model_family]

def merge_dicts(a, b):
    """
    Recursively merges dict b into a deep copy of dict a and returns the result.

    Neither input is modified, and the result shares no mutable objects with
    either of them (values taken from b are deep-copied as well).

    For keys present in both dictionaries:
        - If both values are dicts, they are merged recursively.
        - If both values are lists, the list from a is kept and b's is ignored.
        - Otherwise, the value from b overwrites that in a.
    Keys present only in b are added.

    The merged dictionary (and every recursively merged sub-dictionary) is
    rebuilt with its keys in natural sort order, using natural_keys as the
    sort key. Keys must therefore be strings.
    """
    merged = copy.deepcopy(a)
    for key, incoming in b.items():
        if key in merged:
            current = merged[key]
            if isinstance(current, dict) and isinstance(incoming, dict):
                merged[key] = merge_dicts(current, incoming)
                continue
            if isinstance(current, list) and isinstance(incoming, list):
                # Both sides hold a list: keep the one already in the copy of a.
                continue
        # New key, or a non-mergeable pair: take b's value. Copy it so later
        # mutation of the merged result cannot leak back into b.
        merged[key] = copy.deepcopy(incoming)
    # Rebuild the dict so iteration follows natural key order.
    return {key: merged[key] for key in sorted(merged, key=natural_keys)}

def get_total_len(name, forget_rate):
    """
    Looks up the total length for an evaluation file name and forget_rate.

    Three file names have a fixed length regardless of forget_rate. For any
    other name the length is chosen by forget_rate, falling back to 300 when
    the rate is not one of the listed values.
    """
    fixed_lengths = (
        ('eval_real_author_wo_options.json', 100),
        ('eval_real_world_wo_options.json', 117),
        ('eval_log.json', 300),
    )
    for file_name, length in fixed_lengths:
        if name == file_name:
            return length

    # Any other file name: the length depends on the forget split.
    rate_lengths = (
        ('forget01', 40),
        ('forget05', 200),
    )
    for rate, length in rate_lengths:
        if forget_rate == rate:
            return length
    return 300

def interleave(a, b, size):
    """
    Merges two equally long sequences by alternating chunks of `size` items:
    the first chunk of a, the first chunk of b, the second chunk of a, and so
    on. The final chunks may be shorter than `size`.

    Returns a new list. Length equality and a positive `size` are checked
    with assertions.
    """
    assert len(a) == len(b), "Lists 'a' and 'b' must have the same length for interleaving."
    assert size > 0, "Chunk size must be greater than 0."
    return [
        item
        for start in range(0, len(a), size)
        for chunk in (a[start:start + size], b[start:start + size])
        for item in chunk
    ]

def interleave_eval_result_dict(eval_result_dict, forget_rate, large_bsz, num_processes=2):
    """
    Rewrites, in place, every metric list in eval_result_dict by splitting it
    into two halves and interleaving those halves chunk by chunk.

    Each top-level key is passed to get_total_len (as the evaluation file
    name) together with forget_rate; the interleaved list is truncated to that
    length. The chunk size is large_bsz, except for metrics whose name
    contains 'perturb' or 'paraphrase', which use large_bsz // 4.

    Note: `num_processes` is accepted but not used; the split is always into
    two halves. When a list has an odd number of entries, its last element is
    dropped before interleaving.

    Returns the same (mutated) eval_result_dict.
    """
    small_bsz = large_bsz // 4
    for task_name, metric_table in eval_result_dict.items():
        keep = get_total_len(task_name, forget_rate)
        for metric, entries in metric_table.items():
            uses_small_batches = 'perturb' in metric or 'paraphrase' in metric
            chunk_size = small_bsz if uses_small_batches else large_bsz
            # Two equally sized halves; an odd trailing element falls outside both.
            half = len(entries) // 2
            first_half = entries[:half]
            second_half = entries[half:2 * half]
            metric_table[metric] = interleave(first_half, second_half, chunk_size)[:keep]
    return eval_result_dict

def _perturb_probabilities(perturb_losses):
    """
    Converts perturbed-answer losses to probabilities with shape (N, P).

    Accepts either one loss per sample (shape (N,)) or a list of P losses per
    sample (shape (N, P)); a 1-D input is treated as P == 1.
    """
    probs = np.exp(-1 * np.array(perturb_losses))
    if probs.ndim == 1:
        probs = probs[:, np.newaxis]
    return probs


def get_model_utility(eval_result_dict):
    """
    Computes model utility metrics from a dictionary of evaluation results.

    eval_result_dict maps an evaluation file name (one of the four keys of
    eval_task_dict below) to a dict containing 'avg_gt_loss', 'rougeL_recall',
    'avg_paraphrased_loss' and 'average_perturb_loss'. Each of these maps a
    data index to a loss or score. An 'average_perturb_loss' entry may be a
    single number or a list of losses (one per perturbed answer).

    For every task, three metrics are stored under '<Task> <Metric>':
        - Probability: for file names containing 'eval_log', the mean of
          exp(-gt_loss); otherwise the mean of the ground-truth probability
          normalised by the sum of the ground-truth and perturbed probabilities.
        - ROUGE: the mean of 'rougeL_recall'.
        - Truth Ratio: per sample, the mean perturbed probability divided by
          the paraphrased probability (r). Forget tasks report
          mean(min(r, 1/r)); other tasks report mean(max(0, 1 - r)).

    'Model Utility' is the harmonic mean of all metrics whose key does not
    contain 'Forget'. Metrics of tasks absent from eval_result_dict stay as
    empty-list placeholders, so all three non-forget tasks need to be present
    for the harmonic mean to be computable.

    Returns the dict of metrics. Raises KeyError for an unknown file name.
    """
    eval_task_dict = {
        'eval_real_author_wo_options.json': 'Real Authors',
        'eval_real_world_wo_options.json': 'Real World',
        'eval_log.json': 'Retain',
        'eval_log_forget.json': 'Forget'
    }
    metrics = ['ROUGE', 'Probability', 'Truth Ratio']

    # Placeholders fix the key order; they are overwritten for evaluated tasks.
    output_result = {
        f'{task_name} {metric}': []
        for task_name in eval_task_dict.values()
        for metric in metrics
    }

    for k, task_result in eval_result_dict.items():
        task_name = eval_task_dict[k]

        # Probability.
        gt_probs = np.exp(-1 * np.array(list(task_result['avg_gt_loss'].values())))
        if 'eval_log' in k:
            avg_gt_prob = np.mean(gt_probs)
        else:
            false_probs = _perturb_probabilities(
                list(task_result['average_perturb_loss'].values())
            )
            # Normalise the true-answer probability by the total over the true
            # answer and all of that sample's perturbed answers.
            avg_gt_prob = np.mean(gt_probs / (gt_probs + false_probs.sum(axis=-1)))
        output_result[f'{task_name} Probability'] = avg_gt_prob

        # ROUGE.
        rouge_vals = list(task_result['rougeL_recall'].values())
        output_result[f'{task_name} ROUGE'] = np.array(rouge_vals).mean()

        # Truth Ratio. Both loss tables are read through the paraphrase
        # table's indices so the two arrays stay aligned sample by sample.
        paraphrased = task_result['avg_paraphrased_loss']
        perturbed = task_result['average_perturb_loss']
        data_indices = list(paraphrased.keys())
        paraphrase_probs = np.exp(-1 * np.array([paraphrased[i] for i in data_indices]))
        # Average over each sample's perturbed answers only (last axis), giving
        # one value per sample, as get_forget_quality does for the same losses.
        perturb_probs = _perturb_probabilities(
            [perturbed[i] for i in data_indices]
        ).mean(axis=-1)
        curr_stat_1 = perturb_probs / paraphrase_probs

        if 'forget' in k:
            paraphrased_perturb_ratio = np.mean(np.minimum(curr_stat_1, 1 / curr_stat_1))
        else:
            paraphrased_perturb_ratio = np.mean(np.maximum(0, 1 - curr_stat_1))
        output_result[f'{task_name} Truth Ratio'] = paraphrased_perturb_ratio

    # Overall model utility: harmonic mean over the non-forget metrics.
    model_utility_cands = [
        value for key, value in output_result.items() if 'Forget' not in key
    ]
    output_result['Model Utility'] = hmean(model_utility_cands)
    return output_result

def get_forget_quality(unlearn_result, retain_result):
    """
    Measures forgetting quality with a two-sample Kolmogorov-Smirnov test on
    the 'eval_log_forget.json' results of the unlearned and retained models.

    For each model, the perturbed losses are averaged along the last axis and
    the statistic exp(mean_perturb_loss - paraphrased_loss) is formed. The two
    resulting samples are then compared with scipy's ks_2samp.

    Returns a dict with the KS p-value (under both 'Forget Quality' and
    'KS Test PVal Forget') and the KS statistic ('KS Test Forget').
    """
    def truth_ratios(forget_result):
        paraphrase_losses = np.array(list(forget_result['avg_paraphrased_loss'].values()))
        perturb_losses = np.array(list(forget_result['average_perturb_loss'].values()))
        # Mean over the last axis, then compare against the paraphrased loss.
        return np.exp(perturb_losses.mean(axis=-1) - paraphrase_losses)

    unlearn_forget = unlearn_result['eval_log_forget.json']
    retain_forget = retain_result['eval_log_forget.json']

    ks_outcome = ks_2samp(truth_ratios(unlearn_forget), truth_ratios(retain_forget))
    p_value = ks_outcome.pvalue
    return {
        'Forget Quality': p_value,
        'KS Test PVal Forget': p_value,
        'KS Test Forget': ks_outcome.statistic
    }

def add_dataset_index(dataset):
    """
    Returns the result of adding an 'index' column holding 0..len(dataset)-1.

    The dataset must support len() and provide an `add_column(name, values)`
    method (e.g. a HuggingFace Dataset); whatever that method returns is
    passed back to the caller.
    """
    return dataset.add_column('index', np.arange(len(dataset)))