import sys
import os
import csv
import json

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind, mannwhitneyu, ks_2samp

def prepare_model(model_name, cache_dir, quant=None):
    """Load a causal language model and its tokenizer with ``from_pretrained``.

    Args:
        model_name: model identifier or local path passed to ``from_pretrained``.
        cache_dir: download cache directory, used for both tokenizer and model.
        quant: loading mode:
            - None: default dtype, then moved to the GPU with ``.cuda()``;
            - "fp16": loaded with ``torch_dtype=torch.bfloat16`` (bfloat16,
              despite the name), then moved with ``.cuda()``;
            - "8bit": bitsandbytes 8-bit quantization; not moved with ``.cuda()``.

    Returns:
        (model, tokenizer). The tokenizer pads with its EOS token on the right
        and has ``model_max_length`` set to 512.

    Raises:
        ValueError: if ``quant`` is not None, "fp16" or "8bit".
    """
    # Validate before any download so a typo fails fast with a clear message
    # instead of an UnboundLocalError on ``model`` at the end.
    if quant not in (None, "fp16", "8bit"):
        raise ValueError(f"Unknown quant mode: {quant!r} (expected None, 'fp16' or '8bit')")

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Pad with the EOS token, on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    tokenizer.model_max_length = 512

    load_kwargs = {"cache_dir": cache_dir, "trust_remote_code": True}
    if quant is None:
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).cuda()
    elif quant == "fp16":
        # NOTE(review): the mode is named "fp16" but loads bfloat16; kept as-is
        # so results stay comparable with earlier runs.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, **load_kwargs
        ).cuda()
    else:  # "8bit"
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=quant_config, **load_kwargs
        )

    print("Model loaded")
    return model, tokenizer


def find_empty(data, message=""):
    """Count examples whose prefix, suffix or prediction is blank.

    Each example must provide string values under the keys 'prefix', 'suffix'
    and 'predic'; a value counts as empty when it is '' after ``str.strip()``.
    The counts and their share of all examples are printed to stderr, under a
    ``message`` heading when one is given. An empty ``data`` reports 0.0%
    instead of raising ZeroDivisionError.

    Returns:
        tuple: (empty_prefixes, empty_suffixes, empty_predictions, total)
    """
    empty_prefs = 0
    empty_suffs = 0
    empty_preds = 0
    total = 0
    for ex in data:
        if ex['prefix'].strip() == '':
            empty_prefs += 1
        if ex['suffix'].strip() == '':
            empty_suffs += 1
        if ex['predic'].strip() == '':
            empty_preds += 1
        total += 1

    def share(count):
        # Guard the ratio so an empty dataset does not divide by zero.
        return count / total if total else 0.0

    if len(message) > 0:
        print(f"\n{message}:", file=sys.stderr)
    print(f"Empty prefixes: {empty_prefs} ({share(empty_prefs):.1%})", file=sys.stderr)
    print(f"Empty suffixes: {empty_suffs} ({share(empty_suffs):.1%})", file=sys.stderr)
    print(f"Empty predictions: {empty_preds} ({share(empty_preds):.1%})", file=sys.stderr)
    print("\n", file=sys.stderr)

    return empty_prefs, empty_suffs, empty_preds, total

def write_results_to_csv(output_path, results):
    """Append one row of results to a CSV file.

    Columns follow the order of ``results.keys()``. A header row is written
    first when the file is new or still empty.
    """
    with open(output_path, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=results.keys())

        # Append mode positions the stream at end-of-file, so tell() == 0 means
        # the file has no content yet (just created, or left empty). Checking
        # only for existence would skip the header for an existing empty file.
        if file.tell() == 0:
            writer.writeheader()

        writer.writerow(results)

def pad_left(seqences, padding_value):
    """Left-pad tensors along dim 0 to a common length and stack them.

    Args:
        seqences: iterable of tensors that agree in every dimension except dim 0.
        padding_value: value written into the inserted leading positions.

    Returns:
        Tensor of shape (num_sequences, max_len, *trailing_dims).
    """
    # Materialize once: the input is traversed twice (length scan + padding),
    # which would silently yield nothing for a one-shot iterator.
    seqs = list(seqences)
    max_length = max(seq.size(0) for seq in seqs)
    # F.pad lists padding amounts starting from the LAST dimension, so leave all
    # trailing dimensions untouched and pad only the front of dim 0. For 1-D
    # tensors this is exactly (max_length - len, 0).
    return torch.stack([
        torch.nn.functional.pad(
            seq,
            (0, 0) * (seq.dim() - 1) + (max_length - seq.size(0), 0),
            value=padding_value,
        )
        for seq in seqs
    ])

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Counts elements over ``model.named_parameters()`` and prints the trainable
    count, the total count and the trainable percentage on one line. A model
    without parameters reports 0.0% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable (%): {trainable_pct}"
    )

def make_classifier_dataset(data, col0, col1, split_ratio=0.5):
    """Build a shuffled binary-classification training set.

    Takes the first ``int(len(data) * split_ratio)`` values of ``data[col0]``
    (label 0) and of ``data[col1]`` (label 1), concatenates them into one
    ``Dataset`` with ``text`` and ``label`` columns, and shuffles it with
    seed 4321.
    """
    cutoff = int(len(data) * split_ratio)
    class0_texts = data[col0][:cutoff]
    class1_texts = data[col1][:cutoff]
    columns = {
        "text": class0_texts + class1_texts,
        "label": [0] * cutoff + [1] * cutoff,
    }
    return Dataset.from_dict(columns).shuffle(seed=4321)

def make_eval_dataset(data, col0, col1, split_ratio=0.5):
    """Build held-out evaluation datasets from the rows after the split point.

    The split point is ``int(len(data) * split_ratio)``; rows from there on
    are used.

    Returns:
        (dataset0, dataset1): unshuffled ``Dataset`` objects with ``text`` and
        ``label`` columns; ``dataset0`` holds ``data[col0]`` rows labelled 0 and
        ``dataset1`` holds ``data[col1]`` rows labelled 1.
    """
    n = int(len(data) * split_ratio)

    # The held-out slice has len(data) - n rows, so size the labels from the
    # texts themselves. Sizing them with ``n`` only matched for an even
    # len(data) with split_ratio=0.5; otherwise from_dict received columns of
    # different lengths.
    texts0 = data[col0][n:]
    dataset0 = Dataset.from_dict({"text": texts0, "label": [0] * len(texts0)})

    texts1 = data[col1][n:]
    dataset1 = Dataset.from_dict({"text": texts1, "label": [1] * len(texts1)})

    return dataset0, dataset1

def span_tokenize(text, tokenizer):
    """Lazily yield the ``(start, end)`` character span of every token in ``text``.

    The spans are read from the tokenizer's ``offset_mapping``, with special
    tokens disabled (``add_special_tokens=False``). The tokenizer only runs
    once iteration starts.
    """
    offsets = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)["offset_mapping"]
    yield from ((begin, finish) for begin, finish in offsets)

def stats_diff_train_test_val(train, test, val, alternative=None, verbosity=1):
    """
    Compare the train set and the test set, each against the validation set.

    For both train vs val and test vs val this runs an independent two-sample
    t-test (``ttest_ind``), a Mann-Whitney U test and a two-sample
    Kolmogorov-Smirnov test, two-sided and, when ``alternative`` is given,
    one-sided as well.

    Args:
        train (iterable): Training set values.
        test (iterable): Test set values.
        val (iterable): Validation set values (the reference sample).
        alternative (str): Alternative hypothesis for the one sided tests, 'greater' or 'less'. If None, only two sided tests are performed.
        verbosity (int): Level of verbosity for output. 0 - no output, 1 - print to stderr, 2 - print to stdout.

    Returns:
        dict: keys 't_test_two_sided', 'u_test_two_sided', 'ks_test_two_sided'
        (plus the matching '*_one_sided' keys when ``alternative`` is given),
        each mapping to ``{'train_vs_val': p, 'test_vs_val': p}`` with p-values
        as Python floats.

    Raises:
        ValueError: if ``alternative`` is not None, 'greater' or 'less'.
    """
    # Validate before running anything so a bad value fails fast.
    if alternative is not None and alternative not in ('greater', 'less'):
        raise ValueError(f"Unknown alternative hypothesis: {alternative}")

    def printv(*args, **kwargs):
        if verbosity == 1:
            print(*args, file=sys.stderr, **kwargs)
        elif verbosity == 2:
            print(*args, **kwargs)

    def ks_test(a, b, alternative):
        result = ks_2samp(a, b, alternative=alternative)
        return result.statistic, result.pvalue

    # (result-key prefix, name used in messages, statistic label, test function)
    # NOTE(review): for ks_2samp, 'greater'/'less' describe the empirical CDFs,
    # so alternative='greater' there means the first sample tends to be
    # *smaller*, the opposite direction of the same keyword for ttest_ind and
    # mannwhitneyu. Kept as-is; confirm this is intended.
    tests = (
        ('t_test', 't', 'T-statistic', ttest_ind),
        ('u_test', 'U', 'Statistic', mannwhitneyu),
        ('ks_test', 'KS', 'Statistic', ks_test),
    )
    # Both samples are compared against the same validation set.
    comparisons = (('train', train), ('test', test))

    test_results = {}

    def run_all(sides, alt):
        for key_prefix, name, stat_label, test_fn in tests:
            p_values = {}
            for split_name, sample in comparisons:
                stat, p_value = test_fn(sample, val, alternative=alt)
                printv(f"{sides} sided {name} test for {split_name} vs val")
                printv(f"{stat_label}: {stat:.4f}, P-value: {p_value:.4f}")
                p_values[f'{split_name}_vs_val'] = float(p_value)
            test_results[f'{key_prefix}_{sides.lower()}_sided'] = p_values
            printv("." * 64, flush=True)

    run_all('Two', 'two-sided')
    if alternative is not None:
        run_all('One', alternative)

    return test_results

def stats_diff_train_val_test_val(train, train_val, test, test_val, alternative=None, verbosity=1):
    """
    Compare train vs train_val and test vs test_val with statistical tests.

    For each pair this runs an independent two-sample t-test (``ttest_ind``),
    a Mann-Whitney U test and a two-sample Kolmogorov-Smirnov test, two-sided
    and, when ``alternative`` is given, one-sided as well.

    Args:
        train (iterable): Values from the train set.
        train_val (iterable): Values from the validation set corresponding to train.
        test (iterable): Values from the test set.
        test_val (iterable): Values from the validation set corresponding to test.
        alternative (str, optional): Alternative hypothesis for one-sided tests. Options: 'greater', 'less'. If None, only two-sided tests are performed.
        verbosity (int, optional): Level of verbosity for output. 0 - no output, 1 - print to stderr, 2 - print to stdout.

    Returns:
        dict: keys 't_test_two_sided', 'u_test_two_sided', 'ks_test_two_sided'
        (plus the matching '*_one_sided' keys when ``alternative`` is given),
        each mapping to ``{'train_vs_val': p, 'test_vs_val': p}`` with p-values
        as Python floats.

    Raises:
        ValueError: if ``alternative`` is not None, 'greater' or 'less'.
    """
    # Validate before running anything so a bad value fails fast.
    if alternative is not None and alternative not in ('greater', 'less'):
        raise ValueError(f"Unknown alternative hypothesis: {alternative}")

    def printv(*args, **kwargs):
        if verbosity == 1:
            print(*args, file=sys.stderr, **kwargs)
        elif verbosity == 2:
            print(*args, **kwargs)

    def ks_test(a, b, alternative):
        result = ks_2samp(a, b, alternative=alternative)
        return result.statistic, result.pvalue

    # (result-key prefix, name used in messages, statistic label, test function)
    # NOTE(review): for ks_2samp, 'greater'/'less' describe the empirical CDFs,
    # so alternative='greater' there means the first sample tends to be
    # *smaller*, the opposite direction of the same keyword for ttest_ind and
    # mannwhitneyu. Kept as-is; confirm this is intended.
    tests = (
        ('t_test', 't', 'T-statistic', ttest_ind),
        ('u_test', 'U', 'Statistic', mannwhitneyu),
        ('ks_test', 'KS', 'Statistic', ks_test),
    )
    # Each sample is compared against its own validation set.
    comparisons = (('train', train, train_val), ('test', test, test_val))

    test_results = {}

    def run_all(sides, alt):
        for key_prefix, name, stat_label, test_fn in tests:
            p_values = {}
            for split_name, sample, reference in comparisons:
                stat, p_value = test_fn(sample, reference, alternative=alt)
                printv(f"{sides} sided {name} test for {split_name} vs val")
                printv(f"{stat_label}: {stat:.4f}, P-value: {p_value:.4f}")
                p_values[f'{split_name}_vs_val'] = float(p_value)
            test_results[f'{key_prefix}_{sides.lower()}_sided'] = p_values
            printv("." * 64, flush=True)

    run_all('Two', 'two-sided')
    if alternative is not None:
        run_all('One', alternative)

    return test_results

def safe_float_array(arr):
    """Convert ``arr`` to a 1-D float64 NumPy array, skipping unusable items.

    Each item is converted with ``float()``. Items that cannot be converted
    (e.g. ``None`` or non-numeric strings) and non-finite values (NaN, +/-inf)
    are dropped. Parsing with ``float()`` rather than checking that
    ``str(x)`` is all digits keeps values whose string form uses scientific
    notation (``1e-05``, a typical loss for a confidently predicted token)
    or a minus sign.
    """
    values = []
    for item in arr:
        try:
            value = float(item)
        except (TypeError, ValueError):
            continue
        if np.isfinite(value):
            values.append(value)
    return np.array(values, dtype=float)

# metrics
def perplexity(loss_list):
    '''
    Compute the perplexity, exp(mean loss), of each list of per-token losses.

    input:
        loss_list: a list of lists of per-token losses. Each inner list is
                   first filtered through safe_float_array.

    output:
        perplexity: a list with one float per inner list, in the same order.
                    An inner list with no usable values yields float('nan')
                    instead of raising ZeroDivisionError, so the output stays
                    aligned with the input.
    '''
    ppls = []
    for entry in loss_list:
        values = safe_float_array(entry)
        if len(values) == 0:
            ppls.append(float('nan'))
            continue
        mean = sum(values) / len(values)
        # ppl is the exponent of the mean loss
        ppls.append(torch.exp(torch.tensor(mean)).item())

    return ppls

# Shared per-token cross-entropy criterion: reduction="none" returns one loss
# per target instead of their mean; targets equal to ignore_index (-100, the
# PyTorch default, spelled out here) receive a loss of 0.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
# Function adapted from external source
def raw_values_batch(model, tokenizer, example_list):
    '''
    Compute the per-token language-model loss of every string in a batch.

    input:
        model: causal language model; called as ``model(**encoded)`` and must
               return an object with a ``.logits`` attribute and expose ``.device``
        tokenizer: the tokenizer, called on the whole batch with padding and
                   truncation
        example_list: a list of strings

    output:
        loss_list:  a list of lists. Each list holds the cross-entropy loss of
                    every predicted (non-padding) token of one string, in order.
                    Strings with no predicted token (e.g. a single token) are
                    dropped, so the output can be shorter than ``example_list``.
    '''
    # Cap the length: some tokenizers report a huge sentinel model_max_length.
    max_length = min(getattr(tokenizer, "model_max_length", 2048), 2048)
    encoded = tokenizer(example_list, return_tensors="pt", padding=True, truncation=True, max_length=max_length)

    # Move inputs to wherever the model lives (handles any CUDA index and CPU).
    encoded = {k: v.to(model.device) for k, v in encoded.items()}

    # forward pass with no grad
    with torch.no_grad():
        outputs = model(**encoded)

    # Build labels without mutating the model inputs. Padding is identified from
    # the attention mask: when pad_token == eos_token, matching on pad_token_id
    # would also discard genuine EOS tokens in the text.
    labels = encoded["input_ids"].clone()
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = -100
    elif tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100

    # Logits at position t predict the token at position t + 1.
    shifted_labels = labels[..., 1:]
    shifted_logits = outputs.logits[..., :-1, :]
    token_losses = torch.nn.functional.cross_entropy(
        shifted_logits.reshape(-1, shifted_logits.size(-1)),
        shifted_labels.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(shifted_labels.shape)

    # Select real tokens with the label mask, not by ``loss != 0``: a
    # confidently predicted token can have a loss of exactly 0.0 and must be kept.
    keep = shifted_labels != -100
    loss_list = [row[row_mask].tolist() for row, row_mask in zip(token_losses, keep)]

    # if any list is empty, remove it
    return [entry for entry in loss_list if len(entry) > 0]

def save_metrics4(split_name, metrics_name, metrics_values, out_dir):
    """
    Save one metric to ``{out_dir}/{split_name}_metrics.json``.

    The file holds a single JSON object mapping metric names to lists of
    floats. If the file exists it is loaded and ``metrics_name`` is added or
    overwritten, keeping the other metrics; if it does not hold a valid JSON
    object, a warning is printed to stderr and it is replaced.

    metrics_name: metric name (str)
    metrics_values: list, NumPy array, torch tensor or single scalar (stored as
        a one-element list); every element must be convertible with float()
    """
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{split_name}_metrics.json")

    metrics_dict = {}
    if os.path.exists(out_path):
        with open(out_path, "r", encoding="utf-8") as f:
            try:
                loaded = json.load(f)
            except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
                print(f"Warning: could not parse {out_path} ({err}); overwriting it", file=sys.stderr)
                loaded = {}
        if not isinstance(loaded, dict):
            print(f"Warning: {out_path} does not hold a JSON object; overwriting it", file=sys.stderr)
            loaded = {}
        metrics_dict = loaded

    # Convert to a list of floats (handles tensors, numpy arrays, and Python lists)
    values = metrics_values.tolist() if hasattr(metrics_values, 'tolist') else metrics_values
    # 0-d tensors/arrays become Python scalars via tolist(); store scalars as
    # one-element lists so every metric is a list of floats.
    if isinstance(values, (int, float)):
        values = [values]
    metrics_dict[metrics_name] = [float(x) for x in values]

    # Write to a temporary file and swap it in, so an interrupted write cannot
    # leave a truncated file (which a later call would discard as unparsable).
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(metrics_dict, f, ensure_ascii=False)
    os.replace(tmp_path, out_path)
    print(f"Saved metrics for {split_name} to {out_path}")

def save_metrics(results_file, dict_metrics):
    """
    Save several metrics to ``results_file`` as a single JSON object.

    dict_metrics: {"metric_name": values}, where values is a list, NumPy array,
        torch tensor or single scalar (stored as a one-element list); every
        element must be convertible with float().
    If the file exists it is loaded and each metric in ``dict_metrics`` is
    added or overwritten, keeping the others; if it does not hold a valid JSON
    object, a warning is printed to stderr and it is replaced. Missing parent
    directories are created.
    """
    results_file = os.fspath(results_file)

    metrics_dict = {}
    if os.path.exists(results_file):
        with open(results_file, "r", encoding="utf-8") as f:
            try:
                loaded = json.load(f)
            except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
                print(f"Warning: could not parse {results_file} ({err}); overwriting it", file=sys.stderr)
                loaded = {}
        if not isinstance(loaded, dict):
            print(f"Warning: {results_file} does not hold a JSON object; overwriting it", file=sys.stderr)
            loaded = {}
        metrics_dict = loaded

    # Update/add metrics from dict_metrics
    for metrics_name, metrics_values in dict_metrics.items():
        values = metrics_values.tolist() if hasattr(metrics_values, 'tolist') else metrics_values
        # 0-d tensors/arrays become Python scalars via tolist(); store scalars
        # as one-element lists so every metric is a list of floats.
        if isinstance(values, (int, float)):
            values = [values]
        metrics_dict[metrics_name] = [float(x) for x in values]

    parent = os.path.dirname(results_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Write to a temporary file and swap it in, so an interrupted write cannot
    # leave a truncated file (which a later call would discard as unparsable).
    tmp_path = results_file + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(metrics_dict, f, ensure_ascii=False)
    os.replace(tmp_path, results_file)
    print(f"Saved metrics to {results_file}")
