from Requierments import *
from Datasets import *
from Analytical_MTL import *
from MTL_Neural_Networks import *
from Balanced_Inits import *


def run_multitask_training(
        input_dim=10,
        task1_dim=20,
        task2_dim=20,
        shared_dim=5,
        sigmaweights=0.01,
        sigmaweights_TS=0.025,
        sigmaX=1,
        learning_rate=0.01,
        batchsize=100,
        epochs=150,
        noise2=0,
        rho=0,
        scale_1=1,
        scale_2=1, seed=300, same_task_init=False, aligned=True, ortho=False, teacher_student=False, ortho_sim=False, ali_sim=False, con_sim=False,
        random_regression=False, regression_sigma=0, regression_norm=False, regression_ali=True, rememberweights=True, onlyfinalweights=False, deeperMTL=False, nshared=1, ntaskL=1,
        onlyNN=False, TW2=1, Mnist=False, shift=4, non_linear=False, STL=False, Task1=False, MTL_regression=False, Alignment_factor=1
):
    """Build a two-task problem, train a network on it and/or evolve the
    analytical QQᵀ dynamics, and return the recorded trajectories.

    The task is chosen by the first truthy flag, checked in this order:
    ``teacher_student``, ``ali_sim``, ``ortho_sim``, ``con_sim``,
    ``random_regression``, ``Mnist``, ``MTL_regression``.

    * Network experiments (``teacher_student``, ``random_regression``,
      ``Mnist``, ``MTL_regression``) have explicit inputs and targets. A model
      is trained with full-batch SGD (no momentum) for ``epochs`` steps, and
      products of its weight matrices are recorded.
    * Simulation experiments (``ali_sim``, ``ortho_sim``, ``con_sim``) only
      construct the input-output correlation matrices ``Sigma_1o`` and
      ``Sigma_2o`` from chosen singular vectors/values. No model is trained.

    Unless ``onlyNN`` is set, a ``QQT_MTL`` object (defined elsewhere) is built
    from the initial weights and ``Sigma_1o``/``Sigma_2o``. It is stepped once
    per epoch via ``time_step_nondiag(learning_rate)``, and every step's matrix
    is kept.

    Args:
        input_dim, task1_dim, task2_dim: Input and per-task output dimensions.
            ``Mnist`` overwrites all three from the generated data.
        shared_dim: Width of the shared layer. ``teacher_student`` overwrites
            it with ``min(input_dim, task1_dim + task2_dim)``.
        sigmaweights: Initial-weight scale, forwarded to ``zero_balanced_MTL``
            and to the network constructors.
        sigmaweights_TS: Teacher-weight scale (``teacher_student`` only).
        sigmaX: Std of the Gaussian inputs drawn before ``whiten_input``.
        learning_rate: SGD learning rate; also the analytical step size.
        batchsize: Number of samples. Every step uses the full batch.
        epochs: Number of training/analytical steps; must be >= 1.
        noise2, rho, regression_sigma, regression_norm: Forwarded to
            ``generate_regression_tasks`` (``random_regression`` only).
        scale_1, scale_2: Output scales of task 1 / task 2
            (``teacher_student`` and ``random_regression``).
            ``MTL_regression`` uses ``(1, scale_2)``.
        seed: Seeds ``torch``, NumPy's global RNG, ``random`` and the NumPy
            ``Generator`` used by the simulation experiments.
        same_task_init: If true, task 1 starts from task 2's initial weights.
        aligned, ortho: Forwarded to ``zero_balanced_MTL`` for the teacher.
            ``aligned`` is also the ``alignment`` of ``MultiMNISTGenerator``.
        teacher_student, ali_sim, ortho_sim, con_sim, random_regression,
            Mnist, MTL_regression: Experiment selectors (see above).
        regression_ali, onlyfinalweights: Accepted for backward
            compatibility; not used by this function.
        rememberweights: Record weight products every epoch (otherwise only
            after the final epoch).
        deeperMTL, nshared, ntaskL: Use ``DeepMultiTaskNetwork`` with
            ``nshared`` shared and ``ntaskL`` task layers. The last layer of
            each stack is the one that gets recorded.
        onlyNN: Skip the analytical dynamics entirely.
        TW2: Weight on the task-2 loss:
            ``total = 0.5 * loss1 + TW2 * 0.5 * loss2``.
        shift: ``shift_range`` of ``MultiMNISTGenerator``.
        non_linear: Use ``NonLinearMultiTaskNetwork``.
        STL, Task1: Train a ``SingleTaskNetwork`` on task 1 (``Task1`` true)
            or on task 2. Its loss is stored in ``loss_task1_history`` in
            both cases.
        Alignment_factor: Forwarded to ``MTL_tasks``.

    Returns:
        For network experiments:

        * ``onlyNN`` false, ``STL`` false: ``(NN_plots, analytical_plots,
          loss_task1_history, loss_task2_history, total_loss_history,
          loss_task1_ana, loss_task2_ana, loss_total_ana, tracking_Q,
          Sigma_1o, Sigma_2o, QQ_NN_final, QQ_ana_final, UXi, SXi, VXi)``
        * ``onlyNN`` true, ``STL`` false: ``(NN_plots, loss_task1_history,
          loss_task2_history, total_loss_history, Sigma_1o, Sigma_2o,
          QQ_NN_final)``
        * ``STL`` true: ``(NN_plots, loss_task1_history, total_loss_history,
          Sigma_1o, QQ_NN_final)``

        For simulation experiments: ``(analytical_plots, Sigma_1o, Sigma_2o,
        QQ_ana_final, UXi, SXi, VXi)``.

        ``NN_plots`` is a list of ``(array, label)`` pairs with one row per
        recorded epoch. ``analytical_plots`` holds the nine flattened blocks
        of the analytical matrix (row-major block order), one row per epoch.
        ``tracking_Q`` holds ``||WS WSᵀ - (W1ᵀW1 + W2ᵀW2)||`` per recorded
        epoch.

    Raises:
        ValueError: If ``epochs < 1``, no experiment flag is set, ``onlyNN``
            is used without a network experiment, or ``deeperMTL`` and
            ``non_linear`` are both set for a multi-task network.
    """
    # --- Validate the configuration ---
    # Each of these combinations would otherwise fail later with an
    # UnboundLocalError / IndexError far from the real cause.
    nn_experiment = bool(teacher_student or random_regression or Mnist or MTL_regression)
    sim_experiment = bool(ali_sim or ortho_sim or con_sim)
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if not (nn_experiment or sim_experiment):
        raise ValueError(
            "no experiment selected: set one of teacher_student, ali_sim, ortho_sim, "
            "con_sim, random_regression, Mnist or MTL_regression")
    if onlyNN and not nn_experiment:
        raise ValueError(
            "onlyNN=True needs a network experiment (teacher_student, random_regression, "
            "Mnist or MTL_regression); simulation experiments only produce analytical results")

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # --- Generate input data ---
    # random_regression, Mnist and MTL_regression replace X_train below.
    X_train = torch.normal(0, sigmaX, size=(batchsize, input_dim))
    X_train = whiten_input(X_train)

    # --- Create the task(s) for the selected experiment ---
    if teacher_student:
        shared_dim = min(input_dim, task1_dim + task2_dim)
        # Ground-truth (teacher) weights.
        WS_TS, W1_TS, W2_TS = zero_balanced_MTL(sigmaweights_TS, input_dim, shared_dim, task1_dim, task2_dim,
                                                aligned=aligned, ortho=ortho)

        # Teacher applied to each sample (one column per sample of X_train.T).
        y1_train = scale_1 * W1_TS @ WS_TS @ np.array(X_train.T)
        y2_train = scale_2 * W2_TS @ WS_TS @ np.array(X_train.T)

        Sigma_1o = (1 / batchsize) * y1_train @ np.array(X_train)
        Sigma_2o = (1 / batchsize) * y2_train @ np.array(X_train)

        # Diagnostics comparing the two tasks' correlation matrices.
        sigma1o_flat = Sigma_1o.flatten()
        sigma2o_flat = Sigma_2o.flatten()
        cosine_similarity = np.dot(sigma1o_flat, sigma2o_flat) / (
            np.linalg.norm(sigma1o_flat) * np.linalg.norm(sigma2o_flat))
        print("cossim", cosine_similarity)
        U1, S1, V1 = np.linalg.svd(Sigma_1o, full_matrices=False)
        U2, S2, V2 = np.linalg.svd(Sigma_2o, full_matrices=False)
        # np.linalg.svd returns Vh (rows are right singular vectors). The
        # subspace overlap is therefore V1 @ V2.T, as in the regression
        # branches; V1.T @ V2 would only ever have singular values 1 or 0.
        _, SO, _ = np.linalg.svd(V1 @ V2.T, full_matrices=False)
        print("SALI", SO)
        # Ratio of the two tasks' correlation magnitudes (task 1 / task 2).
        print("magdif", np.linalg.norm(sigma1o_flat) / np.linalg.norm(sigma2o_flat))

    elif ali_sim:
        # Aligned tasks: task 2 reuses task 1's singular vectors (U2 = U1,
        # V2 = V1) and differs only in its random singular values.
        # Sigma_2o therefore has task1_dim rows.
        d = input_dim
        r = int(d / 2)  # rank of each task's correlation matrix
        rng = np.random.default_rng(seed)  # seeded so `seed` makes runs reproducible
        U1, _ = np.linalg.qr(rng.standard_normal((task1_dim, r)))  # (task1_dim, r)
        V1, _ = np.linalg.qr(np.random.randn(d, r))  # (d, r)
        U2 = U1
        V2 = V1
        # r uniform values in [0, 1), sorted in descending order.
        eigval_1 = np.flip(np.sort(np.random.rand(1, r))).flatten()
        eigval_2 = np.flip(np.sort(np.random.rand(1, r))).flatten()
        S1 = np.diag(eigval_1)
        S2 = np.diag(eigval_2)
        Sigma_1o = U1 @ S1 @ V1.T
        Sigma_2o = U2 @ S2 @ V2.T
        # Singular values of V1ᵀV2: cosines of the angles between the input subspaces.
        _, SO, _ = np.linalg.svd(V1.T @ V2, full_matrices=False)
        print("SALI", SO)

    elif ortho_sim:
        # U1 and U2 are the top and bottom row blocks of one matrix with
        # orthonormal columns. V1 and V2 are complementary column blocks of an
        # orthogonal d x d matrix, so V1ᵀV2 = 0.
        d = input_dim
        r = int(d / 2)
        rng = np.random.default_rng(seed)  # seeded so `seed` makes runs reproducible
        U_full, _ = np.linalg.qr(rng.standard_normal((task1_dim + task2_dim, r)))
        U1 = U_full[:task1_dim, :]
        U2 = U_full[task1_dim:task1_dim + task2_dim, :]
        V_full, _ = np.linalg.qr(np.random.randn(d, d))
        V1 = V_full[:, :r]
        V2 = V_full[:, r:d]
        # Fixed 2x2 singular-value matrices: these require r == 2 and
        # d - r == 2, i.e. input_dim == 4.
        S1 = np.diag([4, 1.5])
        S2 = np.diag([2, 0.5])
        Sigma_1o = U1 @ S1 @ V1.T
        Sigma_2o = U2 @ S2 @ V2.T
        _, SO, _ = np.linalg.svd(V1.T @ V2, full_matrices=False)
        print("SALI", SO)

    elif con_sim:
        # Each task gets independently drawn singular vectors and fixed
        # singular values; both Sigmas are scaled to unit Frobenius norm.
        d = input_dim
        r = int(d / 2)
        rng = np.random.default_rng(seed)  # seeded so `seed` makes runs reproducible
        U1, _ = np.linalg.qr(rng.standard_normal((task1_dim, r)))  # (task1_dim, r)
        V1, _ = np.linalg.qr(np.random.randn(d, r))  # (d, r)
        U2, _ = np.linalg.qr(rng.standard_normal((task2_dim, r)))  # (task2_dim, r)
        V2, _ = np.linalg.qr(np.random.randn(d, r))  # (d, r)
        # Fixed 2x2 singular-value matrices: these require r == 2, i.e. input_dim in {4, 5}.
        S1 = np.diag([4, 3])
        S2 = np.diag([2, 1])
        Sigma_1o = U1 @ S1 @ V1.T
        Sigma_2o = U2 @ S2 @ V2.T
        Sigma_1o = Sigma_1o / np.linalg.norm(Sigma_1o, "fro")
        Sigma_2o = Sigma_2o / np.linalg.norm(Sigma_2o, "fro")
        _, SO, _ = np.linalg.svd(V1.T @ V2, full_matrices=False)
        print("SALI", SO)

    elif random_regression:
        # shared_dim is not overridden here; presumably it should be at least
        # min(input_dim, task1_dim + task2_dim).
        X_train, y1_train, y2_train = generate_regression_tasks(input_dim=input_dim, task1_dim=task1_dim,
                                                                task2_dim=task2_dim, num_samples=batchsize, rho=rho,
                                                                alphas=(scale_1, scale_2), noise_std=noise2, seed=seed, sigma_eps=regression_sigma, normalize=regression_norm)

        Sigma_1o = y1_train.T @ X_train / batchsize  # (output_dim, input_dim)
        Sigma_2o = y2_train.T @ X_train / batchsize

        U1, S1, V1 = np.linalg.svd(Sigma_1o, full_matrices=False)
        U2, S2, V2 = np.linalg.svd(Sigma_2o, full_matrices=False)

        # V1/V2 are Vh here, so V1 @ V2.T is the overlap of the input subspaces.
        U, S, V = np.linalg.svd(V1 @ V2.T)
        X_train = torch.tensor(X_train).float()
        y1_train = torch.tensor(y1_train.T).float()
        y2_train = torch.tensor(y2_train.T).float()

        print("alignment valuee", S)

    elif Mnist:
        mnist = fetch_openml('mnist_784', version=1, as_frame=False)
        mnistdata = mnist['data'].astype(np.uint8)  # shape (70000, 784)
        mnistlabels = mnist['target'].astype(int)
        gen_base = MultiMNISTGenerator(mnistdata, mnistlabels, shift_range=shift, alignment=aligned)
        X_train, y1_train, y2_train = gen_base.batchmaker(batchsize, gen_base)

        Sigma_1o = y1_train.T @ X_train / batchsize  # (output_dim, input_dim)
        Sigma_2o = y2_train.T @ X_train / batchsize

        U1, S1, V1 = np.linalg.svd(Sigma_1o, full_matrices=False)
        U2, S2, V2 = np.linalg.svd(Sigma_2o, full_matrices=False)

        U, S, V = np.linalg.svd(V1 @ V2.T)
        # Dimensions come from the generated data, not the arguments.
        input_dim = X_train.shape[1]
        task1_dim = y1_train.shape[1]
        task2_dim = y2_train.shape[1]
        X_train = torch.tensor(X_train).float()
        y1_train = y1_train.T
        y2_train = y2_train.T

    elif MTL_regression:
        X, y1_train, y2_train = MTL_tasks(input_dim=input_dim, task1_dim=task1_dim, task2_dim=task2_dim, num_samples=batchsize, alpha_scales=(1, scale_2), seed=seed, alignment_factor=Alignment_factor)

        Sigma_1o = y1_train @ X / batchsize  # (output_dim, input_dim)
        Sigma_2o = y2_train @ X / batchsize

        U1, S1, V1 = np.linalg.svd(Sigma_1o, full_matrices=False)
        U2, S2, V2 = np.linalg.svd(Sigma_2o, full_matrices=False)
        print(S1)
        print(S2)
        U, S, V = np.linalg.svd(V1 @ V2.T)
        print("alignment valuee", S)

        # Train (and compute the analytical loss) on the inputs the targets
        # were generated from, not on the unrelated X_train drawn above.
        X_train = torch.tensor(np.ascontiguousarray(X), dtype=torch.float32)

    # --- Init weights for training ---
    WS_init, W1_init, W2_init = zero_balanced_MTL(sigmaweights, input_dim, shared_dim, task1_dim, task2_dim)
    if same_task_init == True:
        W1_init = W2_init
    if onlyNN == False:
        # Ground-truth QQᵀ dynamics; one stored matrix per epoch.
        n_total = input_dim + task1_dim + task2_dim
        analyticals = np.zeros((epochs, n_total, n_total))
        if sim_experiment:
            # Simulation experiments hand over their constructed SVD factors.
            analytical = QQT_MTL(WS_init, W1_init, W2_init, X_train, Sigma_1o, Sigma_2o, weightsonly=False,
                                 U1=U1, U2=U2, S1=S1, S2=S2, V1=V1, V2=V2)
        else:
            analytical = QQT_MTL(WS_init, W1_init, W2_init, X_train, Sigma_1o, Sigma_2o, weightsonly=False)
        UXi, SXi, VXi = analytical.return_SVDxi()

    # --- Containers for monitoring (one entry per recorded epoch) ---
    # Row 1 of the block matrix.
    WSWST = []
    WSTW1T = []
    WSTW2T = []
    # Row 2.
    W1WS = []
    W1W1T = []
    W1W2T = []
    # Row 3.
    W2WS = []
    W2W1T = []
    W2W2T = []

    tracking_Q = []

    loss_task1_history, loss_task2_history, total_loss_history = [], [], []

    # --- Model, loss, training ---
    if nn_experiment:
        criterion = nn.MSELoss()
        if STL == True:
            model = SingleTaskNetwork(input_dim, shared_dim, task1_dim, sigmaweights, W1_init=W1_init, W2_init=W2_init, WS_init=WS_init)
        elif deeperMTL == False and non_linear == False:
            model = MultiTaskNetwork(input_dim, shared_dim, task1_dim, task2_dim, sigmaweights,
                                     W1_init=W1_init, W2_init=W2_init, WS_init=WS_init)
        elif deeperMTL == True and non_linear == False:
            model = DeepMultiTaskNetwork(input_dim, shared_dim, task1_dim, task2_dim, nshared, ntaskL, sigmaweights)
        elif deeperMTL == False and non_linear == True:
            model = NonLinearMultiTaskNetwork(input_dim, shared_dim, task1_dim, task2_dim, sigmaweights,
                                              W1_init=W1_init, W2_init=W2_init, WS_init=WS_init)
        else:
            raise ValueError("deeperMTL and non_linear cannot both be enabled for a multi-task network")
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0)

        # Targets are stored transposed relative to the model outputs; convert
        # once instead of on every epoch.
        y1_target = torch.tensor(y1_train, dtype=torch.float32).T
        y2_target = torch.tensor(y2_train, dtype=torch.float32).T

        for epoch in range(epochs):
            model.train()
            if STL == False:
                task1_pred, task2_pred = model(X_train)
                loss_task1 = criterion(task1_pred, y1_target)
                loss_task2 = criterion(task2_pred, y2_target)
                total_loss = 0.5 * loss_task1 + TW2 * 0.5 * loss_task2

                loss_task1_history.append(loss_task1.item())
                loss_task2_history.append(loss_task2.item())
                total_loss_history.append(total_loss.item())
            else:
                # The single trained task's loss always goes into loss_task1_history.
                stl_target = y1_target if Task1 == True else y2_target
                loss_task1 = criterion(model(X_train), stl_target)
                total_loss = loss_task1
                loss_task1_history.append(loss_task1.item())
                total_loss_history.append(total_loss.item())

            # Full-batch gradient step.
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if rememberweights == True or epoch == epochs - 1:
                with torch.no_grad():
                    if STL == True:
                        WS = model.shared_layer.weight
                        W1 = model.task1_layer.weight
                    elif deeperMTL == False:
                        WS = model.shared_layer.weight
                        W1 = model.task1_layer.weight
                        W2 = model.task2_layer.weight
                    else:
                        # Record the last layer of each stack.
                        WS = model.sharedLayers[nshared - 1].weight
                        W1 = model.taskLayers[0][ntaskL - 1].weight
                        W2 = model.taskLayers[1][ntaskL - 1].weight

                    WSWST.append(np.asarray(WS.T @ WS))
                    W1W1T.append(np.asarray(W1 @ W1.T))
                    if STL == False:
                        W2W2T.append(np.asarray(W2 @ W2.T))

                    if ntaskL == 1 and STL == False:
                        W1WS.append(np.asarray(W1 @ WS))
                        W2WS.append(np.asarray(W2 @ WS))

                        WSTW1T.append(np.asarray(WS.T @ W1.T))
                        WSTW2T.append(np.asarray(WS.T @ W2.T))

                        W1W2T.append(np.asarray(W1 @ W2.T))
                        W2W1T.append(np.asarray(W2 @ W1.T))

                        # Deviation from WS WSᵀ = W1ᵀW1 + W2ᵀW2.
                        LL = np.asarray(WS @ WS.T)
                        RL = np.asarray(W1.T @ W1 + W2.T @ W2)
                        tracking_Q.append(np.linalg.norm(LL - RL))
                    elif STL == True:
                        W1WS.append(np.asarray(W1 @ WS))
                        WSTW1T.append(np.asarray(WS.T @ W1.T))

            if onlyNN == False:
                analyticals[epoch] = analytical.time_step_nondiag(learning_rate)
    else:
        # Simulation experiments: analytical dynamics only.
        for epoch in range(epochs):
            if onlyNN == False:
                analyticals[epoch] = analytical.time_step_nondiag(learning_rate)

    # After either loop epoch == epochs - 1 (epochs >= 1 is validated above),
    # so the final-weights blocks below always run.

    # --- Final NN weight-product matrix ---
    if nn_experiment:
        A = WSWST[-1]
        if ntaskL == 1 and STL == False:
            B = WSTW1T[-1]
            C = WSTW2T[-1]
            F = W1W2T[-1]

            G = W2WS[-1]
            H = W2W1T[-1]
            D = W1WS[-1]

        E = W1W1T[-1]
        if STL == False:
            I = W2W2T[-1]
        if ntaskL == 1 and STL == False:
            # Rows: [WSᵀWS, WSᵀW1ᵀ, WSᵀW2ᵀ], [W1WS, W1W1ᵀ, W1W2ᵀ], [W2WS, W2W1ᵀ, W2W2ᵀ].
            top = np.hstack([A, B, C])
            middle = np.hstack([D, E, F])
            bottom = np.hstack([G, H, I])
            QQ_NN_final = np.vstack([top, middle, bottom])
        elif STL == True:
            B = WSTW1T[-1]
            D = W1WS[-1]
            top = np.hstack([A, B])
            middle = np.hstack([D, E])
            QQ_NN_final = np.vstack([top, middle])
        else:
            # Deeper task stacks: cross terms are not recorded, keep the diagonal blocks only.
            QQ_NN_final = [A, E, I]
    if onlyNN == False:
        QQ_ana_final = analyticals[-1]

    # --- Flatten recorded trajectories for plotting ---
    if deeperMTL == True and nshared > 1:
        # The recorded shared layer is sharedLayers[nshared - 1]; presumably
        # (shared_dim, shared_dim), so its products are indexed by shared_dim.
        wswst_plot = np.array(WSWST).reshape(-1, shared_dim * shared_dim)
        w1tw1_plot = np.array(W1W1T).reshape(-1, task1_dim * task1_dim)
        w2tw2_plot = np.array(W2W2T).reshape(-1, task2_dim * task2_dim)
        wstw1t_plot = np.array(WSTW1T).reshape(-1, shared_dim * task1_dim)
        wstw2t_plot = np.array(WSTW2T).reshape(-1, shared_dim * task2_dim)
        w1ws_plot = np.array(W1WS).reshape(-1, task1_dim * shared_dim)  # output function task 1
        w2ws_plot = np.array(W2WS).reshape(-1, task2_dim * shared_dim)  # output function task 2
        w1w2t_plot = np.array(W1W2T).reshape(-1, task1_dim * task2_dim)  # correlation task 1 with 2
        w2w1t_plot = np.array(W2W1T).reshape(-1, task2_dim * task1_dim)
    elif STL == False:
        wswst_plot = np.array(WSWST).reshape(-1, input_dim * input_dim)
        w1tw1_plot = np.array(W1W1T).reshape(-1, task1_dim * task1_dim)
        w2tw2_plot = np.array(W2W2T).reshape(-1, task2_dim * task2_dim)
        wstw1t_plot = np.array(WSTW1T).reshape(-1, input_dim * task1_dim)
        wstw2t_plot = np.array(WSTW2T).reshape(-1, input_dim * task2_dim)
        w1ws_plot = np.array(W1WS).reshape(-1, task1_dim * input_dim)  # output function task 1
        w2ws_plot = np.array(W2WS).reshape(-1, input_dim * task2_dim)  # output function task 2
        w1w2t_plot = np.array(W1W2T).reshape(-1, task1_dim * task2_dim)  # correlation task 1 with 2
        w2w1t_plot = np.array(W2W1T).reshape(-1, task2_dim * task1_dim)  # correlation task 2 with 1
    elif STL == True:
        wswst_plot = np.array(WSWST).reshape(-1, input_dim * input_dim)
        w1tw1_plot = np.array(W1W1T).reshape(-1, task1_dim * task1_dim)
        wstw1t_plot = np.array(WSTW1T).reshape(-1, input_dim * task1_dim)
        w1ws_plot = np.array(W1WS).reshape(-1, task1_dim * input_dim)

    if onlyNN == False:
        # Slice the nine blocks out of the analytical matrix (rows/cols ordered input, task 1, task 2).
        i0, i1, i2 = input_dim, input_dim + task1_dim, input_dim + task1_dim + task2_dim
        # Diagonal blocks.
        WSWST_ana = np.array(analyticals[:, :i0, :i0]).reshape(-1, input_dim * input_dim)
        W1W1T_ana = np.array(analyticals[:, i0:i1, i0:i1]).reshape(-1, task1_dim * task1_dim)
        W2W2T_ana = np.array(analyticals[:, i1:i2, i1:i2]).reshape(-1, task2_dim * task2_dim)
        # Remaining column 1.
        W1WS_ana = np.array(analyticals[:, i0:i1, :i0]).reshape(-1, task1_dim * input_dim)
        W2WS_ana = np.array(analyticals[:, i1:i2, :i0]).reshape(-1, task2_dim * input_dim)
        # Inter-task correlations.
        W2W1T_ana = np.array(analyticals[:, i1:i2, i0:i1]).reshape(-1, task2_dim * task1_dim)
        W1W2T_ana = np.array(analyticals[:, i0:i1, i1:i2]).reshape(-1, task1_dim * task2_dim)
        # Remaining row 1.
        WSTW1T_ana = np.array(analyticals[:, :i0, i0:i1]).reshape(-1, input_dim * task1_dim)
        WSTW2T_ana = np.array(analyticals[:, :i0, i1:i2]).reshape(-1, input_dim * task2_dim)

        analytical_plots = [WSWST_ana, WSTW1T_ana, WSTW2T_ana, W1WS_ana, W1W1T_ana, W1W2T_ana, W2WS_ana, W2W1T_ana,
                            W2W2T_ana]
    if STL == False:
        NN_plots = [
            (wswst_plot, "WS.T @ WS"),
            (wstw1t_plot, "WS.T @ W1.T"),
            (wstw2t_plot, "WS.T @ W2.T"),
            (w1ws_plot, "W1 @ WS"),
            (w1tw1_plot, "W1 @ W1.T"),
            (w1w2t_plot, "W1 @ W2.T"),
            (w2ws_plot, "W2 @ WS"),
            (w2w1t_plot, "W2 @ W1.T"),
            (w2tw2_plot, "W2 @ W2.T")]
    else:
        NN_plots = [
            (wswst_plot, "WS.T @ WS"),
            (wstw1t_plot, "WS.T @ W1.T"),
            (w1ws_plot, "W1 @ WS"),
            (w1tw1_plot, "W1 @ W1.T")]

    # --- Results ---
    if nn_experiment:
        if onlyNN == False and STL == False:
            loss_task1_ana, loss_task2_ana, loss_total_ana = compute_analytical_loss_from_plots(analytical_plots,
                                                                                                X_train, y1_train,
                                                                                                y2_train, input_dim,
                                                                                                task1_dim, task2_dim,
                                                                                                epochs)
            return NN_plots, analytical_plots, loss_task1_history, loss_task2_history, total_loss_history, loss_task1_ana, loss_task2_ana, loss_total_ana, tracking_Q, Sigma_1o, Sigma_2o, QQ_NN_final, QQ_ana_final, UXi, SXi, VXi
        elif onlyNN == True and STL == False:  # onlyNN and MTL
            return NN_plots, loss_task1_history, loss_task2_history, total_loss_history, Sigma_1o, Sigma_2o, QQ_NN_final
        else:  # STL == True
            return NN_plots, loss_task1_history, total_loss_history, Sigma_1o, QQ_NN_final

    else:
        return analytical_plots, Sigma_1o, Sigma_2o, QQ_ana_final, UXi, SXi, VXi



