import os
from random import randint
import uuid

from quinine import QuinineArgumentParser
from tqdm import tqdm
import torch
import yaml

from eval import get_run_metrics
from tasks import get_task_sampler
from samplers import get_data_sampler
from curriculum import Curriculum
from schema import schema
from models import build_model
from tasks import mean_squared_error, mean_squared_error_state, mean_squared_error_measurement
from tasks import squared_error
import wandb
from scipy.stats import ortho_group

# Let cuDNN benchmark candidate kernels and cache the fastest one per input shape.
torch.backends.cudnn.benchmark = True


def train_step(model, xs, ys, optimizer, loss_func):
    """Take one optimiser step where the model is fed both ``xs`` and ``ys``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the model output
        detached from the autograd graph.
    """
    optimizer.zero_grad()
    prediction = model(xs, ys)
    batch_loss = loss_func(prediction, ys)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction

def train_step_mine(model, inputs_batch, outputs_batch, optimizer, loss_func):
    """Take one optimiser step scored against ``outputs_batch[:, ::2]``.

    Only every other position along axis 1 (0, 2, 4, ...) is used as target.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.
    """
    targets = outputs_batch[:, ::2]
    optimizer.zero_grad()
    prediction = model(inputs_batch)
    batch_loss = loss_func(prediction, targets)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction


def train_step_mine_SS(model, inputs_batch, outputs_batch, optimizer, loss_func):
    """Take one optimiser step scored against ``outputs_batch[:, n_dims-1::2]``.

    ``model.n_dims`` sets the offset of the first target position.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.
    """
    first_target = model.n_dims - 1
    targets = outputs_batch[:, first_target::2]
    optimizer.zero_grad()
    prediction = model(inputs_batch)
    batch_loss = loss_func(prediction, targets)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction


def train_step_mine_SS_innovation_noise(model, inputs_batch, outputs_batch, optimizer, loss_func):
    """Take one optimiser step scored against ``outputs_batch[:, 2*(n_dims-1)::2]``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.
    """
    first_target = 2 * (model.n_dims - 1)
    targets = outputs_batch[:, first_target::2]
    optimizer.zero_grad()
    prediction = model(inputs_batch)
    batch_loss = loss_func(prediction, targets)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction

def train_step_mine_SS_innovation_noise_obs_noise(model, inputs_batch, outputs_batch, optimizer, loss_func):
    """Take one optimiser step scored against ``outputs_batch[:, 2*(n_dims-1)+1::2]``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.
    """
    first_target = 2 * (model.n_dims - 1) + 1
    targets = outputs_batch[:, first_target::2]
    optimizer.zero_grad()
    prediction = model(inputs_batch)
    batch_loss = loss_func(prediction, targets)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction

def train_step_mine_SS_innovation_noise_obs_noise_non_scalar_y(model, inputs_batch, outputs_batch, optimizer, loss_func, y_dim=2):
    """Take one optimiser step with targets ``outputs_batch[:, 2*state_dim+1::2, :]``.

    ``loss_func`` is called as ``loss_func(output, targets, y_dim)``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.
    """
    first_target = 2 * model.state_dim + 1
    targets = outputs_batch[:, first_target::2, :]
    optimizer.zero_grad()
    prediction = model(inputs_batch)
    batch_loss = loss_func(prediction, targets, y_dim)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction


def train_step_mine_SS_innovation_noise_obs_noise_state_est_curr(model, inputs_batch, outputs_batch, optimizer, loss_func):
    """Take one optimiser step with targets ``outputs_batch[:, 2*(n_dims-1)+2::2, :]``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.
    """
    first_target = 2 * (model.n_dims - 1) + 2
    targets = outputs_batch[:, first_target::2, :]
    optimizer.zero_grad()
    prediction = model(inputs_batch)
    batch_loss = loss_func(prediction, targets)
    batch_loss.backward()
    optimizer.step()
    detached_prediction = prediction.detach()
    return batch_loss.detach().item(), detached_prediction



def train_step_one_step_pred(model, inputs_batch, outputs_batch, optimizer, loss_func, y_dim=2, discard=True, discard_mode='All'):
    """Take one optimiser step for one-step-ahead prediction.

    Targets are ``outputs_batch[:, start::2, :]``, where ``start`` is
    ``2*model.state_dim + 1`` when ``discard`` is false, ``0`` for
    ``discard_mode='All'`` and ``model.state_dim`` for ``discard_mode='Noise'``.
    ``loss_func`` is called as ``loss_func(output, targets, y_dim)``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.

    Raises:
        ValueError: if ``discard`` is true and ``discard_mode`` is neither
            ``'All'`` nor ``'Noise'`` (this used to surface as an
            ``UnboundLocalError`` after a wasted forward pass).
    """
    if not discard:
        start = 2 * model.state_dim + 1
    elif discard_mode == 'All':
        start = 0
    elif discard_mode == 'Noise':
        start = model.state_dim
    else:
        raise ValueError(f"unknown discard_mode {discard_mode!r}; expected 'All' or 'Noise'")

    optimizer.zero_grad()
    output = model(inputs_batch)
    loss = loss_func(output, outputs_batch[:, start::2, :], y_dim)
    loss.backward()
    optimizer.step()
    return loss.detach().item(), output.detach()


def train_step_one_step_pred_control(model, inputs_batch, outputs_batch, optimizer, loss_func, y_dim=2, discard=True, discard_mode='All', control=True):
    """Take one optimiser step for one-step-ahead prediction, optionally with control tokens.

    Targets are ``outputs_batch[:, start::stride, :]`` with ``stride = 3`` when
    ``control`` is true and ``2`` otherwise. ``start`` is
    ``stride*model.state_dim + 1`` when ``discard`` is false, ``0`` for
    ``discard_mode='All'`` and ``(stride-1)*model.state_dim`` for
    ``discard_mode='Noise'``. ``loss_func`` is called as
    ``loss_func(output, targets, y_dim)``.

    Returns:
        ``(loss, output)``: the loss as a Python float and the detached output.

    Raises:
        ValueError: if ``discard`` is true and ``discard_mode`` is neither
            ``'All'`` nor ``'Noise'`` (this used to surface as an
            ``UnboundLocalError`` after a wasted forward pass).
    """
    stride = 3 if control else 2
    if not discard:
        start = stride * model.state_dim + 1
    elif discard_mode == 'All':
        start = 0
    elif discard_mode == 'Noise':
        start = (stride - 1) * model.state_dim
    else:
        raise ValueError(f"unknown discard_mode {discard_mode!r}; expected 'All' or 'Noise'")

    optimizer.zero_grad()
    output = model(inputs_batch)
    loss = loss_func(output, outputs_batch[:, start::stride, :], y_dim)
    loss.backward()
    optimizer.step()
    return loss.detach().item(), output.detach()

def sample_seeds(total_seeds, count):
    """Draw ``count`` distinct integer seeds uniformly from ``[0, total_seeds)``.

    Args:
        total_seeds: size of the seed pool.
        count: number of distinct seeds wanted.

    Returns:
        A set of ``count`` distinct ints (empty when ``count <= 0``).

    Raises:
        ValueError: if ``count > total_seeds``. The pool cannot supply that many
            distinct values, and the rejection loop below would never finish.
    """
    if count > total_seeds:
        raise ValueError(
            f"cannot draw {count} distinct seeds from a pool of {total_seeds}"
        )
    seeds = set()
    # Rejection sampling with ``randint`` keeps the global ``random`` stream
    # identical to previous runs for the same seed.
    while len(seeds) < count:
        seeds.add(randint(0, total_seeds - 1))
    return seeds


def Gen_data(device='cuda', batch_size=64, eval=False, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, noise_sigma=0, d_curr=8):
    """Generate a batch of noisy linear-regression prompts.

    Each sequence interleaves x-tokens and y-tokens. Token ``2t`` carries
    ``x_t`` in features ``1..input_dim``. Token ``2t+1`` carries
    ``y_t = w^T x_t + noise`` in feature ``0``.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of sequences; must be >= 1.
        eval: if true, every sequence uses one shared weight vector;
            otherwise each sequence uses its own.
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per sequence.
        w_sigma, x_sigma: std of the Gaussian ``w`` and ``x`` entries.
        noise_sigma: std of the additive Gaussian label noise.
        d_curr: curriculum dimension; when ``d_curr < input_dim`` the ``x``
            rows from index ``d_curr - 1`` onward are zeroed.

    Returns:
        ``(inputs_batch, outputs_batch)``:
        - ``inputs_batch`` has shape ``(batch_size, 2*chunk_size, input_dim + 1)``.
        - ``outputs_batch`` has shape ``(batch_size, 2*chunk_size)`` and holds
          ``y_t`` at each x-token position, zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    seq_len = 2 * chunk_size
    w_eval = w_sigma * torch.randn((input_dim, 1))
    samples = []
    for _sample in range(batch_size):
        # A per-sequence w is drawn even in eval mode, so eval and training
        # calls consume the random stream identically.
        w_train = w_sigma * torch.randn((input_dim, 1))
        w = w_eval if eval else w_train
        x = x_sigma * torch.randn((input_dim, chunk_size))
        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        y = torch.matmul(torch.transpose(w, 0, 1), x)
        y += noise_sigma * torch.randn(y.shape)
        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:seq_len:2] = x
        inputs[0, 1:seq_len:2] = y
        samples.append(inputs)

    # Stack once instead of re-concatenating a growing tensor (O(B^2) copies).
    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    # The target at each x-token is the label carried by the following y-token.
    outputs_batch[:, 0:seq_len:2] = inputs_batch[:, 0, 1:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)

def Gen_data_SS(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0):
    """Generate linear-regression prompts whose weights evolve as ``w_{t+1} = F w_t``.

    ``F`` is the identity when ``Dynamic`` is false. Otherwise it is
    ``alpha_F * U + (1 - alpha_F) * I`` with ``U`` a random orthogonal matrix.

    Each prompt has the following layout:
    - ``input_dim`` header tokens; token ``j`` carries column ``j`` of ``F``
      in features ``1..input_dim``.
    - Then alternating x-tokens (``x_t`` in features ``1..input_dim``) and
      y-tokens (``y_t = w_t . x_t`` in feature 0).

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of sequences; must be >= 1.
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per sequence.
        w_sigma, x_sigma: std of the Gaussian initial ``w`` and of ``x``.
        d_curr: curriculum dimension; when ``d_curr < input_dim`` the ``x``
            rows from index ``d_curr - 1`` onward are zeroed.
        Dynamic: if false, ``F`` is the identity (static weights).
        alpha_F: blend weight of the random orthogonal matrix in ``F``.

    Returns:
        ``(inputs_batch, outputs_batch)``:
        - ``inputs_batch`` has shape
          ``(batch_size, 2*chunk_size + input_dim, input_dim + 1)``.
        - ``outputs_batch`` has shape ``(batch_size, 2*chunk_size + input_dim)``
          and holds ``y_t`` at each x-token position, zeros elsewhere.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    seq_len = 2 * chunk_size + input_dim
    samples = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

        # Roll the weights forward, then stack once (avoids O(T^2) concat).
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        trajectory = [w_t]
        for _step in range(int(chunk_size) - 1):
            w_t = torch.matmul(F, w_t)
            trajectory.append(w_t)
        w = torch.stack(trajectory, dim=0)  # (chunk_size, input_dim, 1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        # y_t = w_t . x_t for every t (diagonal of the (T, T) product).
        y = torch.diag(torch.matmul(torch.squeeze(w, dim=-1), x))

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:input_dim] = F
        inputs[1:input_dim + 1, input_dim:seq_len:2] = x
        inputs[0, input_dim + 1:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    # The target at each x-token is the label carried by the following y-token.
    outputs_batch[:, input_dim:seq_len:2] = inputs_batch[:, 0, input_dim + 1:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)





def Gen_data_SS_innovation_noise(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0):
    """Generate state-space regression prompts with process (innovation) noise.

    Per sequence the weights follow ``w_{t+1} = F w_t + A_Q n_t`` with
    ``n_t ~ N(0, I)`` and ``Q = A_Q A_Q^T``. When ``Dynamic`` is false,
    ``F = I`` and ``Q = 0``.

    Each prompt has the following layout:
    - ``input_dim`` tokens carrying the columns of ``F``.
    - ``input_dim`` tokens carrying the columns of ``Q``.
    - Then alternating x-tokens (``x_t`` in features ``1..input_dim``) and
      y-tokens (``y_t = w_t . x_t`` in feature 0).

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of sequences; must be >= 1.
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per sequence.
        w_sigma, x_sigma: std of the Gaussian initial ``w`` and of ``x``.
        d_curr: curriculum dimension; when ``d_curr < input_dim`` the ``x``
            rows from index ``d_curr - 1`` onward are zeroed.
        Dynamic: if false, use static weights (``F = I``, ``Q = 0``).
        alpha_F: blend weight of the random orthogonal matrix in ``F``.
        alpha_Q: scale of the random eigenvalues of ``Q``.

    Returns:
        ``(inputs_batch, outputs_batch)``:
        - ``inputs_batch`` has shape
          ``(batch_size, 2*chunk_size + 2*input_dim, input_dim + 1)``.
        - ``outputs_batch`` has shape
          ``(batch_size, 2*chunk_size + 2*input_dim)`` and holds ``y_t`` at each
          x-token position.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    header = 2 * input_dim  # F tokens followed by Q tokens
    seq_len = 2 * chunk_size + header
    samples = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)
            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches: Q used to exist only when Dynamic was
        # true, so Dynamic=False crashed with NameError when writing the header.
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the weights forward, then stack once (avoids O(T^2) concat).
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        trajectory = [w_t]
        for _step in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_t = torch.matmul(F, w_t) + innovation_noise
            trajectory.append(w_t)
        w = torch.stack(trajectory, dim=0)  # (chunk_size, input_dim, 1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        y = torch.diag(torch.matmul(torch.squeeze(w, dim=-1), x))

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:input_dim] = F
        inputs[1:input_dim + 1, input_dim:header] = Q
        inputs[1:input_dim + 1, header:seq_len:2] = x
        inputs[0, header + 1:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, header:seq_len:2] = inputs_batch[:, 0, header + 1:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_SS_innovation_noise_obs_noise(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0):
    """Generate state-space regression prompts with process and observation noise.

    Per sequence:
    - ``w_{t+1} = F w_t + A_Q n_t`` with ``Q = A_Q A_Q^T``.
    - ``y_t = w_t . x_t + r_t`` with ``r_t ~ N(0, noise_var)`` and
      ``noise_var = alpha_R * U[0, 1)``.
    - When ``Dynamic`` is false, ``F = I`` and ``Q = 0``.

    Feature layout before the final transpose, with
    ``L = 2*chunk_size + 2*input_dim + 1``:
    - columns ``0:input_dim`` (features ``1..``) hold ``F``;
    - columns ``input_dim:2*input_dim`` hold ``Q``;
    - row 0 of column ``2*input_dim + 1`` holds ``noise_var``;
    - x-tokens are at columns ``2*input_dim+1::2`` and y-tokens at
      ``2*input_dim+2::2``.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of sequences; must be >= 1.
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per sequence.
        w_sigma, x_sigma: std of the Gaussian initial ``w`` and of ``x``.
        d_curr: curriculum dimension; when ``d_curr < input_dim`` the ``x``
            rows from index ``d_curr - 1`` onward are zeroed.
        Dynamic: if false, use static weights (``F = I``, ``Q = 0``).
        alpha_F: blend weight of the random orthogonal matrix in ``F``.
        alpha_Q: scale of the random eigenvalues of ``Q``.
        alpha_R: upper bound of the observation-noise variance.

    Returns:
        ``(inputs_batch, outputs_batch)``:
        - ``inputs_batch`` has shape ``(batch_size, L, input_dim + 1)``.
        - ``outputs_batch`` has shape ``(batch_size, L)`` and holds ``y_t`` at
          each x-token position.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    header = 2 * input_dim  # F tokens followed by Q tokens
    seq_len = 2 * chunk_size + header + 1
    samples = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)
            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches: Q used to exist only when Dynamic was
        # true, so Dynamic=False crashed with NameError when writing the header.
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the weights forward, then stack once (avoids O(T^2) concat).
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        trajectory = [w_t]
        for _step in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_t = torch.matmul(F, w_t) + innovation_noise
            trajectory.append(w_t)
        w = torch.stack(trajectory, dim=0)  # (chunk_size, input_dim, 1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R * torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var) * torch.randn((chunk_size,), dtype=float)

        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        y = torch.diag(torch.matmul(torch.squeeze(w, dim=-1), x)) + obs_noise

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:input_dim] = F
        inputs[1:input_dim + 1, input_dim:header] = Q
        # NOTE(review): column `header` is never written, while noise_var goes
        # into row 0 of the first x-token (column header + 1). The non-scalar-y
        # generator uses column `header` for noise stats, so this looks like an
        # off-by-one. It is kept because trained models / callers may rely on
        # this layout.
        inputs[0, header + 1] = noise_var
        inputs[1:input_dim + 1, header + 1:seq_len:2] = x
        inputs[0, header + 2:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, header + 1:seq_len:2] = inputs_batch[:, 0, header + 2:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)






def Gen_data_SS_innovation_noise_obs_noise_state_est_curr(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0):
    """Generate noisy state-space prompts whose targets are the hidden states ``w_t``.

    The inputs are built exactly like ``Gen_data_SS_innovation_noise_obs_noise``:
    - header: ``F`` columns, then ``Q`` columns;
    - ``noise_var`` in row 0 of column ``2*input_dim + 1``;
    - alternating x-tokens and y-tokens after the header.

    The targets differ: ``outputs_batch`` carries the current state ``w_t``
    at each y-token position (columns ``2*input_dim+2::2``), zeros elsewhere.

    Args:
        device: device the returned tensors are moved to.
        batch_size: number of sequences; must be >= 1.
        input_dim: dimension of ``x`` and ``w``.
        chunk_size: number of (x, y) pairs per sequence.
        w_sigma, x_sigma: std of the Gaussian initial ``w`` and of ``x``.
        d_curr: curriculum dimension; when ``d_curr < input_dim`` the ``x``
            rows from index ``d_curr - 1`` onward are zeroed.
        Dynamic: if false, use static weights (``F = I``, ``Q = 0``).
        alpha_F: blend weight of the random orthogonal matrix in ``F``.
        alpha_Q: scale of the random eigenvalues of ``Q``.
        alpha_R: upper bound of the observation-noise variance.

    Returns:
        ``(inputs_batch, outputs_batch)`` with ``L = 2*chunk_size + 2*input_dim + 1``:
        - ``inputs_batch`` has shape ``(batch_size, L, input_dim + 1)``.
        - ``outputs_batch`` has shape ``(batch_size, L, input_dim)``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    header = 2 * input_dim  # F tokens followed by Q tokens
    seq_len = 2 * chunk_size + header + 1
    samples = []
    state_targets = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)
            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches: Q used to exist only when Dynamic was
        # true, so Dynamic=False crashed with NameError when writing the header.
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the weights forward, then stack once (avoids O(T^2) concat).
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        trajectory = [w_t]
        for _step in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_t = torch.matmul(F, w_t) + innovation_noise
            trajectory.append(w_t)
        w_mat = torch.squeeze(torch.stack(trajectory, dim=0), dim=-1)  # (chunk_size, input_dim)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R * torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var) * torch.randn((chunk_size,), dtype=float)

        state_targets.append(w_mat.T)  # (input_dim, chunk_size)

        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        y = torch.diag(torch.matmul(w_mat, x)) + obs_noise

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:input_dim] = F
        inputs[1:input_dim + 1, input_dim:header] = Q
        # NOTE(review): column `header` stays empty and noise_var shares the
        # first x-token column; kept as-is because it is the layout models were
        # trained on.
        inputs[0, header + 1] = noise_var
        inputs[1:input_dim + 1, header + 1:seq_len:2] = x
        inputs[0, header + 2:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, input_dim, seq_len))
    outputs_batch[:, :, header + 2:seq_len:2] = torch.stack(state_targets, dim=0)

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    outputs_batch = torch.transpose(outputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)






def train(model, args):
    """Run the curriculum training loop, resuming from ``args.out_dir/state.pt`` if present.

    Each step does the following:
    - samples ``xs`` from the configured data sampler and ``ys`` from a
      sampled task;
    - takes one Adam step via ``train_step``;
    - logs metrics to wandb every ``args.wandb.log_every_steps`` steps;
    - advances the curriculum;
    - checkpoints periodically.

    Args:
        model: module exposing ``n_dims``; it is called as ``model(xs, ys)``
            inside ``train_step``.
        args: parsed config; reads ``training.*``, ``wandb.log_every_steps``,
            ``out_dir`` and ``test_run``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        # Resume: restore model/optimizer and replay the curriculum schedule
        # (one update per completed step 0..train_step).
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        # NOTE(review): the loop restarts at the saved step itself, which was
        # already trained before the checkpoint was written; confirm whether
        # this should be train_step + 1.
        starting_step = state["train_step"]
        for i in range(state["train_step"] + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    pbar = tqdm(range(starting_step, args.training.train_steps))

    num_training_examples = args.training.num_training_examples

    for i in pbar:
        data_sampler_args = {}
        task_sampler_args = {}

        if "sparse" in args.training.task:
            task_sampler_args["valid_coords"] = curriculum.n_dims_truncated
        if num_training_examples is not None:
            # Finite training pool: sample_seeds needs at least bsize distinct seeds.
            assert num_training_examples >= bsize
            seeds = sample_seeds(num_training_examples, bsize)
            data_sampler_args["seeds"] = seeds
            # Offset by one so task seeds differ from data seeds
            # (presumably to decorrelate the two streams).
            task_sampler_args["seeds"] = [s + 1 for s in seeds]

        xs = data_sampler.sample_xs(
            curriculum.n_points,
            bsize,
            curriculum.n_dims_truncated,
            **data_sampler_args,
        )
        task = task_sampler(**task_sampler_args)
        ys = task.evaluate(xs)

        loss_func = task.get_training_metric()

        # Hard-coded .cuda(): this loop requires a CUDA device.
        loss, output = train_step(model, xs.cuda(), ys.cuda(), optimizer, loss_func)

        point_wise_tags = list(range(curriculum.n_points))
        point_wise_loss_func = task.get_metric()
        point_wise_loss = point_wise_loss_func(output, ys.cuda()).mean(dim=0)

        # Normaliser used for "excess_loss" in the wandb log below.
        baseline_loss = (
            sum(
                max(curriculum.n_dims_truncated - ii, 0)
                for ii in range(curriculum.n_points)
            )
            / curriculum.n_points
        )

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(
                        zip(point_wise_tags, point_wise_loss.cpu().numpy())
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        # Rolling resumable checkpoint (overwritten each time).
        if i % args.training.save_every_steps == 0 and not args.test_run:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        # Permanent snapshot of the whole model every keep_every_steps steps.
        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))




def Gen_data_SS_innovation_noise_obs_noise_F_option_2(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2):
    """Generate noisy state-space regression prompts with a selectable ``F`` family.

    With ``Dynamic`` true, ``F`` is chosen as follows:
    - ``F_option == 2``: ``F = U diag(s) U^T`` with ``U`` random orthogonal and
      ``s ~ U[0, 1)``.
    - Otherwise: ``F = alpha_F * U + (1 - alpha_F) * I``.

    With ``Dynamic`` false, ``F = I`` and ``Q = 0``.

    Layout and targets are the same as ``Gen_data_SS_innovation_noise_obs_noise``
    (header ``F``/``Q``, ``noise_var`` in row 0 of column ``2*input_dim + 1``,
    then alternating x-/y-tokens). ``outputs_batch`` holds ``y_t`` at each
    x-token position.

    Returns:
        ``(inputs_batch, outputs_batch)`` with ``L = 2*chunk_size + 2*input_dim + 1``:
        - ``inputs_batch`` has shape ``(batch_size, L, input_dim + 1)``.
        - ``outputs_batch`` has shape ``(batch_size, L)``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    header = 2 * input_dim  # F tokens followed by Q tokens
    seq_len = 2 * chunk_size + header + 1
    samples = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches: Q used to exist only when Dynamic was
        # true, so Dynamic=False crashed with NameError when writing the header.
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the weights forward, then stack once (avoids O(T^2) concat).
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        trajectory = [w_t]
        for _step in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_t = torch.matmul(F, w_t) + innovation_noise
            trajectory.append(w_t)
        w = torch.stack(trajectory, dim=0)  # (chunk_size, input_dim, 1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R * torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var) * torch.randn((chunk_size,), dtype=float)

        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        y = torch.diag(torch.matmul(torch.squeeze(w, dim=-1), x)) + obs_noise

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:input_dim] = F
        inputs[1:input_dim + 1, input_dim:header] = Q
        # NOTE(review): column `header` stays empty and noise_var shares the
        # first x-token column; kept as-is because it is the layout models were
        # trained on.
        inputs[0, header + 1] = noise_var
        inputs[1:input_dim + 1, header + 1:seq_len:2] = x
        inputs[0, header + 2:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, header + 1:seq_len:2] = inputs_batch[:, 0, header + 2:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_SS_innovation_noise_obs_noise_F_options(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2):
    """Generate noisy state-space regression prompts with a selectable ``F`` family.

    With ``Dynamic`` true, ``F`` is chosen as follows:
    - ``F_option == 2``: ``F = U diag(s) U^T`` with ``s ~ U[0, 1)``.
    - ``F_option == 3``: ``F = U diag(s) U^T`` with ``s ~ U[-1, 1)``.
    - Otherwise: ``F = alpha_F * U + (1 - alpha_F) * I``.

    ``U`` is random orthogonal. With ``Dynamic`` false, ``F = I`` and ``Q = 0``.

    Layout and targets are the same as ``Gen_data_SS_innovation_noise_obs_noise``
    (header ``F``/``Q``, ``noise_var`` in row 0 of column ``2*input_dim + 1``,
    then alternating x-/y-tokens). ``outputs_batch`` holds ``y_t`` at each
    x-token position.

    Returns:
        ``(inputs_batch, outputs_batch)`` with ``L = 2*chunk_size + 2*input_dim + 1``:
        - ``inputs_batch`` has shape ``(batch_size, L, input_dim + 1)``.
        - ``outputs_batch`` has shape ``(batch_size, L)``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    header = 2 * input_dim  # F tokens followed by Q tokens
    seq_len = 2 * chunk_size + header + 1
    samples = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches: Q used to exist only when Dynamic was
        # true, so Dynamic=False crashed with NameError when writing the header.
        Q = torch.matmul(A_Q, A_Q.T)

        # Roll the weights forward, then stack once (avoids O(T^2) concat).
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        trajectory = [w_t]
        for _step in range(int(chunk_size) - 1):
            innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
            w_t = torch.matmul(F, w_t) + innovation_noise
            trajectory.append(w_t)
        w = torch.stack(trajectory, dim=0)  # (chunk_size, input_dim, 1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R * torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var) * torch.randn((chunk_size,), dtype=float)

        if d_curr < input_dim:
            # NOTE(review): zeroes from row d_curr - 1, leaving d_curr - 1
            # active dims; confirm this matches the curriculum convention.
            x[(d_curr - 1):, :] = 0.0
        y = torch.diag(torch.matmul(torch.squeeze(w, dim=-1), x)) + obs_noise

        inputs = torch.zeros((input_dim + 1, seq_len))
        inputs[1:input_dim + 1, 0:input_dim] = F
        inputs[1:input_dim + 1, input_dim:header] = Q
        # NOTE(review): column `header` stays empty and noise_var shares the
        # first x-token column; kept as-is because it is the layout models were
        # trained on.
        inputs[0, header + 1] = noise_var
        inputs[1:input_dim + 1, header + 1:seq_len:2] = x
        inputs[0, header + 2:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, header + 1:seq_len:2] = inputs_batch[:, 0, header + 2:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)

def Gen_data_SS_innovation_noise_obs_noise_F_options_non_scalar_y(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2, y_dim=2):
    """Generate state-space prompts with a ``y_dim``-dimensional observation per step.

    Per sequence:
    - ``w_{t+1} = F w_t + A_Q n_t`` with ``Q = A_Q A_Q^T``.
    - ``y_t = X_t w_t + r_t``, where ``X_t`` has shape ``(y_dim, input_dim)``,
      ``r_t ~ N(0, diag(noise_var))`` and ``noise_var = alpha_R * U[0, 1)^{y_dim}``.

    With ``Dynamic`` true, ``F`` is chosen as follows:
    - ``F_option == 2``: ``U diag(s) U^T`` with ``s ~ U[0, 1)``.
    - ``F_option == 3``: ``U diag(s) U^T`` with ``s ~ U[-1, 1)``.
    - Otherwise: ``alpha_F * U + (1 - alpha_F) * I``.

    With ``Dynamic`` false, ``F = I`` and ``Q = 0``.

    Feature layout before the final transpose, with
    ``L = 2*chunk_size + 2*input_dim + 1``:
    - rows ``y_dim:y_dim+input_dim``, columns ``0:input_dim``: ``F.T``;
    - same rows, columns ``input_dim:2*input_dim``: ``Q.T``;
    - rows ``0:y_dim``, column ``2*input_dim``: ``noise_var``;
    - rows ``y_dim:``, columns ``2*input_dim+1::2``: ``X_t`` flattened row-major;
    - rows ``0:y_dim``, columns ``2*input_dim+2::2``: ``y_t``.

    Returns:
        ``(inputs_batch, outputs_batch)``:
        - ``inputs_batch`` has shape ``(batch_size, L, (input_dim + 1) * y_dim)``.
        - ``outputs_batch`` has shape ``(batch_size, L, y_dim)`` and holds ``y_t``
          at each x-token position.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    header = 2 * input_dim  # F tokens followed by Q tokens
    seq_len = 2 * chunk_size + header + 1
    n_features = (input_dim + 1) * y_dim
    samples = []
    for _sample in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                F = torch.matmul(torch.matmul(U, Sigma_F), U.T)
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = torch.matmul(Q_U, Sigma_Q_sqrt)
        # Computed for both branches: Q used to exist only when Dynamic was
        # true, so Dynamic=False crashed with NameError when writing the header.
        Q = torch.matmul(A_Q, A_Q.T)

        noise_var = alpha_R * torch.rand((y_dim,), dtype=float)
        noise_scale = torch.diag(torch.sqrt(noise_var))
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)

        # Random draws per step keep the original order:
        # [innovation (t > 0)], X_t, observation noise.
        x_cols = []
        y_cols = []
        for step in range(int(chunk_size)):
            if step > 0:
                innovation_noise = torch.matmul(A_Q, torch.randn((input_dim, 1), dtype=float))
                w_t = torch.matmul(F, w_t) + innovation_noise
            x_t = x_sigma * torch.randn((y_dim, input_dim), dtype=float)
            if d_curr < input_dim:
                # NOTE(review): zeroes from column d_curr - 1, leaving d_curr - 1
                # active dims; confirm this matches the curriculum convention.
                x_t[:, (d_curr - 1):] = 0.0
            obs_noise = torch.matmul(noise_scale, torch.randn((y_dim, 1), dtype=float))
            y_cols.append(torch.matmul(x_t, w_t) + obs_noise)
            x_cols.append(torch.reshape(x_t, (y_dim * input_dim, 1)))
        x = torch.cat(x_cols, dim=-1)  # (y_dim * input_dim, chunk_size)
        y = torch.cat(y_cols, dim=-1)  # (y_dim, chunk_size)

        inputs = torch.zeros((n_features, seq_len))
        inputs[y_dim:input_dim + y_dim, 0:input_dim] = F.T
        inputs[y_dim:input_dim + y_dim, input_dim:header] = Q.T
        inputs[0:y_dim, header] = noise_var
        inputs[y_dim:n_features, header + 1:seq_len:2] = x
        inputs[0:y_dim, header + 2:seq_len:2] = y
        samples.append(inputs)

    inputs_batch = torch.stack(samples, dim=0)

    outputs_batch = torch.zeros((batch_size, y_dim, seq_len))
    # The target at each x-token is the observation carried by the next y-token.
    outputs_batch[:, :, header + 1:seq_len:2] = inputs_batch[:, 0:y_dim, header + 2:seq_len:2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    outputs_batch = torch.transpose(outputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_SS_innovation_noise_obs_noise_F_options_discard_noise_stats(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2, discard_noise_stats=True):
    """Generate a batch of scalar-observation state-space regression prompts.

    For every sequence a hidden weight vector evolves as
    ``w_t = F w_{t-1} + A_Q n_t`` with ``n_t ~ N(0, I)`` (innovation covariance
    ``Q = A_Q A_Q^T``) and is observed through ``y_t = <w_t, x_t> + v_t`` with
    ``v_t ~ N(0, noise_var)``.

    Prompt layout before the final transpose (row 0 = scalar slot, rows ``1:``
    = vector slot, one column per token):

    * ``discard_noise_stats=True``:  ``[F | x_0, y_0, x_1, y_1, ...]``
    * ``discard_noise_stats=False``: ``[F | Q | noise_var | x_0, y_0, ...]``

    Args:
        device: Device the returned tensors are moved to.
        batch_size: Number of independent sequences (must be >= 1).
        input_dim: Dimension of ``w_t`` and ``x_t``.
        chunk_size: Number of ``(x_t, y_t)`` pairs per sequence.
        w_sigma: Standard deviation of the initial state ``w_0``.
        x_sigma: Standard deviation of the regressors ``x_t``.
        d_curr: Curriculum dimension. When ``d_curr < input_dim`` the regressor
            coordinates from index ``d_curr - 1`` onward are zeroed.
            NOTE(review): that leaves ``d_curr - 1`` active coordinates,
            presumably because the curriculum counts the y slot; confirm.
        Dynamic: If False, ``F = I`` and ``Q = 0`` (the weights stay fixed).
        alpha_F: Weight of the random orthogonal matrix in
            ``F = alpha_F * U + (1 - alpha_F) * I`` (used only when
            ``F_option`` is neither 2 nor 3).
        alpha_Q: Scale of the uniformly drawn innovation variances.
        alpha_R: Observation-noise variance is drawn from ``U[0, alpha_R)``.
        F_option: 2 -> symmetric ``F`` with eigenvalues in ``[0, 1)``;
            3 -> symmetric ``F`` with eigenvalues in ``[-1, 1)``;
            anything else -> the ``alpha_F`` mixture above.
        discard_noise_stats: Leave ``Q`` and ``noise_var`` out of the prompt.

    Returns:
        ``(inputs_batch, outputs_batch)``. ``inputs_batch`` has shape
        ``(batch_size, seq_len, input_dim + 1)`` and ``outputs_batch`` has shape
        ``(batch_size, seq_len)``. ``outputs_batch`` holds ``y_t`` at the
        position of the ``x_t`` token and zeros elsewhere.

    Raises:
        ValueError: If ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Number of context tokens (F, and optionally Q and noise_var) before x_0.
    x_start = input_dim if discard_noise_stats else 2 * input_dim + 1
    seq_len = x_start + 2 * chunk_size

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = U @ Sigma_F @ U.T
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                F = U @ Sigma_F @ U.T
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = Q_U @ Sigma_Q_sqrt
        # Computed for both branches. Q used to be unbound when Dynamic=False,
        # which crashed the discard_noise_stats=False path.
        Q = A_Q @ A_Q.T

        # Roll the hidden state forward; column t of `w` is w_t.
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_t]
        for _ in range(int(chunk_size) - 1):
            innovation_noise = A_Q @ torch.randn((input_dim, 1), dtype=float)
            w_t = F @ w_t + innovation_noise
            states.append(w_t)
        w = torch.cat(states, dim=1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R * torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var) * torch.randn((chunk_size,), dtype=float)
        if d_curr < input_dim:
            x[(d_curr - 1):, :] = 0.0
        # y_t = <w_t, x_t> as a column-wise dot product. The former diag(W @ X)
        # built a chunk_size x chunk_size matrix only to keep its diagonal.
        y = (w * x).sum(dim=0) + obs_noise

        prompt = torch.zeros((input_dim + 1, seq_len))
        prompt[1:, 0:input_dim] = F
        if not discard_noise_stats:
            prompt[1:, input_dim:2 * input_dim] = Q
            # noise_var gets its own token (column 2*input_dim), matching the
            # multi-output generators; it used to land in the x_0 token.
            prompt[0, 2 * input_dim] = noise_var
        prompt[1:, x_start::2] = x
        prompt[0, x_start + 1::2] = y
        prompts.append(prompt)

    inputs_batch = torch.stack(prompts, dim=0)
    # Targets: y_t placed at the position of the x_t token.
    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, x_start::2] = inputs_batch[:, 0, x_start + 1::2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)



def Gen_data_SS_innovation_noise_obs_noise_F_options_discard_all_stats(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2, discard_stats=True):
    """Generate scalar-observation state-space prompts, optionally without system stats.

    For every sequence a hidden weight vector evolves as
    ``w_t = F w_{t-1} + A_Q n_t`` with ``n_t ~ N(0, I)`` (innovation covariance
    ``Q = A_Q A_Q^T``) and is observed through ``y_t = <w_t, x_t> + v_t`` with
    ``v_t ~ N(0, noise_var)``.

    Prompt layout before the final transpose (row 0 = scalar slot, rows ``1:``
    = vector slot, one column per token):

    * ``discard_stats=True``:  ``[x_0, y_0, x_1, y_1, ...]`` (no F, Q, noise_var)
    * ``discard_stats=False``: ``[F | Q | noise_var | x_0, y_0, ...]``

    Args:
        device: Device the returned tensors are moved to.
        batch_size: Number of independent sequences (must be >= 1).
        input_dim: Dimension of ``w_t`` and ``x_t``.
        chunk_size: Number of ``(x_t, y_t)`` pairs per sequence.
        w_sigma: Standard deviation of the initial state ``w_0``.
        x_sigma: Standard deviation of the regressors ``x_t``.
        d_curr: Curriculum dimension. When ``d_curr < input_dim`` the regressor
            coordinates from index ``d_curr - 1`` onward are zeroed.
        Dynamic: If False, ``F = I`` and ``Q = 0`` (the weights stay fixed).
        alpha_F: Weight of the random orthogonal matrix in
            ``F = alpha_F * U + (1 - alpha_F) * I`` (used only when
            ``F_option`` is neither 2 nor 3).
        alpha_Q: Scale of the uniformly drawn innovation variances.
        alpha_R: Observation-noise variance is drawn from ``U[0, alpha_R)``.
        F_option: 2 -> symmetric ``F`` with eigenvalues in ``[0, 1)``;
            3 -> symmetric ``F`` with eigenvalues in ``[-1, 1)``;
            anything else -> the ``alpha_F`` mixture above.
        discard_stats: Leave ``F``, ``Q`` and ``noise_var`` out of the prompt.

    Returns:
        ``(inputs_batch, outputs_batch)``. ``inputs_batch`` has shape
        ``(batch_size, seq_len, input_dim + 1)`` and ``outputs_batch`` has shape
        ``(batch_size, seq_len)``. ``outputs_batch`` holds ``y_t`` at the
        position of the ``x_t`` token and zeros elsewhere.

    Raises:
        ValueError: If ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Number of context tokens (F, Q, noise_var) before x_0.
    x_start = 0 if discard_stats else 2 * input_dim + 1
    seq_len = x_start + 2 * chunk_size

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = U @ Sigma_F @ U.T
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                F = U @ Sigma_F @ U.T
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = Q_U @ Sigma_Q_sqrt
        # Computed for both branches. Q used to be unbound when Dynamic=False,
        # which crashed the discard_stats=False path.
        Q = A_Q @ A_Q.T

        # Roll the hidden state forward; column t of `w` is w_t.
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)
        states = [w_t]
        for _ in range(int(chunk_size) - 1):
            innovation_noise = A_Q @ torch.randn((input_dim, 1), dtype=float)
            w_t = F @ w_t + innovation_noise
            states.append(w_t)
        w = torch.cat(states, dim=1)

        x = x_sigma * torch.randn((input_dim, chunk_size), dtype=float)
        noise_var = alpha_R * torch.rand((1,), dtype=float)
        obs_noise = torch.sqrt(noise_var) * torch.randn((chunk_size,), dtype=float)
        if d_curr < input_dim:
            x[(d_curr - 1):, :] = 0.0
        # y_t = <w_t, x_t> as a column-wise dot product (no chunk x chunk matmul).
        y = (w * x).sum(dim=0) + obs_noise

        prompt = torch.zeros((input_dim + 1, seq_len))
        if not discard_stats:
            prompt[1:, 0:input_dim] = F
            prompt[1:, input_dim:2 * input_dim] = Q
            # noise_var gets its own token (column 2*input_dim), matching the
            # multi-output generators; it used to land in the x_0 token.
            prompt[0, 2 * input_dim] = noise_var
        prompt[1:, x_start::2] = x
        prompt[0, x_start + 1::2] = y
        prompts.append(prompt)

    inputs_batch = torch.stack(prompts, dim=0)
    # Targets: y_t placed at the position of the x_t token.
    outputs_batch = torch.zeros((batch_size, seq_len))
    outputs_batch[:, x_start::2] = inputs_batch[:, 0, x_start + 1::2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)





def Gen_data_One_Step(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2, y_dim=2, discard=True, discard_mode='All'):
    """Generate a batch of vector-observation state-space regression prompts.

    For every sequence a hidden weight vector ``w_t`` (``input_dim x 1``)
    evolves as ``w_t = F w_{t-1} + A_Q n_t`` with ``n_t ~ N(0, I)`` and is
    observed through ``y_t = X_t w_t + v_t``, where ``X_t`` is a
    ``y_dim x input_dim`` matrix and ``v_t ~ N(0, diag(noise_var))``.

    Prompt layout before the final transpose: rows ``0:y_dim`` are the
    observation slot; rows ``y_dim:`` hold ``X_t`` flattened row-major
    (``y_dim * input_dim`` values). ``F.T`` / ``Q.T`` occupy rows
    ``y_dim:y_dim + input_dim`` of their context tokens.

    * ``discard=False``:                      ``[F.T | Q.T | noise_var | X_0, y_0, ...]``
    * ``discard=True, discard_mode='All'``:   ``[X_0, y_0, X_1, y_1, ...]``
    * ``discard=True, discard_mode='Noise'``: ``[F.T | X_0, y_0, ...]``

    Args:
        device: Device the returned tensors are moved to.
        batch_size: Number of independent sequences (must be >= 1).
        input_dim: Dimension of ``w_t``.
        chunk_size: Number of ``(X_t, y_t)`` pairs per sequence.
        w_sigma: Standard deviation of the initial state ``w_0``.
        x_sigma: Standard deviation of the entries of ``X_t``.
        d_curr: Curriculum dimension. When ``d_curr < input_dim`` the columns of
            ``X_t`` from index ``d_curr - 1`` onward are zeroed.
        Dynamic: If False, ``F = I`` and ``Q = 0`` (the weights stay fixed).
        alpha_F: Weight of the random orthogonal matrix in
            ``F = alpha_F * U + (1 - alpha_F) * I`` (used only when
            ``F_option`` is neither 2 nor 3).
        alpha_Q: Scale of the uniformly drawn innovation variances.
        alpha_R: Each observation-noise variance is drawn from ``U[0, alpha_R)``.
        F_option: 2 -> symmetric ``F`` with eigenvalues in ``[0, 1)``;
            3 -> symmetric ``F`` with eigenvalues in ``[-1, 1)``;
            anything else -> the ``alpha_F`` mixture above.
        y_dim: Dimension of each observation ``y_t``.
        discard: Whether to drop (some of) the system statistics.
        discard_mode: ``'All'`` or ``'Noise'``; only consulted when ``discard``.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, seq_len, (input_dim + 1) * y_dim)`` and
        ``(batch_size, seq_len, y_dim)``. ``outputs_batch`` holds ``y_t`` at the
        position of the ``X_t`` token and zeros elsewhere.

    Raises:
        ValueError: If ``batch_size < 1``, or if ``discard`` is set with a
            ``discard_mode`` other than ``'All'`` / ``'Noise'`` (previously an
            obscure NameError).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if discard and discard_mode not in ('All', 'Noise'):
        raise ValueError(f"discard_mode must be 'All' or 'Noise', got {discard_mode!r}")

    # Number of context tokens placed before X_0.
    if not discard:
        prefix = 2 * input_dim + 1      # F.T | Q.T | noise_var
    elif discard_mode == 'All':
        prefix = 0
    else:
        prefix = input_dim              # F.T
    seq_len = prefix + 2 * chunk_size
    n_rows = (input_dim + 1) * y_dim

    prompts = []
    for _ in range(batch_size):
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = U @ Sigma_F @ U.T
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                F = U @ Sigma_F @ U.T
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = Q_U @ Sigma_Q_sqrt
        # Computed for both branches. Q used to be unbound when Dynamic=False,
        # which crashed the discard=False path.
        Q = A_Q @ A_Q.T

        noise_var = alpha_R * torch.rand((y_dim,), dtype=float)
        noise_scale = torch.diag(torch.sqrt(noise_var))
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)

        x_cols, y_cols = [], []
        for k in range(int(chunk_size)):
            if k > 0:
                # The state advances before every observation except the first.
                innovation_noise = A_Q @ torch.randn((input_dim, 1), dtype=float)
                w_t = F @ w_t + innovation_noise
            x_t = x_sigma * torch.randn((y_dim, input_dim), dtype=float)
            if d_curr < input_dim:
                x_t[:, (d_curr - 1):] = 0.0
            obs_noise = noise_scale @ torch.randn((y_dim, 1), dtype=float)
            y_cols.append(x_t @ w_t + obs_noise)
            x_cols.append(torch.reshape(x_t, (y_dim * input_dim, 1)))
        x = torch.cat(x_cols, dim=1)    # (y_dim * input_dim, chunk_size)
        y = torch.cat(y_cols, dim=1)    # (y_dim, chunk_size)

        prompt = torch.zeros((n_rows, seq_len))
        if not discard:
            prompt[y_dim:input_dim + y_dim, 0:input_dim] = F.T
            prompt[y_dim:input_dim + y_dim, input_dim:2 * input_dim] = Q.T
            prompt[0:y_dim, 2 * input_dim] = noise_var
        elif discard_mode == 'Noise':
            prompt[y_dim:input_dim + y_dim, 0:input_dim] = F.T
        prompt[y_dim:, prefix::2] = x
        prompt[0:y_dim, prefix + 1::2] = y
        prompts.append(prompt)

    inputs_batch = torch.stack(prompts, dim=0)
    # Targets: y_t placed at the position of the X_t token.
    outputs_batch = torch.zeros((batch_size, y_dim, seq_len))
    outputs_batch[:, :, prefix::2] = inputs_batch[:, 0:y_dim, prefix + 1::2]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    outputs_batch = torch.transpose(outputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def Gen_data_One_Step_with_Control(device='cuda', batch_size=64, input_dim=8, chunk_size=40, w_sigma=1, x_sigma=1, d_curr=8, Dynamic=True, alpha_F=0.0, alpha_Q=0.0, alpha_R=0.0, F_option=2, y_dim=2, discard=True, discard_mode='All', control='True'):
    """Generate vector-observation state-space prompts with an optional control input.

    For every sequence a hidden weight vector ``w_t`` (``input_dim x 1``)
    evolves as ``w_t = F w_{t-1} + B u_t + A_Q n_t`` with ``n_t ~ N(0, I)`` and is
    observed through ``y_t = X_t w_t + v_t`` (``X_t`` is ``y_dim x input_dim``,
    ``v_t ~ N(0, diag(noise_var))``). ``u_0 = 0``; for ``t >= 1`` ``u_t`` is a
    unit-norm Gaussian direction when control is enabled, else zero.

    Prompt layout before the final transpose: rows ``0:y_dim`` are the
    observation slot; rows ``y_dim:`` hold ``X_t`` flattened row-major.
    ``F.T`` / ``Q.T`` / ``B.T`` and ``u_t`` occupy rows
    ``y_dim:y_dim + input_dim`` of their tokens. With control:

    * ``discard=False``:  ``[F.T | Q.T | noise_var | B.T | X_0, u_0, y_0, ...]``
    * ``'All'``:          ``[X_0, u_0, y_0, X_1, u_1, y_1, ...]``
    * ``'Noise'``:        ``[F.T | B.T | X_0, u_0, y_0, ...]``

    Without control the layouts are those of ``Gen_data_One_Step`` (no ``B`` or
    ``u`` tokens).

    Args:
        device: Device the returned tensors are moved to.
        batch_size: Number of independent sequences (must be >= 1).
        input_dim: Dimension of ``w_t`` and ``u_t``.
        chunk_size: Number of observations per sequence.
        w_sigma: Standard deviation of the initial state ``w_0``.
        x_sigma: Standard deviation of the entries of ``X_t``.
        d_curr: Curriculum dimension. When ``d_curr < input_dim`` the columns of
            ``X_t`` from index ``d_curr - 1`` onward are zeroed.
        Dynamic: If False, ``F = I``, ``Q = 0`` and ``B = 0`` (fixed weights).
            NOTE(review): ``B = 0`` is a choice made here; this path used to
            crash with an unbound ``B``. Confirm it is the intended semantics.
        alpha_F: Weight of the random orthogonal matrix in
            ``F = alpha_F * U + (1 - alpha_F) * I`` (used only when
            ``F_option`` is neither 2 nor 3).
        alpha_Q: Scale of the uniformly drawn innovation variances.
        alpha_R: Each observation-noise variance is drawn from ``U[0, alpha_R)``.
        F_option: 2 -> symmetric ``F`` with eigenvalues in ``[0, 1)``;
            3 -> symmetric ``F`` with eigenvalues in ``[-1, 1)``;
            anything else -> the ``alpha_F`` mixture above.
        y_dim: Dimension of each observation ``y_t``.
        discard: Whether to drop (some of) the system statistics.
        discard_mode: ``'All'`` or ``'Noise'``; only consulted when ``discard``.
        control: Enable the control input. Strings are parsed, so ``'False'``,
            ``'0'``, ``'no'``, ``'off'`` and ``''`` disable it. A plain
            truthiness test would treat every non-empty string as True.

    Returns:
        ``(inputs_batch, outputs_batch)`` with shapes
        ``(batch_size, seq_len, (input_dim + 1) * y_dim)`` and
        ``(batch_size, seq_len, y_dim)``. ``outputs_batch`` holds ``y_t`` at the
        position of the ``X_t`` token and zeros elsewhere.
        NOTE(review): with control the target sits two tokens before ``y_t``,
        i.e. before ``u_t`` is visible; confirm this is intended.

    Raises:
        ValueError: If ``batch_size < 1``, or if ``discard`` is set with a
            ``discard_mode`` other than ``'All'`` / ``'Noise'``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if discard and discard_mode not in ('All', 'Noise'):
        raise ValueError(f"discard_mode must be 'All' or 'Noise', got {discard_mode!r}")
    if isinstance(control, str):
        control = control.strip().lower() not in ('false', '0', 'no', 'off', '')
    else:
        control = bool(control)

    # Tokens per time step: X_t, (u_t,) y_t.
    stride = 3 if control else 2
    # Number of context tokens placed before X_0.
    if not discard:
        prefix = 3 * input_dim + 1 if control else 2 * input_dim + 1
    elif discard_mode == 'All':
        prefix = 0
    else:
        prefix = 2 * input_dim if control else input_dim
    seq_len = prefix + stride * chunk_size
    n_rows = (input_dim + 1) * y_dim
    state_rows = slice(y_dim, input_dim + y_dim)

    prompts = []
    for _ in range(batch_size):
        # Default (no control / static weights): control has no effect.
        B = torch.zeros((input_dim, input_dim), dtype=float)
        if not Dynamic:
            F = torch.eye(input_dim, dtype=float)
            A_Q = 0.0 * torch.eye(input_dim, dtype=float)
        else:
            if F_option == 2:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(torch.rand((input_dim,), dtype=float))
                F = U @ Sigma_F @ U.T
            elif F_option == 3:
                U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_F = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                F = U @ Sigma_F @ U.T
            else:
                F = alpha_F * torch.tensor(ortho_group.rvs(input_dim), dtype=float) + (1 - alpha_F) * torch.eye(input_dim, dtype=float)

            if control:
                # Symmetric control matrix with eigenvalues in [-1, 1).
                U_B = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
                Sigma_B = torch.diag(2. * torch.rand((input_dim,), dtype=float) - 1.)
                B = U_B @ Sigma_B @ U_B.T

            Sigma_Q_sqrt = torch.sqrt(alpha_Q * torch.diag(torch.rand((input_dim,), dtype=float)))
            Q_U = torch.tensor(ortho_group.rvs(input_dim), dtype=float)
            A_Q = Q_U @ Sigma_Q_sqrt
        # Computed for both branches. Q used to be unbound when Dynamic=False.
        Q = A_Q @ A_Q.T

        noise_var = alpha_R * torch.rand((y_dim,), dtype=float)
        noise_scale = torch.diag(torch.sqrt(noise_var))
        w_t = w_sigma * torch.randn((input_dim, 1), dtype=float)

        x_cols, u_cols, y_cols = [], [], []
        for k in range(int(chunk_size)):
            u_t = torch.zeros((input_dim, 1), dtype=float)
            if k > 0:
                # The state advances before every observation except the first.
                innovation_noise = A_Q @ torch.randn((input_dim, 1), dtype=float)
                if control:
                    u_t = torch.randn((input_dim, 1), dtype=float)
                    u_t = u_t / torch.linalg.vector_norm(u_t)
                    w_t = F @ w_t + innovation_noise + B @ u_t
                else:
                    w_t = F @ w_t + innovation_noise
            x_t = x_sigma * torch.randn((y_dim, input_dim), dtype=float)
            if d_curr < input_dim:
                x_t[:, (d_curr - 1):] = 0.0
            obs_noise = noise_scale @ torch.randn((y_dim, 1), dtype=float)
            x_cols.append(torch.reshape(x_t, (y_dim * input_dim, 1)))
            u_cols.append(u_t)
            y_cols.append(x_t @ w_t + obs_noise)
        x = torch.cat(x_cols, dim=1)    # (y_dim * input_dim, chunk_size)
        u = torch.cat(u_cols, dim=1)    # (input_dim, chunk_size)
        y = torch.cat(y_cols, dim=1)    # (y_dim, chunk_size)

        prompt = torch.zeros((n_rows, seq_len))
        if not discard or discard_mode == 'Noise':
            prompt[state_rows, 0:input_dim] = F.T
        if not discard:
            prompt[state_rows, input_dim:2 * input_dim] = Q.T
            prompt[0:y_dim, 2 * input_dim] = noise_var
            if control:
                prompt[state_rows, 2 * input_dim + 1:3 * input_dim + 1] = B.T
        elif control:  # discard_mode == 'Noise'
            prompt[state_rows, input_dim:2 * input_dim] = B.T
        prompt[y_dim:, prefix::stride] = x
        if control:
            # u_t lives in the state rows. The 'Noise' branch used to write it
            # into the y rows (0:y_dim), crashing whenever input_dim != y_dim.
            prompt[state_rows, prefix + 1::stride] = u
        prompt[0:y_dim, prefix + stride - 1::stride] = y
        prompts.append(prompt)

    inputs_batch = torch.stack(prompts, dim=0)
    # Targets: y_t placed at the position of the X_t token.
    outputs_batch = torch.zeros((batch_size, y_dim, seq_len))
    outputs_batch[:, :, prefix::stride] = inputs_batch[:, 0:y_dim, prefix + stride - 1::stride]

    inputs_batch = torch.transpose(inputs_batch, 1, 2)
    outputs_batch = torch.transpose(outputs_batch, 1, 2)
    return inputs_batch.to(device), outputs_batch.to(device)


def train_mine(model, args):
    """Train ``model`` on freshly generated ``Gen_data`` batches under a curriculum.

    Resumes from ``<args.out_dir>/state.pt`` when it exists (model weights,
    optimizer state and curriculum position). Unless ``args.test_run`` is set:

    * metrics are logged to wandb every ``args.wandb.log_every_steps`` steps;
    * the training state is checkpointed every ``args.training.save_every_steps``
      steps;
    * a full model snapshot is kept every ``args.training.keep_every_steps``
      steps (when that value is positive and the step is > 0).

    Batches are generated on the device holding ``model``'s parameters rather
    than a hard-coded ``'cuda'``, so CPU runs work as well.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay curriculum updates so it matches the checkpointed step.
        for _ in range(state["train_step"] + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    # NOTE(review): neither sampler is used below (data comes from Gen_data).
    # They are still constructed in case that validates the config or consumes
    # RNG state; confirm before removing.
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    device = next(model.parameters()).device
    loss_func = mean_squared_error
    point_wise_loss_func = squared_error

    pbar = tqdm(range(starting_step, args.training.train_steps))
    for i in pbar:
        Inputs_batch, Outputs_batch = Gen_data(device=device, batch_size=bsize, eval=False, input_dim=n_dims - 1, chunk_size=curriculum.n_points, w_sigma=1, x_sigma=1, d_curr=curriculum.n_dims_truncated)

        loss, output = train_step_mine(model, Inputs_batch, Outputs_batch, optimizer, loss_func)

        point_wise_tags = list(range(curriculum.n_points))
        # Same target slice that train_step_mine uses for its loss.
        point_wise_loss = point_wise_loss_func(output, Outputs_batch[:, ::2]).mean(dim=0)

        baseline_loss = (
            sum(
                max(curriculum.n_dims_truncated - ii, 0)
                for ii in range(curriculum.n_points)
            )
            / curriculum.n_points
        )

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(
                        zip(point_wise_tags, point_wise_loss.cpu().numpy())
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        if i % args.training.save_every_steps == 0 and not args.test_run:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))






def train_mine_SS(model, args):
    """Train ``model`` on ``Gen_data_SS`` state-space batches under a curriculum.

    The transition-matrix mixing weight follows ``curriculum.F_alpha``.
    Resumes from ``<args.out_dir>/state.pt`` when it exists (model weights,
    optimizer state and curriculum position). Unless ``args.test_run`` is set:

    * metrics are logged to wandb every ``args.wandb.log_every_steps`` steps;
    * the training state is checkpointed every ``args.training.save_every_steps``
      steps;
    * a full model snapshot is kept every ``args.training.keep_every_steps``
      steps (when that value is positive and the step is > 0).

    Batches are generated on the device holding ``model``'s parameters rather
    than a hard-coded ``'cuda'``, so CPU runs work as well.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay curriculum updates so it matches the checkpointed step.
        for _ in range(state["train_step"] + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    # NOTE(review): neither sampler is used below (data comes from Gen_data_SS).
    # They are still constructed in case that validates the config or consumes
    # RNG state; confirm before removing.
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    device = next(model.parameters()).device
    loss_func = mean_squared_error
    point_wise_loss_func = squared_error

    pbar = tqdm(range(starting_step, args.training.train_steps))
    for i in pbar:
        Inputs_batch, Outputs_batch = Gen_data_SS(device=device, batch_size=bsize, input_dim=n_dims - 1, chunk_size=curriculum.n_points, w_sigma=1, x_sigma=1, d_curr=curriculum.n_dims_truncated, Dynamic=True, alpha_F=curriculum.F_alpha)

        loss, output = train_step_mine_SS(model, Inputs_batch, Outputs_batch, optimizer, loss_func)

        point_wise_tags = list(range(curriculum.n_points))
        # Same target slice that train_step_mine_SS uses for its loss.
        point_wise_loss = point_wise_loss_func(output, Outputs_batch[:, (n_dims - 1)::2]).mean(dim=0)

        baseline_loss = (
            sum(
                max(curriculum.n_dims_truncated - ii, 0)
                for ii in range(curriculum.n_points)
            )
            / curriculum.n_points
        )

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(
                        zip(point_wise_tags, point_wise_loss.cpu().numpy())
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                    "F_alpha": curriculum.F_alpha,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        if i % args.training.save_every_steps == 0 and not args.test_run:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))


def train_mine_SS_innovation_noise(model, args):
    """Train ``model`` on ``Gen_data_SS_innovation_noise`` batches under a curriculum.

    The transition-matrix mixing weight and innovation-noise scale follow
    ``curriculum.F_alpha`` and ``curriculum.Q_alpha``. Resumes from
    ``<args.out_dir>/state.pt`` when it exists (model weights, optimizer state
    and curriculum position). Unless ``args.test_run`` is set:

    * metrics are logged to wandb every ``args.wandb.log_every_steps`` steps;
    * the training state is checkpointed every ``args.training.save_every_steps``
      steps;
    * a full model snapshot is kept every ``args.training.keep_every_steps``
      steps (when that value is positive and the step is > 0).

    Batches are generated on the device holding ``model``'s parameters rather
    than a hard-coded ``'cuda'``, so CPU runs work as well.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay curriculum updates so it matches the checkpointed step.
        for _ in range(state["train_step"] + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    # NOTE(review): neither sampler is used below (data comes from
    # Gen_data_SS_innovation_noise). They are still constructed in case that
    # validates the config or consumes RNG state; confirm before removing.
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    device = next(model.parameters()).device
    loss_func = mean_squared_error
    point_wise_loss_func = squared_error

    pbar = tqdm(range(starting_step, args.training.train_steps))
    for i in pbar:
        Inputs_batch, Outputs_batch = Gen_data_SS_innovation_noise(device=device, batch_size=bsize, input_dim=n_dims - 1, chunk_size=curriculum.n_points, w_sigma=1, x_sigma=1, d_curr=curriculum.n_dims_truncated, Dynamic=True, alpha_F=curriculum.F_alpha, alpha_Q=curriculum.Q_alpha)

        loss, output = train_step_mine_SS_innovation_noise(model, Inputs_batch, Outputs_batch, optimizer, loss_func)

        point_wise_tags = list(range(curriculum.n_points))
        # Same target slice that train_step_mine_SS_innovation_noise uses.
        point_wise_loss = point_wise_loss_func(output, Outputs_batch[:, 2 * (n_dims - 1)::2]).mean(dim=0)

        baseline_loss = (
            sum(
                max(curriculum.n_dims_truncated - ii, 0)
                for ii in range(curriculum.n_points)
            )
            / curriculum.n_points
        )

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(
                        zip(point_wise_tags, point_wise_loss.cpu().numpy())
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        if i % args.training.save_every_steps == 0 and not args.test_run:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))


def train_mine_SS_innovation_noise_obs_noise(model, args):
    """Train ``model`` on batches from ``Gen_data_SS_innovation_noise_obs_noise``.

    If ``<args.out_dir>/state.pt`` exists, the model weights, optimizer state and
    saved step are restored. ``curriculum.update()`` is then replayed
    ``train_step + 1`` times.

    Each iteration:

    * generates a fresh CUDA batch from the current curriculum values;
    * takes one optimizer step via
      ``train_step_mine_SS_innovation_noise_obs_noise`` with
      ``mean_squared_error``;
    * logs overall, excess and per-position loss to wandb every
      ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * writes ``state.pt`` every ``args.training.save_every_steps`` steps;
    * when ``args.training.keep_every_steps > 0``, also writes a full
      ``model_<step>.pt`` snapshot at that interval.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; ``model.n_dims`` sizes the batch.
        args: parsed quinine config.
    """
    cfg = args.training
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    curriculum = Curriculum(cfg.curriculum)

    state_path = os.path.join(args.out_dir, "state.pt")
    first_step = 0
    if os.path.exists(state_path):
        checkpoint = torch.load(state_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        first_step = checkpoint["train_step"]
        # Bring the curriculum schedule back in line with the restored step.
        for _ in range(first_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = cfg.batch_size
    # NOTE(review): the samplers and num_training_examples are unused below.
    # Batches come from the Gen_data_* generator. They are still built so any
    # side effects of these helpers are unchanged.
    data_sampler = get_data_sampler(cfg.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        cfg.task, n_dims, bsize, num_tasks=cfg.num_tasks, **cfg.task_kwargs
    )
    progress = tqdm(range(first_step, cfg.train_steps))
    num_training_examples = cfg.num_training_examples

    loss_func = mean_squared_error
    pointwise_func = squared_error
    # Same slice of the outputs that train_step_mine_SS_innovation_noise_obs_noise trains on.
    target_start = 2 * (n_dims - 1) + 1

    for step in progress:
        inputs, targets = Gen_data_SS_innovation_noise_obs_noise(
            device='cuda',
            batch_size=bsize,
            input_dim=n_dims - 1,
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
        )
        loss, output = train_step_mine_SS_innovation_noise_obs_noise(
            model, inputs, targets, optimizer, loss_func
        )

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        pointwise_loss = pointwise_func(output, targets[:, target_start::2]).mean(dim=0)
        # Normaliser used for "excess_loss" below.
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if step % args.wandb.log_every_steps == 0 and not args.test_run:
            metrics = {
                "overall_loss": loss,
                "excess_loss": loss / baseline_loss,
                "pointwise/loss": dict(zip(range(n_points), pointwise_loss.cpu().numpy())),
                "n_points": n_points,
                "n_dims": n_dims_trunc,
                "F_alpha": curriculum.F_alpha,
                "Q_alpha": curriculum.Q_alpha,
                "R_alpha": curriculum.R_alpha,
            }
            wandb.log(metrics, step=step)

        curriculum.update()
        progress.set_description(f"loss {loss}")

        if step % cfg.save_every_steps == 0 and not args.test_run:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_step": step,
                },
                state_path,
            )

        keep_every = cfg.keep_every_steps
        if keep_every > 0 and step > 0 and step % keep_every == 0 and not args.test_run:
            torch.save(model, os.path.join(args.out_dir, f"model_{step}.pt"))



def train_mine_SS_innovation_noise_obs_noise_F_options(model, args, option=2):
    """Train ``model`` on batches from ``Gen_data_SS_innovation_noise_obs_noise_F_options``.

    If ``<args.out_dir>/state.pt`` exists, the run resumes from it: model, optimizer
    and step are restored, and the curriculum is replayed ``train_step + 1`` times.

    Each step:

    * generates a CUDA batch, forwarding ``option`` as ``F_option``;
    * takes one optimizer step with
      ``train_step_mine_SS_innovation_noise_obs_noise`` and ``mean_squared_error``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints to ``state.pt`` every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; ``model.n_dims`` sizes the batch.
        args: parsed quinine config.
        option: passed to the data generator as ``F_option``.
    """
    cfg = args.training
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    curriculum = Curriculum(cfg.curriculum)

    state_path = os.path.join(args.out_dir, "state.pt")
    first_step = 0
    if os.path.exists(state_path):
        checkpoint = torch.load(state_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        first_step = checkpoint["train_step"]
        # Replay the curriculum so it matches the restored step.
        for _ in range(first_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = cfg.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(cfg.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        cfg.task, n_dims, bsize, num_tasks=cfg.num_tasks, **cfg.task_kwargs
    )
    progress = tqdm(range(first_step, cfg.train_steps))
    num_training_examples = cfg.num_training_examples

    loss_func = mean_squared_error
    pointwise_func = squared_error
    # Mirrors the slice used inside train_step_mine_SS_innovation_noise_obs_noise.
    target_start = 2 * (n_dims - 1) + 1

    for step in progress:
        inputs, targets = Gen_data_SS_innovation_noise_obs_noise_F_options(
            device='cuda',
            batch_size=bsize,
            input_dim=n_dims - 1,
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
            F_option=option,
        )
        loss, output = train_step_mine_SS_innovation_noise_obs_noise(
            model, inputs, targets, optimizer, loss_func
        )

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        pointwise_loss = pointwise_func(output, targets[:, target_start::2]).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if step % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(zip(range(n_points), pointwise_loss.cpu().numpy())),
                    "n_points": n_points,
                    "n_dims": n_dims_trunc,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                    "R_alpha": curriculum.R_alpha,
                },
                step=step,
            )

        curriculum.update()
        progress.set_description(f"loss {loss}")

        if step % cfg.save_every_steps == 0 and not args.test_run:
            snapshot = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": step,
            }
            torch.save(snapshot, state_path)

        keep_every = cfg.keep_every_steps
        if keep_every > 0 and step > 0 and step % keep_every == 0 and not args.test_run:
            torch.save(model, os.path.join(args.out_dir, f"model_{step}.pt"))

def train_mine_SS_innovation_noise_obs_noise_F_options_non_scalar_y(model, args, option=2):
    """Train ``model`` on vector-valued measurements from
    ``Gen_data_SS_innovation_noise_obs_noise_F_options_non_scalar_y``.

    If ``<args.out_dir>/state.pt`` exists, model, optimizer and step are restored,
    and the curriculum is replayed ``train_step + 1`` times.

    Each step:

    * generates a CUDA batch;
    * takes one optimizer step with
      ``train_step_mine_SS_innovation_noise_obs_noise_non_scalar_y`` and
      ``mean_squared_error_measurement``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; reads ``n_dims``, ``y_dim`` and ``state_dim``.
        args: parsed quinine config.
        option: passed to the data generator as ``F_option``.
    """
    cfg = args.training
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    curriculum = Curriculum(cfg.curriculum)

    state_path = os.path.join(args.out_dir, "state.pt")
    starting_step = 0
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay the curriculum so it matches the restored step.
        for _ in range(starting_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    # Use the model's measurement dimension everywhere. Previously input_dim
    # was derived from model.y_dim while the generator and the train step were
    # hard-wired to y_dim=2. That silently mismatched any model built with a
    # different y_dim.
    y_dim = model.y_dim
    bsize = cfg.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(cfg.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        cfg.task, n_dims, bsize, num_tasks=cfg.num_tasks, **cfg.task_kwargs
    )
    pbar = tqdm(range(starting_step, cfg.train_steps))
    num_training_examples = cfg.num_training_examples

    loss_func = mean_squared_error_measurement
    point_wise_loss_func = squared_error
    # NOTE(review): presumably the same positions the train step scores -- confirm
    # against train_step_mine_SS_innovation_noise_obs_noise_non_scalar_y.
    target_start = 2 * model.state_dim + 1

    for i in pbar:
        inputs, targets = Gen_data_SS_innovation_noise_obs_noise_F_options_non_scalar_y(
            device='cuda',
            batch_size=bsize,
            input_dim=int(n_dims / y_dim - 1),
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
            F_option=option,
            y_dim=y_dim,
        )
        loss, output = train_step_mine_SS_innovation_noise_obs_noise_non_scalar_y(
            model, inputs, targets, optimizer, loss_func, y_dim=y_dim
        )

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        point_wise_loss = point_wise_loss_func(output, targets[:, target_start::2, :]).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(zip(range(n_points), point_wise_loss.cpu().numpy())),
                    "n_points": n_points,
                    "n_dims": n_dims_trunc,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                    "R_alpha": curriculum.R_alpha,
                },
                step=i,
            )

        curriculum.update()
        pbar.set_description(f"loss {loss}")

        if i % cfg.save_every_steps == 0 and not args.test_run:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_step": i,
                },
                state_path,
            )

        if (
            cfg.keep_every_steps > 0
            and i % cfg.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))


def train_mine_SS_innovation_noise_obs_noise_F_options_discard_noise_stats(model, args, option=2):
    """Train ``model`` on batches from
    ``Gen_data_SS_innovation_noise_obs_noise_F_options_discard_noise_stats``.

    If ``<args.out_dir>/state.pt`` exists, the run resumes from it and the
    curriculum is replayed to match.

    Each step:

    * generates a CUDA batch, forwarding ``option`` as ``F_option``;
    * takes one optimizer step with ``train_step_mine_SS`` and
      ``mean_squared_error``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; ``model.n_dims`` sizes the batch.
        args: parsed quinine config.
        option: passed to the data generator as ``F_option``.
    """
    cfg = args.training
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    curriculum = Curriculum(cfg.curriculum)

    state_path = os.path.join(args.out_dir, "state.pt")
    first_step = 0
    if os.path.exists(state_path):
        checkpoint = torch.load(state_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        first_step = checkpoint["train_step"]
        for _ in range(first_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = cfg.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(cfg.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        cfg.task, n_dims, bsize, num_tasks=cfg.num_tasks, **cfg.task_kwargs
    )
    progress = tqdm(range(first_step, cfg.train_steps))
    num_training_examples = cfg.num_training_examples

    loss_func = mean_squared_error
    pointwise_func = squared_error
    # Same slice train_step_mine_SS trains on.
    target_start = n_dims - 1

    for step in progress:
        generator_kwargs = dict(
            device='cuda',
            batch_size=bsize,
            input_dim=n_dims - 1,
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
            F_option=option,
        )
        inputs, targets = Gen_data_SS_innovation_noise_obs_noise_F_options_discard_noise_stats(
            **generator_kwargs
        )
        loss, output = train_step_mine_SS(model, inputs, targets, optimizer, loss_func)

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        pointwise_loss = pointwise_func(output, targets[:, target_start::2]).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if step % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(zip(range(n_points), pointwise_loss.cpu().numpy())),
                    "n_points": n_points,
                    "n_dims": n_dims_trunc,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                    "R_alpha": curriculum.R_alpha,
                },
                step=step,
            )

        curriculum.update()
        progress.set_description(f"loss {loss}")

        if step % cfg.save_every_steps == 0 and not args.test_run:
            snapshot = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": step,
            }
            torch.save(snapshot, state_path)

        keep_every = cfg.keep_every_steps
        if keep_every > 0 and step > 0 and step % keep_every == 0 and not args.test_run:
            torch.save(model, os.path.join(args.out_dir, f"model_{step}.pt"))



def train_mine_SS_innovation_noise_obs_noise_state_est_curr(model, args):
    """Train ``model`` on batches from
    ``Gen_data_SS_innovation_noise_obs_noise_state_est_curr``.

    If ``<args.out_dir>/state.pt`` exists, the run resumes from it and the
    curriculum is replayed to match.

    Each step:

    * generates a CUDA batch;
    * takes one optimizer step with
      ``train_step_mine_SS_innovation_noise_obs_noise_state_est_curr`` and
      ``mean_squared_error_state``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; ``model.n_dims`` sizes the batch.
        args: parsed quinine config.
    """
    cfg = args.training
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    curriculum = Curriculum(cfg.curriculum)

    state_path = os.path.join(args.out_dir, "state.pt")
    first_step = 0
    if os.path.exists(state_path):
        checkpoint = torch.load(state_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        first_step = checkpoint["train_step"]
        for _ in range(first_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = cfg.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(cfg.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        cfg.task, n_dims, bsize, num_tasks=cfg.num_tasks, **cfg.task_kwargs
    )
    progress = tqdm(range(first_step, cfg.train_steps))
    num_training_examples = cfg.num_training_examples

    loss_func = mean_squared_error_state
    pointwise_func = squared_error
    # NOTE(review): presumably matches the slice scored inside
    # train_step_mine_SS_innovation_noise_obs_noise_state_est_curr -- confirm.
    target_start = 2 * (n_dims - 1) + 2

    for step in progress:
        inputs, targets = Gen_data_SS_innovation_noise_obs_noise_state_est_curr(
            device='cuda',
            batch_size=bsize,
            input_dim=n_dims - 1,
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
        )
        loss, output = train_step_mine_SS_innovation_noise_obs_noise_state_est_curr(
            model, inputs, targets, optimizer, loss_func
        )

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        pointwise_loss = pointwise_func(output, targets[:, target_start::2, :]).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if step % args.wandb.log_every_steps == 0 and not args.test_run:
            metrics = {
                "overall_loss": loss,
                "excess_loss": loss / baseline_loss,
                "pointwise/loss": dict(zip(range(n_points), pointwise_loss.cpu().numpy())),
                "n_points": n_points,
                "n_dims": n_dims_trunc,
                "F_alpha": curriculum.F_alpha,
                "Q_alpha": curriculum.Q_alpha,
                "R_alpha": curriculum.R_alpha,
            }
            wandb.log(metrics, step=step)

        curriculum.update()
        progress.set_description(f"loss {loss}")

        if step % cfg.save_every_steps == 0 and not args.test_run:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_step": step,
                },
                state_path,
            )

        keep_every = cfg.keep_every_steps
        if keep_every > 0 and step > 0 and step % keep_every == 0 and not args.test_run:
            torch.save(model, os.path.join(args.out_dir, f"model_{step}.pt"))




def train_mine_SS_innovation_noise_obs_noise_F_options_discard_all_stats(model, args, option=2):
    """Train ``model`` on batches from
    ``Gen_data_SS_innovation_noise_obs_noise_F_options_discard_all_stats``.

    If ``<args.out_dir>/state.pt`` exists, the run resumes from it and the
    curriculum is replayed to match.

    Each step:

    * generates a CUDA batch, forwarding ``option`` as ``F_option``;
    * takes one optimizer step with ``train_step_mine`` and
      ``mean_squared_error``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; ``model.n_dims`` sizes the batch.
        args: parsed quinine config.
        option: passed to the data generator as ``F_option``.
    """
    cfg = args.training
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
    curriculum = Curriculum(cfg.curriculum)

    state_path = os.path.join(args.out_dir, "state.pt")
    first_step = 0
    if os.path.exists(state_path):
        checkpoint = torch.load(state_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        first_step = checkpoint["train_step"]
        for _ in range(first_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = cfg.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(cfg.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        cfg.task, n_dims, bsize, num_tasks=cfg.num_tasks, **cfg.task_kwargs
    )
    progress = tqdm(range(first_step, cfg.train_steps))
    num_training_examples = cfg.num_training_examples

    loss_func = mean_squared_error
    pointwise_func = squared_error

    for step in progress:
        inputs, targets = Gen_data_SS_innovation_noise_obs_noise_F_options_discard_all_stats(
            device='cuda',
            batch_size=bsize,
            input_dim=n_dims - 1,
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
            F_option=option,
        )
        loss, output = train_step_mine(model, inputs, targets, optimizer, loss_func)

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        # Every second position from the start, matching train_step_mine's target.
        pointwise_loss = pointwise_func(output, targets[:, ::2]).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if step % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(zip(range(n_points), pointwise_loss.cpu().numpy())),
                    "n_points": n_points,
                    "n_dims": n_dims_trunc,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                    "R_alpha": curriculum.R_alpha,
                },
                step=step,
            )

        curriculum.update()
        progress.set_description(f"loss {loss}")

        if step % cfg.save_every_steps == 0 and not args.test_run:
            snapshot = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": step,
            }
            torch.save(snapshot, state_path)

        keep_every = cfg.keep_every_steps
        if keep_every > 0 and step > 0 and step % keep_every == 0 and not args.test_run:
            torch.save(model, os.path.join(args.out_dir, f"model_{step}.pt"))








def train_one_step_pred(model, args):
    """Train ``model`` on batches from ``Gen_data_One_Step``.

    ``y_dim``, ``discard``, ``discard_mode`` and ``option`` (forwarded as
    ``F_option``) are read from the curriculum config. If ``<args.out_dir>/state.pt``
    exists, the run resumes from it and the curriculum is replayed to match.

    Each step:

    * generates a CUDA batch;
    * takes one optimizer step with ``train_step_one_step_pred`` and
      ``mean_squared_error_measurement``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; reads ``n_dims``, ``y_dim`` and ``state_dim``.
        args: parsed quinine config.

    Raises:
        ValueError: if ``discard`` is set but ``discard_mode`` is neither
            ``'All'`` nor ``'Noise'``. Previously this left ``point_wise_loss``
            unbound and crashed with ``UnboundLocalError`` at the first logged
            step.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)
    y_dim = curriculum.y_dim
    discard = curriculum.discard
    discard_mode = curriculum.discard_mode
    option = curriculum.option

    if discard and discard_mode not in ("All", "Noise"):
        raise ValueError(
            f"discard_mode must be 'All' or 'Noise' when discard is set, got {discard_mode!r}"
        )

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay the curriculum so it matches the restored step.
        for _ in range(starting_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    pbar = tqdm(range(starting_step, args.training.train_steps))
    num_training_examples = args.training.num_training_examples

    loss_func = mean_squared_error_measurement
    point_wise_loss_func = squared_error

    # First position of the per-step target slice (step 2) inside Outputs_batch.
    if not discard:
        target_start = 2 * model.state_dim + 1
    elif discard_mode == "All":
        target_start = 0
    else:  # discard_mode == "Noise" (validated above)
        target_start = model.state_dim

    for i in pbar:
        Inputs_batch, Outputs_batch = Gen_data_One_Step(
            device='cuda',
            batch_size=bsize,
            input_dim=int(n_dims / model.y_dim - 1),
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
            F_option=option,
            y_dim=y_dim,
            discard=discard,
            discard_mode=discard_mode,
        )

        loss, output = train_step_one_step_pred(
            model, Inputs_batch, Outputs_batch, optimizer, loss_func,
            y_dim=y_dim, discard=discard, discard_mode=discard_mode,
        )

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        point_wise_loss = point_wise_loss_func(
            output, Outputs_batch[:, target_start::2, :]
        ).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(zip(range(n_points), point_wise_loss.cpu().numpy())),
                    "n_points": n_points,
                    "n_dims": n_dims_trunc,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                    "R_alpha": curriculum.R_alpha,
                },
                step=i,
            )

        curriculum.update()
        pbar.set_description(f"loss {loss}")

        if i % args.training.save_every_steps == 0 and not args.test_run:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_step": i,
                },
                state_path,
            )

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))

def train_one_step_pred_control(model, args):
    """Train ``model`` on batches from ``Gen_data_One_Step_with_Control``.

    ``y_dim``, ``discard``, ``discard_mode``, ``option`` (forwarded as
    ``F_option``) and ``control`` are read from the curriculum config. If
    ``<args.out_dir>/state.pt`` exists, the run resumes from it and the curriculum
    is replayed to match.

    Each step:

    * generates a CUDA batch;
    * takes one optimizer step with ``train_step_one_step_pred_control`` and
      ``mean_squared_error_measurement``;
    * logs to wandb every ``args.wandb.log_every_steps`` steps;
    * advances the curriculum;
    * checkpoints every ``args.training.save_every_steps`` steps;
    * when ``keep_every_steps > 0``, also keeps ``model_<step>.pt`` snapshots.

    Logging and all file writes are skipped when ``args.test_run`` is set.

    Args:
        model: network being trained; reads ``n_dims``, ``y_dim`` and ``state_dim``.
        args: parsed quinine config.

    Raises:
        ValueError: if ``discard`` is set but ``discard_mode`` is neither
            ``'All'`` nor ``'Noise'``. Previously this left ``point_wise_loss``
            unbound and crashed with ``UnboundLocalError`` at the first logged
            step.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)
    y_dim = curriculum.y_dim
    discard = curriculum.discard
    discard_mode = curriculum.discard_mode
    option = curriculum.option
    control = curriculum.control

    if discard and discard_mode not in ("All", "Noise"):
        raise ValueError(
            f"discard_mode must be 'All' or 'Noise' when discard is set, got {discard_mode!r}"
        )

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay the curriculum so it matches the restored step.
        for _ in range(starting_step + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    # NOTE(review): unused by this loop; kept so helper side effects are unchanged.
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    pbar = tqdm(range(starting_step, args.training.train_steps))
    num_training_examples = args.training.num_training_examples

    loss_func = mean_squared_error_measurement
    point_wise_loss_func = squared_error

    # Step between scored positions in Outputs_batch: 3 with control, 2 without.
    # The start offsets below reproduce the six original slices exactly:
    #   no discard -> stride*state_dim + 1
    #   'All'      -> 0
    #   'Noise'    -> (stride - 1)*state_dim
    stride = 3 if control else 2
    if not discard:
        target_start = stride * model.state_dim + 1
    elif discard_mode == "All":
        target_start = 0
    else:  # discard_mode == "Noise" (validated above)
        target_start = (stride - 1) * model.state_dim

    for i in pbar:
        Inputs_batch, Outputs_batch = Gen_data_One_Step_with_Control(
            device='cuda',
            batch_size=bsize,
            input_dim=int(n_dims / model.y_dim - 1),
            chunk_size=curriculum.n_points,
            w_sigma=1,
            x_sigma=1,
            d_curr=curriculum.n_dims_truncated,
            Dynamic=True,
            alpha_F=curriculum.F_alpha,
            alpha_Q=curriculum.Q_alpha,
            alpha_R=curriculum.R_alpha,
            F_option=option,
            y_dim=y_dim,
            discard=discard,
            discard_mode=discard_mode,
            control=control,
        )

        loss, output = train_step_one_step_pred_control(
            model, Inputs_batch, Outputs_batch, optimizer, loss_func,
            y_dim=y_dim, discard=discard, discard_mode=discard_mode, control=control,
        )

        n_points = curriculum.n_points
        n_dims_trunc = curriculum.n_dims_truncated
        point_wise_loss = point_wise_loss_func(
            output, Outputs_batch[:, target_start::stride, :]
        ).mean(dim=0)
        baseline_loss = sum(max(n_dims_trunc - k, 0) for k in range(n_points)) / n_points

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(zip(range(n_points), point_wise_loss.cpu().numpy())),
                    "n_points": n_points,
                    "n_dims": n_dims_trunc,
                    "F_alpha": curriculum.F_alpha,
                    "Q_alpha": curriculum.Q_alpha,
                    "R_alpha": curriculum.R_alpha,
                },
                step=i,
            )

        curriculum.update()
        pbar.set_description(f"loss {loss}")

        if i % args.training.save_every_steps == 0 and not args.test_run:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_step": i,
                },
                state_path,
            )

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model, os.path.join(args.out_dir, f"model_{i}.pt"))


def main(args):
    """Build the model described by ``args`` and run the active training routine.

    ``y_dim``, ``discard``, ``discard_mode`` and ``control`` are read from the
    curriculum config and forwarded to ``build_model``.

    * Test run: the curriculum is pinned to its final size and training is capped
      at 100 steps.
    * Otherwise: a resumable wandb run is initialised first, and run metrics are
      precomputed after training.
    """
    curriculum_args = args.training.curriculum
    curriculum = Curriculum(curriculum_args)
    y_dim = curriculum.y_dim
    discard = curriculum.discard
    discard_mode = curriculum.discard_mode
    control = curriculum.control

    if args.test_run:
        # Start the curriculum at its end values. curriculum_args was previously
        # undefined here, so every test run crashed with NameError.
        curriculum_args.points.start = curriculum_args.points.end
        curriculum_args.dims.start = curriculum_args.dims.end
        args.training.train_steps = 100
    else:
        # Never hard-code API keys. With no key argument, wandb.login() resolves
        # credentials from the WANDB_API_KEY environment variable or a prior
        # `wandb login`. The key previously committed here should be revoked.
        wandb.login()
        wandb.init(
            dir=args.out_dir,
            project=args.wandb.project,
            entity=args.wandb.entity,
            config=args.__dict__,
            notes=args.wandb.notes,
            name=args.wandb.name,
            resume=True,
        )

    model = build_model(
        args.model, y_dim=y_dim, discard=discard, discard_mode=discard_mode, control=control
    )
    model.cuda()
    model.train()

    # Alternative entry point used previously:
    # train_mine_SS_innovation_noise_obs_noise_F_options_non_scalar_y(model, args, option=3)
    train_one_step_pred_control(model, args)

    if not args.test_run:
        _ = get_run_metrics(args.out_dir)  # precompute metrics for eval


if __name__ == "__main__":
    # Parse the quinine config (CLI + YAML) against the project schema.
    parser = QuinineArgumentParser(schema=schema)
    args = parser.parse_quinfig()
    assert args.model.family in ("gpt2", "lstm")
    print(f"Running with: {args}")

    if not args.test_run:
        # Reuse the run directory named by resume_id, or mint a fresh UUID-named one.
        resume_id = args.training.resume_id
        run_id = str(uuid.uuid4()) if resume_id is None else resume_id

        out_dir = os.path.join(args.out_dir, run_id)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        args.out_dir = out_dir

        # Persist the resolved config next to the run's checkpoints.
        config_path = os.path.join(out_dir, "config.yaml")
        with open(config_path, "w") as config_file:
            yaml.dump(args.__dict__, config_file, default_flow_style=False)

    main(args)
