import argparse
import numpy as np
import torch
import os, json
from datetime import datetime
import pickle
from collections import OrderedDict
import wandb
import time
import ipdb
from timeit import default_timer as timer

# Results directory / environment settings.
# NOTE(review): all three values are empty strings here — presumably
# placeholders to be filled in per deployment; confirm.
RESULTS_LOCAL_DIRNAME = ''
os.environ['RESULTS_DIRNAME'] = ''
os.environ['WANDB_PROJECT_NAME'] = ''

# Load the default experiment arguments from a JSON file, keeping the file's
# key order (OrderedDict).
# NOTE(review): the path is empty, so open('') raises FileNotFoundError at
# import time — set the real path before use.
with open('') as f:
    default_args = json.load(f, object_pairs_hook=OrderedDict)

# One ``--<key>`` command-line option per default argument; each option parses
# its value with the Python type of its JSON default.
# Caveat: for bool defaults, bool('False') is True, so any non-empty string
# given on the command line parses as True.
parser = argparse.ArgumentParser()
for k, v in default_args.items():
    parser.add_argument(f'--{k}', default=v, type=type(v))

# wandb settings, loaded the same way.
# NOTE(review): empty path here as well. Presumably this is the dict later
# passed to setup_wandb (which reads 'key', 'mode', 'project_name', 'org',
# 'name'); confirm.
with open('') as f:
    wandb_init = json.load(f, object_pairs_hook=OrderedDict)


def update_wandb_config(mydict, run_config, key_prefix=''):
    """Flatten a (possibly nested) dict into ``run_config`` as string values.

    Nested dict keys are joined with ``'.'`` (e.g. ``{'a': {'b': 1}}`` becomes
    ``{'a.b': '1'}``). Every leaf value is converted to its string form before
    being written with ``run_config.update``.

    Args:
        mydict: mapping to flatten; values that are dicts are recursed into.
        run_config: object with a dict-style ``update`` method (e.g. a wandb
            run config) that receives the flattened entries.
        key_prefix: string prepended to every key written at this level.
    """
    for name, item in mydict.items():
        if not isinstance(item, dict):
            # Leaf: store under the full dotted key, stringified.
            run_config.update({f'{key_prefix}{name}': f'{item}'})
            continue
        nested_prefix = key_prefix + str(name) + '.'
        update_wandb_config(item, run_config, key_prefix=nested_prefix)


def setup_wandb(setup_dict):
    """Initialise a Weights & Biases run from a settings dict.

    Args:
        setup_dict: mapping with at least the keys ``'key'`` (wandb API key),
            ``'mode'`` (online or offline), ``'project_name'``, ``'org'``
            (wandb entity) and ``'name'`` (run display name). All entries
            except ``'key'`` are flattened into the run config.

    Returns:
        The run object returned by ``wandb.init``.
    """
    print("wandb setup starts...")

    # wandb picks up credentials and mode from the environment.
    os.environ["WANDB_API_KEY"] = setup_dict['key']
    os.environ["WANDB_MODE"] = setup_dict['mode']  # online or offline

    # Set the display name at init time rather than renaming the run and
    # calling the deprecated argument-less run.save().
    run = wandb.init(project=setup_dict['project_name'],
                     entity=setup_dict['org'],
                     name=setup_dict['name'])

    # Our settings override the wandb config. The API key is excluded because
    # run.config is synced to the wandb server and shown in the UI.
    public_settings = {k: v for k, v in setup_dict.items() if k != 'key'}
    update_wandb_config(public_settings, run.config)

    return run


def calculate_time(func):
    """Decorator that times each call of ``func`` and records the duration.

    The duration is measured with ``timer`` (``timeit.default_timer``). Where
    it is recorded depends on the wrapped function:

    * ``__init__``: written to
      ``wandb_run.summary['time: preconditioner __init__']`` when a non-None
      ``wandb_run`` keyword argument was passed; otherwise discarded.
    * any other method: when ``self.wandb_run`` is not None, the list
      ``self.time_track_dict[name]`` (``[call_count, mean_duration]``) is
      updated in place with a running mean, where ``name`` is
      ``'Kmat_xs_xbatch_mult'`` for ``__call__`` and the method's own name
      otherwise. The entry must already exist (a missing key raises KeyError).

    The wrapped function's return value is passed through unchanged.
    """
    from functools import wraps

    @wraps(func)  # keep the decorated method's __name__ / __doc__
    def inner1(*args, **kwargs):
        begin = timer()
        returned_value = func(*args, **kwargs)
        elapsed = timer() - begin

        if func.__name__ == '__init__':
            # .get(): an omitted wandb_run (or one passed positionally) means
            # "nothing to log to" instead of raising KeyError.
            wandb_run = kwargs.get('wandb_run')
            if wandb_run is not None:
                wandb_run.summary['time: preconditioner ' + func.__name__] = elapsed
            return returned_value

        owner = args[0]  # the instance the method was called on
        if owner.wandb_run is not None:
            name = 'Kmat_xs_xbatch_mult' if func.__name__ == '__call__' else func.__name__
            track = owner.time_track_dict[name]
            count, mean = track[0], track[1]
            # Incremental mean: new_mean = (n * old_mean + x) / (n + 1).
            track[1] = (count * mean + elapsed) / (count + 1)
            track[0] = count + 1
        return returned_value

    return inner1


def dict_diff(dict1, dict2):
    """Return the entries of ``dict1`` whose value differs in ``dict2``.

    Keys of ``dict1`` that are missing from ``dict2`` (lookup raises KeyError)
    are skipped, and keys only in ``dict2`` are ignored. The returned values
    are the ones from ``dict1``.
    """
    changed = {}
    for key, value in dict1.items():
        try:
            other = dict2[key]
        except KeyError:
            continue  # not present in dict2: not reported
        if other != value:
            changed[key] = value
    return changed


def save_results_and_args(results, args, dirname, filename):
    """Persist run arguments as JSON and results as a pickle.

    Writes ``<dirname>/<filename>.json`` (``args`` via ``json.dump``) and
    ``<dirname>/<filename>.pickle`` (``results`` via ``pickle.dump``),
    overwriting existing files of those names.

    Args:
        results: any picklable object.
        args: a JSON-serialisable object.
        dirname: output directory, created (with parents) if missing; an
            empty string means the current working directory.
        filename: base file name, without extension.
    """
    # exist_ok avoids the check-then-create race. An empty dirname means the
    # current directory, and os.makedirs('') would raise FileNotFoundError.
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    base_path = os.path.join(dirname, filename)
    with open(f'{base_path}.json', 'w') as f_jsn:
        json.dump(args, f_jsn)

    with open(f'{base_path}.pickle', 'wb') as f_pkl:
        pickle.dump(results, f_pkl)


def float_x(data):
    '''Cast data to single precision.

    Torch tensors are converted with ``.float()`` (float32 tensor). Anything
    else goes through ``np.float32``, which gives a float32 scalar for scalar
    input and a float32 ndarray for array-like input.
    '''
    return data.float() if torch.is_tensor(data) else np.float32(data)


def tensor(data, device, dtype=None, release=False):
    """Build a new tensor from ``data`` on ``device`` and return it.

    Args:
        data: input accepted by ``torch.tensor`` (scalars, nested lists, NumPy
            arrays, ...). The data is always copied.
        device: target device, e.g. ``'cpu'`` or ``'cuda:0'``.
        dtype: optional torch dtype; inferred from ``data`` when None.
        release: unused; kept so existing callers that pass it keep working.

    Returns:
        A tensor with ``requires_grad=False``.
    """
    # Allocate directly on the target device instead of building on the
    # default device and copying with .to(device).
    return torch.tensor(data, dtype=dtype, device=device, requires_grad=False)


def Yaccu(y, method='argmax'):
    """Turn model outputs ``y`` into predicted labels.

    Args:
        y: tensor (or input accepted by ``torch.tensor``) with at least two
            dimensions; dim 1 is treated as the output column. Non-tensors are
            converted with ``torch.tensor``.
        method: used only when ``y`` has more than one column. ``'argmax'``
            returns the index of the largest column per row. ``'top5'``
            returns the indices of the 5 largest columns per row, largest
            first, and needs at least 5 columns.

    Returns:
        For a single column, a tensor shaped like ``y`` (same dtype) holding
        1 where ``y > 0``, -1 where ``y < 0`` and 0 otherwise. For several
        columns, the index tensor described above.

    Raises:
        ValueError: if ``y`` has more than one column and ``method`` is not
            ``'argmax'`` or ``'top5'``.
    """
    if not torch.is_tensor(y):
        y = torch.tensor(y)

    if y.shape[1] == 1:
        # Binary prediction by sign; entries that are exactly 0 (or NaN) stay 0.
        y_s = torch.zeros_like(y)
        y_s[torch.where(y > 0)[0]] = 1
        y_s[torch.where(y < 0)[0]] = -1
    elif method == 'argmax':
        y_s = torch.argmax(y, dim=1)
    elif method == 'top5':
        _, y_s = torch.topk(y, 5, dim=1)
    else:
        raise ValueError(f"unknown method {method!r}; expected 'argmax' or 'top5'")

    return y_s


def square_loss(yhat, y):
    """Mean over dim 0 of 0.5 * (yhat - y)**2."""
    return 0.5 * (yhat - y).pow(2).mean(0)


def square_grad(yhat, y):
    """Per-sample derivative of 0.5 * (yhat - y)**2 with respect to yhat."""
    return yhat - y


def square_hess(yhat, y):
    """Per-sample second derivative of the square loss (all ones)."""
    # ones_like follows yhat's dtype and device; torch.ones(shape) would
    # always allocate on the default device with the default dtype.
    return torch.ones_like(yhat)


def logistic_loss(yhat, y):
    """Mean over dim 0 of log(1 + exp(-yhat * y)); labels presumably +/-1."""
    # softplus(x) = log(1 + exp(x)), computed without overflowing exp() for
    # large negative margins (where the naive form returns inf).
    return torch.nn.functional.softplus(-yhat * y).mean(0)


def logistic_grad(yhat, y):
    """Per-sample derivative of log(1 + exp(-yhat * y)) with respect to yhat."""
    # d/dz log(1 + exp(-y z)) = -y exp(-y z) / (1 + exp(-y z)) = -y sigmoid(-y z)
    return -y * torch.sigmoid(-yhat * y)


def logistic_hess(yhat, y):
    """Per-sample second derivative of the logistic loss (exact when y**2 == 1)."""
    # y**2 * s(-yz)(1 - s(-yz)) == s(yz)(1 - s(yz)) for y in {+1, -1}.
    s = torch.sigmoid(yhat * y)
    return s * (1 - s)


def predict(Kmat, alpha):
    """Linear prediction ``Kmat @ alpha``."""
    return Kmat @ alpha


def accuracy(yhat, y):
    """Fraction (over dim 0) of entries where sign(yhat) equals y."""
    return ((yhat.sign() == y) * 1.).mean(0)
