import numpy as np
import random
import torch
import math

import scipy as sp
from scipy import integrate
import warnings

def save_state_dicts(dicts, task: str, init: str):
    """
    Persist each state dict in `dicts` with torch.save.

    Entry number k is written to 'networks/<init>_<task><k>.pt', relative to the
    current working directory (note: no separator between <task> and k).
    :param dicts: iterable of PyTorch state dictionaries, saved in iteration order
    :param task: abbreviated name of the task, used in the file name
    :param init: abbreviated name of the initialization/architecture, used in the file name
    """
    for idx, sd in enumerate(dicts):
        # Path is built per entry so an empty `dicts` performs no work at all.
        path = ''.join(('networks/', init, '_', task, str(idx), '.pt'))
        torch.save(sd, path)

def copy_memory_seqs(T: int, num: int):
    """
    Generate the copy-memory task and save an 80/20 train/test split to .npy files.

    Every trajectory has T + 20 steps of one-hot features of size 10:
      * steps 0..9: data symbols drawn uniformly from {0, ..., 7} with `random.randint`;
      * steps 10..T+8: the blank symbol 8 (T - 1 steps);
      * step T+9: the delimiter symbol 9;
      * steps T+10..T+19: the blank symbol 8 (the read-out window).
    The target of a trajectory is its 10 data symbols as int64 class indices.

    Files written (the 'input_data/' directory must already exist), with L = T + 20:
      input_data/copy_mem_train_<n_train>_<L>.npy, input_data/copy_mem_test_<n_test>_<L>.npy,
      input_data/copy_mem_train_targ_<n_train>_<L>.npy, input_data/copy_mem_test_targ_<n_test>_<L>.npy
    where n_train = 8 * num // 10 and n_test = 2 * num // 10. Because of the floor
    division, when num is not a multiple of 10 the last trajectory may go unused.

    :param T: time lag; must be >= 1
    :type T: int
    :param num: total number of trajectories
    :raises ValueError: if T < 1 (the delimiter would overwrite a data step)
    :return: None; the data is only saved to disk
    """
    if T < 1:
        raise ValueError(f"copy_memory_seqs: time lag T must be >= 1, got {T}")

    seq_len = T + 20
    input_seq = np.zeros((num, seq_len, 10))
    target = np.zeros((num, 10), dtype='int64')

    # Data symbols are drawn one at a time, trajectory-major, so a given
    # random.seed(...) reproduces exactly the same dataset.
    for j in range(num):
        for i in range(10):
            index = random.randint(0, 7)
            input_seq[j, i, index] = 1.0
            target[j, i] = index

    # The remaining layout is identical for every trajectory: blanks, delimiter, blanks.
    input_seq[:, 10:T + 9, 8] = 1.0
    input_seq[:, T + 9, 9] = 1.0
    input_seq[:, T + 10:T + 20, 8] = 1.0

    num_train = 8 * num // 10
    num_test = 2 * num // 10
    stop = num_train + num_test

    train = input_seq[:num_train]
    train_targ = target[:num_train]
    test = input_seq[num_train:stop]
    test_targ = target[num_train:stop]

    np.save('input_data/copy_mem_train_' + str(num_train) + '_' + str(seq_len) + '.npy', train)
    np.save('input_data/copy_mem_test_' + str(num_test) + '_' + str(seq_len) + '.npy', test)
    np.save('input_data/copy_mem_train_targ_' + str(num_train) + '_' + str(seq_len) + '.npy', train_targ)
    np.save('input_data/copy_mem_test_targ_' + str(num_test) + '_' + str(seq_len) + '.npy', test_targ)

def temporal_ordering_seqs(T: int, num: int):
    """
    Build the temporal-ordering dataset and save an 80/20 train/test split to .npy files.

    Each trajectory is T one-hot binary symbols (feature size 2), each drawn with
    random.randint(0, 1). Its int64 label is 2 * a + b, where a is 1 iff step 0 holds
    symbol 0 and b is 1 iff step T // 2 holds symbol 0.

    Files written (the 'input_data/' directory must already exist):
      input_data/temp_ord_{train,test}_<count>_<T>.npy and
      input_data/temp_ord_{train,test}_targ_<count>_<T>.npy,
    with train count 8 * num // 10 and test count 2 * num // 10.
    :param T: trajectory length
    :type T: int
    :param num: total number of trajectories
    :return: None; the data is only saved to disk
    """
    seqs = np.zeros((num, T, 2))
    for traj in range(num):
        for step in range(T):
            seqs[traj, step, random.randint(0, 1)] = 1

    mid = T // 2
    labels = np.zeros(num, dtype='int64')
    for traj, seq in enumerate(seqs):
        labels[traj] = 2 * seq[0, 0] + seq[mid, 0]

    n_train = 8 * num // 10
    n_test = 2 * num // 10
    stop = n_train + n_test

    # Saved in the order: inputs (train, test), then targets (train, test).
    outputs = (
        ('train', n_train, seqs[:n_train]),
        ('test', n_test, seqs[n_train:stop]),
        ('train_targ', n_train, labels[:n_train]),
        ('test_targ', n_test, labels[n_train:stop]),
    )
    for tag, count, arr in outputs:
        np.save(f'input_data/temp_ord_{tag}_{count}_{T}.npy', arr)


def _gaussian_expectation(fn, q):
    """
    Numerically compute E[fn(sqrt(q) * h)] for h ~ N(0, 1) with scipy.integrate.quad.

    :param fn: elementwise function applied to sqrt(q) * h
    :param q: variance scale; a negative q makes np.sqrt return nan
    :return: the integral value (quad's error estimate is discarded)
    """
    def integrand(h):
        # Same multiplication order as the original inline lambdas, so results are unchanged.
        return np.sqrt(1 / (2 * np.pi)) * np.exp(-(h ** 2) / 2) * fn(np.sqrt(q) * h)

    value, _ = integrate.quad(integrand, -np.inf, np.inf)
    return value


def _sech_fourth(x):
    """sech(x)**4, i.e. the square of tanh'(x) = sech(x)**2."""
    return ((1.0 / np.cosh(x)) ** 2) ** 2


def _tanh_squared(x):
    """tanh(x)**2."""
    return np.tanh(x) ** 2


def optimize_qstar_sigmaw_sigmab(L):
    """
    Find Q in [0, 3] with E[sech^4(sqrt(Q) h)] = C, then derive (sigma_w, sigma_b) from Q.

    The target is C = c / (1 + c) with c = (L - 1) * (L / (L - 1)) ** L, and h ~ N(0, 1).
    Q is found by bounded minimisation of 0.5 * (E[...] - C) ** 2, so if no exact
    solution lies in [0, 3] the closest point in that interval is returned.
    Warnings (e.g. scipy IntegrationWarning) are suppressed only for the duration of
    this call; the process-wide warning filters are left untouched.

    :param L: scalar parameter of the target C (presumably network depth — confirm with
        callers); L == 1 divides by zero.
    :return: (Q, sigma_w, sigma_b), where sigma_w and sigma_b come from optimize_sigmaw_sigmab(Q)
    """
    # Explicit import: `import scipy as sp` alone does not load scipy.optimize on
    # older SciPy releases, so `sp.optimize` could raise AttributeError.
    from scipy import optimize

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        c = (L - 1) * (L / (L - 1)) ** L
        C = c / (1 + c)

        def objective(q):
            return 0.5 * (_gaussian_expectation(_sech_fourth, q) - C) ** 2

        res = optimize.minimize_scalar(objective, method='bounded', bounds=(0, 3),
                                       options={'maxiter': 1000})
        Q = res.x
        sigma_w, sigma_b = optimize_sigmaw_sigmab(Q)
    return Q, sigma_w, sigma_b


def optimize_sigmaw_sigmab(Q):
    """
    Compute (sigma_w, sigma_b) for tanh from the variance Q.

    With h ~ N(0, 1):
      sigma_w = E[sech^4(sqrt(Q) h)] ** -0.5
      sigma_b = sqrt(Q - sigma_w ** 2 * E[tanh^2(sqrt(Q) h)])
    Warnings are suppressed only for the duration of this call. If the radicand for
    sigma_b is negative, sigma_b is nan, and numpy's RuntimeWarning is silenced.

    :param Q: variance; expected to be >= 0
    :return: (sigma_w, sigma_b)
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        gamma = _gaussian_expectation(_tanh_squared, Q)
        GAMMA = _gaussian_expectation(_sech_fourth, Q)
        sigma_w = GAMMA ** -.5
        sigma_b = np.sqrt(Q - sigma_w ** 2 * gamma)
    return sigma_w, sigma_b
