#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import pathlib
import json
import torch
from folders import folders


def list_models(folder=folders['models'], models=None, show_pars=False):
    """Print each model in ``folder`` followed by its version sub-folders.

    Parameters
    ----------
    folder : str
        Root directory holding one sub-directory per model. The default,
        ``folders['models']``, is evaluated once when this module is imported.
    models : None, str or list of str
        Models to list. ``None`` lists every entry of ``folder`` except
        ``.DS_Store``, sorted alphabetically. A single string is treated as
        a one-element list.
    show_pars : bool
        If True, also print the last JSON object stored in each version's
        ``hparam.json`` (when that file exists).
    """
    if models is None:
        models = os.listdir(folder)
        if '.DS_Store' in models:
            models.remove('.DS_Store')
        models.sort()
    elif isinstance(models, str):
        models = [models]
    for m in models:
        print('------------------------------------------------------------')
        print(m)
        print('------------------------------------------------------------')
        versions = os.listdir(os.path.join(folder, m))
        if '.DS_Store' in versions:
            versions.remove('.DS_Store')
        # Versions are ordered numerically by the text after the first 7
        # characters of the folder name.
        # NOTE(review): assumes a fixed 7-char prefix followed by an integer;
        # any other entry makes int() raise ValueError. Confirm naming scheme.
        versions.sort(key=lambda v: int(v[7:]))
        for version in versions:
            print('     ' + version)
            if not show_pars:
                continue
            p_file = os.path.join(folder, m, version, 'hparam.json')
            if os.path.isfile(p_file):
                # Context manager guarantees the handle is released.
                with open(p_file, 'r', encoding='utf-8') as f:
                    string = f.read()
                # The file may contain several concatenated JSON objects. Keep
                # only the text after the last '{', which assumes the final
                # object contains no nested braces.
                last_opt = '{' + string.split(sep='{')[-1]
                print(last_opt)


def add_neigh(folder=folders['models'], models=None):
    """Back-fill a ``'neigh'`` entry into each version's ``hparam.json``.

    For every model/version folder whose ``hparam.json`` exists, the last
    JSON object in that file is parsed. If it has no ``'neigh'`` key, the
    value is derived from the weights in ``pars.pth``: twice the size of the
    first dimension of the tensor stored under the first of ``'log_c'``,
    ``'0.log_c'`` or ``'layer1.log_c'`` that is present. When a value is
    found, the last object in the file is replaced by the updated one. Any
    earlier objects in the file are preserved.

    Parameters
    ----------
    folder : str
        Root directory holding one sub-directory per model. The default,
        ``folders['models']``, is evaluated once when this module is imported.
    models : None, str or list of str
        Models to process. ``None`` processes every entry of ``folder``
        except ``.DS_Store``, sorted alphabetically. A single string is
        treated as a one-element list.
    """
    if models is None:
        models = os.listdir(folder)
        if '.DS_Store' in models:
            models.remove('.DS_Store')
        models.sort()
    elif isinstance(models, str):
        models = [models]
    for m in models:
        print('------------------------------------------------------------')
        print(m)
        print('------------------------------------------------------------')
        versions = os.listdir(os.path.join(folder, m))
        if '.DS_Store' in versions:
            versions.remove('.DS_Store')
        # NOTE(review): assumes version folders are a fixed 7-char prefix
        # followed by an integer; confirm naming scheme.
        versions.sort(key=lambda v: int(v[7:]))
        for version in versions:
            print('     ' + version)
            p_file = os.path.join(folder, m, version, 'hparam.json')
            if not os.path.isfile(p_file):
                continue
            with open(p_file, 'r', encoding='utf-8') as f:
                string = f.read()
            # The file may contain several concatenated JSON objects. Only the
            # text after the last '{' is parsed, which assumes the final object
            # contains no nested braces.
            chunks = string.split(sep='{')
            hparam = json.loads('{' + chunks[-1])
            if 'neigh' in hparam:
                continue
            w_file = os.path.join(folder, m, version, 'pars.pth')
            if os.path.isfile(w_file):
                # NOTE: torch.load unpickles data; only run on trusted checkpoints.
                p_dict = torch.load(w_file, map_location='cpu')
                for key in ('log_c', '0.log_c', 'layer1.log_c'):
                    if key in p_dict:
                        hparam['neigh'] = 2 * p_dict[key].shape[0]
                        break
                else:
                    print('Problem! No log_c found!')
            print(hparam)
            if 'neigh' not in hparam:
                # Nothing could be added: leave the file untouched rather than
                # rewriting (and reformatting) it with unchanged content.
                continue
            # Replace only the last object. json.dumps(...)[1:] drops the
            # leading '{', which the join re-inserts.
            string_new = '{'.join(chunks[:-1]
                                  + [json.dumps(hparam, indent=2)[1:]])
            with open(p_file, 'w', encoding='utf-8') as f:
                f.write(string_new)
