import os
import time
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from tqdm import tqdm
from loguru import logger
from omegaconf import OmegaConf

from utils.model_utils import *
from utils.utils import *
from utils.dataset_loader import (
    SubgraphIsomorphismDataset,
    GraphEditDistanceDataset,
    UneqGraphEditDistanceDataset,
)
from utils.training_utils import (
    pairwise_ranking_loss,
    evaluate_model_faster,
)

from .distn_expt import *

import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)

### Functions for training and plotting the exchangeability experiment (Appendix)
### Associated config: expt3.yaml

###### helper
def train_model_for_v2(model_num,conf,gmn_config,dataset,model_class,arg):
    """Train one seeded model and pickle its initial and final state dicts.

    Used as the per-model worker of ``train_many_v2``: ``model_num`` is the
    only positional argument, everything else is bound by the caller.

    Args:
        model_num: Used both as the random seed (written to
            ``conf.training.seed``) and as the index in the output file names.
        conf: Experiment configuration; ``conf.training.seed`` is overwritten.
        gmn_config: Forwarded unchanged to ``model_class``.
        dataset: Object exposing ``fetch_batch_by_id(batch_idx)`` which returns
            ``(graphs, node_sizes, edge_sizes, labels)``.
        model_class: Callable invoked as ``model_class(conf, gmn_config)``.
        arg: Tuple ``(non_exchbl, n_batches, num_sub_batches)``. ``non_exchbl``
            enables the re-initialisation selected by ``conf.model.non_exchbl``;
            ``num_sub_batches`` is the number of batches that make up one "run".

    Raises:
        ValueError: If the non-exchangeable strategy is unknown, or if training
            is requested but there are no batches to iterate over.
    """
    non_exchbl, n_batches, num_sub_batches = arg
    if n_batches <= 0 and conf.expt.num_epochs > 0:
        # With no batches ``run`` would never advance and the loop below would spin forever.
        raise ValueError(f"Cannot train model {model_num}: n_batches={n_batches}")
    # A sub-epoch length of 0 would make the modulo below raise ZeroDivisionError;
    # treat every batch as its own sub-epoch in that case.
    num_sub_batches = max(1, num_sub_batches)

    conf.training.seed = model_num
    # use the new seed for model initialization
    set_seed(conf.training.seed)
    model = model_class(conf, gmn_config)

    if non_exchbl:
        if conf.model.non_exchbl == "strat1":
            make_nex_1(model.encoder)
        elif conf.model.non_exchbl == "strat2":
            make_nex_2(model.encoder)
        else:
            raise ValueError("Unknown non-exchangeable strategy")
    model.train()

    opt = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay # do we need it for such few epochs?
        )

    # Snapshot of the (possibly re-initialised) weights before any optimisation step.
    with open(f"{conf.write_dir}/data/{conf.task.name}/init_state_dict_{model_num}.pkl","wb") as f:
        pickle.dump(model.state_dict(),f)

    # ``run`` counts completed sub-epochs, not full passes: it is only compared
    # against ``conf.expt.num_epochs`` between passes, so a started pass over
    # the dataset is always finished.
    run = 0
    while run < conf.expt.num_epochs:
        for batch_idx in range(n_batches):
            batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes, labels = dataset.fetch_batch_by_id(batch_idx)

            opt.zero_grad()
            prediction = model(batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes)

            # Labels are thresholded at 0.5 to separate positive from negative pairs.
            predictions_for_positives = prediction[labels > 0.5]
            predictions_for_negatives = prediction[labels < 0.5]
            loss = pairwise_ranking_loss(
                predictions_for_positives.unsqueeze(1),
                predictions_for_negatives.unsqueeze(1),
                conf.training.margin
            )
            loss.backward()
            opt.step()

            if (batch_idx+1)%num_sub_batches==0 or (batch_idx+1)==n_batches:
                run += 1

    with open(f"{conf.write_dir}/data/{conf.task.name}/state_dict_{model_num}_final.pkl","wb") as f:
        pickle.dump(model.state_dict(),f)
    return

## sequential trainer
def train_many_v0(conf,gmn_config):
    """Train several seeded models in one process, feeding each batch to every model.

    Models whose ``state_dict_{i}_final.pkl`` already exists under
    ``{conf.write_dir}/data/{conf.task.name}`` are skipped. The remaining model
    numbers are processed in groups of ``conf.expt.num_procs``; within a group
    every fetched batch is used for one optimisation step of each model.
    The initial and final state dict of every model is pickled to that directory.

    Args:
        conf: Experiment configuration. ``conf.actual_max_node_set_size``,
            ``conf.actual_max_edge_set_size`` and ``conf.training.seed`` are
            overwritten, and ``conf.task.name`` gets a ``_{strategy}`` suffix
            when ``conf.model.non_exchbl`` is set.
        gmn_config: Forwarded unchanged to the model constructor.

    Raises:
        ValueError: For an unknown ``conf.dataset.rel_mode`` or non-exchangeable
            strategy, or when training is requested but the dataset yields no batches.
    """
    num_procs = conf.expt.num_procs
    if conf.dataset.rel_mode == "sub_iso":
        dataset = SubgraphIsomorphismDataset(conf,mode="train")
    elif conf.dataset.rel_mode == "ged":
        dataset = GraphEditDistanceDataset(conf,mode="train")
    elif conf.dataset.rel_mode == "uneq_ged":
        dataset = UneqGraphEditDistanceDataset(conf,mode="train")
    else:
        # Previously fell through and failed later with an unbound ``dataset``.
        raise ValueError("Unknown rel_mode")
    conf.actual_max_node_set_size = dataset.max_node_set_size
    conf.actual_max_edge_set_size = dataset.max_edge_set_size

    model_class = get_class(f"{conf.model.classPath}.{conf.model.name}")

    ### New: add non-exchangeable strategy
    # Best effort: a config without ``model.non_exchbl`` means "no strategy".
    try:
        non_exchbl = conf.model.non_exchbl is not None
    except Exception:
        non_exchbl = False
    if non_exchbl:
        conf.task.name = f"{conf.task.name}_{conf.model.non_exchbl}"

    out_dir = f"{conf.write_dir}/data/{conf.task.name}"
    all_models_to_train = [
        i
        for i in range(conf.expt.start_num,conf.expt.start_num + conf.expt.num_models)
        if not os.path.exists(f"{out_dir}/state_dict_{i}_final.pkl")
    ]

    os.makedirs(out_dir,exist_ok=True)

    # One pass over the data is divided into ``patience//2`` sub-epochs ("runs").
    # Clamped to 1 so that a small patience cannot cause a division by zero.
    sub_factor = max(1, conf.training.patience//2)

    for start in range(0, len(all_models_to_train), num_procs):
        models_to_train = all_models_to_train[start:start + num_procs]
        models = {}
        opts = {}

        for model_num in models_to_train:
            # initialize model and optimiser
            conf.training.seed = model_num
            # use the new seed for model initialization
            set_seed(conf.training.seed)
            models[model_num] = model_class(conf, gmn_config)
            models[model_num].train()
            opts[model_num] = torch.optim.Adam(
                models[model_num].parameters(),
                lr=conf.training.learning_rate,
                weight_decay=conf.training.weight_decay # do we need it for such few epochs?
                )

            if non_exchbl:
                if conf.model.non_exchbl == "strat1":
                    make_nex_1(models[model_num].encoder)
                elif conf.model.non_exchbl == "strat2":
                    make_nex_2(models[model_num].encoder)
                else:
                    raise ValueError("Unknown non-exchangeable strategy")
            # Taken after the optional re-initialisation. The original captured the
            # dict just before it, but make_nex_* overwrite the weights in place, so
            # the pickled values are the same either way.
            with open(f"{out_dir}/init_state_dict_{model_num}.pkl","wb") as f:
                pickle.dump(models[model_num].state_dict(),f)

        n_batches = dataset.create_stratified_batches()
        if n_batches <= 0 and conf.expt.num_epochs > 0:
            # With no batches ``run`` would never advance and the loop below would spin forever.
            raise ValueError(f"Cannot train: n_batches={n_batches}")
        # At least one batch per sub-epoch, otherwise the modulo below divides by zero.
        num_sub_batches = max(1, n_batches//sub_factor)

        # ``run`` counts completed sub-epochs; it is only compared against
        # ``conf.expt.num_epochs`` between full passes over the data.
        run = 0
        while run < conf.expt.num_epochs:
            training_start_time = time.time()
            for batch_idx in tqdm(range(n_batches)):
                batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes, labels = dataset.fetch_batch_by_id(batch_idx)

                for model_num in models_to_train:
                    opts[model_num].zero_grad()
                    prediction = models[model_num](batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes)

                    # Labels are thresholded at 0.5 to separate positive from negative pairs.
                    predictions_for_positives = prediction[labels > 0.5]
                    predictions_for_negatives = prediction[labels < 0.5]
                    losses = pairwise_ranking_loss(
                        predictions_for_positives.unsqueeze(1),
                        predictions_for_negatives.unsqueeze(1),
                        conf.training.margin
                    )
                    losses.backward()
                    opts[model_num].step()

                if (batch_idx+1)%num_sub_batches==0 or (batch_idx+1)==n_batches:
                    epoch_training_time = time.time() - training_start_time
                    logger.info(f"Run: {run}  Time: {epoch_training_time}")
                    training_start_time = time.time()
                    run += 1

        for model_num in models_to_train:
            with open(f"{out_dir}/state_dict_{model_num}_final.pkl","wb") as f:
                pickle.dump(models[model_num].state_dict(),f)

## parallelised trainer
def train_many_v2(conf,gmn_config):
    """Train the missing models in parallel, one ``train_model_for_v2`` call per model.

    Model numbers in ``[start_num, start_num + num_models)`` whose
    ``state_dict_{i}_final.pkl`` does not yet exist are distributed over a
    ``torch.multiprocessing`` pool with ``conf.expt.num_procs`` workers.

    Args:
        conf: Experiment configuration. ``conf.actual_max_node_set_size`` and
            ``conf.actual_max_edge_set_size`` are overwritten, and
            ``conf.task.name`` gets a ``_{strategy}`` suffix when
            ``conf.model.non_exchbl`` is set.
        gmn_config: Forwarded unchanged to every worker.

    Raises:
        ValueError: For an unknown ``conf.dataset.rel_mode``.
    """
    num_procs = conf.expt.num_procs

    if conf.dataset.rel_mode == "sub_iso":
        dataset = SubgraphIsomorphismDataset(conf,mode="train")
    elif conf.dataset.rel_mode == "ged":
        dataset = GraphEditDistanceDataset(conf,mode="train")
    elif conf.dataset.rel_mode == "uneq_ged":
        dataset = UneqGraphEditDistanceDataset(conf,mode="train")
    else:
        # Previously fell through and failed later with an unbound ``dataset``.
        raise ValueError("Unknown rel_mode")

    conf.actual_max_node_set_size = dataset.max_node_set_size
    conf.actual_max_edge_set_size = dataset.max_edge_set_size

    model_class = get_class(f"{conf.model.classPath}.{conf.model.name}")

    ### New: add non-exchangeable strategy
    # Best effort: a config without ``model.non_exchbl`` means "no strategy".
    try:
        non_exchbl = conf.model.non_exchbl is not None
    except Exception:
        non_exchbl = False
    if non_exchbl:
        conf.task.name = f"{conf.task.name}_{conf.model.non_exchbl}"

    out_dir = f"{conf.write_dir}/data/{conf.task.name}"
    all_models_to_train = [
        i
        for i in range(conf.expt.start_num,conf.expt.start_num + conf.expt.num_models)
        if not os.path.exists(f"{out_dir}/state_dict_{i}_final.pkl")
    ]

    os.makedirs(out_dir,exist_ok=True)

    n_batches = dataset.create_stratified_batches()
    # One pass over the data is divided into ``patience//2`` sub-epochs; both
    # values are clamped to 1 so that neither the division here nor the modulo
    # in the worker can divide by zero.
    sub_factor = max(1, conf.training.patience//2)
    num_sub_batches = max(1, n_batches//sub_factor)

    if not all_models_to_train:
        # Nothing left to do: avoid spawning worker processes for an empty job list.
        logger.info("All requested models are already trained")
        return

    # The pool supplies ``model_num`` positionally; everything else is bound here.
    train_model_for_v2_ = partial(train_model_for_v2,conf=conf,gmn_config=gmn_config,dataset=dataset,model_class=model_class,arg=(non_exchbl,n_batches,num_sub_batches))
    with mp.Pool(processes=num_procs ) as pool:
        pool.map(train_model_for_v2_,all_models_to_train)
        
### plotting function ####    

def exchbl_plots(conf, gmn_config=None):
    """Compare transport-plan and sort-based distances of trained models and plot them.

    For each initialisation strategy in ``("strat2", None)`` every
    ``state_dict_{i}_final.pkl`` found in the task directory chosen by
    ``set_task`` is loaded. For a random sample of query/corpus pairs the
    function records (a) the distance returned by the ``*_feature_alignment_distance``
    helper for the Sinkhorn transport plan and (b) the distance computed from
    embeddings sorted along dim 1. Output goes to ``{conf.write_dir}/plots``:

    * ``conf.plt.dev`` true: only a stand-alone legend is written and the
      function returns.
    * otherwise, with ``conf.plt.scatter`` true: one scatter plot for the mean
      over dimensions (``dim == -1``) and one per embedding dimension.

    Args:
        conf: Experiment configuration. ``conf.model.non_exchbl``,
            ``conf.task.name`` and the ``actual_max_*`` fields are overwritten.
        gmn_config: Config forwarded to the model constructor. When omitted, the
            module-level ``gmn_config`` created by the ``__main__`` block is
            used, which is what the function implicitly relied on before.

    Raises:
        NameError: If ``gmn_config`` is neither passed nor defined at module level.
        ValueError: For an unknown ``conf.dataset.rel_mode``.
    """
    if gmn_config is None:
        try:
            gmn_config = globals()["gmn_config"]
        except KeyError:
            raise NameError(
                "gmn_config is not defined: pass it explicitly or run this module as a script"
            ) from None

    # -1 stands for "mean over all dimensions"; the rest are individual dimensions.
    dims = tuple(range(-1,conf.emb_dim))

    def sorted_diff(q_emb,c_emb):
        return (torch.sort(q_emb,dim=1)[0] - torch.sort(c_emb,dim=1)[0])

    def subiso_sort_func(q_emb,c_emb):
        return torch.sum(torch.nn.functional.relu(sorted_diff(q_emb,c_emb)),dim=1)
    def ged_sort_func(q_emb,c_emb):
        return torch.sum(torch.abs(sorted_diff(q_emb,c_emb)), dim=1)
    def uneq_ged_sort_func(q_emb,c_emb):
        return torch.sum(2*torch.nn.functional.relu(sorted_diff(q_emb,c_emb)) + torch.nn.functional.relu(-sorted_diff(q_emb,c_emb)), dim=1)

    rel_mode = conf.dataset.rel_mode
    if rel_mode == "sub_iso":
        dataset = SubgraphIsomorphismDataset(conf,mode="train")
        gt_func, sort_func = subiso_feature_alignment_distance, subiso_sort_func
    elif rel_mode == "ged":
        dataset = GraphEditDistanceDataset(conf,mode="train")
        gt_func, sort_func = ged_feature_alignment_distance, ged_sort_func
    elif rel_mode == "uneq_ged":
        dataset = UneqGraphEditDistanceDataset(conf,mode="train")
        gt_func, sort_func = uneq_ged_feature_alignment_distance, uneq_ged_sort_func
    else:
        # Previously fell through and failed later with an unbound ``dataset``.
        raise ValueError("Unknown rel_mode")

    conf.actual_max_node_set_size = dataset.max_node_set_size
    conf.actual_max_edge_set_size = dataset.max_edge_set_size

    model_class = get_class(f"{conf.model.classPath}.{conf.model.name}")

    model  = model_class(conf, gmn_config)
    model.eval()

    num_points = 5_000 # 1_000

    # Sample pairs with replacement, but embed every distinct graph only once;
    # the inverse indices map the unique embeddings back onto the sampled pairs.
    corp_ids = torch.randint(0, dataset.num_corpus_graphs, (num_points,))
    quer_ids = torch.randint(0, dataset.num_query_graphs, (num_points,))
    corp_uniq,corp_inv = torch.unique(corp_ids,return_inverse=True,sorted=True)
    quer_uniq,quer_inv = torch.unique(quer_ids,return_inverse=True,sorted=True)

    # loguru formats messages with ``{}`` placeholders, so the former
    # ``logger.info("uniq:", a, b)`` silently dropped both numbers.
    logger.info(f"uniq: {len(corp_uniq)} {len(quer_uniq)}")

    corpus_data = tuple([dataset._pack_batch_1d([dataset.corpus_graphs[x] for x in corp_uniq]), dataset.corpus_graph_node_sizes[corp_uniq], dataset.corpus_graph_edge_sizes[corp_uniq]])
    query_data = tuple([dataset._pack_batch_1d([dataset.query_graphs[x] for x in quer_uniq]), dataset.query_graph_node_sizes[quer_uniq], dataset.query_graph_edge_sizes[quer_uniq]])

    def mask_graphs(features, graph_sizes):
        mask = torch.stack([model.graph_size_to_mask_map[i] for i in graph_sizes])
        return mask * features

    gt  = {"sub_iso":[],"ged":[],"uneq_ged":[]}
    sort = {"sub_iso":[],"ged":[],"uneq_ged":[]}

    # non_exchbls = ("strat1","strat2",None)
    non_exchbls = ("strat2",None)
    titles = {"strat2" : r"\textbf{non-iid initialization}",None: r"\textbf{iid initialization}"}
    colors = {"strat2" : "blue",None: "red"}

    for non_exchbl in non_exchbls:
        conf.model.non_exchbl = non_exchbl
        conf = set_task(conf)
        model_path = f"{conf.write_dir}/data/{conf.task.name}"
        logger.info(f"Model path: {model_path}")
        num_models = 0
        for i in tqdm(range(conf.expt.num_models)):
            state_path = f"{model_path}/state_dict_{i}_final.pkl"
            if not os.path.exists(state_path):
                continue
            num_models +=1

            # NOTE(review): one entry is appended per loaded model, while the
            # plotting code below zips the entries against the two strategies.
            # That lines up only when exactly one model per strategy is found -
            # confirm this is intended.
            gt[rel_mode].append([])
            sort[rel_mode].append([])

            # The pickles are produced by the training functions of this module;
            # do not point this at untrusted files.
            with open(state_path,"rb") as f:
                model.load_state_dict(pickle.load(f))
            model.eval()
            model.fetch_embed = True

            stacked_node_features_query = model(*query_data)
            stacked_node_features_corpus = model(*corpus_data)

            # Computation of node transport plan
            transformed_features_query = model.node_sinkhorn_feature_layers(stacked_node_features_query)
            transformed_features_corpus = model.node_sinkhorn_feature_layers(stacked_node_features_corpus)

            masked_features_query = mask_graphs(transformed_features_query, query_data[1])
            masked_features_corpus = mask_graphs(transformed_features_corpus, corpus_data[1])

            node_sinkhorn_input = torch.matmul(masked_features_query[quer_inv], masked_features_corpus[corp_inv].permute(0, 2, 1))
            node_transport_plan = pytorch_sinkhorn_iters(
                log_alpha=node_sinkhorn_input, device=model.device,temperature=model.sinkhorn_temp
            )
            gt[rel_mode][-1].append(gt_func(stacked_node_features_query[quer_inv],stacked_node_features_corpus[corp_inv],node_transport_plan)/stacked_node_features_corpus.shape[2])
            sort[rel_mode][-1].append(sort_func(stacked_node_features_query[quer_inv],stacked_node_features_corpus[corp_inv])/stacked_node_features_corpus.shape[2])

        logger.info(f"Loaded {num_models} models")

    os.makedirs(f"{conf.write_dir}/plots/{conf.task.name}", exist_ok=True)

    gtvs_all = [torch.stack(v).squeeze().cpu().detach().numpy()/conf.dataset.embed_dim for v in gt[rel_mode]]
    xvs_all = [torch.stack(v).squeeze().cpu().detach().numpy() for v in sort[rel_mode]]
    logger.info(f"gtvs: {len(gtvs_all)} xvs: {len(xvs_all)}")
    # Only the first two entries are reported; slicing avoids an IndexError
    # when fewer than two models were loaded.
    logger.info(f"gtvs: {tuple(g.shape for g in gtvs_all[:2])} xvs: {tuple(x.shape for x in xvs_all[:2])}")

    if conf.plt.dev:
        # only legend consisting of two columns, non-iid initialisation and iid initialisation in blue and red text respectively. figsize 1x6
        # NOTE(review): this early return makes the deviation-histogram branch
        # further down unreachable - confirm whether that is intended.
        import matplotlib.patches as mpatches
        fig = plt.figure(figsize=(6, 1))
        handles = [
            mpatches.Patch(color=colors["strat2"], label=titles["strat2"]),
            mpatches.Patch(color=colors[None], label=titles[None])
        ]
        # Create a dummy axis just for the legend, then remove it
        ax = fig.add_subplot(111)
        ax.axis('off')
        fig.legend(handles=handles, loc='center', frameon=False, ncol=2, fontsize=24)
        fig.patch.set_alpha(0.0)
        plt.tight_layout()
        plt.savefig(f"{conf.write_dir}/plots/legend_dev.pdf", bbox_inches='tight', transparent=True)
        plt.close()
        return

    if not gtvs_all:
        # Without any loaded model the min()/max() calls below would fail on empty input.
        logger.warning("No trained models found; nothing to plot")
        return

    # NOTE(review): ``sns`` is not imported in this file; presumably it arrives
    # through one of the star imports - verify.
    for dim in dims:

        gtvs = gtvs_all
        xvs = [e[:,dim] if dim != -1 else e.mean(axis=-1) for e in xvs_all]

        title_base = f"{rel_mode} - {'Mean' if dim == -1 else 'Dim ' + str(dim)}"

        # Common range for both axes, used to draw the x = y reference line.
        xy_min = min(min(gtv.min() for gtv in gtvs), min(xv.min() for xv in xvs))
        xy_max = max(max(gtv.max() for gtv in gtvs), max(xv.max() for xv in xvs))

        # ---------------- SEABORN SCATTERPLOT ----------------
        if conf.plt.scatter:
            plt.figure(figsize=(9, 6))
            for gtv, xv, non_exchbl in zip(gtvs, xvs, non_exchbls):
                sns.scatterplot(x=gtv, y=xv, alpha=0.5, label=titles[non_exchbl], color=colors[non_exchbl])
            plt.plot([xy_min, xy_max], [xy_min, xy_max], 'r--', label='x = y')
            plt.xlabel(r"$\frac{1}{D} \Delta(G_c,G_q)$")
            plt.ylabel(r"$\Delta_d(G_c,G_q)$" if dim != -1 else r"$\frac{1}{D} \sum_{d=1}^D \Delta_d(G_c,G_q)$")
            plt.title(f"Scatterplot (Seaborn) for {title_base}")
            plt.legend(prop={'weight': 'bold'})
            plt.tight_layout()
            plt.savefig(f"{conf.write_dir}/plots/{conf.task.name}/{rel_mode}_avg{dim==-1}{'' if dim==-1 else dim}_scatterplot_seaborn.pdf")
            plt.close()

        if conf.plt.dev:
            # ---------------- SEABORN DEVIATION HISTOGRAM ----------------
            # Currently unreachable: the legend branch above returns when conf.plt.dev is set.
            plt.figure(figsize=(9, 6))
            peaks = {}
            for gtv, xv, non_exchbl in zip(gtvs, xvs, non_exchbls):
                hist = sns.histplot(gtv - xv, kde=True, stat="density", color=colors[non_exchbl], alpha=0.3, bins=50, edgecolor=None)
                kde_line = hist.lines[-1]  # Always get the last line plotted (KDE)
                x_data, y_data = kde_line.get_data()
                max_idx = np.argmax(y_data)
                peaks[non_exchbl] = (x_data[max_idx], y_data[max_idx])

            plt.vlines(0, ymin=0, ymax=1.05 * max(peak[1] for peak in peaks.values()), colors='gray', linestyles='dashed', lw=2)

            plt.xlabel(
                r"$\mathrm{sim}_d(G_c,G_q) - \frac{1}{D} \mathrm{sim}(G_c,G_q)$" if dim != -1
                else r"$\frac{1}{D} \sum_{d=1}^D \mathrm{sim}_d(G_c,G_q) - \frac{1}{D} \mathrm{sim}(G_c,G_q) \longrightarrow$",
                fontsize=30 if dim != -1 else 26
            )
            plt.ylabel(r"\textbf{Density} $\longrightarrow$", fontsize=26)

            plt.tight_layout()
            plt.savefig(f"{conf.write_dir}/plots/{conf.task.name}/{rel_mode}_avg{dim==-1}{'' if dim==-1 else dim}_dev_hist_seaborn.pdf")
            plt.close()

    print(f"Written all plots to {conf.write_dir}/plots/{conf.task.name}/")
    
### UTILS ###

def set_task(conf):
    """Recompute ``conf.task.name`` from the current config and return ``conf``.

    The name is built from the model, dataset and training settings plus the
    module-level ``task_name`` string, which the ``__main__`` block of this
    file assigns. When ``conf.model.non_exchbl`` is set (not ``None``), the
    strategy is appended as a ``_{strategy}`` suffix.

    Args:
        conf: Configuration object; modified in place.

    Returns:
        The same ``conf`` object.

    Raises:
        ValueError: If ``conf.dataset.rel_mode`` is not one of ``sub_iso``,
            ``ged`` or ``uneq_ged``.
    """
    rel_mode = conf.dataset.rel_mode
    if rel_mode in ("sub_iso", "ged"):
        conf.task.name = f"distn_{conf.model.name}_{conf.dataset.rel_mode}_{conf.dataset.name}_{task_name}_stemp={conf.training.sinkhorn_temp}_margin={conf.training.margin}"
    elif rel_mode == "uneq_ged":
        conf.task.name = f"distn_Uneq_{conf.model.name}_{conf.dataset.name}_numC={conf.dataset.aug_num_cgraphs}_MaxSkew={conf.dataset.MaxSkew}_{task_name},stemp={conf.training.sinkhorn_temp},margin={conf.training.margin}"
    else:
        raise ValueError("Unknown rel_mode")

    ### New: add non-exchangeable strategy
    # Best effort: a config without ``model.non_exchbl`` means "no strategy".
    try:
        non_exchbl = conf.model.non_exchbl is not None
    except Exception:
        non_exchbl = False
    if non_exchbl:
        conf.task.name = f"{conf.task.name}_{conf.model.non_exchbl}"
    return conf

# Strategy 1 - replace the weights with a non-iid initialization.
# The original Kaiming init draws every weight from [-a, a]; here the first half of the
# rows is shifted into (roughly) [-a, 0] and the second half into (roughly) [0, a].
def make_nex_1(encoder):
    """Shift the weights of ``encoder.MLP1`` in place so the two row halves differ in sign range.

    For every layer of ``encoder.MLP1`` that has a ``weight`` attribute, the
    layer's maximum weight is subtracted from the first ``rows // 2`` rows and
    added to the remaining rows, and the result is halved. Layers without a
    ``weight`` are left alone. Nothing is returned.
    """
    with torch.no_grad():
        for module in encoder.MLP1:
            if not hasattr(module, 'weight'):
                continue
            w = module.weight
            half_rows = w.shape[0] // 2
            # -1 for the first half of the rows, +1 for the rest
            sign = torch.ones_like(w)
            sign[:half_rows, :] = -1
            shifted = (w + torch.max(w) * sign) / 2
            w.copy_(shifted)
## Strategy 2 - initialization trick: sort the initialised weights along dim 0 (independently within each column)
def make_nex_2(encoder):
    """Sort the weights of ``encoder.MLP1`` in place along dim 0.

    Every layer of ``encoder.MLP1`` that has a ``weight`` attribute gets its
    weight replaced by ``torch.sort(weight, dim=0)`` values; layers without a
    ``weight`` are left alone. Nothing is returned.
    """
    with torch.no_grad():
        for module in encoder.MLP1:
            if not hasattr(module, 'weight'):
                continue
            ordered = torch.sort(module.weight, dim=0).values
            module.weight.copy_(ordered)


if __name__=="__main__":
    # Script entry point: merge the configs, derive the task name, then
    # optionally train (conf.expt.train) and plot (conf.expt.plot).
    # NOTE: ``task_name`` and ``gmn_config`` below must stay module-level names;
    # ``set_task`` and ``exchbl_plots`` read them as globals.
    expt_config = OmegaConf.load("configs/expt3.yaml")

    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")

    latexify()

    # Make sure ``model.non_exchbl`` exists (defaulting to None) so later code
    # can read it without guarding against a missing key.
    try:
        _ = model_conf.model.non_exchbl
    except Exception:
        model_conf = OmegaConf.merge(model_conf, OmegaConf.create({"model": {"non_exchbl": None}}))

    # Later configs override earlier ones; command-line values win.
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, expt_config, cli_conf)

    name_removal_set = {'classPath', 'name'}
    # NOTE(review): ``"=".format(i, k)`` ignores both arguments, so every model
    # option contributes a bare "=". Left untouched on purpose: conf.task.name
    # names the output directories, so changing it would change where existing
    # results are looked up, and would make the name depend on model.non_exchbl.
    task_name = ",".join("=".format(i, k ) for i, k  in conf.model.items() if (i not in name_removal_set))

    if conf.dataset.rel_mode in ("sub_iso", "ged"):
        conf.task.name = f"distn_{conf.model.name}_{conf.dataset.rel_mode}_{conf.dataset.name}_{task_name}_stemp={conf.training.sinkhorn_temp}_margin={conf.training.margin}"
    elif conf.dataset.rel_mode == "uneq_ged":
        conf.task.name = f"distn_Uneq_{conf.model.name}_{conf.dataset.name}_numC={conf.dataset.aug_num_cgraphs}_MaxSkew={conf.dataset.MaxSkew}_{task_name},stemp={conf.training.sinkhorn_temp},margin={conf.training.margin}"
    else:
        raise ValueError("Unknown rel_mode")
    logger.info(f"Task name: {conf.task.name}")
    logger.add(f"{conf.log.dir}/{conf.task.name}.log")
    logger.info(OmegaConf.to_yaml(conf))

    # Disable TF32 so matmul/cuDNN results are computed in full float32 precision.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    gmn_config = modify_gmn_main_config(get_default_gmn_config(conf), conf, logger)

    ## generate data
    if conf.expt.train:
        try:
            if conf.expt.version==0:
                train_many_v0(conf,gmn_config)
            elif conf.expt.version==2:
                train_many_v2(conf,gmn_config)
            else:
                print("No Data Generated: Unknown version")
        except Exception as e:
            # Best effort: a training failure must not prevent plotting of
            # already existing results, but keep the traceback in the log.
            print("No Data Generated: ",e)
            logger.exception("Training failed")

    ## plot data
    with torch.no_grad():
        if conf.expt.plot:
            exchbl_plots(conf)
            print("Plots 3 done")
    
    
    
