import argparse
import pathlib
import sys
import torch
import numpy

import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor

from recognizers.analysis.print_summary_table import (
    read_trials
)

# Model architectures whose results are analysed; one regression is fitted
# and reported per entry, in this order.
ARCHITECTURES = (
    'transformer',
    'rnn',
    'lstm',
)

def read_automaton_stats(dirname):
    """Load the state count, alphabet size and depth stored for one language.

    Reads ``dirname / 'language.pt'`` with ``torch.load`` onto the CPU. The
    loaded object is indexed with ``'depth'`` and ``'language'``; the latter
    must provide ``num_states()`` and ``alphabet_size()``.

    Args:
        dirname: ``pathlib.Path`` of the language directory.

    Returns:
        ``(states, alphabet, depth)`` on success. ``(None, None, None)`` if the
        file does not exist or lacks any of the required entries/methods; a
        message naming the file is printed to stderr in both failure cases.
    """
    language_file_name = dirname / 'language.pt'
    try:
        # NOTE: weights_only=False unpickles arbitrary Python objects (needed
        # here because the file holds a language object, not just tensors).
        # Only point this at trusted, locally generated files.
        data = torch.load(language_file_name, weights_only=False, map_location=torch.device('cpu'))
    except FileNotFoundError:
        print(f'missing: {language_file_name}', file=sys.stderr)
        return None, None, None
    try:
        depth = data['depth']
        language = data['language']
        states = language.num_states()
        alphabet = language.alphabet_size()
    except (KeyError, AttributeError) as e:
        # The file exists but does not carry the statistics we need; report it
        # instead of skipping the language silently.
        print(f'incomplete: {language_file_name} ({e!r})', file=sys.stderr)
        return None, None, None
    return states, alphabet, depth

def _collect_language_stats(
    base_dir,
    num_languages,
    num_trials,
    parameter_budget,
    architecture,
    validation_set,
    aggregate_scores,
    automaton_stats,
    missing_dirs,
):
    """Return ``(score, states, alphabet, depth)`` tuples, one per usable language.

    Trials are read from
    ``base_dir/models/<parameter_budget>/random-star-free-<n>/<architecture>/rec/<validation_set>/<trial>``
    for trials ``1..num_trials`` and languages ``1..num_languages``. The
    accuracies of a language's trials are combined with ``aggregate_scores``.
    Languages with no readable trials, or whose automaton statistics are
    unavailable, are skipped.

    Args:
        automaton_stats: dict cache ``language_no -> (states, alphabet, depth)``
            shared across calls so each language file is loaded only once.
        missing_dirs: list that the directories reported missing by
            ``read_trials`` are appended to.
    """
    stats = []
    for language_no in range(1, num_languages + 1):
        language_name = f'random-star-free-{language_no}'
        trials_dir = (
            base_dir / 'models' / f'{parameter_budget}' / language_name
            / architecture / 'rec' / validation_set
        )
        trials, trial_missing_dirs = read_trials(
            trials_dir / str(i) for i in range(1, num_trials + 1)
        )
        missing_dirs.extend(trial_missing_dirs)
        scores = numpy.array([t.accuracy for t in trials])
        if len(scores) == 0:
            continue
        if language_no not in automaton_stats:
            automaton_stats[language_no] = read_automaton_stats(
                base_dir / 'languages' / language_name
            )
        states, alphabet, depth = automaton_stats[language_no]
        if states is None or alphabet is None or depth is None:
            continue
        stats.append((aggregate_scores(scores), states, alphabet, depth))
    return stats


def _report_regression(stats):
    """Fit OLS of performance on the automaton statistics and print diagnostics.

    Prints the OLS summary, the predictor correlation matrix and the variance
    inflation factor of every design-matrix column (including the constant
    added by ``sm.add_constant``).
    """
    performances, num_states, alphabet_sizes, depths = zip(*stats)
    df = pd.DataFrame({
        "dot_depth": depths,
        "num_states": num_states,
        "alphabet_sizes": alphabet_sizes,
        "performance": performances,
    })
    predictors = ["dot_depth", "num_states", "alphabet_sizes"]
    X = sm.add_constant(df[predictors])
    y = df["performance"]

    model = sm.OLS(y, X).fit()
    print(model.summary())

    print("Predictor correlations")
    print(df[predictors].corr())

    vif_data = pd.DataFrame()
    vif_data["feature"] = X.columns
    vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]

    print(vif_data)


def main():
    """Regress mean validation accuracy on automaton statistics per architecture.

    For each architecture in ``ARCHITECTURES``, the mean accuracy over trials
    on ``validation-short`` is regressed on dot depth, number of states and
    alphabet size. Architectures with no usable results are reported on stderr
    and skipped; trial directories that were missing are listed on stderr at
    the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--parameter-budget', type=int, required=True)
    parser.add_argument('--num-trials', type=int, required=True)
    args = parser.parse_args()

    missing_dirs = []
    # Language statistics do not depend on the architecture; cache them so
    # each language.pt is loaded (and reported missing) only once.
    automaton_stats = {}

    for architecture in ARCHITECTURES:
        print(architecture)
        stats = _collect_language_stats(
            args.base_dir,
            args.num_languages,
            args.num_trials,
            args.parameter_budget,
            architecture,
            'validation-short',
            numpy.mean,
            automaton_stats,
            missing_dirs,
        )
        if not stats:
            # Previously this crashed with an opaque TypeError from zip(*None).
            print(f'no usable results for {architecture}; skipping regression', file=sys.stderr)
            continue
        _report_regression(stats)

    for missing_dir in missing_dirs:
        print(f'missing: {missing_dir}', file=sys.stderr)


if __name__ == '__main__':
    main()