import numpy as np
from itertools import chain
from itertools import groupby
import torch
import argparse
import losses
import spaces
import invertible_network_utils
import torch.nn.functional as F
import random
import os
import latent_spaces
import encoders
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import r2_score
from sklearn import kernel_ridge, linear_model
from sklearn.neural_network import MLPRegressor
import string
from scipy.stats import wishart
import csv
import utils
# import lasso_reg # NOTE: this part of code is kindly granted by Fumero et al, 2023
import yaml
from pathlib import Path
import json

# Module-wide default compute device: CUDA when torch reports it is available,
# otherwise the CPU. main() may later rebind this global to "cpu" (--no-cuda).
device = "cuda" if torch.cuda.is_available() else "cpu"

print("device:", device)


def valid_str(v):
    """Turn ``v`` into a lowercase token made only of letters, digits, '-' and '_'.

    Objects exposing ``__name__`` (functions, classes, ...) are represented by
    that name; tuples and lists are converted element-wise and joined with "-".
    Every character outside the allowed alphabet is replaced by "-".
    """
    if hasattr(v, "__name__"):
        return valid_str(v.__name__)
    if isinstance(v, (tuple, list)):
        return "-".join(valid_str(item) for item in v)

    allowed = set("-_" + string.ascii_letters + string.digits)
    lowered = str(v).lower()
    return "".join(ch if ch in allowed else "-" for ch in lowered)


def get_exp_name(
    args,
    parser,
    blacklist=("evaluate", "num_train_batches", "num_eval_batches", "evaluate_iter"),
):
    """Build an experiment name from the arguments that differ from their defaults.

    Args:
        args: parsed namespace (anything accepted by ``vars``).
        parser: the parser that produced ``args``; queried through
            ``parser.get_default`` for each attribute name.
        blacklist: attribute names that never contribute to the name. The
            default is an immutable tuple so it cannot be mutated across calls;
            any container supporting ``in`` is accepted.

    Returns:
        "_"-joined fragments, one per non-default argument: boolean arguments
        contribute their bare name when True (nothing when False), all other
        arguments contribute ``name + valid_str(value)``. Leading underscores
        are stripped from the result.
    """
    fragments = []
    for name, value in vars(args).items():
        if name in blacklist or value == parser.get_default(name):
            continue
        if isinstance(value, bool):
            if value:
                fragments.append("_" + name)
        else:
            fragments.append("_" + name + valid_str(value))
    return "".join(fragments).lstrip("_")


# ------ store content and style dict into args for global use
# --------------------------------------------------------------------------
def update_args(args):
    """
    Enrich ``args`` in place with the subset structure derived from ``args.S_k``.

    Attributes written: ``powerset``, ``powerset_indicators``, ``content_dict``,
    ``style_dict``, ``content_size_dict`` (length of every ``content_dict``
    entry) and ``z_n`` (number of distinct latent indices found in ``args.S_k``).
    The same ``args`` object is returned.
    """
    # [n_views, n_sk]: the view-specific latent indices exactly as given in args
    view_latents = torch.tensor(args.S_k)

    # powerset, content dict and style dict for all subsets and views;
    # with only_consider_whole_set this reduces to theorem 1 (if mode != soft)
    powerset, powerset_indicators, content_dict, style_dict = (
        utils.content_style_from_subsets(
            views=range(args.view_k),
            zs=view_latents,
            only_consider_whole_set=args.only_consider_whole_set,
        )
    )
    args.powerset_indicators = powerset_indicators
    args.powerset = powerset  # list of lists with len = num_subsets
    args.content_dict = content_dict
    args.style_dict = style_dict

    # content sizes, needed by the "known_content_size" readout mode
    args.content_size_dict = {
        subset: len(content) for subset, content in content_dict.items()
    }

    # keep the number of latents consistent with what S_k actually references
    distinct_latents = set(chain.from_iterable(args.S_k))
    args.z_n = len(distinct_latents)
    return args


def load_config_dict():
    """Read ``configs/fzoo.yaml`` and return ``(solver_config, model_config)``.

    The "solver" and "model" sections of the YAML document are each wrapped in
    a ``utils.ConfigDict``. The path is relative to the working directory.
    """
    yaml_text = Path("configs/fzoo.yaml").read_text()
    sections = yaml.safe_load(yaml_text)
    solver_cfg = utils.ConfigDict(sections["solver"])
    model_cfg = utils.ConfigDict(sections["model"])
    return solver_cfg, model_cfg


# ---------- initialisation functions ----------------------
# ----------------------------------------------------------
def init_or_load_mixing_functions(device, args, latent_dim=None):
    """Create (or load) one frozen invertible-MLP mixing function per view.

    Args:
        device: device every mixing function is moved to.
        args: namespace providing ``view_k``, ``S_k``, ``n_mixing_layer``,
            ``load_F`` and ``shared_mixing_function``.
        latent_dim: input/output width of every network; when falsy, the width
            of view ``i`` is ``len(args.S_k[i])``.

    Returns:
        A ``torch.nn.ModuleList`` of mixing functions, or a plain list holding
        the first function ``view_k`` times when ``args.shared_mixing_function``
        is set. All parameters have ``requires_grad = False``.
    """

    def build(view_idx):
        # An invertible MLP needs identical input and output size.
        return invertible_network_utils.construct_invertible_mlp(
            n=latent_dim or len(args.S_k[view_idx]),
            n_layers=args.n_mixing_layer,
            cond_thresh_ratio=0.001,
            n_iter_cond_thresh=25000,
        )

    # Fresh networks are always generated first (even when a checkpoint is
    # loaded afterwards) so the random-number stream consumed here does not
    # depend on whether load_F is set.
    F = torch.nn.ModuleList([build(i) for i in range(args.view_k)])

    if args.load_F is not None:
        F = torch.nn.ModuleList()
        checkpoint_names = sorted(
            name for name in os.listdir(args.load_F) if name.endswith(".pth")
        )
        # NOTE(review): the order is lexicographic, so "f_10.pth" sorts before
        # "f_2.pth" -- confirm this is acceptable for more than 10 views.
        for position, file_name in enumerate(checkpoint_names):
            # Size the network for the view this checkpoint belongs to. The
            # previous code reused the stale loop variable of the loop above,
            # i.e. always the last view. Surplus files fall back to the last
            # view, as before.
            f_i = build(min(position, args.view_k - 1))
            model_path = os.path.join(args.load_F, file_name)
            # map_location lets checkpoints written on a GPU load on CPU hosts.
            f_i.load_state_dict(torch.load(model_path, map_location=device))
            F.append(f_i)

    # Mixing functions are fixed once generated: move and freeze them.
    for f_i in F:
        f_i.to(device)
        for p in f_i.parameters():
            p.requires_grad = False

    if args.shared_mixing_function:
        F = [F[0]] * args.view_k
    return F


def init_or_load_encoder_models(device, args, latent_dim=None):
    """Define the trainable encoders, one MLP per view, optionally from a checkpoint.

    Args:
        device: device every encoder is moved to (also used as ``map_location``
            when reading the checkpoint).
        args: namespace providing ``view_k``, ``S_k``, ``load_G`` and
            ``shared_encoder``.
        latent_dim: input/output width of every encoder; when falsy, view ``i``
            uses ``len(args.S_k[i])``.

    Returns:
        A ``torch.nn.ModuleList`` of encoders, or a plain list holding the
        first encoder ``view_k`` times when ``args.shared_encoder`` is set.
        With ``args.load_G`` set, weights come from ``<load_G>/model.pt`` under
        the keys ``encoder_{i}_state_dict``.
    """

    def build(view_idx):
        n_sk = len(args.S_k[view_idx])
        width = latent_dim or n_sk
        # Hidden widths scale with the number of latents observed by the view.
        hidden = [n_sk * factor for factor in (10, 50, 50, 50, 50, 10)]
        return encoders.get_mlp(n_in=width, n_out=width, layers=hidden)

    # Encoders are always initialised first (even when a checkpoint is loaded
    # afterwards) so the random-number stream consumed here is unchanged.
    G = torch.nn.ModuleList()
    for i in range(args.view_k):
        g_i = build(i)
        G.append(g_i)
        g_i.to(device)

    if args.load_G is not None:
        G = torch.nn.ModuleList()

        save_path = os.path.join(args.load_G, "model.pt")
        print(save_path)
        # map_location lets checkpoints written on a GPU load on CPU hosts.
        ckpt = torch.load(save_path, map_location=device)

        for i in range(args.view_k):
            g_i = build(i)
            g_i.load_state_dict(ckpt[f"encoder_{i}_state_dict"])
            g_i.to(device)
            G.append(g_i)

    if args.shared_encoder:
        G = [G[0]] * args.view_k
    return G


def init_or_load_training_models(F, G, device, args):
    """Define hat{z} = h(z) with h = g o f (encoder g applied after mixing f).

    Args:
        F: per-view mixing functions; each must be iterable over its layers.
        G: per-view encoders; each must be iterable over its layers.
        device: device for the modules created in "soft" mode.
        args: namespace providing ``view_k`` and ``readout_mode`` (plus
            ``powerset`` and ``S_k`` in "soft" mode).

    Returns:
        A dict with keys "H", "F", "G" and "backbone". In "soft" readout mode
        it additionally holds "solver", "main_model" and "linear_readout".
    """
    # h_i = g_i(f_i(z)): layers of the mixing function followed by the encoder's
    H = torch.nn.ModuleList(
        [
            torch.nn.Sequential(*list(F[i]), *list(G[i]))
            for i in range(args.view_k)
        ]
    )

    # torch.nn.Module wrapper for encoder-mixing_function composition
    backbone = encoders.CompositionEncMix(H=H)

    if args.readout_mode != "soft":
        return {"H": H, "F": F, "G": G, "backbone": backbone}

    # NOTE(review): `lasso_reg` is only available when its (currently
    # commented-out) module-level import is restored; otherwise this branch
    # raises NameError.
    config_solver, config_model = load_config_dict()
    n_tasks = len(args.powerset)

    # soft linear heads to learn importance of different features
    solver = lasso_reg.ParallelRegularizedLasso(config_solver, ntasks=n_tasks).to(
        device
    )

    # main model wrapping the solver around the shared backbone
    main_model = lasso_reg.MainModel(
        solver, config_model, ntasks=n_tasks, backbone=backbone
    )

    # one scalar linear head per subset, sized by the first view's latents
    linear_readout_dict = {
        subset: torch.nn.Linear(in_features=len(args.S_k[0]), out_features=1).to(
            device
        )
        for subset in args.powerset
    }
    return {
        "H": H,
        "solver": solver,
        "F": F,
        "G": G,
        "main_model": main_model,
        "backbone": backbone,
        "linear_readout": linear_readout_dict,
    }


def init_or_load_optimizer(models: dict, optimizer_class=torch.optim.Adam, args=None):
    """Collect the trainable encoder parameters and build the optimizer.

    Args:
        models: dict whose "G" entry is the sequence of per-view encoders.
        optimizer_class: optimizer constructor called as
            ``optimizer_class(params, lr=args.lr)``; Adam by default.
        args: namespace providing ``shared_encoder`` and ``lr`` (required
            despite the ``None`` default).

    Returns:
        ``(params, optimizer)``. ``params`` is always a list, so it stays
        usable after the optimizer has consumed it; previously the
        shared-encoder case returned an already exhausted generator.
    """
    if args.shared_encoder:
        # every view uses the same encoder: optimise its parameters only once
        params = list(models["G"][0].parameters())
    else:
        # encoders' parameters are trainable
        params = [p for g_i in models["G"] for p in g_i.parameters()]

    optimizer = optimizer_class(params, lr=args.lr)
    return params, optimizer


# ---------------- checkpoint and resume training ---------------
# -----------------------------------------------------------------------
def save_models(models: dict, optimizer=None, args=None):
    """Write a single checkpoint ``<args.save_dir_G>/model.pt``.

    The checkpoint is a dict with one ``encoder_{k}_state_dict`` entry per
    view, preceded by ``solver_state_dict`` in "soft" readout mode and followed
    by ``optimizer_state_dict`` when an optimizer is given. The target
    directory is created when it does not exist.
    """
    target_dir = args.save_dir_G
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    checkpoint = {}
    if args.readout_mode == "soft":
        checkpoint["solver_state_dict"] = models["solver"].state_dict()

    checkpoint.update(
        (f"encoder_{k}_state_dict", models["G"][k].state_dict())
        for k in range(args.view_k)
    )

    if optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()

    torch.save(checkpoint, os.path.join(target_dir, "model.pt"))


def load_models(models, optimizer, args):
    """Restore encoders (and solver / optimizer state) from ``<save_dir_G>/model.pt``.

    Args:
        models: dict with the encoders under "G" (and "solver" in "soft" mode);
            the modules are updated in place.
        optimizer: optimizer to restore, or ``None`` to skip optimizer state.
        args: namespace providing ``save_dir_G``, ``readout_mode`` and ``view_k``.

    Returns:
        ``(models, optimizer)`` -- the same objects that were passed in.
    """
    save_path = os.path.join(args.save_dir_G, "model.pt")
    # Deserialize onto the CPU so checkpoints written on a GPU can be read on a
    # CPU-only host; load_state_dict copies into the existing parameters.
    ckpt = torch.load(save_path, map_location="cpu")

    if args.readout_mode == "soft":
        models["solver"].load_state_dict(ckpt["solver_state_dict"])
    for k in range(args.view_k):
        models["G"][k].load_state_dict(ckpt[f"encoder_{k}_state_dict"])

    # save_models() accepts optimizer=None, so tolerate it here as well.
    if optimizer is not None and "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return models, optimizer


# ------------------------ binary readouts ------------------
# -------------------------------------------------------------
# this only applies to CL, not to classification with soft linear heads
def content_mask(args, hzs: dict):
    """Dispatch to the binary readout selected by ``args.readout_mode``.

    "ground_truth" and "known_content_size" are supported; any other mode
    (including "soft") raises ``NotImplementedError``.
    """
    mode = args.readout_mode
    if mode not in ("ground_truth", "known_content_size"):
        raise NotImplementedError
    if mode == "ground_truth":
        return ground_truth_readout(args, hzs)
    return known_content_size_readout(args, hzs)


# Lemma A.5, ground truth content indices are given
def ground_truth_readout(args, hzs: dict):
    """Binary content masks built from the ground-truth content indices (Lemma A.5).

    For every subset in ``args.powerset`` and every view ``k`` in it, the mask
    row marks with 1 the entries of ``args.S_k[k]`` that belong to
    ``args.content_dict[subset]``; the row is repeated once per sample, where
    the batch size is taken from ``hzs[0]["hz"]``.

    Returns:
        ``{subset: {view: tensor of shape [batch_size, len(S_k[view])]}}`` on
        the module-level ``device``.
    """
    n_samples = hzs[0]["hz"].shape[0]
    masks = {subset: {} for subset in args.powerset}
    for subset, per_view in masks.items():
        for view in subset:
            content_latents = list(args.content_dict[subset])
            row = torch.tensor(
                [int(latent in content_latents) for latent in args.S_k[view]]
            )
            tiled = torch.stack([row] * n_samples, 0)
            per_view[view] = tiled.to(device)
    return masks


# Lemma A.6, content size is given
def known_content_size_readout(args, hzs: dict, content_size_dict: dict = None):
    """Binary content masks when only the content size is known (Lemma A.6).

    For every subset and every view in it, the batch-averaged encoding of the
    view is passed as logits to ``utils.topk_gumble_softmax`` (``tau=1.0``,
    ``hard=True``) with ``k`` set to the subset's content size; the resulting
    row is repeated for each sample of the batch.

    Args:
        args: namespace providing ``powerset`` and ``content_size_dict``.
        hzs: ``{view: {"hz": tensor}}``; the batch size is read from view 0.
        content_size_dict: optional override for ``args.content_size_dict``.

    Returns:
        ``{subset: {view: mask}}`` with masks moved to the module-level ``device``.
    """
    n_samples = hzs[0]["hz"].shape[0]
    sizes = args.content_size_dict if content_size_dict is None else content_size_dict

    masks = {subset: {} for subset in args.powerset}
    for subset, per_view in masks.items():
        for view in subset:
            # average over the batch, keep a leading axis of length one
            mean_logits = hzs[view]["hz"].mean(0)[None]
            row = utils.topk_gumble_softmax(
                k=sizes[subset],
                logits=mean_logits,
                tau=1.0,
                hard=True,
            )
            # batch_size, nSk
            per_view[view] = torch.concat([row] * n_samples, 0).to(device)
    return masks


# ----------------- data generation ----------------------
# -----------------------------------------------------------------
def sample_whole_latent(latent_space, size, device=device):
    """Draw two batches from ``latent_space`` with the same ``size`` and ``device``.

    Returns ``(z, z3)``: the first draw is used as the positive sample and the
    second as the negative sample.
    """
    z, z3 = (latent_space.sample_latent(size=size, device=device) for _ in range(2))
    return z, z3


def generate_data(
    latent_space, models, num_batches=1, batch_size=4096, loss_func=None, args=None
):
    """Sample latents and collect ground truth, encodings and content masks.

    Args:
        latent_space: object whose ``sample_latent(batch_size)`` yields a
            latent tensor (indexed below as ``[batch, latent]``).
        models: dict with the composed model under "backbone".
        num_batches: number of batches to sample.
        batch_size: number of samples per batch.
        loss_func: unused; kept for interface compatibility.
        args: namespace providing ``powerset``, ``view_k``, ``S_k``,
            ``content_dict``, ``style_dict`` and the readout settings.

    Returns:
        ``(data_dict, hz_dict, all_z)`` where every array has a leading
        ``num_batches`` axis:
        ``data_dict[subset][k]`` holds the ground-truth content ("c") and
        style ("s") latents, ``hz_dict[k]`` holds the encodings ("hz") and the
        per-subset content masks ("hc_mask"), and ``all_z`` is the full latent
        array.
    """
    models["backbone"].eval()

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    data_dict = {
        subset: {k: {"c": [], "s": []} for k in subset} for subset in args.powerset
    }

    hz_dict = {
        k: {
            "hz": [],  # unified encoded information
            "hc_mask": {s: [] for s in args.powerset if k in s},  # for all subsets
        }
        for k in range(args.view_k)
    }

    all_z = []

    with torch.no_grad():
        for _ in range(num_batches):
            zs = latent_space.sample_latent(batch_size)  # [batch_size, n_z]
            all_z.append(zs)

            # estimated latents for each view (using the unified encoder);
            # kept as tensors because the readout needs the tensor type
            hzs = {}
            for k in range(args.view_k):
                hz = models["backbone"].view_specific_forward(
                    zs, k, args.S_k
                )  # [batch_size, nz]
                hzs[k] = {"hz": hz}
                hz_dict[k]["hz"].append(to_numpy(hz))

            content_indicators = content_mask(args=args, hzs=hzs)

            for subset in args.powerset:
                # the content block is identical for every view of the subset:
                # gather and convert it once
                content_np = to_numpy(zs[:, list(args.content_dict[subset])])
                for k in subset:
                    style_np = to_numpy(zs[:, list(args.style_dict[subset][k])])

                    data_dict[subset][k]["c"].append(content_np)
                    data_dict[subset][k]["s"].append(style_np)
                    hz_dict[k]["hc_mask"][subset].append(
                        to_numpy(content_indicators[subset][k])
                    )

        # stack the per-batch lists into [num_batches, batch_size, ...] arrays
        for subset_dict in data_dict.values():
            for k_dict in subset_dict.values():
                k_dict["c"] = np.stack(k_dict["c"], axis=0)
                k_dict["s"] = np.stack(k_dict["s"], axis=0)

        for view_dict in hz_dict.values():
            view_dict["hz"] = np.stack(view_dict["hz"], axis=0)
            for subset, masks in view_dict["hc_mask"].items():
                view_dict["hc_mask"][subset] = np.stack(masks, axis=0)

        return data_dict, hz_dict, torch.stack(all_z, 0).detach().cpu().numpy()


# ------------------ Training -----------------
# --------------------------------------------------------
def train_step(data, loss, models, optimizer, params, args, **kwargs):
    """
    Run one training step and return ``(loss value as float, accuracies)``.

    Args:
        data = (z_positive, z_negative); both are moved to the module-level device,
        loss: loss object from losses.py; its ``loss(...)`` method is used in the hard modes,
        models: dict holding "backbone" (and "main_model" in "soft" mode); the backbone
            composes, per view, h = g o f with g the encoder and f the predefined mixing function,
        optimizer: optimizer object, or None to skip zero_grad / backward / step,
        params: parameter to optimize (not used inside this function),
        args: namespace providing readout_mode, S_k, view_k and evaluate.

    Returns:
        total_loss_value.item() and accs, where accs is -1.0 in the hard modes and a
        NumPy array with one accuracy per task in "soft" mode.
    """

    models["backbone"].train()
    # reset grad
    if optimizer is not None:
        optimizer.zero_grad()

    z, z3 = data
    z = z.to(device)
    z3 = z3.to(device)

    if args.readout_mode != "soft":  # compute hard loss
        # forward pass
        z_rec, z3_rec, hzs = models["backbone"].forward(
            z=z, z3=z3, S_k=args.S_k, n_views=args.view_k
        )
        # binary content masks per subset and view, derived from the encodings
        content_mask_dict = content_mask(args=args, hzs=hzs)
        total_loss_value, _, _ = loss.loss(content_mask_dict, z_rec, z3_rec)
        accs = -1.0  # no acc in the hard mode
    else:
        # placeholders; both are overwritten further down in this branch
        total_loss_value = 0.0
        accs = 0.0
        model = models["main_model"]
        model.args = args  # sloppy

        # first half of the batch feeds the inner loop, second half the outer loop
        idx = int(len(z) * 0.5)
        xS, xQ = z[:idx], z[idx:]  # 0.5 split

        if not model.warm_start:
            model.solver.reinit()

        # inner loop on xS, run without tracking gradients
        with torch.no_grad():
            _, _ = model.inner_loop(
                xS, y=None, niter=model.inner_steps, eval=args.evaluate
            )

        if model.implicit:
            # work on a detached copy of the solver's gamma that requires grad
            gamma = model.solver.gamma.detach().clone()
            gamma.requires_grad = True
            model.solver.zero_grad()
            for _ in range(model.k_):
                # write the current gamma into the solver, run 10 inner iterations,
                # then blend the solver's gamma back in with weight lambda_
                model.solver.gamma.detach().copy_(gamma.detach())
                _, _ = model.inner_loop(xS, y=None, niter=10)
                gamma = (1 - model.lambda_) * gamma + model.lambda_ * model.solver.gamma
            model.solver.gamma.detach().copy_(gamma.detach())
            y_hat, z, yQ = model.outer_loop(xQ, gamma)
        else:
            y_hat, z, yQ = model.outer_loop(xQ, model.solver)
        # NOTE: from here on `loss` no longer refers to the `loss` argument
        loss = model.criterion(y_hat.reshape(model.ntasks, -1), yQ)
        total_loss_value = loss
        # per-task accuracy: predictions are y_hat thresholded at 0.5
        accs = (
            (
                yQ.reshape(model.ntasks, -1)
                == ((y_hat > 0.5).float().reshape(model.ntasks, -1)).float()
            )
            .float()
            .mean(1)
            .detach()
            .cpu()
            .numpy()
        )

    if optimizer is not None:
        total_loss_value.backward()
        optimizer.step()

    return total_loss_value.item(), accs


def generate_latent_space(args):
    """Build the product latent space holding one Gaussian space of size ``args.z_n``.

    When not evaluating, the covariance ``Sigma_z`` is created and saved to
    ``<model_base_dir>/Sigma_z.csv``: the identity if ``args.independent``,
    otherwise the identity whose leading ``n_dependent_dims`` block is replaced
    by a Wishart draw. When evaluating, ``Sigma_z`` is read back from that file
    and printed. A falsy ``args.n_dependent_dims`` is replaced by ``args.z_n``.
    """
    if not args.n_dependent_dims:
        args.n_dependent_dims = args.z_n
    assert args.n_dependent_dims <= args.z_n

    sigma_file = os.path.join(args.model_base_dir, "Sigma_z.csv")
    if args.evaluate:
        Sigma_z = np.loadtxt(sigma_file, delimiter=",")
        print(Sigma_z)
    else:
        Sigma_z = np.eye(args.z_n)
        if not args.independent:
            # dependent case: the first n_dependent_dims latents are correlated,
            # the remaining ones keep unit variance and stay uncorrelated
            n_dep = args.n_dependent_dims
            Sigma_z[:n_dep, :n_dep] = wishart.rvs(n_dep, np.eye(n_dep), size=1)
        np.savetxt(sigma_file, Sigma_z, delimiter=",")

    space = spaces.NRealSpace(args.z_n)

    # Here just one latent space
    def sample_latent(space, size, device=device):
        return space.normal(None, args.m_param, size, device, Sigma=Sigma_z)

    only_space = latent_spaces.LatentSpace(space=space, sample_latent=sample_latent)
    return latent_spaces.ProductLatentSpace(spaces=[only_space])


# --------- Evaluate using regression models and R2 score ------------
# ----------------------------------------------------------------------
def evaluate_prediction(regression_model, metric, X_train, y_train, X_test, y_test):
    """Fit ``regression_model`` on the training split and score it on the test split.

    Args:
        regression_model: object with ``fit(X, y)`` and ``predict(X)``.
        metric: callable ``metric(y_true, y_pred)``.
        X_train, y_train, X_test, y_test: arrays. One-dimensional inputs are
            treated as a single feature column.

    Returns:
        ``metric(y_test, y_pred)``, or ``np.nan`` when any array has a
        zero-length axis.
    """
    arrays = [X_train, y_train, X_test, y_test]
    # handle edge cases when inputs or labels are zero-dimensional
    if any(0 in a.shape for a in arrays):
        return np.nan
    # handle edge cases when the inputs are one-dimensional: turn them into a
    # single feature column (the previous reshape only touched arrays that
    # were already 2-D, so 1-D inputs crashed on shape[1])
    if X_train.ndim == 1:
        X_train = X_train.reshape(-1, 1)
    if X_test.ndim == 1:
        X_test = X_test.reshape(-1, 1)
    assert X_train.shape[1] == X_test.shape[1]
    # compare everything but the sample axis so 1-D labels are accepted too
    assert y_train.shape[1:] == y_test.shape[1:]
    regression_model.fit(X_train, y_train)
    y_pred = regression_model.predict(X_test)
    return metric(y_test, y_pred)


# ------------------ Evaluate for hard readout --------------------
# ------------------------------------------------------------------
def evaluate(models, latent_space, args):
    """Run the evaluation step matching ``args.readout_mode``.

    "soft" uses ``eval_step_soft``; every other mode uses ``eval_step_hard``.
    Nothing is returned.
    """
    if args.readout_mode == "soft":
        eval_step_soft(models, latent_space, args)
        return
    eval_step_hard(models, latent_space, args)


def eval_step_soft(models, latent_space, args):
    """Evaluate the soft-readout model and append one summary row to a CSV file.

    For each batch the sample is split in half: the first half drives the
    solver's inner loop (500 iterations, after re-initialising the solver), the
    second half is scored by the outer loop. The mean and standard deviation of
    the losses and of the per-task accuracies are appended to
    ``Evaluation.csv`` (when ``args.evaluate``) or ``Training.csv`` in the
    current working directory.

    Args:
        models: dict with the soft model under "main_model".
        latent_space: latent space passed to ``sample_whole_latent``.
        args: namespace providing ``evaluate``, ``num_eval_batches`` and
            ``batch_size``.
    """
    model = models["main_model"]
    if args.evaluate:
        num_batches = args.num_eval_batches
        file_name = "Evaluation"
    else:
        num_batches = 1
        file_name = "Training"

    eval_losses, eval_accs = [], []

    for _ in range(num_batches):
        z, _ = sample_whole_latent(latent_space, size=args.batch_size)
        z = z.to(device)

        half = len(z) // 2
        xS, xQ = z[:half], z[half:]

        model.backbone.eval()
        model.solver.reinit()

        with torch.no_grad():
            # eval 1 no feat sharing optimized
            model.inner_loop(xS, y=None, niter=500, eval=1)
            logits, z, yQ = model.outer_loop(xQ, model.solver)
            loss = model.criterion(logits.reshape(model.ntasks, -1), yQ)
            # per-task accuracy: predictions are the logits thresholded at 0.5
            accs = (
                (
                    yQ.reshape(model.ntasks, -1)
                    == ((logits > 0.5).float().reshape(model.ntasks, -1)).float()
                )
                .float()
                .mean(1)
                .detach()
                .cpu()
                .numpy()
            )
            eval_losses.append(loss.item())
            eval_accs.append(accs)

    eval_losses = np.asanyarray(eval_losses)
    eval_accs = np.asarray(eval_accs)

    row = [
        "loss_mean",
        f"{np.mean(eval_losses):.3f}",
        "loss_std",
        f"{np.std(eval_losses):.3f}",
        "acc_mean",
        f"{np.mean(eval_accs):.3f}",
        "acc_std",
        f"{np.std(eval_accs):.3f}",
    ]
    # The context manager closes the file even if writing fails; newline=""
    # is what the csv module expects from files it writes to.
    with open(f"{file_name}.csv", "a+", newline="") as fileobj:
        csv.writer(fileobj).writerow(row)


def eval_step_hard(
    models, latent_space, args, evaluate_individual=True
):  # TODO: add it to argsparser
    """Evaluate the hard-readout encoders with regression R2 scores, logged to CSV.

    Data is produced by ``generate_data``; the encodings of every view are
    standardised, the columns selected by the predicted content mask are used
    as regression inputs, and linear (plus, when ``args.evaluate``, nonlinear)
    regressors are scored with R2 on a random train/test split per batch.

    * ``args.evaluate_individual_latents`` unset: targets are the ground-truth
      content ("c") and style ("s") blocks; one row per (subset, view, block)
      goes to ``<model_base_dir>/<Evaluation|Training>_<c|s>.csv``.
    * set: targets are the individual ground-truth latents; one row per
      (subset, view, latent) goes to ``<model_base_dir>/<Evaluation|Training>.csv``.
      Scores are collected separately for every view, so a row only reflects
      the view it is labelled with.

    Args:
        models: dict forwarded to ``generate_data``.
        latent_space: latent space forwarded to ``generate_data``.
        args: namespace providing ``evaluate``, ``num_eval_batches``,
            ``mlp_eval``, ``evaluate_individual_latents``, ``z_n`` and
            ``model_base_dir``.
        evaluate_individual: unused; kept for interface compatibility.
    """

    def generate_nonlinear_model():
        if args.mlp_eval:
            return MLPRegressor(max_iter=5000)  # lightweight option
        # grid search is time- and memory-intensive
        return GridSearchCV(
            kernel_ridge.KernelRidge(kernel="rbf", gamma=0.1),
            param_grid={
                "alpha": [1e0, 0.1, 1e-2, 1e-3],
                "gamma": np.logspace(-2, 2, 4),
            },
            cv=3,
            n_jobs=-1,
        )

    def r2_pair(inputs, labels):
        """Return (linear R2, nonlinear R2) on one random train/test split."""
        train_inputs, test_inputs, train_labels, test_labels = train_test_split(
            inputs, labels
        )
        data = [train_inputs, train_labels, test_inputs, test_labels]
        r2_linear = evaluate_prediction(
            linear_model.LinearRegression(n_jobs=-1), r2_score, *data
        )
        if args.evaluate:
            r2_nonlinear = evaluate_prediction(
                generate_nonlinear_model(), r2_score, *data
            )
        else:
            r2_nonlinear = -1.0  # not computed
        return r2_linear, r2_nonlinear

    def masked_inputs(k, subset, i):
        """Columns of view k's encodings selected by the content mask of batch i."""
        mask = hz_dict[k]["hc_mask"][subset][i].astype(bool)
        selected_idx = np.where(mask)[-1].reshape(batch_size, -1)
        return np.take_along_axis(hz_dict[k]["hz"][i], selected_idx, axis=-1)

    def mean_std(values):
        return f"{np.mean(values):.3f} +- {np.std(values):.3f}"

    def avg_mask(k, subset):
        return f"{np.mean(hz_dict[k]['hc_mask'][subset], axis=(0, 1))}"

    def append_row(file_path, row):
        # closes the file even if writing fails; newline="" as the csv module expects
        with open(file_path, "a+", newline="") as fileobj:
            csv.writer(fileobj).writerow(row)

    if args.evaluate:
        num_batches = args.num_eval_batches
        file_name = "Evaluation"
    else:
        num_batches = 1
        file_name = "Training"

    data_dict, hz_dict, all_zs = generate_data(
        latent_space=latent_space, models=models, num_batches=num_batches, args=args
    )

    # standardize the estimated latents hz; each view keeps its own
    # [num_batches, batch_size, nSk] shape
    batch_size = hz_dict[0]["hz"].shape[1]
    for view_dict in hz_dict.values():
        view_shape = view_dict["hz"].shape
        view_dict["hz"] = (
            StandardScaler()
            .fit_transform(np.concatenate(view_dict["hz"], axis=0))
            .reshape(*view_shape)
        )

    if not args.evaluate_individual_latents:
        for subset in data_dict:
            for k in subset:
                scores = {
                    "c_linear": [],
                    "c_nonlinear": [],
                    "s_linear": [],
                    "s_nonlinear": [],
                }
                for i in range(num_batches):
                    inputs = masked_inputs(k, subset, i)
                    for keyword in ["c", "s"]:
                        # [batch_size, n_keyword]
                        labels = data_dict[subset][k][keyword][i]
                        r2_linear, r2_nonlinear = r2_pair(inputs, labels)
                        scores[f"{keyword}_linear"].append(r2_linear)
                        scores[f"{keyword}_nonlinear"].append(r2_nonlinear)
                for keyword in ["c", "s"]:
                    append_row(
                        os.path.join(
                            args.model_base_dir, f"{file_name}_{keyword}.csv"
                        ),
                        [
                            subset,
                            "view",
                            k,
                            keyword,
                            "linear mean",
                            mean_std(scores[f"{keyword}_linear"]),
                            "nonlinear mean",
                            mean_std(scores[f"{keyword}_nonlinear"]),
                            "avg_mask",
                            avg_mask(k, subset),
                        ],
                    )
    else:
        file_path = os.path.join(args.model_base_dir, f"{file_name}.csv")
        for subset in data_dict:
            for k in subset:
                # fresh accumulators per view: rows are written per view, so
                # they must not include the scores of previously visited views
                scores = {
                    latent_idx: {"linear": [], "nonlinear": []}
                    for latent_idx in range(args.z_n)
                }
                for i in range(num_batches):
                    inputs = masked_inputs(k, subset, i)
                    for latent_idx in range(args.z_n):
                        # [batch_size, 1]
                        labels = all_zs[i, :, latent_idx][:, None]
                        r2_linear, r2_nonlinear = r2_pair(inputs, labels)
                        scores[latent_idx]["linear"].append(r2_linear)
                        scores[latent_idx]["nonlinear"].append(r2_nonlinear)
                for latent_idx in range(args.z_n):
                    append_row(
                        file_path,
                        [
                            subset,
                            "view",
                            k,
                            "latent_idx",
                            latent_idx,
                            "linear mean",
                            mean_std(scores[latent_idx]["linear"]),
                            "nonlinear mean",
                            mean_std(scores[latent_idx]["nonlinear"]),
                            "avg_mask",
                            avg_mask(k, subset),
                        ],
                    )


# ---------------------------- Parser --------------------------
# ---------------------------------------------------------------
def parse_args():
    """Parse the command line and return ``(args, parser)``.

    Boolean-valued options (``--mlp-eval``, ``--shared-mixing-function``,
    ``--shared-encoder``) accept true/false, yes/no, on/off, 1/0 (case
    insensitive); a plain ``type=bool`` would turn the string "False" into
    True. ``--S-k`` takes a JSON value, e.g. ``--S-k "[[0, 1], [1, 2]]"``, so
    that the nested list of per-view latent indices can be given on the
    command line.
    """

    def str2bool(value):
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in {"true", "t", "yes", "y", "on", "1"}:
            return True
        # the empty string stays False, as with the former type=bool
        if lowered in {"false", "f", "no", "n", "off", "0", ""}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-n", type=int, default=0)
    parser.add_argument("--z-n", type=int, default=6)
    parser.add_argument("--evaluate", action="store_true")  # by default false
    parser.add_argument("--model_base_dir", type=str, default="res_numerical")
    parser.add_argument("--model-dir-G", type=str, default="models-G")
    parser.add_argument("--model-dir-F", type=str, default="models-F")
    parser.add_argument("--num-train-batches", type=int, default=5)
    parser.add_argument("--save-dir-G", type=str, default="")
    parser.add_argument("--save-dir-F", type=str, default="")
    parser.add_argument("--num-eval-batches", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--c-param", type=float, default=1.0)
    parser.add_argument("--m-param", type=float, default=1.0)
    parser.add_argument("--n-mixing-layer", type=int, default=3)
    parser.add_argument(
        "--lr", type=float, default=1e-4
    )  # 1e-4 for hard; 1e-1 for soft
    parser.add_argument("--no-cuda", action="store_true")
    parser.add_argument("--batch-size", type=int, default=4096)  # 6144
    parser.add_argument("--n-log-steps", type=int, default=100)  # 250
    parser.add_argument("--n-steps", type=int, default=10001)  # 100001
    # NOTE: store_false means the value is True unless the flag is passed
    parser.add_argument("--resume-training", action="store_false")
    parser.add_argument("--view-k", type=int, default=4)  # number of views we consider
    parser.add_argument(
        "--S-k",
        type=json.loads,  # JSON list of lists: latent indices seen by each view
        default=[[0, 1, 2, 3, 4], [0, 1, 2, 4, 5], [0, 1, 2, 3, 5], [0, 1, 3, 4, 5]],
    )
    parser.add_argument("--load-F", default=None)
    parser.add_argument("--load-G", default=None)
    parser.add_argument(
        "--only-consider-whole-set", action="store_true"
    )  # take default as false
    parser.add_argument("--mlp-eval", type=str2bool, default=False)
    parser.add_argument("--shared-mixing-function", type=str2bool, default=False)
    parser.add_argument("--shared-encoder", type=str2bool, default=False)
    # parser.add_argument("--reg-coefficient", type=float, default=0.)
    parser.add_argument(
        "--readout-mode",
        type=str,
        default="ground_truth",
        choices=["ground_truth", "known_content_size", "soft"],
    )
    parser.add_argument("--independent", action="store_true")
    parser.add_argument("--evaluate_individual_latents", action="store_true")
    parser.add_argument("--n_dependent_dims", default=3, type=int)
    args = parser.parse_args()
    return args, parser


# ------------------- main loop ------------------------------------
# ------------------------------------------
def main():
    """Entry point: parse arguments, build the models, then train or evaluate.

    Training (default): the settings are dumped to ``settings.json``, the
    mixing functions are saved under ``save_dir_F``, and every
    ``n_log_steps`` steps (and at the last step) the encoders are checkpointed
    and evaluated. With ``--evaluate`` the models are loaded from
    ``model_base_dir`` and a single evaluation is run.

    Training always starts at step 1 with empty histories; ``--resume-training``
    currently has no effect.
    """
    args, parser = parse_args()
    os.makedirs(args.model_base_dir, exist_ok=True)
    with open(os.path.join(args.model_base_dir, "settings.json"), "w") as fp:
        json.dump(args.__dict__, fp, ensure_ascii=False)
    args = update_args(args)  # update powersets and information

    if not args.evaluate:
        args.save_dir_G = os.path.join(args.model_base_dir, args.model_dir_G)
        args.save_dir_F = os.path.join(args.model_base_dir, args.model_dir_F)
    else:
        args.load_G = os.path.join(args.model_base_dir, args.model_dir_G)
        args.load_F = os.path.join(args.model_base_dir, args.model_dir_F)
        args.n_steps = 1

    print("Arguments:")
    for k, v in vars(args).items():
        print(f"\t{k}: {v}")

    global device
    if args.no_cuda:
        device = "cpu"
        print("Using cpu")

    # set seed
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ---------- Initialisation
    F = init_or_load_mixing_functions(device, args)  # mixing function always gives S_k
    G = init_or_load_encoder_models(device, args)
    models = init_or_load_training_models(F=F, G=G, device=device, args=args)
    params, optimizer = init_or_load_optimizer(models=models, args=args)

    # initialise loss function
    loss = losses.UnifiedEncoderLoss(losses.LpSimCLRLoss())
    # initialise latent space
    latent_space = generate_latent_space(args)

    if args.save_dir_F:
        os.makedirs(args.save_dir_F, exist_ok=True)
        for idx, model in enumerate(F):
            save_path = os.path.join(args.save_dir_F, f"f_{idx}.pth")
            torch.save(model.state_dict(), save_path)

    # ----------Training
    # --------------------------------------------
    # The former `"total_loss_values" in locals()` check could never be true
    # inside this function, so the histories always start empty.
    total_loss_values = []
    accs_global = []

    global_step = 1
    while global_step <= args.n_steps and not args.evaluate:
        # Pass the device explicitly: the default of sample_whole_latent was
        # bound at import time and would ignore --no-cuda.
        data = sample_whole_latent(
            latent_space=latent_space, size=args.batch_size, device=device
        )
        total_loss_value, accs = train_step(
            data=data,
            loss=loss,
            models=models,
            optimizer=optimizer,
            params=params,
            args=args,
        )

        # store losses
        total_loss_values.append(total_loss_value)
        accs_global.append(accs)

        # checkpoint & evaluate for every n_log_steps
        if global_step % args.n_log_steps == 1 or global_step == args.n_steps:
            # the checkpoint has no step index, so it is overwritten each time
            save_models(models, optimizer, args)
            evaluate(models, latent_space, args)
            print(
                f"Step: {global_step} \t",
                f"Loss: {total_loss_value:.4f} \t",
                f"<Loss>: {np.mean(np.array(total_loss_values[-args.n_log_steps:])):.4f} \t",
                f"<Acc>: {np.mean(np.array(accs_global[-args.n_log_steps:])):.4f} \t",
            )
        global_step += 1

    # ----- Evaluation
    # --------------------------------------
    if args.evaluate:
        evaluate(models, latent_space, args)

# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
