import torch as t
from typing import Callable, List
import pickle as pk
from scipy import stats
import numpy as np

import hashlib
import pickle as pk
import json
import os
def hash_dict(d):
    """Return the hex SHA-256 digest of `d`, independent of top-level key order.

    The dict is rebuilt with its top-level items sorted by key, serialised
    with `json.dumps`, encoded as UTF-8 and hashed. Only the top level is
    reordered; nested dicts are serialised in their own insertion order.
    """
    ordered = dict(sorted(d.items()))
    payload = json.dumps(ordered).encode('utf-8')
    return hashlib.sha256(payload).hexdigest()
import fcntl
class Lock:
    """Context manager that holds an exclusive `flock` on a lock file.

    Entering opens `lockfile` (creating it if it does not exist) and takes
    an exclusive lock on it with `fcntl.flock(..., LOCK_EX)`; leaving
    releases the lock and closes the file. The open handle is kept on
    `self.fp` while the lock is held.

    lockfile: path of the file to lock on.
    """

    def __init__(self, lockfile="./lock.file"):
        self.lockfile = lockfile
        self.fp = None

    def __enter__(self):
        # O_CREAT: the first user must not fail just because nobody has
        # created the lock file yet. The file is still opened read-only,
        # since only its descriptor is needed for flock.
        fd = os.open(self.lockfile, os.O_RDONLY | os.O_CREAT, 0o644)
        self.fp = os.fdopen(fd)
        try:
            fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX)
        except BaseException:
            # do not leak the handle if acquiring the lock fails or is interrupted
            self.fp.close()
            raise
        return self

    def __exit__(self, _type, value, tb):
        try:
            fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN)
        finally:
            # always close, even if the explicit unlock raised
            self.fp.close()
# Default lock-file path, passed as `Lock(lockfile=lockfile)` by the pickle helpers in this module.
lockfile = "./lock.file"

def get_ds_to_results(pkl_files: List, filter_params: Callable) -> dict:
    """
    Load result pickles and attach summary statistics to every result.

    pkl_files: list of pickle file paths which we need to process. each entry
        must expose `.stem` (e.g. a `pathlib.Path`); the stem is used as the
        dataset name
    filter_params: called with a result's `params`; only results for which it
        returns a truthy value are kept in the output

    each pickle file has the form: {hash: {metrics: [...],
                                           params: {..}}}
    where `hash` is a hash of the parameters. this is so that the code running experiments
    will not introduce duplicate results. we ignore the hash once processing the results

    every loaded result dict is updated in place with:
      - 'model': 'gcdkm' if 'Pi' is in params, 'gcn' if 'nhidden' is in params
      - 'keep': the value returned by `filter_params`
      - for metric in ('val_acc', 'test_acc'): the mean under `metric`, plus
        `std_{metric}`, `se_{metric}`, `split_std_{metric}`, `split_se_{metric}`

    returns {dataset_name: [result, ...]} containing only kept results, sorted
    by float(params['dof']) when that is possible.

    raises ValueError if a result's params contain neither 'Pi' nor 'nhidden'.
    """
    ds_to_results = {}
    for pkl_file in pkl_files:
        dataset_name = pkl_file.stem
        with Lock(lockfile=lockfile):
            with open(pkl_file, 'rb') as f:
                # NOTE: unpickling can execute arbitrary code; only load trusted files
                hash_to_results = pk.load(f)

        for result in hash_to_results.values():
            params = result['params']
            if 'Pi' in params:
                result['model'] = 'gcdkm'
            elif 'nhidden' in params:
                result['model'] = 'gcn'
            else:
                raise ValueError('model not recognized')

            # results whose params do not pass the filter are dropped further down
            result['keep'] = filter_params(params)

            # NOTE(review): these statistics are computed before duplicate
            # (seed, split) runs are removed below, so duplicates still count
            # here -- confirm this is intended
            for metric in ['val_acc', 'test_acc']:
                accs = np.array([x[metric] for x in result['metrics']])
                result[metric] = accs.mean()

                result[f'std_{metric}'] = accs.std()
                result[f'se_{metric}'] = stats.sem(accs)
                if result['metrics'][0]['split'] != 'public':  # \implies several data splits
                    # assume splits are given by 'test-$SPLIT-$SEED'
                    # NOTE(review): runs are grouped by the LAST '-'-separated
                    # field of the split name -- confirm that is the split index
                    per_split = {}
                    for m in result['metrics']:
                        split_ix = m['split'].split('-')[-1]
                        # always collect the metric being summarised (the first
                        # entry of each group used to be taken from 'val_acc'
                        # regardless of `metric`)
                        per_split.setdefault(split_ix, []).append(m[metric])
                    split_means = np.array([np.mean(v) for v in per_split.values()])
                    result[f'split_std_{metric}'] = split_means.std()
                    result[f'split_se_{metric}'] = stats.sem(split_means)
                else:
                    # single public split: the split-level spread is the run-level spread
                    result[f'split_std_{metric}'] = result[f'std_{metric}']
                    result[f'split_se_{metric}'] = result[f'se_{metric}']

        results = [v for v in hash_to_results.values() if v['keep']]

        ## remove duplicate runs with same (seed, split), keeping the first one
        for r in results:
            unique_metrics = []
            seen = set()
            for m in r['metrics']:
                key = (m['seed'], m['split'])
                if key not in seen:
                    seen.add(key)
                    unique_metrics.append(m)
            r['metrics'] = unique_metrics

        try:
            results = sorted(results, key=lambda x: float(x['params']['dof']))
        except Exception:
            ## try and sort it, it doesn't matter if it fails
            ## (e.g. no 'dof' parameter); KeyboardInterrupt etc. still propagate
            pass

        ds_to_results[dataset_name] = results
    return ds_to_results


def check_ds_to_results(ds_to_results):
    """Sanity-check that all datasets have comparable result lists.

    ds_to_results: {dataset_name: [result, ...]}; must contain the key 'cora',
        whose number of results is the reference for every other dataset.

    Asserts that
      - every dataset has as many results as 'cora', and
      - within a dataset, every result has as many 'metrics' entries as the
        first result of that dataset.
    A dataset with no results is reported with a printed message and skipped.

    Raises AssertionError when a check fails and KeyError if 'cora' is missing.
    """
    len_cora = len(ds_to_results['cora'])
    for ds, results in ds_to_results.items():
        assert len(results) == len_cora, f"len(ds_to_results[{ds}]) != len_cora, got {len(results)}"
        if len(results) == 0:
            print(f"no results for {ds}")
            # nothing to compare; indexing results[0] below would raise IndexError
            continue
        r0 = results[0]

        for r in results:
            assert len(r0['metrics']) == len(r['metrics']), f"{ds}: len(r0['metrics']) != len(r['metrics']), got {len(r0['metrics'])} and {len(r['metrics'])}"

def read_pkl_file(fname):
    """Unpickle and return the contents of `fname` while holding the module lock."""
    with Lock(lockfile=lockfile), open(fname, 'rb') as handle:
        return pk.load(handle)

def write_pkl_file(x, fname):
    """Pickle `x` into `fname` (opened with mode 'wb') while holding the module lock."""
    with Lock(lockfile=lockfile), open(fname, 'wb') as handle:
        pk.dump(x, handle)

def maybe_read_pkl_file(fname):
    """Return the unpickled contents of `fname`, or an empty dict if the path does not exist."""
    if not os.path.exists(fname):
        return {}
    return read_pkl_file(fname)

def mean_std_to_latex(mean, std, *extra, bold=False):
    """Format `mean` and `std` as the LaTeX math string `$mean \\pm std$`.

    Both numbers are printed with one decimal place. With `bold=True` the
    content is wrapped in `\\mathbf{...}`. Any `extra` strings are appended
    after the closing `$`.
    """
    body = f"{mean:.1f} \\pm {std:.1f}"
    if bold:
        body = "\\mathbf{" + body + "}"
    return "$" + body + "$" + "".join(extra)
def mean_to_latex(mean, *extra, bold=False):
    """Format `mean` as the LaTeX math string `$mean$` with one decimal place.

    With `bold=True` the number is wrapped in `\\mathbf{...}`. Any `extra`
    strings are appended after the closing `$`.
    """
    body = f"{mean:.1f}"
    if bold:
        body = "\\mathbf{" + body + "}"
    return "$" + body + "$" + "".join(extra)
def means_stds_to_latex(means_and_stds):
    """Format a sequence of (mean, std, *extra) entries as LaTeX cells.

    means_and_stds: iterable of sequences whose first two items are the mean
        and the std; any further items are strings appended to the cell.

    Per entry:
      - mean == -1 and std == -1 marks a missing value and yields "---";
      - a negative std yields the mean only (`mean_to_latex`);
      - otherwise the cell is `mean \\pm std` (`mean_std_to_latex`).
    A cell is bold when its mean, rounded to 2 decimals, is at least the
    largest mean rounded to 2 decimals.

    Returns the list of formatted strings; an empty input gives an empty list.
    """
    # default=None: an empty input has no best entry and simply yields no cells
    best_entry = max(means_and_stds, key=lambda x: x[0], default=None)
    if best_entry is None:
        return []
    best_rounded = round(best_entry[0], 2)

    res = []
    for xs in means_and_stds:
        mean = xs[0]
        std = xs[1]
        extra = xs[2:]
        if mean == -1. and std == -1.:
            res.append("---")
            continue
        bold = round(mean, 2) >= best_rounded
        if std < 0.:
            res.append(mean_to_latex(mean, *extra, bold=bold))
        else:
            res.append(mean_std_to_latex(mean, std, *extra, bold=bold))
    return res

def get_best(ds_to_results, param_keys):
    """Print and return, per dataset, the best result of each model variant.

    ds_to_results: {dataset_name: [result, ...]}; each result needs
        'val_acc', 'test_acc' and a 'params' dict with a 'model' entry.
    param_keys: names of the parameters to show for each selected result.

    For each of the models 'kipf', 'res' and 'kipfres' the result with the
    highest 'val_acc' is selected (the first one on ties). The selected
    parameters and accuracies are printed; a model without any result for a
    dataset is reported instead of raising.

    Returns {dataset_name: [best_kipf, best_res, best_kipfres]}, with None in
    place of a model that has no results.
    """
    models = ('kipf', 'res', 'kipfres')

    def select_best(rs):
        # first result with the highest val_acc; None when there are no candidates
        best = None
        for r in rs:
            if best is None or r['val_acc'] > best['val_acc']:
                best = r
        return best

    def filter_model(rs, model):
        return [r for r in rs if r['params']['model'] == model]

    data = {
        ds_name: [select_best(filter_model(rs, model)) for model in models]
        for ds_name, rs in ds_to_results.items()
    }

    for ds_name, bests in data.items():
        print("=====")
        for model, best in zip(models, bests):
            if best is None:
                print(ds_name, f"no results for model {model}")
                continue
            print({k: v for k, v in best['params'].items() if k in param_keys})
            print(ds_name, f"val_acc: {best['val_acc']:.4f}, test_acc: {best['test_acc']:.4f}")
        print("=====")
    return data