import torch
import matplotlib.pyplot as plt
import numpy as np
import random


class TransitionOne:
    """Single-layer recurrent transition with a tanh nonlinearity.

    Each call concatenates the current states with the inputs along dim 1,
    applies the linear map ``w`` followed by ``tanh``, and emits a binary
    output from the sign of the first component of the new state.
    """

    def __init__(self, w=None):
        """Create the transition.

        Args:
            w: Optional weight matrix of shape
                (state_dim, state_dim + input_dim). Defaults to the original
                hard-coded 3 x 4 matrix (3 state dims, 1 input dim).
        """
        if w is None:
            w = [[2, -2, 0, -2], [0, 0, 2, 1], [0, 0, 2, -1]]
        self.w = torch.as_tensor(w, dtype=torch.float)

    def __call__(self, s, inp):
        """Advance every state in the batch by one step.

        Args:
            s: Current states, shape (batch, state_dim).
            inp: Inputs, shape (batch, input_dim).

        Returns:
            (new_s, o): ``new_s`` has shape (batch, state_dim); ``o`` is a long
            tensor of shape (batch,) that is 1 where ``new_s[:, 0] >= 0`` and
            0 otherwise.
        """
        x = torch.cat([s, inp], dim=1)
        new_s = torch.tanh(x @ self.w.T)
        # No .squeeze(): o is already 1-D, and squeezing turned a batch of
        # size 1 into a 0-d tensor.
        o = (new_s[:, 0] >= 0).long()
        return new_s, o

def run_seq(s_init, tfunc, iiter):
    """Apply ``tfunc`` ``iiter`` times, feeding a constant input of ones.

    Args:
        s_init: Initial state tensor; its first dimension is the batch size.
        tfunc: Callable ``(s, inp) -> (s, o)``, e.g. ``TransitionOne()``.
        iiter: Number of steps to run.

    Returns:
        (s, o): the final state and the output of the last step. When
        ``iiter`` is 0, ``s`` is ``s_init`` and ``o`` is ``None``. (Previously
        this case raised ``UnboundLocalError``.)
    """
    s, o = s_init, None
    # The input is the same every step, so build it once outside the loop.
    inp = torch.ones([s.shape[0], 1])
    for _ in range(iiter):
        s, o = tfunc(s, inp)
    return s, o


if __name__ == "__main__":
    # Visual experiment: for each seed, sample a cloud of initial states
    # around a fixed mean, iterate TransitionOne on it, and plot the cloud
    # before (left panel of each pair) and after (right panel) in 3-D.
    seeds = [12, 435, 54, 5634, 4634, 457, 243, 665, 423]
    num_iter = 100
    num_points = 1000
    state_mean = torch.tensor([-0.9854, -0.7573, -0.9950])
    state_std = torch.tensor([0.1, 0.1, 0.1])

    # Two panels per seed, six panels per row (3 x 6 for the nine seeds above).
    ncols = 6
    nrows = (2 * len(seeds) + ncols - 1) // ncols

    fig = plt.figure()
    tfunc = TransitionOne()
    # Colour points by sample index so a point can be traced between panels.
    color = torch.arange(num_points).numpy()

    for jj, seed in enumerate(seeds):
        print("================={} : {} ===============".format(jj, seed))
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        s_init = torch.normal(
            torch.ones([num_points, 3]) * state_mean,
            torch.ones([num_points, 3]) * state_std,
        )
        # The last-step outputs are not used here; only the final state is plotted.
        ss, _ = run_seq(s_init, tfunc, num_iter)

        print(ss.shape)
        print(s_init[:10])
        print(ss[:10])

        for panel, pts in ((2 * jj + 1, s_init), (2 * jj + 2, ss)):
            ax = fig.add_subplot(nrows, ncols, panel, projection="3d")
            ax.scatter(
                pts[:, 0].numpy(),
                pts[:, 1].numpy(),
                pts[:, 2].numpy(),
                c=color,
                marker="o",
            )
            # No ax.legend(): no artist carries a label, so it would only warn.
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            ax.set_zlabel("Z")
            # Use the same camera angle on every panel so they are comparable.
            ax.view_init(elev=20.0, azim=-35)

    plt.show()