import torch
import argparse
import time
import numpy as np
import pickle
import os
import math
import random
from functools import partial

import torch.multiprocessing
from torch.multiprocessing import Process, Manager

import pyro
import pyro.distributions as dist
from pyro.infer import HMC, MCMC, NUTS

from sklearn.linear_model import LinearRegression
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Hide all GPUs and keep the underlying linear-algebra libraries single-threaded.
# This is needed both to get AMD CPUs to work well with MKL and to stop the
# libraries' own multi-threading from interfering with the one-process-per-beta
# multiprocessing done with PyTorch below.
# NOTE(review): torch and numpy are imported above this point, so in the parent
# process these variables may come too late to affect thread pools that were
# already initialised; spawned worker processes presumably inherit them from the
# environment. Consider moving this above the imports — confirm.
#
# Shell equivalent:
#   export CUDA_VISIBLE_DEVICES=""
#   export MKL_DEBUG_CPU_TYPE=5
#   export MKL_SERIAL=YES; export OMP_NUM_THREADS=1
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '',
    'MKL_DEBUG_CPU_TYPE': '5',
    'MKL_SERIAL': 'YES',
    'OMP_NUM_THREADS': '1',
})

def one_hot_vectors(num_classes):
    """Return a list of ``num_classes`` one-hot float vectors.

    Entry ``k`` of the list is a separate float NumPy array of length
    ``num_classes`` with 1.0 at index ``k`` and 0.0 everywhere else.
    """
    return [
        np.array([1.0 if col == row else 0.0 for col in range(num_classes)])
        for row in range(num_classes)
    ]
    
def stagedUTM_forward(pnv_S,
        pnv_Y,
        symbol_weights,
        state_weights,
        direction_weights,
        num_symbols,
        num_states,
        num_directions,
        work_tape_max,
        theta_tensor,
        num_tuples):
    """Advance the staged (probabilistic) UTM by one time step.

    pnv_S: tensor of shape (batch, num_states), distribution over the state.
    pnv_Y: tensor of shape (batch, tape_size, num_symbols) with
        tape_size = 2*work_tape_max + 1; index work_tape_max is the cell under
        the head.
    symbol_weights: (num_tuples, num_symbols) written-symbol weights per tuple.
    state_weights: (num_tuples, num_states) next-state weights per tuple.
    direction_weights: (num_tuples, num_directions) move weights per tuple in
        (left, right, stay) order. The stacking below hard-codes three
        directions, so num_directions must be 3.
    theta_tensor: (num_tuples, num_symbols, num_states) indicator of the
        (read symbol, current state) pair each tuple matches.
    num_states, num_directions: accepted for interface compatibility; the
        sizes are taken from the tensors themselves.

    Returns (pnv_S, pnv_Y) after one step. pnv_S is renormalised to sum to 1
    along dim 1; pnv_Y is not renormalised.
    """
    batch_shape = pnv_S.shape[0]

    # Index letters in the einsums: a = batch, s = tuple, i = symbol,
    # j = state, k = direction.

    # Distribution of the symbol currently under the head.
    head_sym = pnv_Y[:, work_tape_max, :]

    # External tensor product of the head symbol and the state.
    tensor_Y0_S = torch.einsum('ai,aj->aij', head_sym, pnv_S)

    # lambda (called lada): contraction of theta with tensor_Y0_S, see p.16 of
    # (nstagutm). Shape [batch, num_tuples].
    lada = torch.einsum('sij,aij->as', theta_tensor, tensor_Y0_S)

    # aux = ( prod_{l=2}^N (1-lambda_l), prod_{l=3}^N (1-lambda_l), ..., (1-lambda_N), 1 )
    # with N = num_tuples, see the bottom of p.15 of (nstagutm).
    # Each iteration multiplies every existing entry by (1 - lambda_{s+2}) and
    # appends a 1, so aux grows from width 1 to width num_tuples. The explicit
    # loop keeps the original multiplication order (and autograd graph).
    batch_ones = torch.ones((batch_shape, 1), dtype=torch.float)
    aux = batch_ones
    for s in range(num_tuples - 1):
        aux = torch.cat([torch.einsum('a,az->az', 1 - lada[:, s + 1], aux),
                         batch_ones], dim=1)

    # Elementwise product from the final two lines of p.15 of (nstagutm),
    # shape [batch, num_tuples].
    lambda_aux = lada * aux

    # P_nv(s_hat) (p.15), P_nv(q_hat) (top of p.16), P_nv(d_hat) (middle of p.16).
    pnv_shat = torch.einsum('si,as->ai', symbol_weights, lambda_aux)
    pnv_qhat = torch.einsum('sj,as->aj', state_weights, lambda_aux)
    pnv_dhat = torch.einsum('sk,as->ak', direction_weights, lambda_aux)

    # prod_{l=1}^N (1 - lambda_l), shape [batch]; weights keeping the old
    # state / head symbol and moving to "stay".
    prod_all_oneminuslambda = (1 - lada[:, 0]) * aux[:, 0]

    # New state tape, renormalised along the state axis.
    oml_pnv_S = torch.einsum('a,aj->aj', prod_all_oneminuslambda, pnv_S)
    pnv_S = pnv_qhat + oml_pnv_S
    pnv_S = pnv_S / torch.sum(pnv_S, axis=1, keepdim=True)

    # Symbol written at the head position.
    A = pnv_shat + torch.einsum('a,ai->ai', prod_all_oneminuslambda, head_sym)

    # Move distribution in (left, right, stay) order; the residual mass goes to
    # "stay". Loop-invariant, so built once rather than once per tape cell.
    stay_indicator = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float)
    pnv_Mv = pnv_dhat + torch.einsum('a,q->aq', prod_all_oneminuslambda, stay_indicator)

    # Blank symbol (index 0) shifted in at the tape ends; also built once.
    blank = torch.zeros((batch_shape, num_symbols), dtype=torch.float)
    blank[:, 0] = 1.0

    last = 2 * work_tape_max
    new_tensors = []
    for j in range(last + 1):
        # Move left: cell j receives old cell j-1; the written symbol lands at
        # work_tape_max + 1 and a blank enters at j = 0.
        if j == work_tape_max + 1:
            left_sym = A
        elif j > 0:
            left_sym = pnv_Y[:, j - 1, :]
        else:
            left_sym = blank

        # Move right: cell j receives old cell j+1; the written symbol lands at
        # work_tape_max - 1 and a blank enters at the last cell.
        if j == work_tape_max - 1:
            right_sym = A
        elif j < last:
            right_sym = pnv_Y[:, j + 1, :]
        else:
            right_sym = blank

        # Stay: only the head cell changes.
        stay_sym = A if j == work_tape_max else pnv_Y[:, j, :]

        # nearby_sym has indices [a, k, i] with k in (left, right, stay) order.
        nearby_sym = torch.stack([left_sym, right_sym, stay_sym], dim=1)
        new_tensors.append(torch.einsum('ak,aki->ai', pnv_Mv, nearby_sym))

    # New state of the work tape.
    pnv_Y = torch.stack(new_tensors, dim=1)

    return pnv_S, pnv_Y

def TM_descr_detectA():
    """Return the tuple table of the detectA Turing machine.

    Each row is [read symbol, current state, written symbol, next state, move],
    with symbols blank=0, A=1, B=2; states reject=0, accept=1; and moves
    left=0, right=1, stay=2. The machine scans right over B's, accepts on the
    first A, and freezes (stays in place, state unchanged) on a blank or once
    it has accepted.
    """
    blank, sym_a, sym_b = 0, 1, 2
    reject, accept = 0, 1
    right, stay = 1, 2

    return [
        [blank, reject, blank, reject, stay],   # blank: freeze
        [blank, accept, blank, accept, stay],
        [sym_a, reject, sym_a, accept, stay],   # an A: accept and freeze
        [sym_a, accept, sym_a, accept, stay],
        [sym_b, reject, sym_b, reject, right],  # a B: keep moving right
        [sym_b, accept, sym_b, accept, stay],
    ]
    
def TM_descr_parityCheck():
    """Return the tuple table of the parityCheck Turing machine.

    Each row is [read symbol, current state, written symbol, next state, move].
    Symbols are blank=0, A=1, B=2, X=3. States are reject=0, accept=1,
    getNextAB=2, getNextA=3, getNextB=4 and gotoStart=5. Moves are L=0, R=1, S=2.

    The machine repeatedly crosses out (overwrites with X) one A and one B,
    rewinding to the left end between pairs. It accepts when a pass finds
    nothing left to match, and rejects when a partner symbol is missing.
    There is exactly one row per (symbol, state) pair, 24 rows in total.
    """
    s_blank, s_A, s_B, s_X = range(0, 4)
    q_reject, q_accept, q_getNextAB, q_getNextA, q_getNextB, q_gotoStart = range(0, 6)
    d_L, d_R, d_S = range(0, 3)

    # Terminal states: leave the tape alone and stay put, for every symbol.
    frozen = [[sym, q, sym, q, d_S]
              for q in (q_reject, q_accept)
              for sym in (s_blank, s_A, s_B, s_X)]

    working = [
        [s_blank, q_getNextAB, s_blank, q_accept, d_S],      # nothing left: accept
        [s_A, q_getNextAB, s_X, q_getNextB, d_R],            # cross out A, seek a B
        [s_B, q_getNextAB, s_X, q_getNextA, d_R],            # cross out B, seek an A
        [s_X, q_getNextAB, s_X, q_getNextAB, d_R],           # skip crossed-out cells
        [s_blank, q_getNextA, s_blank, q_reject, d_S],       # no partner A: reject
        [s_A, q_getNextA, s_X, q_gotoStart, d_L],            # partner A found
        [s_B, q_getNextA, s_B, q_getNextA, d_R],
        [s_X, q_getNextA, s_X, q_getNextA, d_R],
        [s_blank, q_getNextB, s_blank, q_reject, d_S],       # no partner B: reject
        [s_A, q_getNextB, s_A, q_getNextB, d_R],
        [s_B, q_getNextB, s_X, q_gotoStart, d_L],            # partner B found
        [s_X, q_getNextB, s_X, q_getNextB, d_R],
        [s_A, q_gotoStart, s_A, q_gotoStart, d_L],           # rewind to the left end
        [s_B, q_gotoStart, s_B, q_gotoStart, d_L],
        [s_X, q_gotoStart, s_X, q_gotoStart, d_L],
        [s_blank, q_gotoStart, s_blank, q_getNextAB, d_R],   # at the left end: new pass
    ]
    return frozen + working

def model(X, Y, beta, args):
    """Pyro model: Dirichlet priors on the tuple weights and a beta-tempered
    Categorical likelihood over the final state of the staged UTM.

    X: tensor of shape [num_samples, tape_size, num_symbols], the initial work tapes.
    Y: tensor of shape [num_samples] holding the target state index of each
       sample; it is the observation of a Categorical site.
    beta: inverse temperature. Each observation's log-likelihood is multiplied by beta.
    args: namespace providing prior_alpha, num_tuples, num_symbols, num_states,
       num_directions, init_state, time_steps, work_tape_max and theta_tensor.

    Returns the value of the observed "Y" site.
    """
    num_data = X.shape[0]

    # Symmetric Dirichlet priors with one simplex per tuple (row).
    symbol_weights = pyro.sample("symbol_weights", dist.Dirichlet(args.prior_alpha * torch.ones((args.num_tuples, args.num_symbols))))

    state_weights = pyro.sample("state_weights", dist.Dirichlet(args.prior_alpha * torch.ones((args.num_tuples, args.num_states))))

    direction_weights = pyro.sample("direction_weights", dist.Dirichlet(args.prior_alpha * torch.ones((args.num_tuples, args.num_directions))))

    # Every sample starts in args.init_state (one-hot), tiled across the batch dimension.
    one_hots_states = one_hot_vectors(args.num_states)
    init_states = torch.tensor(one_hots_states[args.init_state], dtype=torch.float)
    init_states = torch.unsqueeze(init_states, dim=0)
    init_states = init_states.repeat((num_data, 1))

    # Let Pyro know the samples are conditionally independent.
    with pyro.plate("plate", len(X)):
        pnv_S = init_states
        pnv_Y = X

        for _ in range(args.time_steps):
            pnv_S, pnv_Y = stagedUTM_forward(pnv_S, pnv_Y, symbol_weights, state_weights,
                                             direction_weights, args.num_symbols,
                                             args.num_states, args.num_directions,
                                             args.work_tape_max, args.theta_tensor,
                                             args.num_tuples)

        # Tempered likelihood p(y | w)^beta: scale the observation's log-density
        # by beta. Passing probs=pnv_S**beta to Categorical instead would NOT be
        # equivalent. Categorical renormalises its probs, which gives
        # p^beta / sum_y p^beta. That normaliser depends on w and therefore
        # distorts the tempered posterior.
        with pyro.poutine.scale(scale=float(beta)):
            y = pyro.sample("Y", dist.Categorical(probs=pnv_S), obs=Y)
        return y
    
def run_inference(model, args, X, Y, beta, beta_num, samples):
    """Sample the posterior of ``model`` at inverse temperature ``beta`` with NUTS.

    Runs MCMC on ``model(X, Y, beta, args)`` and prints a summary. Under
    ``args.path`` it saves the draws and the elapsed wall-clock seconds, both
    tagged with ``beta_num``. The draws are also stored in ``samples[beta_num]``.
    """
    started = time.time()

    nuts_kernel = NUTS(
        model,
        adapt_step_size=True,
        target_accept_prob=args.target_accept_prob,
        jit_compile=args.jit,
    )
    sampler = MCMC(nuts_kernel, num_samples=args.num_samples, warmup_steps=args.num_warmup)
    sampler.run(X, Y, beta, args)

    print("\n[beta = {}]".format(beta))
    sampler.summary(prob=0.5)

    draws = sampler.get_samples()
    torch.save(draws, '{}/mcmc_beta{}_samples.pt'.format(args.path, beta_num))
    torch.save(time.time() - started, '{}/mcmc_beta{}_time_secs.pt'.format(args.path, beta_num))
    samples[beta_num] = draws

def generateXY_parityCheck(min_seq_length,max_seq_length):
    """Sample a random A/B string and its parityCheck label.

    The length is drawn uniformly from [min_seq_length, max_seq_length] and
    each symbol uniformly from {A=1, B=2}, using the ``random`` module.
    Returns (seq, target). ``seq`` is a NumPy array. ``target`` is q_accept (1)
    when seq contains as many A's as B's, otherwise q_reject (0).
    """
    s_A, s_B = 1, 2
    q_reject, q_accept = 0, 1

    n_symbols = random.randint(min_seq_length, max_seq_length)
    seq = np.array([random.randint(s_A, s_B) for _ in range(n_symbols)])

    # +1 for each A, -1 for each B; balanced strings sum to zero.
    balance = sum(1 if sym == s_A else (-1 if sym == s_B else 0) for sym in seq)
    target = q_accept if balance == 0 else q_reject
    return seq, target

def generateXY_detectA(min_seq_length,max_seq_length):
    """Sample a random A/B string and its detectA label.

    The length is drawn uniformly from [min_seq_length, max_seq_length] and
    each symbol uniformly from {A=1, B=2}, using the ``random`` module.
    Returns (seq, target). ``seq`` is a NumPy array. ``target`` is q_accept (1)
    if seq contains at least one A, otherwise q_reject (0).
    """
    s_A, s_B = 1, 2
    q_reject, q_accept = 0, 1

    n_symbols = random.randint(min_seq_length, max_seq_length)
    seq = np.array([random.randint(s_A, s_B) for _ in range(n_symbols)])

    has_A = bool(np.any(seq == s_A))
    return seq, (q_accept if has_A else q_reject)
            
def get_data_true(args):
    """Generate a synthetic dataset of initial work tapes and target states.

    Each example is drawn from the generator for ``args.problem``. It is encoded
    as a one-hot work tape of 2*work_tape_max + 1 cells. The sequence starts at
    the head position (index work_tape_max) and runs rightwards; every other
    cell holds the blank symbol 0.

    Returns:
        initial_tape: float tensor [num_data, 2*work_tape_max + 1, num_symbols].
        correct_outputs_int: int32 tensor [num_data] of target state indices.

    Raises:
        ValueError: if args.max_seq_length exceeds the work_tape_max + 1 cells
            available from the head rightwards. Without this check the sequence
            would be silently truncated on the tape while its label still
            describes the full sequence. Also raised if args.problem is not
            recognised.
    """
    capacity = args.work_tape_max + 1
    if args.max_seq_length > capacity:
        raise ValueError(
            "max_seq_length={} does not fit on the work tape (at most {} cells "
            "from the head rightwards)".format(args.max_seq_length, capacity))

    if args.problem == "detectA":
        generator = generateXY_detectA
    elif args.problem == "parityCheck":
        generator = generateXY_parityCheck
    else:
        raise ValueError("unknown problem {!r}".format(args.problem))
    generateXY = partial(generator, args.min_seq_length, args.max_seq_length)

    s_blank = 0
    tapesize = 2 * args.work_tape_max + 1

    initial_tape_np = np.zeros((args.num_data, tapesize, args.num_symbols))
    targets = []
    for a in range(args.num_data):
        seq, target = generateXY()
        for j in range(tapesize):
            # z is the position relative to the head; the sequence occupies z = 0, 1, ...
            z = j - args.work_tape_max
            symbol = seq[z] if 0 <= z < len(seq) else s_blank
            initial_tape_np[a, j, symbol] = 1.0
        targets.append(target)

    initial_tape = torch.tensor(initial_tape_np, dtype=torch.float)
    correct_outputs_int = torch.tensor(targets, dtype=torch.int)
    return initial_tape, correct_outputs_int

def expected_nll_posterior(samples, X, Y, args):
    """Monte Carlo estimate of the posterior mean of n*L_n(w).

    For each stored draw r, the staged UTM is run for args.time_steps steps,
    starting from args.init_state on the tapes X. The untempered negative
    log-likelihood -sum_i log p(Y_i | X_i, w_r) is then taken under a
    Categorical over the final states.

    samples: dict with 'symbol_weights', 'state_weights' and 'direction_weights',
        each indexed by draw along dim 0 (as from MCMC.get_samples()).
    X: initial tapes [num_data, tape_size, num_symbols].
    Y: target state indices [num_data].

    Returns the average over draws as a 0-dim tensor.

    Raises:
        ValueError: if samples contains no draws.
    """
    num_data = X.shape[0]

    # Read the number of draws from the samples themselves rather than from
    # args.num_samples. For a single-chain run they agree; if they differ, this
    # neither indexes out of range nor silently ignores draws.
    num_draws = samples['symbol_weights'].shape[0]
    if num_draws == 0:
        raise ValueError("samples contains no posterior draws")

    # Initial states, tiled across the batch dimension.
    one_hots_states = one_hot_vectors(args.num_states)
    init_states = torch.tensor(one_hots_states[args.init_state], dtype=torch.float)
    init_states = torch.unsqueeze(init_states, dim=0)
    init_states = init_states.repeat((num_data, 1))

    nll = []
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for r in range(num_draws):
            symbol_weights = samples['symbol_weights'][r]
            state_weights = samples['state_weights'][r]
            direction_weights = samples['direction_weights'][r]

            pnv_S = init_states.clone().detach()
            pnv_Y = X.clone().detach()

            for _ in range(args.time_steps):
                pnv_S, pnv_Y = stagedUTM_forward(pnv_S, pnv_Y, symbol_weights, state_weights,
                                                 direction_weights, args.num_symbols,
                                                 args.num_states, args.num_directions,
                                                 args.work_tape_max, args.theta_tensor,
                                                 args.num_tuples)

            ydist = dist.Categorical(probs=pnv_S)
            nll.append(-ydist.log_prob(Y).sum())

    return sum(nll) / num_draws

def main(args):
    """Estimate the RLCT of the staged-UTM model on one synthetic dataset.

    Steps:
      1. Generate and save a dataset.
      2. Run NUTS in parallel, one process per beta, with the betas centred on
         1/log(n) where n = args.chain_temp.
      3. Estimate E^beta_w[nL_n(w)] at each beta.
      4. Fit these estimates against 1/beta by OLS. The slope is saved and
         printed as the RLCT estimate, and the fit is plotted to linfit.png.

    Raises:
        ValueError: if args.chain_temp < 2, since 1/log(n) would be undefined
            or negative.
        RuntimeError: if any inference worker process exits with a non-zero code.
    """
    path = args.path
    n = args.chain_temp
    if n < 2:
        raise ValueError("chain_temp must be >= 2 so that 1/log(n) is positive, got {}".format(n))

    X, Y = get_data_true(args)
    # Save the dataset up front so it survives a failure in the inference stage.
    torch.save(X, '{}/data_X.pt'.format(path))
    torch.save(Y, '{}/data_Y.pt'.format(path))

    log_n = np.log(n)
    betas = np.linspace(1 / log_n * (1 - 1 / np.sqrt(2 * log_n)),
                        1 / log_n * (1 + 1 / np.sqrt(2 * log_n)), args.num_betas)

    # Do inference: one process per beta. Results come back through a manager
    # dict keyed by beta index. The context manager shuts the manager down.
    with Manager() as manager:
        samples = manager.dict()
        jobs = []
        for i, beta in enumerate(betas):
            p = Process(target=run_inference, args=(model, args, X, Y, beta, i, samples))
            jobs.append(p)
            p.start()
        for p in jobs:
            p.join()

        # A crashed worker leaves no entry in `samples`; detect it instead of
        # silently misaligning the estimates with `betas`.
        failed = [i for i, p in enumerate(jobs) if p.exitcode != 0]
        if failed:
            raise RuntimeError("inference failed for beta index(es) {}".format(failed))

        estimates = [expected_nll_posterior(samples[i], X, Y, args) for i in range(len(betas))]

    regr = LinearRegression(fit_intercept=True)
    one_on_betas = (1 / betas).reshape(args.num_betas, 1)
    regr.fit(one_on_betas, estimates)
    score = regr.score(one_on_betas, estimates)
    b_ols = regr.intercept_
    m_ols = regr.coef_[0]

    torch.save(m_ols, '{}/rlct_estimate.pt'.format(path))
    print('RLCT estimate {} with r2 coeff {}'.format(m_ols, score))

    plt.figure()
    plt.title("E^beta_w[nL_n(w)] against 1/beta for single dataset")
    plt.scatter(1/betas, np.array(estimates))
    plt.plot(1/betas, [m_ols * x + b_ols for x in 1/betas], label='ols')
    plt.legend(loc='best')
    plt.savefig("{}/linfit.png".format(path))
    plt.close()


if __name__ == "__main__":
    random.seed()

    parser = argparse.ArgumentParser(description="RLCT_HMC_symmetric")
    # Both are needed to build the output directory. Requiring them gives a
    # usage error instead of a TypeError on None further down.
    parser.add_argument("--experiment-id", required=True)
    parser.add_argument("--save-prefix", required=True)
    parser.add_argument("--num-samples", nargs="?", default=100000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=30000, type=int)
    parser.add_argument("--num-data", nargs='?', default=1000, type=int)
    parser.add_argument("--prior-alpha", nargs='?', default=1.0, type=float)
    parser.add_argument("--target-accept-prob", nargs='?', default=0.9, type=float)
    parser.add_argument("--min-seq-length", nargs='?', default=4, type=int)
    parser.add_argument("--max-seq-length", nargs='?', default=4, type=int)
    parser.add_argument("--num-betas", default=5, type=int)
    parser.add_argument("--chain-temp", default=0, type=int)
    # Restrict to the problems configured below; anything else previously
    # failed later with a NameError on TM_descr.
    parser.add_argument("--problem", default="detectA", choices=["detectA", "parityCheck"])
    parser.add_argument("--jit", action='store_true', default=False)
    # NOTE(review): --cuda is parsed but never read; CUDA_VISIBLE_DEVICES is
    # cleared at import time, so everything runs on the CPU.
    parser.add_argument("--cuda", action='store_true', default=False, help="run this in GPU")

    args = parser.parse_args()
    args_dict = vars(args)
    print(args_dict)

    # By default we center betas around 1/log(n) with n = num_data.
    if args.chain_temp == 0:
        args.chain_temp = args.num_data

    if args.problem == "detectA":
        args.num_symbols = 3 # symbols identified with 0, ..., num_symbols - 1 where 0 is blank
        args.num_states = 2
        args.num_directions = 3 # 0,1,2 (left,right,stay)
        args.work_tape_max = 11 # The relative positions allowed on the work tape are -tape_max, ..., tape_max.
        args.time_steps = 10 # The number of time steps to run (this means TM time steps)
        args.init_state = 0 # initial state on the state tape
        TM_descr = TM_descr_detectA()
    elif args.problem == "parityCheck":
        # NOTE: this is the original parityCheck, not parityCheckPrime, i.e. there is
        # no special symbol to the left of the initial head position.
        # We allow sequences of any length; odd-length ones are always rejected.
        # We restrict to sequences of length <= 6, where the true TM takes <= 40 steps.
        args.num_symbols = 4 # symbols identified with 0, ..., num_symbols - 1 where 0 is blank
        args.num_states = 6
        args.num_directions = 3 # 0,1,2 (left,right,stay)
        args.work_tape_max = 7 # The relative positions allowed on the work tape are -tape_max, ..., tape_max.
        args.time_steps = 42 # The number of time steps to run (this means TM time steps)
        args.init_state = 2 # initial state on the state tape
        TM_descr = TM_descr_parityCheck()

    args.num_tuples = len(TM_descr)

    # See p.16 of (nstagutm). This tensor encodes the tuples on the description
    # tape: theta[s, i, j] = 1 iff tuple s reads symbol i in state j.
    theta = np.zeros([args.num_tuples, args.num_symbols, args.num_states])
    for s, tup in enumerate(TM_descr):
        theta[s, tup[0], tup[1]] = 1.0

    args.theta_tensor = torch.tensor(theta, dtype=torch.float)

    # Create the output directory.
    args.path = args.save_prefix + '/{}'.format(args.experiment_id)
    os.makedirs(args.path, exist_ok=True)

    # Save the simulation settings.
    torch.save(args, '{}/args.pt'.format(args.path))

    # for GPU see https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py
    # work around the error "CUDA error: initialization error"
    # see https://github.com/pytorch/pytorch/issues/2517
    torch.multiprocessing.set_start_method("spawn")

    # work around the error "RuntimeError: received 0 items of ancdata"
    # see https://discuss.pytorch.org/t/received-0-items-of-ancdata-pytorch-0-4-0/19823
    torch.multiprocessing.set_sharing_strategy("file_system")

    main(args)