import numpy as np
import matplotlib.pyplot as plt
import time
# from kuramoto import Kuramoto, plot_phase_coherence, plot_activity 
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

    

class SpringSim(object):
    """Simulate balls coupled by springs inside a square 2-D box.

    Positions and velocities are stored as ``(2, n_balls)`` arrays. The box
    spans ``[-box_size, box_size]`` on both axes and balls bounce back
    elastically when they cross a wall (see :meth:`_clamp`).
    """

    def __init__(self, n_balls=5, box_size=5., loc_std=.5, vel_norm=.5,
                 interaction_strength=.1, noise_var=0., coupling=1.0):
        """Store the simulation parameters.

        :param n_balls: number of simulated balls (graph nodes)
        :param box_size: half-width of the box the balls are confined to
        :param loc_std: scale applied to the standard-normal initial positions
        :param vel_norm: norm every initial velocity vector is rescaled to
        :param interaction_strength: spring constant multiplying the edge
            weights when forces and energies are computed
        :param noise_var: scale applied to the standard-normal noise added
            to the recorded positions and velocities
        :param coupling: value written over unit edges that touch a node
            with index >= 5 in :meth:`generate_static_graph_hide_n_seek`
        """
        self.n_balls = n_balls
        self.box_size = box_size
        self.loc_std = loc_std
        self.vel_norm = vel_norm
        self.interaction_strength = interaction_strength
        self.noise_var = noise_var
        self.coupling = coupling

        # Edge weights a sampled graph can take.
        self._spring_types = np.array([0., 0.5, 1.])
        # Integration time step.
        self._delta_T = 0.001
        # Forces are clipped to [-_max_F, _max_F] on every step.
        self._max_F = 0.1 / self._delta_T

    def _energy(self, loc, vel, edges):
        """Return kinetic plus spring potential energy of one state.

        :param loc: 2xN locations at one time stamp
        :param vel: 2xN velocities at one time stamp
        :param edges: NxN edge weights
        :return: ``0.5 * sum(vel ** 2)`` plus, summed over every ordered pair
            ``i != j``, ``0.5 * interaction_strength * edges[i, j] * d ** 2 / 2``
            where ``d`` is the distance between balls ``i`` and ``j``
        """
        kinetic = 0.5 * (vel ** 2).sum()
        potential = 0
        n_nodes = loc.shape[1]
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j:
                    r = loc[:, i] - loc[:, j]
                    dist = np.sqrt((r ** 2).sum())
                    potential += 0.5 * self.interaction_strength * edges[
                        i, j] * (dist ** 2) / 2
        return potential + kinetic

    def _clamp(self, loc, vel):
        '''
        Reflect balls that left the box back inside (arrays are modified
        in place and also returned).

        :param loc: 2xN location at one time stamp
        :param vel: 2xN velocity at one time stamp
        :return: location and velocity after hitting the walls and bouncing
            back elastically
        '''
        # A single reflection can only repair overshoots smaller than the
        # box itself, so anything far outside is treated as a blow-up.
        assert (np.all(loc < self.box_size * 3))
        assert (np.all(loc > -self.box_size * 3))

        # Upper walls: mirror the position, force the velocity inwards.
        over = loc > self.box_size
        loc[over] = 2 * self.box_size - loc[over]
        assert (np.all(loc <= self.box_size))
        vel[over] = -np.abs(vel[over])

        # Lower walls: same treatment with the signs flipped.
        under = loc < -self.box_size
        loc[under] = -2 * self.box_size - loc[under]
        assert (np.all(loc >= -self.box_size))
        vel[under] = np.abs(vel[under])

        return loc, vel

    def _l2(self, A, B):
        """
        Input: A is a Nxd matrix
               B is a Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the square norm
            between A[i,:] and B[j,:]
        i.e. dist[i,j] = ||A[i,:]-B[j,:]||^2
        """
        A_norm = (A ** 2).sum(axis=1).reshape(A.shape[0], 1)
        B_norm = (B ** 2).sum(axis=1).reshape(1, B.shape[0])
        dist = A_norm + B_norm - 2 * A.dot(B.transpose())
        return dist

    def _spring_forces(self, loc, forces_size):
        """Return the clipped 2xN net force for the 2xN positions ``loc``.

        ``forces_size`` is the NxN matrix ``-interaction_strength * edges``
        with a zero diagonal (no self-forces).
        """
        # diff[d, i, j] = loc[d, i] - loc[d, j]
        diff = loc[:, :, None] - loc[:, None, :]
        F = (forces_size[None, :, :] * diff).sum(axis=-1)
        np.clip(F, -self._max_F, self._max_F, out=F)
        return F

    def generate_static_graph(self, spring_prob=(1. / 2, 0, 1. / 2)):
        """Sample a symmetric edge-weight matrix without self-loops.

        :param spring_prob: probabilities of drawing each entry of
            ``self._spring_types`` (one probability per spring type)
        :return: ``(n_balls, n_balls)`` symmetric array with zero diagonal
        """
        edges = np.random.choice(self._spring_types,
                                 size=(self.n_balls, self.n_balls),
                                 p=spring_prob)
        # Mirror the lower triangle so the graph is undirected.
        edges = np.tril(edges) + np.tril(edges, -1).T
        np.fill_diagonal(edges, 0)

        return edges

    def generate_static_graph_hide_n_seek(self, spring_prob=(1./2, 0, 1./2), hide_balls=1):
        """Sample a symmetric graph and derive a truncated and a re-weighted copy.

        :param spring_prob: probabilities of drawing each entry of
            ``self._spring_types``
        :param hide_balls: number of trailing nodes dropped from the
            truncated graph
        :return: tuple ``(broken_edges, edges_full, adj_modified)`` where
            ``edges_full`` is the full symmetric graph without self-loops,
            ``broken_edges`` is its leading
            ``(n_balls - hide_balls) x (n_balls - hide_balls)`` block, and
            ``adj_modified`` equals ``edges_full`` except that unit edges
            touching a node with index >= 5 carry ``self.coupling``
        """
        edges = np.random.choice(self._spring_types,
                                 size=(self.n_balls, self.n_balls),
                                 p=spring_prob)

        # Nodes from this index onwards get the alternative coupling value.
        # NOTE(review): the threshold 5 is hard-coded; confirm it is meant
        # to be independent of n_balls.
        first_coupled = 5
        adj_modified = np.copy(edges)
        adj_modified[first_coupled:, :] = np.where(
            adj_modified[first_coupled:, :] == 1,
            self.coupling, adj_modified[first_coupled:, :])
        adj_modified[:, first_coupled:] = np.where(
            adj_modified[:, first_coupled:] == 1,
            self.coupling, adj_modified[:, first_coupled:])

        # Mirror the lower triangles and remove self-loops.
        edges = np.tril(edges) + np.tril(edges, -1).T
        adj_modified = np.tril(adj_modified) + np.tril(adj_modified, -1).T
        np.fill_diagonal(edges, 0)
        np.fill_diagonal(adj_modified, 0)

        n_visible = self.n_balls - hide_balls
        edges_full = edges.copy()
        broken_edges = edges.copy()[:n_visible, :n_visible]
        return broken_edges, edges_full, adj_modified

    def sample_trajectory_static_graph_irregular_difflength_each_no_sampling(self, args, edges, isTrain=True):
        '''
        Integrate the spring system on a fixed graph and return one
        trajectory per node. Train and test use the same observation rule.

        :param args: object providing ``sample_freq`` (integration steps
            between recorded samples), ``ode`` (integration steps of one
            training trajectory) and, when ``isTrain`` is false,
            ``num_test_box`` (extra trajectory lengths to simulate)
        :param edges: NxN edge weights (not modified)
        :param isTrain: simulate ``args.ode`` steps when true, otherwise
            ``args.ode * (1 + args.num_test_box)`` steps
        :return: ``(loc_sample, vel_sample, time_sample)``, three lists with
            one entry per node: ``(k, 2)`` positions, ``(k, 2)`` velocities
            and ``np.arange(k)``, where ``k`` is ``self.num_steps`` (90)
            capped at the number of recorded samples
        '''
        sample_freq = args.sample_freq
        ode_step = args.ode
        self.num_steps = 90

        n = self.n_balls

        if isTrain:
            T = ode_step
        else:
            T = ode_step * (1 + args.num_test_box)

        # One row for the initial state plus one per multiple of
        # sample_freq below T. This equals T // sample_freq when sample_freq
        # divides T and avoids overrunning the buffers when it does not.
        step = (T - 1) // sample_freq + 1

        # Initialize location and velocity
        loc = np.zeros((step, 2, n))
        vel = np.zeros((step, 2, n))
        loc_next = np.random.randn(2, n) * self.loc_std
        vel_next = np.random.randn(2, n)
        v_norm = np.sqrt((vel_next ** 2).sum(axis=0)).reshape(1, -1)
        vel_next = vel_next * self.vel_norm / v_norm
        loc[0, :, :], vel[0, :, :] = self._clamp(loc_next, vel_next)

        # The graph is static, so the signed spring constants are computed
        # once; self forces are zero.
        forces_size = -self.interaction_strength * np.asarray(edges)
        np.fill_diagonal(forces_size, 0)

        vel_next += self._delta_T * self._spring_forces(loc_next, forces_size)

        counter = 1  # row 0 holds the initial point
        # run leapfrog
        for i in range(1, T):
            loc_next += self._delta_T * vel_next
            loc_next, vel_next = self._clamp(loc_next, vel_next)

            if i % sample_freq == 0:
                loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next
                counter += 1

            vel_next += self._delta_T * self._spring_forces(loc_next,
                                                            forces_size)

        # Add noise to observations
        loc += np.random.randn(step, 2, self.n_balls) * self.noise_var
        vel += np.random.randn(step, 2, self.n_balls) * self.noise_var

        # Never report more time stamps than there are recorded rows.
        num_obs = min(self.num_steps, step)
        loc_sample = [loc[:num_obs, :, i] for i in range(n)]
        vel_sample = [vel[:num_obs, :, i] for i in range(n)]
        time_sample = [np.arange(num_obs) for _ in range(n)]

        return loc_sample, vel_sample, time_sample





class KuramotoModel(object):
    """Run a Kuramoto oscillator model on a random graph and expose its
    activity as per-node trajectories."""

    def __init__(self, n_balls=5, connect_probability=0.5):
        """Sample an Erdos-Renyi graph and simulate the oscillators on it.

        :param n_balls: number of nodes (oscillators)
        :param connect_probability: edge probability handed to
            ``nx.erdos_renyi_graph``
        """
        graph_nx = nx.erdos_renyi_graph(n_balls, connect_probability)
        graph = nx.to_numpy_array(graph_nx)
        print(graph)
        self.Adjacency = graph
        # NOTE(review): `Kuramoto` is not imported in the visible part of
        # this file (the `from kuramoto import ...` line is commented out),
        # so this raises NameError unless the name is provided elsewhere.
        self.model = Kuramoto(coupling=1, dt=0.01, T=10, n_nodes=len(graph))
        self.act_mat = self.model.run(adj_mat=graph)
        self.n_balls = n_balls

    def sample_trajectory_static_graph_irregular_difflength_each_no_sampling(self, isTrain=True):
        '''
        Split the simulated activity into one trajectory per node.

        ``self.act_mat`` is indexed as ``act_mat[node, time]``; the number
        of time steps is taken from its second axis.

        :param isTrain: when false, both returned lists are empty
        :return: ``(loc_sample, adjacency, time_sample)`` where
            ``loc_sample[i]`` is the ``(num_steps, 1)`` activity of node
            ``i``, ``adjacency`` is ``self.Adjacency`` and
            ``time_sample[i]`` is ``np.arange(num_steps)``
        '''
        n = self.n_balls

        activity = np.asarray(self.act_mat)
        # Derive the trajectory length from the data instead of assuming
        # exactly 1000 simulated steps.
        num_steps = activity.shape[1]

        loc = np.zeros((num_steps, 1, n))
        print(self.act_mat.shape)
        for i in range(n):
            loc[:, 0, i] = activity[i, :]

        loc_sample = []
        time_sample = []
        if isTrain:
            loc_sample = [loc[:num_steps, :, i] for i in range(n)]
            time_sample = [np.arange(num_steps) for _ in range(n)]

        return loc_sample, self.Adjacency, time_sample






        
        


