import os
import argparse
import sys
import json
import hashlib
sys.path.append('../src/')

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import siren
import data_utils
import trainer
import argparse

import yaml 


from sklearn.model_selection import ParameterGrid
from tqdm import tqdm

import torch

def run_exp(config, device):
    """Train one experiment and return the trainer together with its report.

    Args:
        config: Experiment configuration forwarded to ``trainer.Trainer``.
        device: Device identifier forwarded to ``trainer.Trainer``.

    Returns:
        A ``(trainer_instance, report)`` tuple, where ``report`` is whatever
        ``Trainer.run()`` returns.
    """
    # Use a distinct local name: assigning to ``trainer`` inside this function
    # would make it a local variable and shadow the imported ``trainer``
    # module (UnboundLocalError). The bare name ``Trainer`` is not imported.
    _trainer = trainer.Trainer(config, device)
    report = _trainer.run()

    return _trainer, report

def parse_args(argv=None):
    """Parse the command-line options for the experiment sweep.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        An ``argparse.Namespace`` with ``config`` (path to the JSON sweep
        file) and ``device`` (string forwarded to the trainer, or ``None``
        when omitted).
    """
    parser = argparse.ArgumentParser(description='Run a series of experiments')
    parser.add_argument(
        # Required: the entry point opens this path unconditionally, so a
        # missing flag should fail with a usage error, not a TypeError.
        '--config', type=str, required=True,
        help='path to json config file'
    )
    parser.add_argument(
        '--device', type=str,
        help='which gpu to use'
    )
    return parser.parse_args(argv)

def create_name(d):
    """Return a deterministic experiment directory name for config ``d``.

    ``d`` is serialised to YAML with ``sort_keys=True``, so the mapping's key
    insertion order does not affect the result. The name is ``exp_`` followed
    by the MD5 hex digest of that UTF-8 encoded text.
    """
    canonical = yaml.dump(d, sort_keys=True)
    return "exp_" + hashlib.md5(canonical.encode()).hexdigest()

if __name__ == "__main__":
    args = parse_args()
    # Root directory under which one sub-directory per config is created.
    DESTINATION = "/srv/thetis2/gc453/iclr_ntk_runs/"

    with open(args.config) as f:
        # The JSON file must contain a list of configs under "configs".
        config_grid = json.load(f)['configs']

    for config in tqdm(config_grid):
        # CREATE PATH: the directory name is a hash of the config, so
        # re-running the same sweep maps each config to the same folder.
        exp_name = create_name(config)
        exp_path = os.path.join(DESTINATION, exp_name)
        os.makedirs(exp_path, exist_ok=True)
        model_path = os.path.join(exp_path, 'model.pth')

        # TEST IF RUN EXISTS: model.pth is written last (and atomically,
        # below), so its presence marks a fully completed run.
        if os.path.exists(model_path):
            tqdm.write("SKIPPING CONFIG: MODEL EXISTS")
            continue

        # IF NOT, RUN
        _trainer = trainer.Trainer(config, args.device)
        report = _trainer.run()

        # SAVE
        # save config
        tqdm.write("SAVING SUCCESSFULL RUN")
        config_save_path = os.path.join(exp_path, 'config.json')
        with open(config_save_path, 'w') as f:
            json.dump(config, f, indent=4)

        report_save_path = os.path.join(exp_path, 'report.json')
        with open(report_save_path, 'w') as f:
            json.dump(report, f, indent=4)

        # presumably _trainer.stats is a pandas DataFrame (to_json with
        # orient/indent) — TODO confirm against trainer.Trainer
        stats_save_path = os.path.join(exp_path, 'order_params.json')
        _trainer.stats.to_json(stats_save_path, orient='records', indent=2)

        data_utils.save_exp_dict(exp_path, _trainer.exp_dict)

        # save model: write to a temporary file, then atomically rename it.
        # Writing straight to model.pth could leave a truncated file after an
        # interrupted save, and the existence check above would then skip
        # this config forever.
        tmp_model_path = model_path + '.tmp'
        torch.save(_trainer.model.state_dict(), tmp_model_path)
        os.replace(tmp_model_path, model_path)
        tqdm.write("RUN COMPLETED")
    