#
# Imports
#
import causal._causal, causal.models
import argparse, pdb, os, sys, time, unittest
import numpy     as np
import torch



#
# Tests checking the C extension (causal._causal) against the PyTorch model.
#
class TestCCode(unittest.TestCase):
    """Cross-check the C extension ``causal._causal`` against the PyTorch model.

    ``sample_mlp`` and ``logprob_mlp`` are fed NumPy copies of the weights of a
    ``causal.models.CategoricalWorld``; the C forward output and gradients are
    compared with ``model.logprob`` and its autograd gradients, and finally
    both implementations are timed.
    """

    # Absolute tolerance for every C-vs-PyTorch comparison.
    ATOL = 1e-6
    # Iterations of each timing loop at the end of test_sample_mlp.
    TIMING_ITERS = 1000

    @staticmethod
    def _make_args():
        """Return the fake argument namespace the CategoricalWorld model is built from."""
        a = argparse.Namespace()
        a.seed            = 1
        a.temperature     = 1.0
        a.batch_size      = 256
        a.num_vars        = 8
        a.num_cats        = 2
        a.num_parents     = 5
        a.hidden_truth    = None
        a.hidden_learn    = None
        a.graph           = None
        a.summary         = False
        return a

    @staticmethod
    def _align(n, alignment=64):
        """Round ``n`` up to a multiple of ``alignment`` (which must be a power of two)."""
        return (n + alignment - 1) & ~(alignment - 1)

    @staticmethod
    def _padded(alloc, padded_shape, shape, dtype):
        """Allocate ``padded_shape`` with ``alloc`` and return its leading ``shape`` view.

        The returned view keeps the strides of the padded allocation, so the
        padding (zeros for ``np.zeros``) sits past the end of each visible row.
        """
        full = alloc(padded_shape, dtype=dtype)
        return full[tuple(slice(0, n) for n in shape)]

    @staticmethod
    def _time_loop(step, iters):
        """Call ``step()`` ``iters`` times and return the elapsed seconds.

        Uses ``time.perf_counter`` (monotonic, high resolution) rather than the
        adjustable wall clock of ``time.time``.
        """
        start = time.perf_counter()
        for _ in range(iters):
            step()
        return time.perf_counter() - start

    def test_sample_mlp(self):
        """Compare C sampling/logprob kernels with the PyTorch model, then time both."""
        a = self._make_args()

        # Create model with deterministic seeds.
        np.random.seed(a.seed)
        torch.random.manual_seed(a.seed)
        model = causal.models.CategoricalWorld(a)

        # Constants of the problem.
        BS    = a.batch_size
        M     = model.M
        Hgt   = model.Hgt
        N     = np.full(M, a.num_cats, dtype=np.int32)
        Ns    = int(N.sum())
        alpha = 0.1

        # "Aligned" constants: each rounded up to a multiple of 64.
        BSa, Ma, Hgta, Nsa = (self._align(x) for x in (BS, M, Hgt, Ns))

        # Allocate arrays as views into padded buffers (zero-padded, with
        # inner strides rounded up to multiples of 64 elements).
        f32, i32 = np.float32, np.int32
        W0     = self._padded(np.zeros, (M, Nsa, Hgta), (M, Ns, Hgt), f32)
        dW0    = self._padded(np.zeros, (M, Nsa, Hgta), (M, Ns, Hgt), f32)
        B0     = self._padded(np.zeros, (M, Hgta),      (M, Hgt),     f32)
        dB0    = self._padded(np.zeros, (M, Hgta),      (M, Hgt),     f32)
        W1     = self._padded(np.zeros, (Nsa, Hgta),    (Ns, Hgt),    f32)
        dW1    = self._padded(np.zeros, (Nsa, Hgta),    (Ns, Hgt),    f32)
        B1     = self._padded(np.zeros, (Nsa,),         (Ns,),        f32)
        dB1    = self._padded(np.zeros, (Nsa,),         (Ns,),        f32)
        batch  = self._padded(np.zeros, (Ma, BSa),      (M, BS),      i32)
        config = self._padded(np.ones,  (Ma, Ma),       (M, M),       f32)
        out    = self._padded(np.zeros, (Ma, BSa),      (M, BS),      f32)

        # Copy over weights and clear the config diagonal. The NumPy buffers are
        # written through shared-memory tensors; no_grad keeps these copies out
        # of the autograd graph of the model parameters.
        with torch.no_grad():
            torch.from_numpy(config).diagonal().zero_()
            torch.from_numpy(W0).copy_(model.W0slow.permute(0, 2, 3, 1).reshape(M, Ns, Hgt))
            torch.from_numpy(B0).copy_(model.B0slow)
            torch.from_numpy(W1).copy_(model.W1slow.reshape(Ns, Hgt))
            torch.from_numpy(B1).copy_(model.B1slow.reshape(Ns))

        #
        # Invoke C function for sample(); it fills `batch` in place.
        #
        self.assertIsNone(causal._causal.sample_mlp(W0, B0, W1, B1, N, config, batch, alpha))
        print("Batch of samples:")
        np.savetxt(sys.stdout, batch.T, fmt="%d", delimiter=" ")
        print("Expectations of variables:")
        for row in batch:
            sys.stdout.write("{:.3f} ".format(row.astype(np.float32).mean()))
        print()
        print()

        #
        # Invoke PyTorch code for forward logprob() on a one-hot encoding of
        # the C samples.
        #
        batchpy = torch.zeros((BS, M, a.num_cats), dtype=torch.float32)
        batchpy.scatter_(2, torch.from_numpy(batch).t().long().unsqueeze(2), 1)
        outpy   = model.logprob(batchpy, torch.from_numpy(config))[0]
        print("outpy:")
        print(outpy)

        #
        # Invoke C function for forward logprob() (no gradient buffers).
        # A fresh all-zero array shaped like N is passed as the 6th argument
        # on every call, as in the original test.
        #
        causal._causal.logprob_mlp(W0, B0, W1, B1, N, np.zeros_like(N), batch, config, out,
                                   None, None, None, None, alpha, a.temperature)
        print("outavx:")
        print(torch.as_tensor(out.T))
        self.assertLess((outpy - torch.as_tensor(out.T)).abs().max().item(), self.ATOL)
        print()
        print()

        #
        # Invoke PyTorch code for forward+backward logprob().
        #
        model.logprob(batchpy, torch.from_numpy(config))[0].mean(0).sum().backward()

        #
        # Invoke C function for forward+backward logprob(), writing dW0/dB0/dW1/dB1.
        #
        causal._causal.logprob_mlp(W0, B0, W1, B1, N, np.zeros_like(N), batch, config, out,
                                   dW0, dB0, dW1, dB1, alpha, a.temperature)
        grads = (
            ("dB1", dB1, model.B1slow.grad.reshape(Ns)),
            ("dW1", dW1, model.W1slow.grad.reshape(Ns, Hgt)),
            ("dB0", dB0, model.B0slow.grad),
            ("dW0", dW0, model.W0slow.grad.permute(0, 2, 3, 1).reshape(M, Ns, Hgt)),
        )
        for name, grad_c, grad_py in grads:
            err = (torch.as_tensor(grad_c) - grad_py).abs().max().item()
            self.assertLess(err, self.ATOL, msg=name)

        #
        # Timing: PyTorch forward+backward first, then C sample+forward+backward.
        #
        smpiter = model.sampleiter(a.batch_size)

        def step_py():
            model.logprob(next(smpiter), torch.from_numpy(config))[0].mean(0).sum().backward()

        def step_c():
            causal._causal.sample_mlp(W0, B0, W1, B1, N, config, batch, alpha)
            causal._causal.logprob_mlp(W0, B0, W1, B1, N, np.zeros_like(N), batch, config, out,
                                       dW0, dB0, dW1, dB1, alpha, a.temperature)

        print("Time taken: " + str(self._time_loop(step_py, self.TIMING_ITERS)))
        print("Time taken: " + str(self._time_loop(step_c, self.TIMING_ITERS)))


#
# Main
#
if __name__ == '__main__':
    # Script entry point: hand control to the unittest command-line runner,
    # pointing it explicitly at this module (the same as its default target).
    unittest.main(module='__main__')
