import os
import pathlib
import re
import collections
import functools
import inspect
import sys
import pytest
from typing import List
from pytorch_lightning import Callback
from models.GPT2_Model_valid import GPT2Valid
import pandas as pd
import math

def lcs_of_1D_tensor(X, Y):
    """Return the length of the longest common contiguous run ("substring") of X and Y.

    Adapted from https://www.geeksforgeeks.org/longest-common-substring-dp-29/.

    :param X: 1-D sequence exposing ``.shape[0]`` and integer indexing
        (e.g. a 1-D torch tensor or numpy array); elements are compared with ``==``.
    :param Y: same kind of sequence as ``X``; lengths may differ.
    :return: int, 0 when either input is empty or nothing matches.

    Dynamic programming in O(len(X) * len(Y)) time. Only two rows of the DP
    table are kept, so extra memory is O(len(Y)).
    """
    n, m = X.shape[0], Y.shape[0]
    # prev[j] / curr[j]: length of the longest common suffix of X[:i] and Y[:j]
    # for the previous / current i. Row i == 0 (and column j == 0) is all zeros.
    prev = [0] * (m + 1)
    result = 0
    # i walks X and j walks Y, so each index stays within its own sequence's
    # bounds regardless of which input is longer.
    for i in range(1, n + 1):
        curr = [0] * (m + 1)
        x_item = X[i - 1]
        for j in range(1, m + 1):
            if x_item == Y[j - 1]:
                curr[j] = prev[j - 1] + 1
                if curr[j] > result:
                    result = curr[j]
        prev = curr
    return result

def ngram_of_1D_tensor(X, n):
    """Return every contiguous length-``n`` window of ``X`` as a tuple.

    ``X`` must support ``.shape[0]``, slicing and ``.tolist()`` (e.g. a 1-D
    torch tensor or numpy array), so each tuple holds plain Python scalars.
    Windows are listed by increasing start index; the result is empty when
    ``n`` exceeds ``X.shape[0]``.
    """
    last_start = X.shape[0] - n
    windows = []
    for start in range(last_start + 1):
        window = X[start:start + n]
        windows.append(tuple(window.tolist()))
    return windows

class ExitCodeError(Exception):
    """Raised by ``sh`` when a shell command finishes with a non-zero status."""


def sh(x):
    """Run ``x`` through ``os.system``; raise ExitCodeError if it reports failure.

    Any non-zero value returned by ``os.system`` counts as failure. ``x`` is
    handed to the shell verbatim, so never build it from untrusted input.
    """
    status = os.system(x)
    if status != 0:
        raise ExitCodeError()


def simple_parse_args_string(args_string):
    """
    Parses something like
        args1=val1,arg2=val2
    Into a dictionary.

    Each comma-separated entry is split on its FIRST ``=`` only, so values may
    themselves contain ``=`` (``path=a=b`` -> ``{"path": "a=b"}``). Only the
    string as a whole is stripped; keys and values are kept verbatim, and all
    values are returned as strings. An empty/blank string gives ``{}``.

    :raises ValueError: if an entry contains no ``=`` (including the empty
        entry produced by a trailing comma).
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    args_dict = {}
    for arg in args_string.split(","):
        # partition splits at the first '=' only; `sep` is empty when there is none.
        k, sep, v = arg.partition("=")
        if not sep:
            raise ValueError(f"Malformed argument {arg!r}: expected key=value")
        args_dict[k] = v
    return args_dict

def join_iters(iters):
    """Lazily concatenate the iterables in ``iters`` (flattening exactly one level)."""
    for sub_iterable in iters:
        yield from sub_iterable


def chunks(iter, n):
    """Yield consecutive lists of ``n`` items drawn from ``iter``.

    The final list holds the leftover items and may be shorter than ``n``;
    nothing is yielded for an empty input.
    """
    batch = []
    for item in iter:
        batch.append(item)
        if len(batch) != n:
            continue
        yield batch
        batch = []
    if len(batch) > 0:
        yield batch

def group(arr, fn):
    """Bucket the items of ``arr`` by the key ``fn(item)``.

    Returns a list of buckets (lists). Buckets appear in the order their key is
    first seen, and items keep their input order within a bucket. ``fn`` is
    called exactly once per item and must return a hashable key.
    """
    buckets = {}
    for ob in arr:
        buckets.setdefault(fn(ob), []).append(ob)
    return list(buckets.values())

def general_detokenize(string):
    """Undo common whitespace artifacts left by word-level tokenization.

    Removes the space before ``n't``, ``)``, an apostrophe, ``.`` and ``,``,
    the space after ``(``, and spaces on either side of a double quote.
    The plain replacements run in a fixed order, each on the previous result,
    followed by the regex pass.
    """
    substitutions = (
        (" n't", "n't"),
        (" )", ")"),
        ("( ", "("),
        ("\" ", "\""),
        (" \"", "\""),
    )
    for old, new in substitutions:
        string = string.replace(old, new)
    return re.sub(r" (['.,])", r"\1", string)


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    Split ``token_list`` into (input_tokens, pred_tokens) windows for
    rolling-window scoring.

    Every token of ``token_list`` is a prediction target in exactly one window,
    in order. Within a window, the last ``len(pred_tokens)`` input positions are
    the tokens immediately preceding each target (inputs are the targets shifted
    right by one).

    - The first window predicts the first ``min(max_seq_len, len(token_list))``
      tokens. Its input is ``prefix_token`` followed by all but the last of them.
    - Each later window's input is exactly ``max_seq_len`` tokens ending just
      before its last target. It predicts at most
      ``max_seq_len - context_len + 1`` new tokens. As a result, every target in
      those windows has at least ``context_len`` preceding tokens in its input.

    :param token_list: list
        List of tokens to be PREDICTED
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Must satisfy
        1 <= context_len <= max_seq_len (checked with ``assert``, which is
        skipped under ``python -O``).
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Yields nothing when ``token_list`` is empty.
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[:first_seq_len - 1],
        token_list[:first_seq_len]
    )
    predicted += first_seq_len

    # The loop only runs when len(token_list) > max_seq_len, in which case
    # `predicted` starts at max_seq_len.
    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        # window_end >= max_seq_len + 1 here, so the input slice start is never
        # negative and the input is exactly max_seq_len tokens long.
        yield (
            token_list[window_end - max_seq_len - 1:window_end - 1],
            token_list[window_end - window_pred_len:window_end],
        )
        predicted += window_pred_len

def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation.

    For windows produced by ``get_rolling_token_windows``, the last
    ``len(b) - 1`` tokens of the context ``a`` are the first ``len(b) - 1``
    tokens of the continuation ``b``, because inputs are the targets shifted
    by one. Those tokens are dropped from ``a`` so that ``a + b`` lists each
    token once.

    :param pair: ``(a, b)`` = (input_tokens, pred_tokens).
    :return: ``(a_trimmed, b)``; ``b`` is returned unchanged.
    """
    a, b = pair
    # Use an explicit end index: the negative-slice form `a[:-(len(b) - 1)]`
    # turned into `a[:0]` (an empty context) whenever len(b) == 1.
    cut = max(0, len(a) - (len(b) - 1))
    return a[:cut], b

class Reorderer:
    """Deduplicate and sort items, then map per-item results back to input order.

    Items whose ``fn`` key compares equal are collapsed into one representative
    (the first such item in ``arr``). Representatives are sorted by ``fn`` (the
    sort is stable). Compute one result per entry of ``get_reordered()``, then
    pass those results to ``get_original`` to expand them back to
    ``len(arr)`` positions.
    """

    def __init__(self, arr, fn):
        self.size = len(arr)

        # Bucket (original_index, item) pairs by key, preserving first-seen order.
        buckets = {}
        for idx, item in enumerate(arr):
            buckets.setdefault(fn(item), []).append((idx, item))

        # Each entry: (all original indices sharing the key, representative item).
        entries = [
            ([idx for idx, _ in members], members[0][1])
            for members in buckets.values()
        ]
        self.arr = sorted(entries, key=lambda entry: fn(entry[1]))

    def get_reordered(self):
        """Return the deduplicated representatives in sorted order."""
        return [representative for _, representative in self.arr]

    def get_original(self, newarr):
        """Spread results for the reordered items back over the original positions.

        ``newarr[k]`` is copied to every original index grouped under entry ``k``.
        An ``assert`` checks that every position received a value.
        """
        restored = [None] * self.size
        filled = [False] * self.size

        for (positions, _), value in zip(self.arr, newarr):
            for pos in positions:
                restored[pos] = value
                filled[pos] = True

        assert all(filled)

        return restored

def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.

    A warning is printed when a call passes more positional arguments than
    allowed. Zero are allowed for a plain function; one is allowed when ``fn``
    is already a bound method, which keeps the original author's allowance.
    ``inspect.ismethod`` is evaluated on ``fn`` at decoration time. A function
    decorated inside a class body is therefore treated as a plain function,
    and its ``self`` argument will trigger the warning.
    """
    # Parenthesized on purpose: `len(args) != 1 if cond else 0` parses as
    # `(len(args) != 1) if cond else 0`, which never warned for plain functions.
    allowed_positional = 1 if inspect.ismethod(fn) else 0

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) > allowed_positional:
            print(f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!")
        return fn(*args, **kwargs)
    return _wrapper

@positional_deprecated
def find_test_root(start_path: pathlib.Path, *, max_layers: int = 3) -> pathlib.Path:
    """
    Search upward in the directory tree for the package root, i.e. the first
    directory containing ``tests/test_version_stable.py``.

    :param start_path: where the search starts; resolved to an absolute path.
        It is itself the first candidate, followed by its ancestors.
    :param max_layers: how many directories to check, counting ``start_path``
        (default 3, as before). Keyword-only.
    :return: the resolved package root directory.
    :raises FileNotFoundError: if no candidate within ``max_layers`` qualifies.
    """
    cur_path = start_path.resolve()
    for _ in range(max_layers):
        if (cur_path / 'tests' / 'test_version_stable.py').exists():
            return cur_path
        cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} levels upwards of {start_path}"
    )

@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run its ``tests/test_version_stable.py`` for the
    given tasks.

    :param task_list: task names, joined with ``or`` into a pytest ``-k``
        selection expression.
    :raises ValueError: if ``pytest.main`` returns a non-zero exit code.
    """
    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = ' or '.join(task_list)
    args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', task_string]
    # Make the package root importable, presumably for the test module's own
    # imports. Add it only once so repeated calls do not keep growing sys.path.
    root_str = str(package_root)
    if root_str not in sys.path:
        sys.path.append(root_str)
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}")


class MetricTracker(Callback):
    """Lightning callback that writes logged metrics to ``<out_dir>/<run_name>.csv``.

    Before writing, the ``wikitext/loss`` column is exponentiated and every
    other column is multiplied by 100 (presumably turning fractions into
    percentages — confirm against the logged metrics).

    ``self.df`` holds the most recent table: transformed after a CSV write,
    raw after a validation epoch without a write. The untransformed rows are
    kept in ``self._raw_df``. Later writes therefore always rescale raw values,
    and rows that were already rescaled are never rescaled a second time.
    """

    PPL_COL = 'wikitext/loss'

    def __init__(self, run_name, validation_only, out_dir='csv_out/main/lr-vary'):
        super().__init__()
        self.df = None
        self._raw_df = None  # untransformed rows accumulated so far
        self.run_name = run_name
        self.validation_only = validation_only
        self.out_dir = out_dir

    @staticmethod
    def _metrics_frame(trainer):
        """Return ``trainer.logged_metrics`` as a one-row DataFrame of Python scalars."""
        return pd.DataFrame({k: [v.item()] for k, v in trainer.logged_metrics.items()})

    @classmethod
    def _transform(cls, raw):
        """Return a rescaled copy of ``raw``: exp() on the loss column, x100 on the rest.

        ``raw`` itself is left untouched. A missing loss column is tolerated.
        """
        out = raw.copy()
        for col in out.columns:
            if col == cls.PPL_COL:
                out[col] = out[col].apply(math.exp)
            else:
                out[col] = out[col].apply(lambda x: x * 100)
        return out

    def _write_csv(self):
        """Write ``self.df`` to the run's CSV, creating ``out_dir`` if needed."""
        os.makedirs(self.out_dir, exist_ok=True)
        self.df.to_csv(os.path.join(self.out_dir, f'{self.run_name}.csv'), index=False)

    def on_fit_end(self, trainer, module):
        print(trainer.logged_metrics)
        new_df = self._metrics_frame(trainer)
        train_cols = [col for col in new_df if 'train' in col]
        new_df = new_df.drop(columns=train_cols)
        if self._raw_df is not None:
            self._raw_df = pd.concat([self._raw_df, new_df])
        else:
            self._raw_df = new_df
        self.df = self._transform(self._raw_df)
        self._write_csv()

    def on_validation_epoch_end(self, trainer, module):
        if not isinstance(module, GPT2Valid):
            return
        # Each validation epoch replaces (does not accumulate) the raw rows.
        self._raw_df = self._metrics_frame(trainer)
        if self.validation_only:
            self.df = self._transform(self._raw_df)
            self._write_csv()
        else:
            self.df = self._raw_df
