import datetime
import numpy as np
import torch
import wandb
import os
import argparse
import json
from config import cfg as default_cfg
from tqdm import tqdm 
from matplotlib.colors import LogNorm, Normalize, SymLogNorm
from wandb_fcs import wandb_log_lropt, wandb_log_trainerror, wandb_log_gradnrm

from config import cfg, REL_PATH_DATA_FLDR, REL_PATH_EXP_FLDR
from plot_fcs import plot_realfc_vs_time, wandb_log_heatmap
from linear_transformer import PosEnc
from kalman_filter import KalmanFilter, run_kf
from data_processing import create_tokens, get_train_data, extract_pred_and_label
from helpers import create_loss_fc, create_optimizer, create_lr_scheduler, get_lr, create_transformer, \
                    update_min_max, seqs_vs_ground_truth, extract_att_params, extract_embed_params, \
                    update_att_param_lists, update_embed_param_lists, \
                    set_seeds, RLSSingle, RLS, run_rls, get_script_dir
from test_gd_vs_model import get_stable_state_kf, get_gd_steps_history_p, \
                                distance_and_cos, learned_tf_gd_output


def train(train_inputs, train_targets, cfg, val_inputs, val_targets, val_set, wandb_writer):
    """Train a freshly created transformer and record training diagnostics.

    The training set is consumed in a single pass of contiguous mini-batches of
    ``cfg["batch_sz"]`` rows (no shuffling), for at most ``cfg["max_iters"]``
    iterations. Whenever ``(i + 1) * batch_sz`` is a multiple of
    ``cfg["train_record_freq"]``, the model is evaluated on the whole
    validation set and the validation loss, gradient norm and learning rate are
    stored (and logged to wandb when a writer is given). When
    ``cfg["camera_ready"]`` is set, first-layer attention (and optional
    embedding) parameters are snapshotted every ``freq_heatmap`` iterations and
    logged as heatmaps once training finishes.

    Args:
        train_inputs: Training tokens indexed as ``[sample, ...]``; the number of
            samples must be a multiple of ``cfg["batch_sz"]``.
        train_targets: Targets aligned sample-for-sample with ``train_inputs``.
        cfg: Configuration dict. Keys read here: ``device``, ``batch_sz``,
            ``max_iters``, ``max_gr_nrm``, ``schedule_lr``, ``train_record_freq``,
            ``camera_ready``, ``extra_input_lin_layer``, ``pos_enc_type``,
            ``projection`` (plus whatever the ``create_*`` helpers read).
        val_inputs: Validation tokens, evaluated in a single forward pass.
        val_targets: Validation targets aligned with ``val_inputs``.
        val_set: Dict providing at least ``"meas"``, ``"trans"`` and ``"is_diag"``;
            also passed to ``get_stable_state_kf``.
        wandb_writer: wandb run used for scalar logging, or ``None`` to skip it.

    Returns:
        ``(model, results)``. ``results`` maps metric names to dicts holding a
        ``"data"`` numpy array of length ``n_iter`` (zero at iterations that were
        not recorded) and a string ``"label"``.
    """
    device = cfg["device"]
    batch_sz = cfg["batch_sz"]
    assert train_inputs.shape[0] % batch_sz == 0
    n_batches = train_inputs.shape[0] // batch_sz
    n_iter = min(cfg["max_iters"], n_batches) if cfg["max_iters"] is not None else n_batches

    model = create_transformer(cfg=cfg).to(device=device)
    optimizer = create_optimizer(cfg, model=model)
    lr_scheduler = create_lr_scheduler(cfg, optimizer)
    loss_fc = create_loss_fc(cfg)

    # Parameter snapshots for heatmaps are taken every `freq_heatmap` iterations.
    freq_heatmap = 500
    # Per-iteration metrics; entries stay 0 at iterations that are not recorded.
    train_loss = np.zeros((n_iter,))
    grad_norm = np.zeros((n_iter,))
    err_wrt_gd = np.zeros((n_iter,))  # only filled by the disabled GD comparison below
    cos_wrt_gd = np.zeros((n_iter,))  # only filled by the disabled GD comparison below
    lr_schedule = np.zeros((n_iter,))

    # Snapshot lists plus global min/max used to calibrate the heatmaps.
    # NOTE: the min/max updates are currently disabled, so these stay None.
    IOlinmaps, min_IOlinmap, max_IOlinmap = [], None, None
    POSATs, min_POSAT, max_POSAT = [], None, None
    WVs, min_WV, max_WV = [], None, None
    ATs, min_AT, max_AT = [], None, None
    QKs, min_QK, max_QK = [], None, None
    WQKs, min_WQK, max_WQK = [], None, None
    Prjs, min_Prj, max_Prj = [], None, None
    input_embeds, min_input_embed, max_input_embed = [], None, None
    output_embeds, min_output_embed, max_output_embed = [], None, None

    # Only consumed by the commented-out gradient-descent comparison below.
    Cs = val_set["meas"].to(device=device)
    As = val_set["trans"].to(device=device)
    is_diag = val_set["is_diag"]
    Ks, _ = get_stable_state_kf(val_set)
    Ks = Ks.to(device=device)

    #if cfg["camera_ready"]:
    #    #test_gd_preds = get_gd_steps_history_p(is_diag, val_inputs, As, Cs, Ks[:, -1, :], 29)
    #    test_gd_preds = learned_tf_gd_output(val_set["seqs"][:, :-1, :], val_set["seqs"][:, 1:, :], is_diag, As, Cs, Ks[:, -1, :], 1)
    #    test_gd_preds = test_gd_preds.to(device=device)

    # Evaluate on the model's device (training batches are moved below as well);
    # .to() is a no-op when the tensors already live there.
    val_inputs = val_inputs.to(device)
    val_targets = val_targets.to(device)
    n_val, val_seq_len = val_inputs.shape[0], val_inputs.shape[-2]

    print("Training begins with {} iterations of batch size {} ... ".format(n_iter, batch_sz))
    model.train()
    for i in tqdm(range(n_iter)):
        ### generate inputs and targets (contiguous slice = view, no index copy)
        start = i * batch_sz
        inputs = train_inputs[start:start + batch_sz].to(device)
        targets = train_targets[start:start + batch_sz].to(device)

        ### forward pass
        optimizer.zero_grad()
        outputs, transf_params = model(inputs)
        preds, labels = extract_pred_and_label(outputs, targets, cfg)
        loss = 1 / (2 * batch_sz * inputs.shape[-2]) * loss_fc(preds, labels)

        ### backward and train
        loss.backward()
        if cfg["max_gr_nrm"] is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_gr_nrm"])
        optimizer.step()

        if cfg["schedule_lr"]:
            lr_scheduler.step()

        ### Record losses & stuff (only at recording iterations)
        if (i + 1) * batch_sz % cfg["train_record_freq"] != 0:
            continue

        # L2 norm over all parameter gradients, measured after clipping (so it is
        # bounded by cfg["max_gr_nrm"] when clipping is enabled).
        grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
        grad_norm[i] = torch.cat(grads).norm().item() if grads else 0.0
        lr_schedule[i] = get_lr(optimizer)

        # Validation loss with layers such as dropout in inference mode, normalized
        # by the validation set's own size (not the training batch size).
        model.eval()
        with torch.no_grad():
            val_outputs, _ = model(val_inputs)
            val_preds, val_labels = extract_pred_and_label(val_outputs, val_targets, cfg)
            val_loss = 1 / (2 * n_val * val_seq_len) * loss_fc(val_preds, val_labels)
            train_loss[i] = val_loss.item()

            #if cfg["camera_ready"]:
            #    err_wrt_gd[i], cos_wrt_gd[i] = distance_and_cos(test_tf_preds[:, -1, :], test_gd_preds[:, -1, :])
        model.train()

        if wandb_writer is not None:
            wandb_log_trainerror(wandb_writer, i, train_loss[i])
            wandb_log_lropt(wandb_writer, i, lr_schedule[i])
            wandb_log_gradnrm(wandb_writer, i, grad_norm[i])

        if cfg["camera_ready"] and i % freq_heatmap == 0:
            # TODO: this below will not come out right
            W_V, W_QK, PA, POS_AT, AT, QK, Prj, IO_linmap_h = extract_att_params(transf_params["lyr_1"]["att"])
            # Sum the per-head maps over the leading (head) axis. Reducing in place
            # keeps the result on IO_linmap_h's device, unlike a CPU-allocated accumulator.
            IO_linmap = IO_linmap_h.sum(0)
            update_att_param_lists(WQKs, W_QK, IOlinmaps, IO_linmap, POSATs, POS_AT,
                                   WVs, W_V, ATs, AT, QKs, QK, Prjs, Prj)
            if cfg["extra_input_lin_layer"]:
                input_embed, output_embed = extract_embed_params(transf_params)
                update_embed_param_lists(input_embeds, input_embed, output_embeds, output_embed)
            # Disabled heatmap calibration (re-enable to get shared color scales):
            # min_WV, max_WV = update_min_max(W_V, min_WV, max_WV)
            # min_IOlinmap, max_IOlinmap = update_min_max(IO_linmap, min_IOlinmap, max_IOlinmap)
            # min_AT, max_AT = update_min_max(AT, min_AT, max_AT)
            # min_POSAT, max_POSAT = update_min_max(POS_AT, min_POSAT, max_POSAT)
            # min_QK, max_QK = update_min_max(QK, min_QK, max_QK)
            # min_WQK, max_WQK = update_min_max(W_QK, min_WQK, max_WQK)
            # min_Prj, max_Prj = update_min_max(Prj, min_Prj, max_Prj)
            # min_input_embed, max_input_embed = update_min_max(input_embed, min_input_embed, max_input_embed)
            # min_output_embed, max_output_embed = update_min_max(output_embed, min_output_embed, max_output_embed)

    if cfg["camera_ready"]:
        print("Processing heatmaps.")
        special_step_name="Iter. " + str(freq_heatmap) + "ths"
        for i in tqdm(range(0, len(WVs))):
            # TODO: why is this thing working when I pass the worng vmin vmax to symlognorm?
            # wandb_log_heatmap(wandb_writer, IOlinmaps[i], min_IOlinmap, max_IOlinmap, "IO lin. map, calibrated", norm=SymLogNorm(1e-5), special_step=i, special_step_name=special_step_name)
            wandb_log_heatmap(wandb_writer, WVs[i], min_WV, max_WV, "W^V, calibrated", norm=SymLogNorm(1e-7), annot=True, special_step=i, special_step_name=special_step_name)
            if cfg["pos_enc_type"] is not None:
                wandb_log_heatmap(wandb_writer, POSATs[i], min_POSAT, max_POSAT, "phi(Q_POS @ K_POS^T), calibrated", norm=SymLogNorm(1e-5), special_step=i, special_step_name=special_step_name)
            wandb_log_heatmap(wandb_writer, ATs[i], min_AT, max_AT, "phi(QK^T), calibrated", norm=Normalize(), special_step=i, special_step_name=special_step_name)
            wandb_log_heatmap(wandb_writer, QKs[i], min_QK, max_QK, "QK^T, calibrated", norm=Normalize(), special_step=i, special_step_name=special_step_name)
            wandb_log_heatmap(wandb_writer, WQKs[i], min_WQK, max_WQK, "W^QK, calibrated", norm=Normalize(), annot=True, special_step=i, special_step_name=special_step_name)
            if cfg["projection"]:
                wandb_log_heatmap(wandb_writer, Prjs[i], min_Prj, max_Prj, "Prj, calibrated", norm=Normalize(), annot=True, special_step=i, special_step_name=special_step_name)
            if cfg["extra_input_lin_layer"]:
                wandb_log_heatmap(wandb_writer, input_embeds[i], min_input_embed, max_input_embed, "Input embed, calibrated", norm=Normalize(), annot=True, special_step=i, special_step_name=special_step_name)
                wandb_log_heatmap(wandb_writer, output_embeds[i], min_output_embed, max_output_embed, "Output embed, calibrated", norm=Normalize(), annot=True, special_step=i, special_step_name=special_step_name)

    print("Training is done.")
    results = {
        "train_loss" : {
            "data": train_loss,
            "label": 'Pop. loss'
        },
        "err_wrt_gd1" : {
            "data": err_wrt_gd,
            "label": 'err_vs_gd'
        },
        "cos_wrt_gd1" : {
            "data" : cos_wrt_gd,
            # Was the data array itself; plot legends expect a string label.
            "label" : 'cos_vs_gd'
        },
        "grad_norm" : {
            # NOTE: the data is the L2 norm of the (clipped) gradients.
            "data": grad_norm,
            "label": "grad RMSE"
        },
        "lr_schedule" :{
            "data": lr_schedule,
            "label": "lr. schedule"
        }
    }

    return model, results



############# MAIN
if __name__ == '__main__':
    # sys.exit() is always available, unlike the `exit` builtin injected by `site`.
    import sys

    ############# Parse sweep configs if given as argument
    parser = argparse.ArgumentParser()
    parser.add_argument("--sweep_cfg", type=str, default=None)
    parser.add_argument("--run_id", type=str, default=None)
    parser.add_argument("--wandb_group", type=str, default=None)
    args, _ = parser.parse_known_args()

    if args.sweep_cfg:
        sweep_cfg = json.loads(args.sweep_cfg)
        default_cfg.update(sweep_cfg)
    cfg = default_cfg

    if args.run_id:
        cfg["run_id"] = args.run_id

    ############# Coherence checks
    assert cfg["seq_len"] <= 100, "If you need a larger seq here, you need to change the pos enc max"
    assert cfg["pos_enc_type"] != PosEnc.FIXED_ONE_HOT_CONCAT, "Can't have this pos encoding for now."
    assert cfg["batch_sz"] <= 1000, "Batchsize cannot be more than 1000."

    ############# Fix randomness
    set_seeds(cfg["rand_seed"])

    ############# Set device
    device = cfg["device"]
    print("\n*** Device: " + str(cfg["device"]) + ".\n")

    ############# Set output paths
    DATA_FLDR = get_script_dir() + REL_PATH_DATA_FLDR
    EXP_FLDR = get_script_dir() + REL_PATH_EXP_FLDR
    saved_model = False
    saved_data = False

    # ############## Set up time series data generation
    # NOTE(review): hard-coded overrides that look like debug settings. They are
    # applied BEFORE wandb.init so the config logged to wandb matches what runs.
    cfg["max_iters"] = 10
    train_size = 10 * 1000
    val_size = 10 #int(0.1 * train_size)

    os.environ['WANDB_INIT_TIMEOUT'] = '120'
    # NOTE(review): --run_id/--wandb_group are parsed but only used by the commented-out call.
    # wandb_writer = wandb.init(project="Learning distributions", config=cfg, id=args.run_id, group=args.wandb_group)
    wandb_writer = wandb.init(project="Learning distributions", config=cfg)
    # wandb_writer = None

    ############# Generate data & create tokens
    train_data, val_data = get_train_data(saved_data, cfg, train_size, val_size)
    inputs, labels = create_tokens(train_data["seqs"], cfg)
    val_inputs, val_labels = create_tokens(val_data["seqs"], cfg)
    print("\n**** Done augmenting tokens. Initial token dimension = {}, new token dimension = {}.\n".format(cfg["output_dim"], inputs.shape[-1]))

    ############# Train
    # One path for both saving and loading, prefixed with the run id when one is set.
    saved_filename = "noisy_transformer.pt"
    if cfg.get("run_id") is not None:
        saved_filename = str(cfg["run_id"]) + "_" + saved_filename
    saved_path = DATA_FLDR + "/" + saved_filename
    if not saved_model:
        model, results = train(inputs, labels, cfg, val_inputs, val_labels, val_data, wandb_writer)
        torch.save((cfg, model, results), saved_path)
    else:
        # NOTE(review): this unpickles a whole nn.Module; recent torch versions
        # default torch.load to weights_only=True and may need weights_only=False here.
        (cfg, model, results) = torch.load(saved_path)

    # klm_f = KalmanFilter(device=device)
    # train_K_t, train_P_t, train_y_hat, train_y_hat_forward, train_x_hat_y = run_kf(train_data, klm_f, device=device)

    if not cfg["camera_ready"]:
        sys.exit()

    plot_path = EXP_FLDR + "/"
    plot_realfc_vs_time(plot_path, [results["train_loss"]["data"]], [results["train_loss"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\mathcal{L}(X)$")
    plot_realfc_vs_time(plot_path, [results["grad_norm"]["data"]], [results["grad_norm"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\|\nabla TF(\theta)\|$", y_scale='linear')
    plot_realfc_vs_time(plot_path, [results["lr_schedule"]["data"]], [results["lr_schedule"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$LR$", y_scale='linear')

    # NOTE(review): everything below this exit is currently unreachable. It also
    # relies on `klm_f` and `train_K_t`, which are only defined by the commented-out
    # KalmanFilter lines above -- re-enable those before removing this exit.
    sys.exit()

    ############# Compare with KF on test sequences
    ys = val_data["seqs"]

    #test_inputs = ys[:, :-1, :]
    #aug_test_inputs = augmented_tokens(test_inputs, 2, device)

    if not cfg["shuffle"]:
        tf_preds, _ = model(ys[:, :-1, :].to(device=device))
        #aug_tf_preds, _ = model(aug_test_inputs.to(device=device))
        #dim_y = ys.shape[2]
        #tf_preds = aug_tf_preds[:, :, :dim_y]
        print("The output of the TF: ", tf_preds[0, :, :])
        #tf_preds_proj_f_PE_f, _ = model_proj_f_PE_f(ys[:, :-1, :].to(device=device))
        #tf_preds_proj_f_PE_t, _ = model_proj_f_PE_t(ys[:, :-1, :].to(device=device))
        #tf_preds_proj_t_PE_f, _ = model_proj_t_PE_f(ys[:, :-1, :].to(device=device))
        #tf_preds_proj_t_PE_t, _ = model_proj_t_PE_t(ys[:, :-1, :].to(device=device))
    else:
        # Shuffle the order of the input sequences
        seq_len = ys.shape[1] - 1
        tf_preds = torch.zeros(ys.shape[0], seq_len, ys.shape[2], device=device)
        for i in range(seq_len):
            ys_new = ys[:, :-1, :].clone()
            ys_unshuffled = ys[:, :i, :].clone()
            if i < 7:
                print("i: ", i)
                print("ys_new before shuffling: ", ys_new[0, :, 0])
            ys_shuffled = ys_unshuffled[:, torch.randperm(ys_unshuffled.size(1)), :]
            #ys_shuffled = ys_unshuffled
            ys_new[:, :i, :] = ys_shuffled
            if i < 7:
                print("ys_new after shuffling: ", ys_new[0, :, 0])
                print("ys_new[i]: ", ys_new[0, i, 0])
            #print("The difference between ys_schuffled and ys_unschuffled", ys_shuffled - ys_unshuffled)
            #output_ys, _ = model(ys_new[:, :i+1, :].to(device=device))
            output_ys, _ = model(ys_new.to(device=device))
            tf_preds[:, i, :] = output_ys[:, i, :]
        print("tf_preds: ", tf_preds[0, :, 0])

    untrained_model = create_transformer(cfg=cfg).to(device)
    tf_preds_untrained, _ = untrained_model(ys[:, :-1, :].to(device=device))
    #aug_tf_preds_untrained, _ = untrained_model(aug_test_inputs.to(device=device))
    #tf_preds_untrained = aug_tf_preds_untrained[:, :, :dim_y]

    test_ys = val_data["seqs"][:, :-1, :]
    test_labels = val_data["seqs"][:, 1:, :]
    Cs = val_data["meas"]
    As = val_data["trans"]
    is_diag = val_data["is_diag"]
    Ds = torch.ones_like(Cs)
    Ks, _ = get_stable_state_kf(val_data)

    #test_gd_preds = get_gd_steps_history_p(is_diag, ys[:, :-1, :], As, Cs, Ks[:, -1, :], 1)
    #test_gd_preds = learned_tf_gd_output(ys[:, :-1, :], ys[:, 1:, :], is_diag, As, Cs, Ks[:, -1, :], 1)
    #test_gd_preds = test_gd_preds.to(device=device)
    _, _, kf_preds, kf_preds_forward, _ = run_kf(val_data, klm_f, device)
    rls_preds = torch.from_numpy(run_rls(test_ys, 2))
    #print("The length of rls_preds: ", rls_preds.shape[1])

    [err_tf, err_kf, err_kf_forward, ynorm, err_tf_untrained, err_rls], [err_tf_avg, err_kf_avg, err_kf_forward_avg, ynorm_avg, err_tf_untrained_avg, err_rls_avg] = seqs_vs_ground_truth(
                        [tf_preds.detach(), kf_preds[:, 1:, :], kf_preds_forward[:, :-1, :], torch.zeros_like(kf_preds[:, 1:, :]), tf_preds_untrained.detach(), rls_preds[:, :, :]], 
                        ys[:, 1:, :])

    #[err_tf, ynorm_avg, err_tf_untrained] = seqs_vs_ground_truth(
    #                    [tf_preds.detach(), torch.zeros_like(tf_preds.detach()), tf_preds_untrained.detach()], 
    #                    ys[:, 2:, :])

    #[err_tf_proj_f_PE_f, err_tf_proj_t_PE_f, err_tf_proj_f_PE_t, err_tf_proj_t_PE_t, err_kf, err_kf_forward, ynorm_avg] = seqs_vs_ground_truth(
    #                    [tf_preds_proj_f_PE_f.detach(), tf_preds_proj_t_PE_f.detach(), tf_preds_proj_f_PE_t.detach(), tf_preds_proj_t_PE_t.detach(), kf_preds[:, 1:, :], kf_preds_forward[:, :-1, :], torch.zeros_like(kf_preds[:, 1:, :])],
    #                    ys[:, 1:, :])

    # With fixed TF weight configurations as GD 
    #[err_tf, err_kf, err_kf_forward, ynorm_avg, err_tf_untrained, err_tf_gd] = seqs_vs_ground_truth(
    #                    [tf_preds.detach(), kf_preds[:, 1:, :], kf_preds_forward[:, 1:, :], torch.zeros_like(kf_preds[:, 1:, :]), tf_preds_untrained.detach(), test_gd_preds.detach()], 
    #                    ys[:, 1:, :])
    '''
    no_rs = cfg["test_rand_seed_number"]
    err_tf_all = torch.zeros(no_rs, cfg["seq_len"], device=device)
    err_kf_all = torch.zeros(no_rs, cfg["seq_len"], device=device)
    err_kf_forward_all = torch.zeros(no_rs, cfg["seq_len"], device=device)
    ynorm_avg_all = torch.zeros(no_rs, cfg["seq_len"], device=device)
    err_tf_untrained_all = torch.zeros(no_rs, cfg["seq_len"], device=device)
    err_rls_all = torch.zeros(no_rs, cfg["seq_len"], device=device)
    for i in range(no_rs):
        set_seeds(cfg["rand_seed"] + 100 * i)
        _, test_data, _ = get_train_data(saved_data, cfg, 
                                            1, val_size, s_noise_var, m_noise_var, NOISE_VAR)

        test_ys = test_data["seqs"][:, :-1, :]
        test_labels = test_data["seqs"][:, 1:, :]
        Cs = test_data["meas"]
        As = test_data["trans"]
        is_diag = test_data["is_diag"]
        Ds = torch.ones_like(Cs)
        Ks, _ = get_stable_state_kf(test_data)
    
        #test_gd_preds = get_gd_steps_history_p(is_diag, ys[:, :-1, :], As, Cs, Ks[:, -1, :], 1)
        #test_gd_preds = learned_tf_gd_output(ys[:, :-1, :], ys[:, 1:, :], is_diag, As, Cs, Ks[:, -1, :], 1)
        #test_gd_preds = test_gd_preds.to(device=device)
        _, _, kf_preds, kf_preds_forward, _ = run_kf(test_data, klm_f, device)
        rls_preds = torch.from_numpy(run_rls(test_ys, 2))
        #print("The length of rls_preds: ", rls_preds.shape[1])
        
        [err_tf, err_kf, err_kf_forward, ynorm_avg, err_tf_untrained, err_rls] = seqs_vs_ground_truth(
                            [tf_preds.detach(), kf_preds[:, 1:, :], kf_preds_forward[:, :-1, :], torch.zeros_like(kf_preds[:, 1:, :]), tf_preds_untrained.detach(), rls_preds[:, :, :]], 
                            ys[:, 1:, :])
        
        #[err_tf, ynorm_avg, err_tf_untrained] = seqs_vs_ground_truth(
        #                    [tf_preds.detach(), torch.zeros_like(tf_preds.detach()), tf_preds_untrained.detach()], 
        #                    ys[:, 2:, :])
        
        #[err_tf_proj_f_PE_f, err_tf_proj_t_PE_f, err_tf_proj_f_PE_t, err_tf_proj_t_PE_t, err_kf, err_kf_forward, ynorm_avg] = seqs_vs_ground_truth(
        #                    [tf_preds_proj_f_PE_f.detach(), tf_preds_proj_t_PE_f.detach(), tf_preds_proj_f_PE_t.detach(), tf_preds_proj_t_PE_t.detach(), kf_preds[:, 1:, :], kf_preds_forward[:, :-1, :], torch.zeros_like(kf_preds[:, 1:, :])],
        #                    ys[:, 1:, :])
    
        # With fixed TF weight configurations as GD 
        #[err_tf, err_kf, err_kf_forward, ynorm_avg, err_tf_untrained, err_tf_gd] = seqs_vs_ground_truth(
        #                    [tf_preds.detach(), kf_preds[:, 1:, :], kf_preds_forward[:, 1:, :], torch.zeros_like(kf_preds[:, 1:, :]), tf_preds_untrained.detach(), test_gd_preds.detach()], 
        #                    ys[:, 1:, :])
        
        err_tf_all[i, :] = err_tf
        err_kf_all[i, :] = err_kf
        err_kf_forward_all[i, :] = err_kf_forward
        ynorm_avg_all[i, :] = ynorm_avg
        err_tf_untrained_all[i, :] = err_tf_untrained
        err_rls_all[i, :] = err_rls
    '''

    #plot_realfc_vs_time(plot_path, [err_tf, err_kf, ynorm_avg], ["err_tf", "err_kf", "ynorm_avg"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    #plot_realfc_vs_time(plot_path, [err_kf], ["err_kf"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    if cfg["plot_tube"]:
        plot_realfc_vs_time(plot_path, [err_tf, err_kf, err_kf_forward, err_tf_untrained, err_rls, ynorm], ["err_tf", "err_kf", "err_kf_forward", "err_tf_untrained", "err_rls", "ynorm"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    else:
        plot_realfc_vs_time(plot_path, [err_tf_avg, err_kf_avg, err_kf_forward_avg, err_tf_untrained_avg, err_rls_avg, ynorm_avg], ["err_tf", "err_kf", "err_kf_forward", "err_tf_untrained", "err_rls", "ynorm"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    plot_realfc_vs_time(plot_path, [err_tf_avg/ynorm_avg, err_kf_avg/ynorm_avg, err_kf_forward_avg/ynorm_avg, err_tf_untrained_avg/ynorm_avg, err_rls_avg/ynorm_avg], ['tf/ynrm', 'kf/ynrm', 'kf_forward/ynrm', 'tf_untr/ynrm', 'rls/ynrm'], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')

    #plot_realfc_vs_time(plot_path, [err_tf_all, err_kf_all, err_kf_forward_all, err_tf_untrained_all, err_rls_all, ynorm_avg_all], ["err_tf", "err_kf", "err_kf_forward", "err_tf_untrained", "err_rls", "ynorm_avg"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    #plot_realfc_vs_time(plot_path, [err_tf/ynorm_avg, err_kf/ynorm_avg, err_kf_forward/ynorm_avg, err_tf_untrained/ynorm_avg, err_rls/ynorm_avg], ['tf/ynrm', 'kf/ynrm', 'kf_forward/ynrm', 'tf_untr/ynrm', 'rls/ynrm'], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')

    #plot_realfc_vs_time(plot_path, [err_tf, err_tf_untrained, ynorm_avg], ["err_tf", "err_tf_untrained", "ynorm_avg"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')

    #plot_realfc_vs_time(plot_path, [err_tf_proj_f_PE_f, err_tf_proj_t_PE_f, err_tf_proj_f_PE_t, err_tf_proj_t_PE_t, err_kf, err_kf_forward, ynorm_avg], ["err_tf_proj_f_PE_f", "err_tf_proj_t_PE_f", "err_tf_proj_f_PE_t", "err_tf_proj_t_PE_t", "err_kf", "err_kf_forward", "ynorm_avg"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')

    ######### With fixed TF weight configurations as GD 
    #plot_realfc_vs_time(plot_path, [err_tf, err_kf, err_kf_forward, err_tf_untrained, err_tf_gd, ynorm_avg], ["err_tf", "err_kf", "err_kf_forward", "err_tf_untrained", "err_tf_gd", "ynorm_avg"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    #plot_realfc_vs_time(plot_path, [err_tf/ynorm_avg, err_kf/ynorm_avg, err_kf_forward/ynorm_avg, err_tf_gd/ynorm_avg, err_tf_untrained/ynorm_avg], ['tf/ynrm', 'kf/ynrm', 'kf_forward/ynrm', 'tf_gd/ynrm', 'tf_untr/ynrm'], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')

    #plot_realfc_vs_time(plot_path, [err_kf/ynorm_avg[1:]], ["err_kf_seqlen={}".format(optimizer_cfg["seq_len"])], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    #plot_realfc_vs_time(plot_path, [err_kf2], ["err_kf"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='linear')

    #plot_realfc_vs_time(plot_path, [results["err_wrt_gd1"]["data"]], [results["err_wrt_gd1"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\|y^{transf} - y^{GD}\|$")
    #plot_realfc_vs_time(plot_path, [results["cos_wrt_gd1"]["data"]], [results["cos_wrt_gd1"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\cos(y^{transf}, y^{GD})$", y_scale='linear')

    #plot_realfc_vs_time(EXP_FLDR + "/", [err_kf/ynorm_avg[1:]], ["err_kf_seqlen={}".format(optimizer_cfg["seq_len"])], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='log')
    #plot_realfc_vs_time(EXP_FLDR + "/", [err_kf2], ["err_kf"], extra_title_info=str(datetime.datetime.now()), x_label='seq. len.', y_label=r"$\|\hat{y} - y\|$", y_scale='linear')

    #plot_realfc_vs_time(EXP_FLDR + "/", [results["train_loss"]["data"]], [results["train_loss"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\mathcal{L}(X)$")
    #plot_realfc_vs_time(EXP_FLDR + "/", [results["grad_norm"]["data"]], [results["grad_norm"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\|\nabla TF(\theta)\|$", y_scale='linear')
    #plot_realfc_vs_time(EXP_FLDR + "/", [results["lr_schedule"]["data"]], [results["lr_schedule"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$LR$", y_scale='linear')

    #plot_realfc_vs_time(EXP_FLDR + "/", [results["err_wrt_gd1"]["data"]], [results["err_wrt_gd1"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\|y^{transf} - y^{GD}\|$")
    #plot_realfc_vs_time(EXP_FLDR + "/", [results["cos_wrt_gd1"]["data"]], [results["cos_wrt_gd1"]["label"]], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\cos(y^{transf}, y^{GD})$", y_scale='linear')

    # Compute the Kalman filter mapping for the "translation" task
    kf_map = KalmanFilter.compute_KF_linear_map(train_data["seqs"].shape[1] - 1, train_K_t[0, :, :], train_data["trans"][0, :], train_data["meas"][0, :], is_diag)

    if wandb_writer is not None:
        wandb_log_heatmap(wandb_writer, kf_map.cpu(), np.min(kf_map.cpu().numpy()), np.max(kf_map.cpu().numpy()), "KF map heatmap")
        wandb_log_heatmap(wandb_writer, kf_map.cpu(), np.min(kf_map.cpu().numpy()), np.max(kf_map.cpu().numpy()), "KF map heatmap lognorm", norm=SymLogNorm(linthresh=1e-4, vmin=-1, vmax=1))

    
    
