import os
from random import seed, shuffle
from time import perf_counter, time
from typing import Any

import yaml
from tqdm import tqdm

from do_shap.frontiers import *


# --- Experiment configuration ---------------------------------------------
SEED = 123  # passed to random.seed() once, before any experiment runs
MIN_K, MAX_K = 5, 20  # inclusive range of K values passed to the DAG sampler
PS = [i / 10 for i in range(1, 10)]  # p values 0.1 .. 0.9 passed to the sampler
REPS = 30  # independent graphs sampled per (K, p) combination
NPERMS = 1_000  # permutations per graph in the sampling (non-exact) mode


def results_path() -> str:
    """Return ``$RESULTS_DIR/fra/frontiers.yaml``, creating ``fra/`` if needed.

    Raises:
        RuntimeError: if the ``RESULTS_DIR`` environment variable is unset or
            empty. This is checked before any experiment runs, so a
            misconfigured environment fails fast instead of after the whole
            benchmark.
    """
    results_dir = os.getenv('RESULTS_DIR')
    if not results_dir:
        raise RuntimeError('RESULTS_DIR environment variable is not set')
    out_dir = os.path.join(results_dir, 'fra')
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, 'frontiers.yaml')


def run_fr1(graph: Any, X: list[Any], exact: bool, nperms: int) -> Any:
    """Build an FR1 for ``graph`` and run it over subsets of ``X``.

    Args:
        graph: the sampled DAG handed to ``FR1``.
        X: variables to query. In the non-exact mode this list is shuffled
            IN PLACE, so the caller sees the final permutation.
        exact: if True, run on every element yielded by ``parts_of(X)``;
            otherwise run on every prefix of ``nperms`` random permutations.
        nperms: number of permutations used when ``exact`` is False.

    Returns:
        The FR1 instance, so its caches can be inspected afterwards.
    """
    fr1 = FR1(graph)
    if exact:
        for comb in parts_of(X):
            fr1.run(comb)
    else:
        for _ in range(nperms):
            shuffle(X)  # new permutation
            for i in range(len(X) + 1):  # all prefixes, including empty/full
                fr1.run(X[:i])
    return fr1


def run_fr2(graph: Any, X: list[Any], exact: bool, nperms: int) -> Any:
    """Build an FR2 for ``graph`` and run it on the analogous workload.

    Each variable in ``X`` is first mapped with ``FR2.encode_x``, and FR2 is
    queried with the ``sum`` of the encoded variables in a subset
    (NOTE(review): presumably a bitmask encoding — confirm in FR2).
    ``X`` itself is not modified; a separate encoded list is shuffled.

    Returns:
        The FR2 instance.
    """
    X2 = [FR2.encode_x(i) for i in X]
    fr2 = FR2(graph)
    if exact:
        for comb in parts_of(X2):
            fr2.run(sum(comb))
    else:
        for _ in range(nperms):
            shuffle(X2)  # new permutation
            for i in range(len(X2) + 1):
                fr2.run(sum(X2[:i]))
    return fr2


def run_once(k: int, p: float, exact: bool, nperms: int,
             rep: int) -> dict[str, Any]:
    """Sample one DAG, time FR1 and FR2 on it, and return the result record."""
    record: dict[str, Any] = {
        'k': k, 'p': p, 'exact': exact, 'nperms': nperms, 'rep': rep,
    }
    graph = sample_dag_all_ancestors_not_all_parents(k, p)

    # Nodes of every edge, in order, joined into one space-separated string.
    record['edges'] = ' '.join(
        str(node) for edge in graph.edges for node in edge
    )

    X = graph.X[:]  # copy: run_fr1 may shuffle it in place

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t1 = perf_counter()
    fr1 = run_fr1(graph, X, exact, nperms)
    t2 = perf_counter()
    # Encoding X for FR2 is deliberately included in FR2's timing.
    run_fr2(graph, X, exact, nperms)
    t3 = perf_counter()

    record['frontiers'] = len(fr1.fr_cache)
    record['values'] = len(fr1.v_cache)
    record['time_FR1'] = t2 - t1
    record['time_FR2'] = t3 - t2
    return record


def main() -> None:
    """Benchmark FR1 vs FR2 over sampled DAGs for every (K, p) and dump to YAML.

    For each K in [MIN_K, MAX_K] and each p in PS, REPS graphs are sampled
    and timed. All records are written to ``$RESULTS_DIR/fra/frontiers.yaml``.
    """
    # Validate the output location before the long-running loop.
    out_path = results_path()

    seed(SEED)  # for reproducibility

    results: list[dict[str, Any]] = []
    for k in range(MIN_K, MAX_K + 1):
        # TODO: use a value depending on K; look at the cache formula.
        nperms = NPERMS
        for p in PS:
            # Add False here to also benchmark the permutation-sampling mode.
            for exact in [True]:
                print(k, p, exact)
                for rep in tqdm(range(REPS)):
                    results.append(run_once(k, p, exact, nperms, rep))

    with open(out_path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(results, f)


if __name__ == '__main__':
    main()
