import torch





def run_kf(data, kf, device="cpu"):
    '''
        Run the Kalman filter `kf` on the batch of systems described by `data`.

        The filter is deliberately started from a zero state estimate (not the true x0)
        and from a unit initial covariance: a tensor of ones shaped like A in the
        diagonal case, a batch of identity matrices otherwise.

        data: dict read for the keys "seqs" (measurements), "trans" (A), "meas" (C),
              "Qw", "Qv", "x0", "states" (true states) and "is_diag".
              NOTE(review): "x0" appears to be stored as a column (dim_x, 1) that is
              transposed and repeated once per batch element -- confirm with the data generator.
        kf: object exposing run(y, x_0, P_0, A, C, Qw, Qv, is_diag), e.g. a KalmanFilter.
        device: device of the identity initial covariance in the non-diagonal case.

        returns: (K_t, P_t, y_hat, y_hat_forward, x_hat_y) as produced by kf.run.
        For KalmanFilter, x_hat_y[:, t] is [x_hat[t] | y[t]], where x_hat[t] is the
        filtered estimate after incorporating y[t].
    '''
    measurements = data["seqs"]
    trans = data["trans"]
    meas = data["meas"]
    noise_w = data["Qw"]
    noise_v = data["Qv"]
    # One copy of the initial state per system; only its shape/dtype/device are used.
    x0_batch = data["x0"].transpose(-2, -1).repeat(measurements.shape[0], 1)
    true_states = data["states"]
    diagonal = data["is_diag"]

    # Start away from the true initial state so the filter's convergence is visible.
    x_init = torch.zeros_like(x0_batch)

    P_init = (
        torch.ones_like(trans)
        if diagonal
        else torch.eye(true_states.shape[-1], device=device).repeat(true_states.shape[0], 1, 1)
    )

    _x_hat, y_hat, y_hat_forward, P_t, K_t, x_hat_y = kf.run(
        measurements, x_init, P_init, trans, meas, noise_w, noise_v, diagonal
    )
    return K_t, P_t, y_hat, y_hat_forward, x_hat_y



####### Notations based on this course: https://www.youtube.com/watch?v=CacYop3kLTw&list=PLCVaVjGeO3uGH9JDGB27TFgXoo2pNwzSN
class KalmanFilter:
    '''
        Kalman filter for linear state-space systems of the form
            x_{n+1} = A x_n + w_n
            y_n     = C x_n + v_n
        where Qw and Qv are the covariance matrices of the noises w and v. The noises are
        uncorrelated with each other and also temporally uncorrelated among themselves.

        Two representations are supported, selected through an `is_diag` flag:
          - diagonal: A, C, Qw, Qv and the covariance P are diagonal and stored as vectors
            of shape (batch_size, dim); every operation is element-wise and dim_x == dim_y
            is assumed.
          - regular: full matrices of shape (batch_size, rows, cols), combined with
            batched matrix products.

        Everything is a torch tensor so that the outputs can be compared directly with
        the outputs of the transformer, which are torch tensors as well.
    '''

    @staticmethod
    def compute_KF_linear_map(seq_len, K_0, A, C, is_diag):
        '''
        Compute the linear map from the measurements y[0], ..., y[seq_len-2] to the KF
        measurement estimates y_hat[0], ..., y_hat[seq_len-2] (sequence to sequence),
        for a filter started from a zero initial state estimate.

        For the diagonal filter the state recursion is
            x_i = (1 - K_i C) A x_{i-1} + K_i y_i,
        hence, with a zero initial estimate,
            y_hat_i = C x_i = sum_{j <= i} C [prod_{k=j+1}^{i} (1 - K_k C) A] K_j y_j.
        Block (i, j) of the result is that (diagonal) coefficient.

        K_0 = Kalman gains over time: [seq_len, dim] if diagonal else [seq_len, dim, dim]
        A = [dim] if diagonal, else, [dim, dim]
        C = [dim] if diagonal, else, [dim, dim]

        returns: block lower-triangular tensor of shape
                 [(seq_len-1) * dim, (seq_len-1) * dim].
        raises: Exception for the non-diagonal case (not implemented).
        '''
        if not is_diag:
            raise Exception("This method needs to be adapted to also handle non-diagonal cases")

        K = K_0
        state_dim = K.shape[1]
        n_steps = seq_len - 1

        kf_map = torch.zeros(n_steps * state_dim, n_steps * state_dim, device=K.device)
        for i in range(n_steps):
            rows = slice(i * state_dim, (i + 1) * state_dim)
            # M holds prod_{k=j+1}^{i} (1 - K_k C) A; it is the empty product when j == i.
            M = torch.ones(state_dim, device=K.device)
            for j in range(i, -1, -1):
                # y_j enters the state through its own gain K_j (not K_i).
                kf_map[rows, j * state_dim : (j + 1) * state_dim] = torch.diag(C * M * K[j, :])
                # Closed-loop factor of step j: x_j = (1 - K_j C) A x_{j-1} + K_j y_j.
                M = (A - K[j, :] * C * A) * M

        return kf_map

    def __init__(self, device="cpu"):
        # Device on which run() allocates its output tensors.
        self.device = device

    def run(self, y, x_0, P_0, A, C, Qw, Qv, is_diag):
        '''
            Filter the measurement sequences y, dispatching to the diagonal (element-wise)
            or the full-matrix implementation. See __run_diag / __run_regular for shapes.

            returns: (x_hat, y_hat, y_hat_forward, P_t, K_t, x_hat_y)
        '''
        if is_diag:
            return self.__run_diag(y, x_0, P_0, A, C, Qw, Qv)

        return self.__run_regular(y, x_0, P_0, A, C, Qw, Qv)

    @staticmethod
    def run_P_only(is_diag, seq_len, P_0, A, C, Qw, Qv):
        '''
            Run only the gain / covariance recursion for seq_len steps. It does not
            depend on the measurements, so no y is needed.

            returns: (K_t, P_t), equal to the K_t / P_t that run() produces for the
            same P_0, A, C, Qw, Qv.
        '''
        if is_diag:
            return KalmanFilter.run_diag_P_only(seq_len, P_0, A, C, Qw, Qv)

        return KalmanFilter.run_regular_P_only(seq_len, P_0, A, C, Qw, Qv)

    @staticmethod
    def run_regular_P_only(seq_len, P_0, A, C, Qw, Qv):
        '''
            Full-matrix gain / covariance recursion.

            P_0, A, Qw: (batch_size, dim_x, dim_x); C: (batch_size, dim_y, dim_x);
            Qv: (batch_size, dim_y, dim_y).

            returns:
            K_t: (batch_size, seq_len, dim_x, dim_y)
            P_t: (batch_size, seq_len, dim_x, dim_x)
        '''
        assert  A.shape == Qw.shape and A.shape == P_0.shape \
                and C.shape[-2] == Qv.shape[-1] and C.shape[-2] == Qv.shape[-2]

        batch_sz = A.shape[0]
        P_n = P_0
        device = A.device
        # Identity matrix (an all-ones matrix would corrupt the (I - K C) P update).
        Id = torch.eye(A.shape[-1], device=device).unsqueeze(0).repeat(batch_sz, 1, 1)
        K_t = torch.zeros(batch_sz, seq_len, A.shape[-2], C.shape[-2], device=device)
        P_t = torch.zeros(batch_sz, seq_len, A.shape[-2], A.shape[-1], device=device)

        for n in range(seq_len):
            # Prediction step
            P_pred = torch.matmul(torch.matmul(A, P_n), A.transpose(-2, -1)) + Qw #TODO need to be sure Qw is in the right format

            # Observation update step
            CPCT = torch.matmul(torch.matmul(C, P_pred), C.transpose(-2, -1))
            K_n = torch.matmul(P_pred,
                                torch.matmul(C.transpose(-2, -1), torch.inverse(CPCT + Qv)))

            K_t[:, n, :, :] = K_n
            P_n = torch.matmul((Id - torch.matmul(K_n, C)), P_pred)
            # Re-symmetrize exactly as __run_regular does, so both paths agree.
            P_n = 0.5 * (P_n + P_n.transpose(-2, -1))
            P_t[:, n, :, :] = P_n

        return K_t, P_t

    @staticmethod
    def run_diag_P_only(seq_len, P_0, A, C, Qw, Qv):
        '''
            Diagonal (element-wise) gain / covariance recursion.

            P_0, A, C, Qw, Qv: (batch_size, dim), the diagonals of the matrices.

            returns:
            K_t: (batch_size, seq_len, dim)
            P_t: (batch_size, seq_len, dim)
        '''
        assert A.shape == C.shape and A.shape == Qw.shape and A.shape == Qv.shape and A.shape == P_0.shape
        batch_sz = A.shape[0]
        P_n = P_0
        C_sqr = C * C
        A_sqr = A * A
        device = A.device
        # Diagonal of the identity matrix.
        Id = torch.ones(A.shape[1], device=device)
        K_t = torch.zeros(batch_sz, seq_len, A.shape[1], device=device)
        P_t = torch.zeros(batch_sz, seq_len, A.shape[1], device=device)

        for n in range(seq_len):
            # Prediction step
            P_pred = A_sqr * P_n + Qw

            # Observation update step
            K_n = P_pred * C * (1. / (C_sqr * P_pred + Qv))
            K_t[:, n, :] = K_n
            P_n = (Id - K_n * C) * P_pred
            P_t[:, n, :] = P_n

        return K_t, P_t

    def __run_diag(self, y, x_0, P_0, A, C, Qw, Qv):
        '''
            !!! All the concerned matrices here are diagonal, stored as vectors, so every
            operation reduces to element-wise products !!!

            x_0: tensor of shape (batch_size, dim_x), initial state estimate
            y: tensor of shape (batch_size, seq_len, dim_y)
            A, C, Qw, Qv: tensors of shape (batch_size, dim_x)
            P_0: tensor of shape (batch_size, dim_x), initial covariance

            !!! This method assumes that dim_x = dim_y, i.e., the observables are of the
            same dimension as the states (no dimensionality reduction).

            At every step n = 0, ..., seq_len-1 the previous estimate is propagated with A
            and then corrected with the measurement y[:, n, :].

            returns (allocated on self.device):
            x_hat: (batch_size, seq_len, dim_x), estimate after incorporating y[:, n]
            y_hat: (batch_size, seq_len, dim_y), C * x_hat
            y_hat_forward: (batch_size, seq_len, dim_y), C * A * x_hat, the one-step-ahead
                measurement prediction
            P_t: (batch_size, seq_len, dim_x), covariance after each update
            K_t: (batch_size, seq_len, dim_x), Kalman gain of each step
            x_hat_y: (batch_size, seq_len, dim_x + dim_y), concatenation [x_hat[n] | y[n]]
        '''
        assert A.shape == C.shape and A.shape == Qw.shape and A.shape == Qv.shape and A.shape == P_0.shape
        assert y.shape[2] == A.shape[1] and y.shape[0] == A.shape[0] and y.shape[2] == x_0.shape[1]

        batch_size = y.shape[0]
        seq_len = y.shape[1]
        dim_x = x_0.shape[1]
        dim_y = y.shape[2]

        x_hat = torch.zeros(batch_size, seq_len, dim_x, device=self.device)
        y_hat = torch.zeros(batch_size, seq_len, dim_x, device=self.device)
        y_hat_forward = torch.zeros(batch_size, seq_len, dim_x, device=self.device)
        # x_hat_y is a sequence of Kalman filter estimates and new measurements {[x_hat, y]}
        x_hat_y = torch.zeros(batch_size, seq_len, dim_x + dim_y, device=self.device)

        x_n = x_0
        P_n = P_0
        C_sqr = C * C
        A_sqr = A * A
        Id = torch.ones(x_n.shape, device=self.device)

        K_t = torch.zeros(batch_size, seq_len, A.shape[1], device=self.device)
        P_t = torch.zeros(batch_size, seq_len, A.shape[1], device=self.device)

        for n in range(seq_len):
            # Prediction step
            x_pred = A * x_n
            P_pred = A_sqr * P_n + Qw

            # Observation update step (every measurement, including y[:, 0, :], is used)
            K_n = P_pred * C * (1. / (C_sqr * P_pred + Qv))
            x_n = x_pred + K_n * (y[:, n, :] - C * x_pred)
            P_n = (Id - K_n * C) * P_pred

            # record values
            x_hat[:, n, :] = x_n
            y_hat[:, n, :] = C * x_n
            y_hat_forward[:, n, :] = C * A * x_n
            x_hat_y[:, n, 0 : dim_x] = x_n
            x_hat_y[:, n, dim_x : (dim_x + dim_y)] = y[:, n, :]

            P_t[:, n, :] = P_n
            K_t[:, n, :] = K_n

        return x_hat, y_hat, y_hat_forward, P_t, K_t, x_hat_y

    def __run_regular(self, y, x_0, P_0, A, C, Qw, Qv):
        '''
            !!! This function handles full matrices !!!

            x_0: tensor of shape (batch_size, dim_x), initial state estimate
            y: tensor of shape (batch_size, seq_len, dim_y)
            A: tensor of shape (batch_size, dim_x, dim_x)
            C: tensor of shape (batch_size, dim_y, dim_x)
            Qw: tensor of shape (batch_size, dim_x, dim_x)
            Qv: tensor of shape (batch_size, dim_y, dim_y)
            P_0: tensor of shape (batch_size, dim_x, dim_x)

            At every step n = 0, ..., seq_len-1 the previous estimate is propagated with A
            and then corrected with the measurement y[:, n, :].

            returns (allocated on self.device):
            x_hat: (batch_size, seq_len, dim_x), estimate after incorporating y[:, n]
            y_hat: (batch_size, seq_len, dim_y), C x_hat
            y_hat_forward: (batch_size, seq_len, dim_y), C A x_hat, the one-step-ahead
                measurement prediction
            P_t: (batch_size, seq_len, dim_x, dim_x), covariance after each update
            K_t: (batch_size, seq_len, dim_x, dim_y), Kalman gain of each step
            x_hat_y: (batch_size, seq_len, dim_x + dim_y), concatenation [x_hat[n] | y[n]]
        '''
        # TODO: figure if these assertions need to be updated
        # assert A.shape[-1] == C.shape[-1] and A.shape == Qw.shape \
        #         and A.shape == Qv.shape and A.shape == P_0.shape

        batch_size = y.shape[0]
        seq_len = y.shape[1]
        dim_x = x_0.shape[1]
        dim_y = y.shape[2]

        x_hat = torch.zeros(batch_size, seq_len, dim_x, device=self.device)
        y_hat = torch.zeros(batch_size, seq_len, dim_y, device=self.device)
        y_hat_forward = torch.zeros(batch_size, seq_len, dim_y, device=self.device)
        # x_hat_y is a sequence of Kalman filter estimates and new measurements {[x_hat, y]}
        x_hat_y = torch.zeros(batch_size, seq_len, dim_x + dim_y, device=self.device)

        # Column-vector form [batch, dim_x, 1] for the batched matmuls.
        x_n = x_0.unsqueeze(-1)
        P_n = P_0

        K_t = torch.zeros(batch_size, seq_len, dim_x, dim_y, device=self.device)
        P_t = torch.zeros(batch_size, seq_len, dim_x, dim_x, device=self.device)
        Id = torch.eye(dim_x, device=self.device).unsqueeze(0).repeat(batch_size, 1, 1)

        for n in range(seq_len):
            # Prediction step
            x_pred = torch.matmul(A, x_n) #[batch, dimx, 1]
            P_pred = torch.matmul(torch.matmul(A, P_n), A.transpose(-2, -1)) + Qw #TODO need to be sure Qw is in the right format

            # Observation update step
            CPCT = torch.matmul(torch.matmul(C, P_pred), C.transpose(-2, -1))
            K_n = torch.matmul(P_pred,
                                torch.matmul(C.transpose(-2, -1), torch.inverse(CPCT + Qv)))

            x_n = x_pred + torch.matmul(K_n, (y[:, n, :].unsqueeze(-1) - torch.matmul(C, x_pred)))
            P_n = torch.matmul((Id - torch.matmul(K_n, C)), P_pred)
            # (I - K C) P_pred is not exactly symmetric in floating point; re-symmetrize.
            P_n = 0.5 * (P_n + P_n.transpose(-2, -1))

            # record values
            x_hat[:, n, :] = x_n[:, :, 0]
            y_hat[:, n, :] = torch.matmul(C, x_n)[:, :, 0]
            y_hat_forward[:, n, :] = torch.matmul(torch.matmul(C, A), x_n)[:, :, 0]
            x_hat_y[:, n, 0 : dim_x] = x_n[:, :, 0]
            x_hat_y[:, n, dim_x : (dim_x + dim_y)] = y[:, n, :]

            P_t[:, n, :, :] = P_n
            K_t[:, n, :, :] = K_n

        return x_hat, y_hat, y_hat_forward, P_t, K_t, x_hat_y
