import torch
import numpy
import matplotlib.pyplot as plt
import models
from scipy.stats import ortho_group
from filterpy.kalman import KalmanFilter
import pickle

def Gen_data(device='cuda', batch_size=64, eval=False, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1):
    """Generate a batch of noiseless static linear-regression prompts.

    Each prompt interleaves ``chunk_size`` inputs ``x_k ~ N(0, x_sigma^2 I)``
    with labels ``y_k = w^T x_k``.  Every token has ``input_dim + 1`` features:
    ``x_k`` tokens (even positions) use features ``1:``, ``y_k`` tokens (odd
    positions) use feature 0.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        eval: if True every prompt shares one weight vector; otherwise a fresh
            ``w ~ N(0, w_sigma^2 I)`` is used per prompt.
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per prompt.
        w_sigma: standard deviation of the Gaussian weights.
        x_sigma: standard deviation of the Gaussian inputs.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, 2*chunk_size, input_dim + 1)`` and
        ``(batch_size, 2*chunk_size)``.  ``outputs_batch`` holds ``y_k`` at the
        position of ``x_k`` (even positions) and zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    seq_len = 2 * chunk_size
    # Drawn unconditionally so the random stream is the same in eval and train mode.
    w_eval = w_sigma * torch.randn((input_dim, 1))
    prompts = []
    for _ in range(batch_size):
        # A training weight is drawn every iteration (even in eval mode) to keep
        # the RNG sequence identical to the original generator.
        w_train = w_sigma * torch.randn((input_dim, 1))
        w = w_eval if eval else w_train
        x = x_sigma * torch.randn((input_dim, chunk_size))
        y = torch.matmul(torch.transpose(w, 0, 1), x).squeeze(0)  # (chunk_size,)
        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:, 0:seq_len:2] = x
        inputs[0, 1:seq_len:2] = y
        prompts.append(inputs)

    # A single stack avoids the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, 0:seq_len:2] = inputs_batch[:, 0, 1:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)

def Gen_data_SS(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0):
    """Generate regression prompts whose weights evolve as ``w_{k+1} = F w_k``.

    Per prompt:
    - ``F`` is the identity when ``Dynamic`` is False, otherwise
      ``alpha_F * U + (1 - alpha_F) * I`` with ``U`` a random orthogonal matrix.
    - ``w_0 ~ N(0, w_sigma^2 I)`` and ``x_k ~ N(0, x_sigma^2 I)``.
    - ``y_k = w_k^T x_k``.

    Sequence layout (length ``2*chunk_size + input_dim``):
    - Tokens ``0:input_dim`` carry ``F``; token ``j`` holds column ``j`` in
      features ``1:``.
    - The remaining tokens alternate ``x_k`` tokens (features ``1:``) and
      ``y_k`` tokens (feature 0).

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per prompt.
        w_sigma: standard deviation of ``w_0``.
        x_sigma: standard deviation of the inputs.
        d_curr: curriculum dimension; when smaller than ``input_dim`` only the
            first ``d_curr`` coordinates of each ``x_k`` are non-zero.
        Dynamic: if False, ``F`` is the identity.
        alpha_F: mixing weight between the random rotation and the identity.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, 2*chunk_size + input_dim, input_dim + 1)`` and
        ``(batch_size, 2*chunk_size + input_dim)``.  ``outputs_batch`` holds
        ``y_k`` at the position of ``x_k`` and zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    seq_len = 2 * chunk_size + input_dim
    x_cols = slice(input_dim, seq_len, 2)      # positions of the x_k tokens
    y_cols = slice(input_dim + 1, seq_len, 2)  # positions of the y_k tokens

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F*torch.tensor(ortho_group.rvs(input_dim), dtype=float)+(1-alpha_F)*torch.eye(input_dim, dtype=float)

        # Roll the state forward; column k of ``w`` is w_k.
        w_k = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_k]
        for _ in range(int(chunk_size) - 1):
            w_k = torch.matmul(F, w_k)
            states.append(w_k)
        w = torch.cat(states, dim=1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        if d_curr < input_dim:
            # Keep exactly the first d_curr coordinates active.  The previous
            # cut-off at d_curr - 1 left only d_curr - 1 of them, which was
            # inconsistent with d_curr == input_dim keeping all of them.
            x[d_curr:, :] = 0.0
        y = torch.sum(w * x, dim=0)  # y_k = w_k^T x_k

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:, 0:input_dim] = F
        inputs[1:, x_cols] = x
        inputs[0, y_cols] = y
        prompts.append(inputs)

    # A single stack avoids the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, x_cols] = inputs_batch[:, 0, y_cols]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_SS_innovation_noise(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0):
    """Generate regression prompts whose weights follow a noisy linear state-space model.

    Dynamics: ``w_{k+1} = F w_k + A_Q n_k`` with ``n_k ~ N(0, I)``, so the
    process-noise covariance is ``Q = A_Q A_Q^T``.  Labels are
    ``y_k = w_k^T x_k``.

    Per prompt:
    - When ``Dynamic`` is False, ``F = I`` and ``Q = 0``.
    - Otherwise ``F = alpha_F * U + (1 - alpha_F) * I`` with ``U`` a random
      orthogonal matrix.
    - ``A_Q = V sqrt(alpha_Q * diag(u))``, with ``V`` a random orthogonal
      matrix and ``u ~ U[0, 1)``.

    Sequence layout (length ``2*chunk_size + 2*input_dim``):
    - Tokens ``0:input_dim`` carry the columns of ``F``.
    - Tokens ``input_dim:2*input_dim`` carry the columns of ``Q``.
    - Then ``x_k`` tokens (features ``1:``) alternate with ``y_k`` tokens
      (feature 0).

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per prompt.
        w_sigma: standard deviation of ``w_0``.
        x_sigma: standard deviation of the inputs.
        d_curr: curriculum dimension; when smaller than ``input_dim`` only the
            first ``d_curr`` coordinates of each ``x_k`` are non-zero.
        Dynamic: if False, the state is constant (``F = I``, ``Q = 0``).
        alpha_F: mixing weight between the random rotation and the identity.
        alpha_Q: scale of the process-noise eigenvalues.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, 2*chunk_size + 2*input_dim, input_dim + 1)`` and
        ``(batch_size, 2*chunk_size + 2*input_dim)``.  ``outputs_batch`` holds
        ``y_k`` at the position of ``x_k`` and zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prefix = 2 * input_dim  # F and Q tokens
    seq_len = 2 * chunk_size + prefix
    x_cols = slice(prefix, seq_len, 2)
    y_cols = slice(prefix + 1, seq_len, 2)

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0*torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F*torch.tensor(ortho_group.rvs(input_dim), dtype=float)+(1-alpha_F)*torch.eye(input_dim, dtype=float)
            Sigma_Q_sqrt = torch.sqrt(alpha_Q*torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches.  Q used to be defined only when Dynamic
        # was True, so Dynamic=False raised NameError below.
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the noisy state forward; column k of ``w`` is w_k.
        w_k = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_k]
        for _ in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_k = torch.matmul(F, w_k) + innovation_noise
            states.append(w_k)
        w = torch.cat(states, dim=1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        if d_curr < input_dim:
            # Keep exactly the first d_curr coordinates active (was d_curr - 1).
            x[d_curr:, :] = 0.0
        y = torch.sum(w * x, dim=0)  # y_k = w_k^T x_k

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:, 0:input_dim] = F
        inputs[1:, input_dim:prefix] = Q
        inputs[1:, x_cols] = x
        inputs[0, y_cols] = y
        prompts.append(inputs)

    # A single stack avoids the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, x_cols] = inputs_batch[:, 0, y_cols]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)

def Gen_data_SS_innovation_noise_obs_noise(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0):
    """Generate state-space regression prompts with process and observation noise.

    Dynamics: ``w_{k+1} = F w_k + A_Q n_k`` (``Q = A_Q A_Q^T``).  Labels are
    ``y_k = w_k^T x_k + v_k`` with ``v_k ~ N(0, r)`` and
    ``r = alpha_R * u``, ``u ~ U[0, 1)`` drawn once per prompt.

    Sequence layout (length ``2*chunk_size + 2*input_dim + 1``):
    - Tokens ``0:input_dim`` hold the columns of ``F``.
    - Tokens ``input_dim:2*input_dim`` hold the columns of ``Q``.
    - Token ``2*input_dim`` is all zeros.
    - ``x_k`` tokens start at position ``2*input_dim + 1`` (features ``1:``),
      interleaved with ``y_k`` tokens (feature 0).
    - ``r`` is written into feature 0 of the first ``x`` token.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per prompt.
        w_sigma: standard deviation of ``w_0``.
        x_sigma: standard deviation of the inputs.
        d_curr: curriculum dimension; when smaller than ``input_dim`` only the
            first ``d_curr`` coordinates of each ``x_k`` are non-zero.
        Dynamic: if False, the state is constant (``F = I``, ``Q = 0``).
        alpha_F: mixing weight between a random rotation and the identity.
        alpha_Q: scale of the process-noise eigenvalues.
        alpha_R: upper bound of the observation-noise variance.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, 2*chunk_size + 2*input_dim + 1, input_dim + 1)`` and
        ``(batch_size, 2*chunk_size + 2*input_dim + 1)``.  ``outputs_batch``
        holds ``y_k`` at the position of ``x_k`` and zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prefix = 2 * input_dim + 1  # F tokens, Q tokens, one empty token
    seq_len = 2 * chunk_size + prefix
    x_cols = slice(prefix, seq_len, 2)
    y_cols = slice(prefix + 1, seq_len, 2)

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0*torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F*torch.tensor(ortho_group.rvs(input_dim), dtype=float)+(1-alpha_F)*torch.eye(input_dim, dtype=float)
            Sigma_Q_sqrt = torch.sqrt(alpha_Q*torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches (was NameError when Dynamic was False).
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the noisy state forward; column k of ``w`` is w_k.
        w_k = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_k]
        for _ in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_k = torch.matmul(F, w_k) + innovation_noise
            states.append(w_k)
        w = torch.cat(states, dim=1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R*torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var)*torch.randn((chunk_size,), dtype=float)

        if d_curr < input_dim:
            # Keep exactly the first d_curr coordinates active (was d_curr - 1).
            x[d_curr:, :] = 0.0
        y = torch.sum(w * x, dim=0) + obs_noise  # y_k = w_k^T x_k + v_k

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:, 0:input_dim] = F
        inputs[1:, input_dim:2 * input_dim] = Q
        # NOTE(review): this lands on the first x token (position 2*input_dim + 1)
        # while position 2*input_dim stays empty; the *_non_scalar_y generator
        # writes the noise variance at position 2*input_dim instead.  Kept as-is
        # because trained models may depend on this layout; confirm intent.
        inputs[0, prefix] = noise_var
        inputs[1:, x_cols] = x
        inputs[0, y_cols] = y
        prompts.append(inputs)

    # A single stack avoids the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, x_cols] = inputs_batch[:, 0, y_cols]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_SS_innovation_noise_obs_noise_state_est_curr(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0):
    """Generate noisy state-space prompts whose targets are the hidden states.

    The prompts are generated like ``Gen_data_SS_innovation_noise_obs_noise``:
    - Dynamics ``w_{k+1} = F w_k + A_Q n_k``.
    - Labels ``y_k = w_k^T x_k + v_k``.
    - Sequence length ``2*chunk_size + 2*input_dim + 1``; the observation-noise
      variance sits in feature 0 of the first ``x`` token.

    The target at each ``y_k`` token is the current state ``w_k``.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per prompt.
        w_sigma: standard deviation of ``w_0``.
        x_sigma: standard deviation of the inputs.
        d_curr: curriculum dimension; when smaller than ``input_dim`` only the
            first ``d_curr`` coordinates of each ``x_k`` are non-zero.
        Dynamic: if False, the state is constant (``F = I``, ``Q = 0``).
        alpha_F: mixing weight between a random rotation and the identity.
        alpha_Q: scale of the process-noise eigenvalues.
        alpha_R: upper bound of the observation-noise variance.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, L, input_dim + 1)`` and ``(batch_size, L, input_dim)``,
        where ``L = 2*chunk_size + 2*input_dim + 1``.  ``outputs_batch`` holds
        ``w_k`` at the ``y_k`` positions and zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prefix = 2 * input_dim + 1  # F tokens, Q tokens, one empty token
    seq_len = 2 * chunk_size + prefix
    x_cols = slice(prefix, seq_len, 2)
    y_cols = slice(prefix + 1, seq_len, 2)

    prompts = []
    trajectories = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0*torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F*torch.tensor(ortho_group.rvs(input_dim), dtype=float)+(1-alpha_F)*torch.eye(input_dim, dtype=float)
            Sigma_Q_sqrt = torch.sqrt(alpha_Q*torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches (was NameError when Dynamic was False).
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the noisy state forward; column k of ``w`` is w_k.
        w_k = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_k]
        for _ in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_k = torch.matmul(F, w_k) + innovation_noise
            states.append(w_k)
        w = torch.cat(states, dim=1)  # (input_dim, chunk_size)
        trajectories.append(w)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R*torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var)*torch.randn((chunk_size,), dtype=float)

        if d_curr < input_dim:
            # Keep exactly the first d_curr coordinates active (was d_curr - 1).
            x[d_curr:, :] = 0.0
        y = torch.sum(w * x, dim=0) + obs_noise  # y_k = w_k^T x_k + v_k

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:, 0:input_dim] = F
        inputs[1:, input_dim:2 * input_dim] = Q
        # NOTE(review): noise variance shares the first x token (see docstring).
        inputs[0, prefix] = noise_var
        inputs[1:, x_cols] = x
        inputs[0, y_cols] = y
        prompts.append(inputs)

    # Single stacks avoid the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)
    w_batch = torch.stack(trajectories, dim=0)  # (batch_size, input_dim, chunk_size)

    outputs_batch = torch.zeros((batch_size, input_dim, seq_len))
    outputs_batch[:, :, y_cols] = w_batch

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    outputs_batch = torch.transpose(outputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)



def Gen_data_SS_innovation_noise_obs_noise_F_options(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2):
    """Generate noisy state-space regression prompts with a selectable F family.

    Choice of ``F`` (only when ``Dynamic`` is True):
    - ``F_option == 2``: ``U diag(u) U^T`` with ``u ~ U[0, 1)``.
    - ``F_option == 3``: ``U diag(2u - 1) U^T``.
    - otherwise: ``alpha_F * U + (1 - alpha_F) * I``.

    In all cases ``U`` is a random orthogonal matrix.  Everything else,
    including the token layout, matches
    ``Gen_data_SS_innovation_noise_obs_noise``.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per prompt.
        w_sigma: standard deviation of ``w_0``.
        x_sigma: standard deviation of the inputs.
        d_curr: curriculum dimension; when smaller than ``input_dim`` only the
            first ``d_curr`` coordinates of each ``x_k`` are non-zero.
        Dynamic: if False, the state is constant (``F = I``, ``Q = 0``).
        alpha_F: mixing weight used by the default ``F`` family.
        alpha_Q: scale of the process-noise eigenvalues.
        alpha_R: upper bound of the observation-noise variance.
        F_option: selects the ``F`` family (see above).

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, L, input_dim + 1)`` and ``(batch_size, L)``, where
        ``L = 2*chunk_size + 2*input_dim + 1``.  ``outputs_batch`` holds
        ``y_k`` at the position of ``x_k``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prefix = 2 * input_dim + 1  # F tokens, Q tokens, one empty token
    seq_len = 2 * chunk_size + prefix
    x_cols = slice(prefix, seq_len, 2)
    y_cols = slice(prefix + 1, seq_len, 2)

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0*torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2.*torch.rand((input_dim,), dtype=float)-1.)
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            else:
                F = alpha_F*torch.tensor(ortho_group.rvs(input_dim), dtype=float)+(1-alpha_F)*torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q*torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches (was NameError when Dynamic was False).
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the noisy state forward; column k of ``w`` is w_k.
        w_k = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_k]
        for _ in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_k = torch.matmul(F, w_k) + innovation_noise
            states.append(w_k)
        w = torch.cat(states, dim=1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R*torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var)*torch.randn((chunk_size,), dtype=float)

        if d_curr < input_dim:
            # Keep exactly the first d_curr coordinates active (was d_curr - 1).
            x[d_curr:, :] = 0.0
        y = torch.sum(w * x, dim=0) + obs_noise  # y_k = w_k^T x_k + v_k

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:, 0:input_dim] = F
        inputs[1:, input_dim:2 * input_dim] = Q
        # NOTE(review): noise variance shares the first x token (see docstring).
        inputs[0, prefix] = noise_var
        inputs[1:, x_cols] = x
        inputs[0, y_cols] = y
        prompts.append(inputs)

    # A single stack avoids the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, x_cols] = inputs_batch[:, 0, y_cols]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_SS_innovation_noise_obs_noise_F_options_non_scalar_y(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2, y_dim=1):
    """Generate noisy state-space prompts with vector observations.

    Dynamics: ``w_{k+1} = F w_k + A_Q n_k`` (``Q = A_Q A_Q^T``).
    Observations: ``y_k = X_k w_k + v_k``, where ``X_k`` is a
    ``(y_dim, input_dim)`` Gaussian matrix and ``v_k ~ N(0, diag(r))`` with
    ``r = alpha_R * u``, ``u ~ U[0, 1)^y_dim`` drawn once per prompt.
    ``F`` is chosen by ``F_option`` as in
    ``Gen_data_SS_innovation_noise_obs_noise_F_options``.

    Token features are ``(input_dim + 1) * y_dim`` wide.  Sequence layout
    (length ``2*chunk_size + 2*input_dim + 1``):
    - Tokens ``0:input_dim`` hold the rows of ``F``.
    - Tokens ``input_dim:2*input_dim`` hold the rows of ``Q``.
    - Both sit in features ``y_dim:y_dim + input_dim``.
    - Token ``2*input_dim`` holds ``r`` in features ``0:y_dim``.
    - Then ``X_k`` tokens (flattened row-major into features ``y_dim:``)
      alternate with ``y_k`` tokens (features ``0:y_dim``).

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of prompts (must be >= 1).
        input_dim: dimension of ``w``.
        chunk_size: number of (X, y) pairs per prompt.
        w_sigma: standard deviation of ``w_0``.
        x_sigma: standard deviation of the entries of ``X_k``.
        d_curr: curriculum dimension; when smaller than ``input_dim`` only the
            first ``d_curr`` columns of each ``X_k`` are non-zero.
        Dynamic: if False, the state is constant (``F = I``, ``Q = 0``).
        alpha_F: mixing weight used by the default ``F`` family.
        alpha_Q: scale of the process-noise eigenvalues.
        alpha_R: upper bound of the per-output observation-noise variance.
        F_option: 2 -> ``U diag(u) U^T``; 3 -> ``U diag(2u-1) U^T``;
            otherwise ``alpha_F * U + (1 - alpha_F) * I``.
        y_dim: dimension of each observation.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, L, (input_dim + 1) * y_dim)`` and
        ``(batch_size, L, y_dim)``, where ``L = 2*chunk_size + 2*input_dim + 1``.
        ``outputs_batch`` holds ``y_k`` at the position of ``X_k``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    seq_len = 2 * chunk_size + 2 * input_dim + 1
    x_cols = slice(2 * input_dim + 1, seq_len, 2)
    y_cols = slice(2 * input_dim + 2, seq_len, 2)

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0*torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2.*torch.rand((input_dim,), dtype=float)-1.)
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            else:
                F = alpha_F*torch.tensor(ortho_group.rvs(input_dim), dtype=float)+(1-alpha_F)*torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q*torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches (was NameError when Dynamic was False).
        Q = torch.matmul(A_Q, A_Q.T)

        noise_var = alpha_R * torch.rand((y_dim,), dtype=float)
        # Equivalent to diag(sqrt(noise_var)) @ n without building the matrix.
        noise_std = torch.unsqueeze(torch.sqrt(noise_var), dim=-1)  # (y_dim, 1)
        w_k = w_sigma * torch.randn((input_dim, 1), dtype=float)

        # Random draws happen in the same order as before:
        # [innovation (k > 0)], X_k, observation noise.
        x_tokens, y_tokens = [], []
        for k in range(int(chunk_size)):
            if k > 0:
                innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
                w_k = torch.matmul(F, w_k) + innovation_noise
            x_t = x_sigma * torch.randn((y_dim, input_dim), dtype=float)
            if d_curr < input_dim:
                # Keep exactly the first d_curr columns active (was d_curr - 1).
                x_t[:, d_curr:] = 0.0
            obs_noise = noise_std * torch.randn((y_dim, 1), dtype=float)
            y_tokens.append(torch.matmul(x_t, w_k) + obs_noise)
            x_tokens.append(torch.reshape(x_t, (y_dim * input_dim, 1)))
        x = torch.cat(x_tokens, dim=-1)  # (y_dim * input_dim, chunk_size)
        y = torch.cat(y_tokens, dim=-1)  # (y_dim, chunk_size)

        inputs = torch.zeros(((input_dim + 1) * y_dim, seq_len))
        inputs[y_dim:input_dim + y_dim, 0:input_dim] = F.T
        inputs[y_dim:input_dim + y_dim, input_dim:2 * input_dim] = Q.T
        inputs[0:y_dim, 2 * input_dim] = noise_var
        inputs[y_dim:, x_cols] = x
        inputs[0:y_dim, y_cols] = y
        prompts.append(inputs)

    # A single stack avoids the quadratic cost of concatenating inside the loop.
    inputs_batch = torch.stack(prompts, dim=0)

    outputs_batch = torch.zeros((batch_size, y_dim, seq_len))
    outputs_batch[:, :, x_cols] = inputs_batch[:, 0:y_dim, y_cols]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    outputs_batch = torch.transpose(outputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def KNN(inputs_batch, n=3, device='cuda'):
    """Predict each prompt's final label as the mean label of its n nearest context inputs.

    ``inputs_batch`` is indexed as ``[prompt, token, feature]``: even tokens
    carry ``x`` in features ``1:`` and odd tokens carry ``y`` in feature 0.
    The final (x, y) pair is the query and is excluded from the neighbour set.

    Returns:
        ``(prediction, MSE)``: ``prediction`` is a tensor of one value per
        prompt (created on ``device``); ``MSE`` is its mean squared error
        against the final labels.
    """
    half_len = int(inputs_batch.shape[-2] / 2)
    n_prompts = inputs_batch.shape[0]

    query = inputs_batch[:, 2 * half_len - 2, 1:]
    context_x = inputs_batch[:, 0:2 * half_len - 2:2, 1:]
    context_y = inputs_batch[:, 1:2 * half_len - 1:2, 0]

    # Squared Euclidean distance from the query to every context input.
    sq_dist = torch.sum((torch.unsqueeze(query, dim=-2) - context_x) ** 2, dim=-1)
    order = torch.argsort(sq_dist, dim=-1)

    prediction = torch.tensor(
        [torch.mean(context_y[b][order[b]][:n]) for b in range(n_prompts)],
        device=device,
    )
    target = inputs_batch[:, 2 * half_len - 1, 0]
    return prediction, torch.mean((prediction - target) ** 2)

def Stochastic_Gradient_Descent_Regression(inputs_batch, alpha=0.01, device='cuda'):
    """Run one pass of per-prompt SGD on the squared loss, then score the final query.

    ``inputs_batch`` is indexed as ``[prompt, token, feature]``: even tokens
    carry ``x`` in features ``1:`` and odd tokens carry ``y`` in feature 0.
    Weights start at zero and take one step per context pair (all pairs except
    the last) with learning rate ``alpha``.

    Returns:
        ``(W, MSE)``: ``W`` has one weight row per prompt; ``MSE`` is the mean
        squared error on the held-out final pair.
    """
    n_pairs = int(inputs_batch.shape[-2] / 2)
    n_features = inputs_batch.shape[-1] - 1
    n_prompts = inputs_batch.shape[0]

    weights = torch.zeros((n_prompts, n_features), device=device)
    for step in range(n_pairs - 1):
        feats = inputs_batch[:, 2 * step, 1:]
        label = torch.unsqueeze(inputs_batch[:, 2 * step + 1, 0], dim=-1)
        # Per-prompt prediction = diagonal of weights @ feats.T.
        pred = torch.unsqueeze(torch.diag(torch.matmul(weights, feats.T)), dim=-1)
        weights = weights - 2 * alpha * (feats * pred - feats * label)

    query = inputs_batch[:, 2 * n_pairs - 2, 1:]
    target = inputs_batch[:, 2 * n_pairs - 1, 0]
    residual = target - torch.diag(torch.matmul(weights, query.T))
    return weights, torch.mean(residual ** 2)


def Stochastic_Gradient_Descent_Regression_explicit(inputs_batch, final_state, alpha=0.01, device='cuda'):
    """Run SGD over every (x, y) pair and measure the distance to a reference state.

    ``inputs_batch`` is indexed as ``[prompt, token, feature]``: even tokens
    carry ``x`` in features ``1:`` and odd tokens carry ``y`` in feature 0.
    The weights start as a single zero row of shape ``(1, input_dim)``.

    Returns:
        ``(W, MSE)`` where ``MSE`` is, despite its name,
        ``sqrt(sum((W.T - final_state)**2))`` (moved to the CPU).  That is the
        Euclidean distance for a single prompt when ``final_state`` is shaped
        ``(input_dim, 1)`` (assumed; confirm against callers).
    """
    n_pairs = int(inputs_batch.shape[-2] / 2)
    n_features = inputs_batch.shape[-1] - 1

    weights = torch.zeros((1, n_features), device=device)
    for step in range(n_pairs):
        feats = inputs_batch[:, 2 * step, 1:]
        pred = torch.unsqueeze(torch.diag(torch.matmul(weights, feats.T)), dim=-1)
        label = torch.unsqueeze(inputs_batch[:, 2 * step + 1, 0], dim=-1)
        weights = weights - 2 * alpha * (feats * pred - feats * label)

    sq_err = torch.squeeze((weights.T - final_state) ** 2, dim=-1)
    distance = torch.sqrt(sum(sq_err)).cpu()
    return weights, distance



def Stochastic_Gradient_Descent_Regression_one_step(inputs_batch, alpha=0.01, device='cuda'):
    """Run SGD on all but the last (x, y) pair and predict the last label.

    ``inputs_batch`` is indexed as ``[prompt, token, feature]``: even tokens
    carry ``x`` in features ``1:`` and odd tokens carry ``y`` in feature 0.
    The weights start as a single zero row of shape ``(1, input_dim)``.

    Returns:
        ``(out, MSE)``: ``out = W @ x_last.T``.  ``MSE`` is the element-wise
        squared error ``(out - y_last)**2``, moved to the CPU.
    """
    n_pairs = int(inputs_batch.shape[-2] / 2)
    n_features = inputs_batch.shape[-1] - 1

    weights = torch.zeros((1, n_features), device=device)
    for step in range(n_pairs - 1):
        feats = inputs_batch[:, 2 * step, 1:]
        pred = torch.unsqueeze(torch.diag(torch.matmul(weights, feats.T)), dim=-1)
        label = torch.unsqueeze(inputs_batch[:, 2 * step + 1, 0], dim=-1)
        weights = weights - 2 * alpha * (feats * pred - feats * label)

    out = torch.matmul(weights, inputs_batch[:, -2, 1:].T)
    sq_err = (out - inputs_batch[:, -1, 0]) ** 2
    return out, sq_err.cpu()

def Stochastic_Gradient_Descent_Regression_one_step_non_scalar(inputs_batch, alpha=0.01, device='cuda', y_dim=2):
    """Run SGD with vector observations and predict the final observation.

    Token layout (``inputs_batch`` indexed as ``[prompt, token, feature]``):
    - Even tokens carry ``X_k`` flattened row-major into features ``y_dim:``,
      i.e. ``X_k`` has shape ``(y_dim, input_dim)``.
    - Odd tokens carry ``y_k`` in features ``0:y_dim``.

    One gradient step on ``||X_k w - y_k||^2`` with rate ``alpha`` is taken
    per context pair (all but the last).  Weights start at zero.

    Args:
        inputs_batch: prompt tensor as described above; any batch size.
        alpha: SGD learning rate.
        device: device of the initial weight vector.
        y_dim: observation dimension.

    Returns:
        Predicted final observation ``X_last w`` of shape ``(batch, y_dim)``.
    """
    chunk_size = int(inputs_batch.shape[-2] / 2)
    input_dim = inputs_batch.shape[-1] // y_dim - 1
    batch_size = inputs_batch.shape[0]

    # Shape (1, input_dim); broadcasts to (batch, input_dim) after the first step.
    W = torch.zeros((1, input_dim), device=device)

    for i in range(chunk_size - 1):
        X = torch.reshape(inputs_batch[:, 2 * i, y_dim:], (batch_size, y_dim, input_dim))
        Y = torch.reshape(inputs_batch[:, 2 * i + 1, :y_dim], (batch_size, y_dim, 1))
        X_T = torch.transpose(X, 1, 2)
        # Gradient of ||X w - y||^2 up to the factor 2: X^T X w - X^T y.
        grad = (torch.squeeze(torch.matmul(torch.matmul(X_T, X), torch.unsqueeze(W, dim=-1)), dim=-1)
                - torch.squeeze(torch.matmul(X_T, Y), dim=-1))
        W = W - 2 * alpha * grad

    X_query = torch.reshape(inputs_batch[:, -2, y_dim:], (batch_size, y_dim, input_dim))
    return torch.squeeze(torch.matmul(X_query, torch.unsqueeze(W, dim=-1)), -1)



def Ridge_Regression(inputs_batch, lambda_=0.1, device='cuda'):
    """Fit per-prompt ridge regression on the context and score the final query.

    Args:
        inputs_batch: tensor indexed as ``[prompt, token, feature]``; even
            tokens carry ``x`` in features ``1:`` and odd tokens carry ``y`` in
            feature 0.  The last (x, y) pair is held out from the fit.
        lambda_: ridge penalty added to the diagonal of ``X^T X``.
        device: device for the weight matrix and the identity.

    Returns:
        ``(W, MSE)``:
        - ``W`` has shape ``(batch, input_dim)``.  A prompt whose regularised
          Gram matrix cannot be inverted keeps an all-zero row.
        - ``MSE`` is the mean squared error on the held-out pair.
    """
    chunk_size = int(inputs_batch.shape[-2] / 2)
    input_dim = inputs_batch.shape[-1] - 1
    batch_size = inputs_batch.shape[0]
    W = torch.zeros((batch_size, input_dim), device=device)
    eye = torch.eye(input_dim, device=device)

    for i in range(batch_size):
        X = inputs_batch[i, 0:2 * chunk_size - 2:2, 1:]
        Y = inputs_batch[i, 1:2 * chunk_size - 1:2, 0]
        try:
            inverse_term = torch.linalg.inv(torch.matmul(X.T, X) + lambda_ * eye)
        except RuntimeError:
            # torch.linalg.inv signals a singular matrix with LinAlgError, a
            # RuntimeError subclass.  The former bare ``except`` also swallowed
            # KeyboardInterrupt/SystemExit.
            continue
        W[i, :] = torch.matmul(inverse_term, torch.matmul(X.T, Y))

    preds = torch.diag(torch.matmul(W, inputs_batch[:, 2 * chunk_size - 2, 1:].T))
    MSE = torch.mean((inputs_batch[:, 2 * chunk_size - 1, 0] - preds) ** 2)
    return W, MSE

def Ridge_Regression_explicit(inputs_batch, final_state, lambda_=0.1, device='cuda'):
    """Fit ridge regression on every (x, y) pair and measure the distance to a reference state.

    Args:
        inputs_batch: tensor indexed as ``[prompt, token, feature]``; even
            tokens carry ``x`` in features ``1:`` and odd tokens carry ``y`` in
            feature 0.
        final_state: reference weight vector, presumably ``(input_dim, 1)``
            (TODO confirm against callers).
        lambda_: ridge penalty added to the diagonal of ``X^T X``.
        device: device for the weight matrix and the identity.

    Returns:
        ``(W, MSE)``:
        - ``W`` has shape ``(batch, input_dim)``.  Prompts with a non-invertible
          regularised Gram matrix keep a zero row.
        - ``MSE`` is, despite its name, ``sqrt(sum((W.T - final_state)**2))``
          on the CPU: the Euclidean distance for a single prompt.
    """
    chunk_size = int(inputs_batch.shape[-2] / 2)
    input_dim = inputs_batch.shape[-1] - 1
    batch_size = inputs_batch.shape[0]
    W = torch.zeros((batch_size, input_dim), device=device)
    eye = torch.eye(input_dim, device=device)

    for i in range(batch_size):
        X = inputs_batch[i, 0:2 * chunk_size:2, 1:]
        Y = inputs_batch[i, 1:2 * chunk_size:2, 0]
        try:
            inverse_term = torch.linalg.inv(torch.matmul(X.T, X) + lambda_ * eye)
        except RuntimeError:
            # Singular matrix (LinAlgError subclasses RuntimeError); leave the
            # zero row instead of swallowing every exception with a bare except.
            continue
        W[i, :] = torch.matmul(inverse_term, torch.matmul(X.T, Y))

    MSE = torch.sqrt(sum(torch.squeeze((W.T - final_state) ** 2, dim=-1))).cpu()
    return W, MSE


def Ridge_Regression_one_step(inputs_batch, lambda_=0.1, device='cuda'):
    """Fit per-prompt ridge regression on the context and predict the final label.

    Args:
        inputs_batch: tensor indexed as ``[prompt, token, feature]``; even
            tokens carry ``x`` in features ``1:`` and odd tokens carry ``y`` in
            feature 0.  The last (x, y) pair is the query.
        lambda_: ridge penalty added to the diagonal of ``X^T X``.
        device: device for the weight matrix and the identity.

    Returns:
        ``(out, MSE)``:
        - ``out`` has shape ``(1, batch)``; entry ``j`` is prompt ``j``'s
          weights applied to its own query.
        - ``MSE`` is the element-wise squared error, moved to the CPU.
        - Prompts with a non-invertible regularised Gram matrix use zero weights.
    """
    chunk_size = int(inputs_batch.shape[-2] / 2)
    input_dim = inputs_batch.shape[-1] - 1
    batch_size = inputs_batch.shape[0]
    W = torch.zeros((batch_size, input_dim), device=device)
    eye = torch.eye(input_dim, device=device)

    for i in range(batch_size):
        X = inputs_batch[i, 0:2 * chunk_size - 2:2, 1:]
        Y = inputs_batch[i, 1:2 * chunk_size - 1:2, 0]
        try:
            inverse_term = torch.linalg.inv(torch.matmul(X.T, X) + lambda_ * eye)
        except RuntimeError:
            # Singular matrix (LinAlgError subclasses RuntimeError); previously
            # a bare ``except`` that also swallowed KeyboardInterrupt.
            continue
        W[i, :] = torch.matmul(inverse_term, torch.matmul(X.T, Y))

    query = inputs_batch[:, -2, 1:]
    # Score each prompt with its own weights.  W @ query.T also contains
    # cross-prompt products, so keep only its diagonal, laid out as a row.
    # For a single prompt this is the same (1, 1) result as before.
    out = torch.unsqueeze(torch.diagonal(torch.matmul(W, query.T)), dim=0)
    MSE = (out - inputs_batch[:, -1, 0]) ** 2
    return out, MSE.cpu()


def Ridge_Regression_one_step_non_scalar(inputs_batch, lambda_=0.1, device='cuda', y_dim=2):
    """Fit ridge regression with vector observations and predict the final observation.

    Token layout (``inputs_batch`` indexed as ``[prompt, token, feature]``):
    - Even tokens carry ``X_k`` flattened row-major into features ``y_dim:``,
      i.e. ``X_k`` has shape ``(y_dim, input_dim)``.
    - Odd tokens carry ``y_k`` in features ``0:y_dim``.

    All context pairs except the last are stacked into one least-squares
    system per prompt.

    Args:
        inputs_batch: prompt tensor as described above; any batch size.
        lambda_: ridge penalty added to the diagonal of ``X^T X``
            (``0.0`` gives OLS).
        device: device of the identity matrix.
        y_dim: observation dimension.

    Returns:
        Predicted final observation of shape ``(batch, y_dim, 1)``.

    Raises:
        ValueError: if the prompt has fewer than two (X, y) pairs.
        RuntimeError: (``torch.linalg.LinAlgError``) if the regularised Gram
            matrix is singular.
    """
    chunk_size = int(inputs_batch.shape[-2] / 2)
    input_dim = inputs_batch.shape[-1] // y_dim - 1
    batch_size = inputs_batch.shape[0]
    n_context = chunk_size - 1
    if n_context < 1:
        raise ValueError(f"need at least one context (X, y) pair, got chunk_size={chunk_size}")

    # A single reshape replaces the per-pair concatenation loop; rows are
    # ordered (pair k, output r), exactly as before.
    X = torch.reshape(inputs_batch[:, 0:2 * n_context:2, y_dim:], (batch_size, n_context * y_dim, input_dim))
    Y = torch.reshape(inputs_batch[:, 1:2 * n_context:2, :y_dim], (batch_size, n_context * y_dim, 1))
    X_T = torch.transpose(X, 1, 2)

    gram = torch.matmul(X_T, X) + lambda_ * torch.eye(input_dim, device=device)
    W = torch.matmul(torch.linalg.inv(gram), torch.matmul(X_T, Y))  # (batch, input_dim, 1)

    X_query = torch.reshape(inputs_batch[:, -2, y_dim:], (batch_size, y_dim, input_dim))
    return torch.matmul(X_query, W)





def perform_kalman_filtering(h,y, F, input_dim=8, chunk_size=40, Q=0.0*torch.eye(8), R=0.0):
    """Run a filterpy Kalman filter over one prompt and predict its last label.

    The filter state is the ``input_dim``-dimensional weight vector.  Each
    step uses the measurement row ``H = x_i^T``, taken from ``h[2*i, 1:]``,
    with scalar observation ``y[2*i]``.

    Args:
        h: token matrix of one prompt; row ``2*i`` holds ``x_i`` in columns
            ``1:``.  Presumably ``(2*chunk_size, input_dim + 1)`` (TODO
            confirm against callers).
        y: sequence whose entry ``2*i`` is the observation paired with ``h[2*i]``.
        F: state-transition matrix (tensor).
        input_dim: state dimension.
        chunk_size: number of (x, y) pairs.  The first ``chunk_size - 1``
            are filtered; the last is predicted.
        Q: process-noise covariance (tensor or array-like).  The default is
            sized for ``input_dim == 8``.
        R: observation-noise variance (tensor or scalar).  A scalar is turned
            into a 1x1 matrix; previously the default ``0.0`` crashed on
            ``R.cpu()``.

    Returns:
        ``(out, MSE)``: ``out`` is the predicted label ``x_last^T w``.
        ``MSE`` is the absolute error against ``y[2*chunk_size-2]``.
    """
    f = KalmanFilter(dim_x=input_dim, dim_z=1)

    if torch.is_tensor(R):
        f.R = R.cpu().numpy()
    elif numpy.ndim(R) == 0:
        f.R = float(R) * numpy.eye(1)
    else:
        f.R = numpy.asarray(R, dtype=float)
    f.F = F.cpu()
    f.Q = Q.cpu().numpy() if torch.is_tensor(Q) else numpy.asarray(Q, dtype=float)
    for i in range(chunk_size-1):
        f.H = torch.unsqueeze(h[2*i, 1:], dim=0).cpu()
        f.predict()
        y_i = torch.unsqueeze(y[2*i], dim=0).cpu()
        f.update(y_i)

    # Propagate once more so the state estimate refers to the query step.
    f.predict()
    x = torch.tensor(f.x)
    h_last = torch.unsqueeze(h[2*chunk_size-2, 1:], dim=0).cpu()

    out = torch.matmul(h_last.float(), x.float())
    # NOTE(review): despite the name this is |error|, not a squared error;
    # kept for compatibility with existing result files.
    MSE = abs(y[2*chunk_size-2].cpu() - out)
    return out, MSE




def perform_kalman_filtering_non_scalar(h,y, F, input_dim=8, chunk_size=40, Q=0.0*torch.eye(8), R=0.0, y_dim=2):
    """Run a filterpy Kalman filter with vector observations and predict the last one.

    The filter state is the ``input_dim``-dimensional weight vector.  At step
    ``i`` the measurement matrix is ``h[2*i, y_dim:]`` reshaped to
    ``(y_dim, input_dim)``, and the observation is ``y[2*i, :]``.

    Args:
        h: token matrix of one prompt; row ``2*i`` holds the flattened
            measurement matrix in columns ``y_dim:``.
        y: observations; row ``2*i`` holds the ``y_dim`` observation paired
            with ``h[2*i]``.
        F: state-transition matrix (tensor).
        input_dim: state dimension.
        chunk_size: number of (H, y) pairs.  The first ``chunk_size - 1`` are
            filtered; the last is predicted.
        Q: process-noise covariance (tensor or array-like).  The default is
            sized for ``input_dim == 8``.
        R: observation-noise covariance (tensor, array, or scalar).  A scalar
            ``r`` becomes ``r * I(y_dim)``; previously the default ``0.0``
            crashed on ``R.cpu()``.
        y_dim: observation dimension.

    Returns:
        Predicted final observation ``H_last w`` of shape ``(y_dim, 1)``.
    """
    f = KalmanFilter(dim_x=input_dim, dim_z=y_dim)

    if torch.is_tensor(R):
        f.R = R.cpu().numpy()
    elif numpy.ndim(R) == 0:
        f.R = float(R) * numpy.eye(y_dim)
    else:
        f.R = numpy.asarray(R, dtype=float)
    f.F = F.cpu()
    f.Q = Q.cpu().numpy() if torch.is_tensor(Q) else numpy.asarray(Q, dtype=float)
    for i in range(chunk_size-1):
        f.H = torch.reshape(h[2*i, y_dim:], (y_dim, input_dim)).cpu()
        f.predict()
        y_i = torch.unsqueeze(y[2*i, :], dim=-1).cpu()
        f.update(y_i)

    # Propagate once more so the state estimate refers to the query step.
    f.predict()
    x = torch.tensor(f.x)
    h_last = torch.reshape(h[2*chunk_size-2, y_dim:], (y_dim, input_dim)).cpu()

    return torch.matmul(h_last.float(), x.float())




def perform_kalman_filtering_explicit(h, F, final_state, input_dim=8, chunk_size=40, Q=0.0 * torch.eye(8), R=0.0):
    """Filter every (x, y) pair of one prompt and measure the distance to a reference state.

    At step ``i`` the measurement row is ``h[2*i, 1:]`` and the scalar
    observation is ``h[2*i + 1, 0]``.

    Args:
        h: token matrix of one prompt (x tokens on even rows, y tokens on odd rows).
        F: state-transition matrix (tensor).
        final_state: reference state, presumably ``(input_dim, 1)`` like the
            other ``*_explicit`` helpers (TODO confirm against callers).
        input_dim: state dimension.
        chunk_size: number of (x, y) pairs to filter.
        Q: process-noise covariance (tensor or array-like).  The default is
            sized for ``input_dim == 8``.
        R: observation-noise variance (tensor or scalar).  A scalar is turned
            into a 1x1 matrix; previously the default ``0.0`` crashed on
            ``R.cpu()``.

    Returns:
        ``(x, MSE)``: ``x`` is the filtered state estimate.  ``MSE`` is,
        despite its name, the Euclidean distance
        ``sqrt(sum((final_state - x)**2))`` as a shape-``(1,)`` tensor.
    """
    f = KalmanFilter(dim_x=input_dim, dim_z=1)

    if torch.is_tensor(R):
        f.R = R.cpu().numpy()
    elif numpy.ndim(R) == 0:
        f.R = float(R) * numpy.eye(1)
    else:
        f.R = numpy.asarray(R, dtype=float)
    f.F = F.cpu()
    f.Q = Q.cpu().numpy() if torch.is_tensor(Q) else numpy.asarray(Q, dtype=float)
    for i in range(chunk_size):
        f.H = torch.unsqueeze(h[2 * i, 1:], dim=0).cpu()
        f.predict()
        y_i = torch.unsqueeze(h[2 * i + 1, 0], dim=0).cpu()
        f.update(y_i)

    x = torch.tensor(f.x)
    # Square before summing.  The old ``sqrt(sum(diff)**2)`` evaluated to
    # |sum(diff)|, which is inconsistent with the L2 distance returned by the
    # SGD/ridge ``*_explicit`` helpers.
    MSE = torch.sqrt(torch.sum((final_state.cpu() - x) ** 2, dim=0))
    return x, MSE


# ---------------------------------------------------------------------------
# Evaluation script: compares a trained model's one-step-ahead predictions
# with Kalman filtering, SGD, ridge and OLS baselines across context lengths.
# ---------------------------------------------------------------------------
loss_fcn=torch.nn.MSELoss()  # NOTE(review): not referenced by the visible code

# Loads a whole pickled model object from disk (not a state_dict).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files.  Newer torch releases may need weights_only=False here; confirm.
trained_model=torch.load('Saved_Model/linear_regression_innovation_noise_obs_noise_F_option_1_no_stats_at_all.pt')
trained_model.eval()

trained_model=trained_model.to('cuda')
# Which statistics tokens to drop before calling the model when drop_stats is True:
# 'Noise' keeps the first input_dim (F) tokens, 'All' drops every statistics token.
drop_mode='All' # 'Noise', 'All'
input_dim = 8
y_dim = 1;
# Passed as F_option to the data generator.
option=1 # Option 1 for Strategy 1 and 3 for Strategy 2
# When True the model sees the prompt without (some of) the F/Q/noise tokens.
drop_stats = True
# Values of Q_alpha_factor (scale of the process-noise level) to sweep.
param_vec = [0.025]



# One entry per Q_alpha_factor; each entry is a per-context-length curve.
MSPD_list_cz_param = [];
MSPD_list_cz_SGD_0_pt_01_param = []
MSPD_list_cz_SGD_0_pt_05_param = []
MSPD_list_cz_Ridge_0_pt_01_param = []
MSPD_list_cz_Ridge_0_pt_05_param = []
MSPD_list_cz_OLS_param = []

# Sweep the process-noise scale; for each value evaluate every context length.
for Q_alpha_factor in param_vec:
    # Per-context-length averages for this Q scale (one entry per chunk_size).
    MSPD_list_cz = [];
    MSPD_list_cz_SGD_0_pt_01 = []
    MSPD_list_cz_SGD_0_pt_05 = []
    MSPD_list_cz_Ridge_0_pt_01 = []
    MSPD_list_cz_Ridge_0_pt_05 = []
    MSPD_list_cz_OLS = []
    # Context lengths 2..41 (numpy.arange excludes the stop value).
    for chunk_size in numpy.arange(2, 42):
        print(chunk_size)
        # Per-trial squared differences, averaged after the trial loop.
        MSPD_list = []
        MSPD_list_SGD_0_pt_01 = []
        MSPD_list_Ridge_0_pt_01 = []
        MSPD_list_SGD_0_pt_05 = []
        MSPD_list_Ridge_0_pt_05 = []
        MSPD_list_OLS = []
        MSE_transformer_list = [];  # NOTE(review): never appended to or read
        for i in range(5000):

            # Per-trial random scales for the F mixing, process noise and observation noise.
            F_alpha = torch.rand((1,))
            Q_alpha = Q_alpha_factor * torch.rand((1,))
            R_alpha = 0.025 * torch.rand((1,))
            inputs_batch, outputs_batch = Gen_data_SS_innovation_noise_obs_noise_F_options_non_scalar_y(
                chunk_size=chunk_size, batch_size=1, alpha_F=F_alpha, alpha_Q=Q_alpha, alpha_R=R_alpha, F_option=option, y_dim=y_dim)

            # Read the true system matrices back from the statistics tokens.
            # The generator stores F.T / Q.T, so after its final transpose these
            # slices are F and Q themselves.
            R = torch.diag(inputs_batch[0, 2 * input_dim, :y_dim])
            Q = inputs_batch[0, input_dim:2 * input_dim, y_dim:y_dim + input_dim]
            F = inputs_batch[0, :input_dim, y_dim:y_dim + input_dim]
            # Baselines see only the (x, y) tokens, which start at position 2*input_dim + 1.
            out_filtering = perform_kalman_filtering_non_scalar(inputs_batch[0, 2 * input_dim + 1:, :],
                                                                outputs_batch[0, 2 * input_dim + 1:], F=F, Q=Q, R=R,
                                                                chunk_size=chunk_size, y_dim=y_dim)

            out_SGD_0_pt_01 = Stochastic_Gradient_Descent_Regression_one_step_non_scalar(
                inputs_batch[:, 2 * input_dim + 1:, :], alpha=0.01, device='cuda', y_dim=y_dim)
            out_Ridge_0_pt_01 = Ridge_Regression_one_step_non_scalar(
                inputs_batch=inputs_batch[:, 2 * input_dim + 1:, :], lambda_=0.01, device='cuda', y_dim=y_dim)
            out_SGD_0_pt_05 = Stochastic_Gradient_Descent_Regression_one_step_non_scalar(
                inputs_batch[:, 2 * input_dim + 1:, :], alpha=0.05, device='cuda', y_dim=y_dim)
            out_Ridge_0_pt_05 = Ridge_Regression_one_step_non_scalar(
                inputs_batch=inputs_batch[:, 2 * input_dim + 1:, :], lambda_=0.05,
                device='cuda', y_dim=y_dim)
            # OLS = ridge with lambda_=0.  It presumably needs chunk_size - 1 >=
            # input_dim context pairs for X^T X to be invertible, hence the guard
            # (input_dim = 8 here).
            if chunk_size >= 9:
                out_OLS = Ridge_Regression_one_step_non_scalar(
                    inputs_batch=inputs_batch[:, 2 * input_dim + 1:, :], lambda_=0.0,
                    device='cuda', y_dim=y_dim)
            # Model inputs always exclude the final y token (end index is seq_len - 1).
            with torch.no_grad():
                if drop_stats:
                    if drop_mode=='Noise':
                        # Keep the F tokens, drop the Q and noise-variance tokens.
                        out = trained_model(torch.cat((inputs_batch[:, :input_dim, :], inputs_batch[:,
                                                                                   2 * input_dim + 1:2 * chunk_size + 2 * input_dim + 1 - 1,
                                                                                   :]), dim=1))
                    elif drop_mode=='All':
                        # Drop every statistics token.
                        out = trained_model(inputs_batch[:,2 * input_dim + 1:2 * chunk_size + 2 * input_dim + 1 - 1,:])
                    # NOTE(review): any other drop_mode leaves ``out`` undefined
                    # (or stale from the previous trial).

                else:
                    # Full prompt, statistics tokens included.
                    out = trained_model(inputs_batch[:, 0:2 * chunk_size + 2 * input_dim + 1 - 1, :])

            # Squared distance between the model's last-position output and each baseline.
            # NOTE(review): assumes trained_model returns (batch, seq, y_dim); confirm.
            MSPD_list += [torch.sum(((torch.squeeze(out_filtering) - out[:, -1].cpu()) ** 2))]
            MSPD_list_SGD_0_pt_01 += [torch.sum((((out_SGD_0_pt_01.cpu() - out[:, -1].cpu()) ** 2)))]
            MSPD_list_Ridge_0_pt_01 += [torch.sum((((torch.squeeze(out_Ridge_0_pt_01.cpu()) - out[:, -1].cpu()) ** 2)))]
            MSPD_list_SGD_0_pt_05 += [torch.sum((((out_SGD_0_pt_05.cpu() - out[:, -1].cpu()) ** 2)))]
            MSPD_list_Ridge_0_pt_05 += [
                torch.sum((((torch.squeeze(out_Ridge_0_pt_05.cpu()) - out[:, -1].cpu()) ** 2)))]

            if chunk_size >= 9:
                MSPD_list_OLS += [
                    torch.sum((((torch.squeeze(out_OLS.cpu()) - out[:, -1].cpu()) ** 2)))]

        # Average over the 5000 trials.  The hard-coded /8 presumably
        # normalises by input_dim; confirm.
        MSPD_list_cz += [numpy.mean(torch.tensor(MSPD_list).numpy() / 8)]
        MSPD_list_cz_SGD_0_pt_01 += [numpy.mean(torch.tensor(MSPD_list_SGD_0_pt_01).numpy() / 8)]
        MSPD_list_cz_SGD_0_pt_05 += [numpy.mean(torch.tensor(MSPD_list_SGD_0_pt_05).numpy() / 8)]
        MSPD_list_cz_Ridge_0_pt_01 += [numpy.mean(torch.tensor(MSPD_list_Ridge_0_pt_01).numpy() / 8)]
        MSPD_list_cz_Ridge_0_pt_05 += [numpy.mean(torch.tensor(MSPD_list_Ridge_0_pt_05).numpy() / 8)]
        if chunk_size >= 9:
            MSPD_list_cz_OLS += [numpy.mean(torch.tensor(MSPD_list_OLS).numpy() / 8)]

    # Store this Q scale's curves.
    MSPD_list_cz_param += [MSPD_list_cz];
    MSPD_list_cz_SGD_0_pt_01_param += [MSPD_list_cz_SGD_0_pt_01]
    MSPD_list_cz_SGD_0_pt_05_param += [MSPD_list_cz_SGD_0_pt_05]
    MSPD_list_cz_Ridge_0_pt_01_param += [MSPD_list_cz_Ridge_0_pt_01]
    MSPD_list_cz_Ridge_0_pt_05_param += [MSPD_list_cz_Ridge_0_pt_05]
    MSPD_list_cz_OLS_param += [MSPD_list_cz_OLS]

# Plot each method's MSPD-vs-context-length curve and persist the raw curves.
plt.clf()

context_lengths = numpy.arange(2, 42)      # every chunk_size evaluated above
ols_context_lengths = numpy.arange(9, 42)  # OLS is only evaluated for chunk_size >= 9
legend_labels = ['ICL and Kalman Filter', 'ICL and SGD 0.01', 'ICL and Ridge 0.01', 'ICL and SGD 0.05',
                 'ICL and Ridge 0.05', 'ICL and OLS']

for param_idx, q_param in enumerate(param_vec):
    # Same order as legend_labels; the OLS curve is last and shorter.
    curves = [MSPD_list_cz_param[param_idx], MSPD_list_cz_SGD_0_pt_01_param[param_idx],
              MSPD_list_cz_Ridge_0_pt_01_param[param_idx], MSPD_list_cz_SGD_0_pt_05_param[param_idx],
              MSPD_list_cz_Ridge_0_pt_05_param[param_idx], MSPD_list_cz_OLS_param[param_idx]]

    plt.figure()
    for curve in curves[:-1]:
        plt.plot(context_lengths, curve)
    plt.plot(ols_context_lengths, curves[-1])
    plt.legend(legend_labels)
    plt.xlabel('Context Length')
    plt.ylabel('1/n MSPD')
    plt.ylim([0.0, 1.0])

    # The figure and the pickle share one file stem.
    stem = ('SS_one_step_non_scalar_Option_' + str(option) + '_y_dim_' + str(y_dim) + '_param_q_' + str(q_param)
            + '_disc_' + str(drop_stats) + '_md_' + drop_mode)
    plt.savefig(stem + '.png')
    with open(stem + '.pkl', 'wb') as file:
        pickle.dump(curves, file)



