import argparse
import os
import os.path as osp
import pdb
import random
import shutil
import sys
from collections import defaultdict
from random import shuffle

import numpy as np
import scipy.sparse as ssp
from scipy import linalg
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN
from torch_scatter import scatter_mean, scatter_max
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import degree
from torch_geometric.utils import to_scipy_sparse_matrix
from torch_geometric.nn import GCNConv, GINConv, global_add_pool
import torch_geometric.transforms as T
from torch_geometric.datasets import GNNBenchmarkDataset
from torch_geometric.data import Data
import igraph
# from k_gnn import GraphConv, max_pool
# from k_gnn import TwoMalkin, ConnectedThreeMalkin

from dataloader import DataLoader  # use a custom dataloader to handle subgraphs
# from utils import create_subgraphs



"Transform Functions"

def resistance_distance(data):
    """Attach the resistance distance from node 0 to every node of a graph.

    See "Link prediction in complex networks: A survey".
    Adapted from NestedGNN: https://github.com/muhanzhang/NestedGNN

    With ``L_inv`` the pseudo-inverse of the graph Laplacian, the value stored
    for node ``y`` is ``L_inv[0, 0] + L_inv[y, y] - L_inv[0, y] - L_inv[y, 0]``.

    Args:
        data (PyG.Data): pyg data object; ``edge_index`` and ``num_nodes`` are
            read from it.

    Returns:
        The same ``data`` object, with ``data.rd`` set to a float tensor of
        shape ``[num_nodes, 1]``.
    """
    adj = to_scipy_sparse_matrix(
        data.edge_index, num_nodes=data.num_nodes
    ).tocsr()
    laplacian = ssp.csgraph.laplacian(adj).toarray()
    try:
        L_inv = linalg.pinv(laplacian)
    except linalg.LinAlgError:
        # pinv failed to converge: nudge the diagonal and invert again.
        # The inverse must be recomputed here, otherwise L_inv stays unbound
        # and the lines below cannot run.
        laplacian = laplacian + 0.01 * np.eye(*laplacian.shape)
        L_inv = linalg.pinv(laplacian)

    lxx = L_inv[0, 0]
    lyy = np.diag(L_inv)
    lxy = L_inv[0, :]
    lyx = L_inv[:, 0]
    data.rd = torch.FloatTensor(lxx + lyy - lxy - lyx).unsqueeze(1)
    return data


def post_transform(wo_path_encoding, wo_edge_feature):
    """Build the post transformation of the dataset for KP-GNN.

    Args:
        wo_path_encoding (bool): If true, remove path encoding from model
        wo_edge_feature (bool): If true, remove edge feature from model

    Returns:
        A callable that edits a graph ``g`` in place and returns it. With both
        flags false the callable returns ``g`` untouched.
    """

    def _cap_edge_attr(g, columns=None):
        # Cap entries of g.edge_attr that exceed 2 at 2, either everywhere
        # (columns is None) or only in the selected column(s).
        attr = g.edge_attr
        if columns is None:
            attr[attr > 2] = 2
        else:
            selected = attr[:, columns]
            selected[selected > 2] = 2
            attr[:, columns] = selected
        g.edge_attr = attr

    def _zero_path_encoding(g):
        # Zero every positive entry of g.pe_attr when the graph carries one.
        if "pe_attr" in g:
            encoding = g.pe_attr
            encoding[encoding > 0] = 0
            g.pe_attr = encoding

    if wo_path_encoding and wo_edge_feature:
        steps = (_cap_edge_attr, _zero_path_encoding)
    elif wo_edge_feature:
        # Only the first edge_attr column is capped.
        steps = (lambda g: _cap_edge_attr(g, 0),)
    elif wo_path_encoding:
        # Every edge_attr column except the first is capped.
        steps = (lambda g: _cap_edge_attr(g, slice(1, None)), _zero_path_encoding)
    else:
        steps = ()

    def transform(g):
        for step in steps:
            step(g)
        return g

    return transform


"subgraph canonicalization functions"


def k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=False,
                   num_nodes=None, flow='source_to_target', node_label='hop',
                   max_nodes_per_hop=None):
    """Extract the ``num_hops``-hop subgraph rooted at ``node_idx``.

    Nodes are collected hop by hop, starting from the root, by following the
    edges of ``edge_index`` in the direction selected by ``flow``.

    Args:
        node_idx (int): root node of the subgraph.
        num_hops (int): maximum number of hops to expand.
        edge_index (LongTensor): edge list of shape ``[2, num_edges]``.
        relabel_nodes (bool): if true, the returned edge index uses positions
            in the returned ``subset`` instead of the original node ids.
        num_nodes (int, optional): number of nodes of the full graph; inferred
            as ``edge_index.max() + 1`` when omitted.
        flow (str): ``'source_to_target'`` or ``'target_to_source'``.
        node_label (str): ``'hop'`` (hop index of each node), ``'spd'`` /
            ``'spdK'`` (up to K recorded distances per node, K defaults to 2)
            or ``'drnl'``. Any other value leaves ``z`` as ``None``.
        max_nodes_per_hop (int, optional): if set, at most this many newly
            reached nodes are kept per hop (chosen with ``random.sample``).

    Returns:
        tuple: ``(subset, edge_index, edge_mask, z)`` where ``subset`` lists
        the subgraph nodes with the root first, ``edge_index`` holds the edges
        between them, ``edge_mask`` marks those edges in the input and ``z``
        holds the node labels.
    """
    if num_nodes is None:
        # Same inference rule as the file's maybe_num_nodes helper.
        num_nodes = edge_index.max().item() + 1

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index
    device = row.device

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    subsets = [torch.tensor([node_idx], device=device).flatten()]
    visited = set(subsets[-1].tolist())
    # label[node] collects one distance entry per time the node is reached;
    # the root starts at 1, nodes first reached in hop h get h + 2.
    label = defaultdict(list)
    for node in subsets[-1].tolist():
        label[node].append(1)

    use_hop_label = node_label == 'hop'
    if use_hop_label:
        # torch.zeros/full honour `device`; the legacy LongTensor constructor
        # only works for CPU tensors.
        hops = [torch.zeros(1, dtype=torch.long, device=device)]

    for h in range(num_hops):
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        torch.index_select(node_mask, 0, row, out=edge_mask)
        new_nodes = col[edge_mask]
        tmp = []
        for node in new_nodes.tolist():
            if node in visited:
                continue
            tmp.append(node)
            label[node].append(h + 2)
        if len(tmp) == 0:
            break
        if max_nodes_per_hop is not None:
            if max_nodes_per_hop < len(tmp):
                tmp = random.sample(tmp, max_nodes_per_hop)
        new_nodes = set(tmp)
        visited = visited.union(new_nodes)
        new_nodes = torch.tensor(list(new_nodes), device=device)
        subsets.append(new_nodes)
        if use_hop_label:
            hops.append(torch.full((len(new_nodes),), h + 1,
                                   dtype=torch.long, device=device))

    subset = torch.cat(subsets)
    if use_hop_label:
        hop = torch.cat(hops)
    # Add `node_idx` to the beginning of `subset`.
    subset = subset[subset != node_idx]
    subset = torch.cat([torch.tensor([node_idx], device=device), subset])

    z = None
    if use_hop_label:
        hop = hop[hop != 0]
        hop = torch.cat([torch.zeros(1, dtype=torch.long, device=device), hop])
        z = hop.unsqueeze(1)
    elif node_label.startswith('spd') or node_label == 'drnl':
        is_spd = node_label.startswith('spd')
        if is_spd:
            # keep top k shortest-path distances
            num_spd = int(node_label[3:]) if len(node_label) > 3 else 2
            z = torch.zeros(
                [subset.size(0), num_spd], dtype=torch.long, device=device
            )
        else:
            # see "Link Prediction Based on Graph Neural Networks", a special
            # case of spd2
            num_spd = 2
            z = torch.zeros([subset.size(0), 1], dtype=torch.long, device=device)

        for i, node in enumerate(subset.tolist()):
            dists = label[node][:num_spd]  # keep top num_spd distances
            if is_spd:
                # Covers 'spd' and every 'spdK' variant accepted above.
                z[i][:len(dists)] = torch.tensor(
                    dists, dtype=torch.long, device=device)
            else:
                dist1 = dists[0]
                dist2 = dists[1] if len(dists) == 2 else 0
                if dist2 == 0:
                    dist = dist1
                else:
                    dist = dist1 * (num_hops + 1) + dist2
                z[i][0] = dist

    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        # Map original node ids to their position in `subset` (-1 elsewhere).
        position = row.new_full((num_nodes,), -1)
        position[subset] = torch.arange(subset.size(0), device=device)
        edge_index = position[edge_index]

    return subset, edge_index, edge_mask, z


def maybe_num_nodes(index, num_nodes=None):
    """Return ``num_nodes`` when given, else infer it as ``index.max() + 1``."""
    if num_nodes is not None:
        return num_nodes
    largest_id = index.max().item()
    return largest_id + 1


def bound_list_generate(data_list, n, max_h):
    """Return, per hop, the largest number of nodes found at that hop.

    For every graph in ``data_list`` and every root in ``range(n)`` the
    ``max_h``-hop subgraph is extracted with hop labels; entry ``j - 1`` of the
    result is the maximum count of nodes labelled ``j`` over all of them.

    Args:
        data_list: iterable of graphs exposing ``edge_index``.
        n (int): number of nodes used for every graph in ``data_list``.
        max_h (int): number of hops to expand.

    Returns:
        list[int]: ``max_h`` per-hop upper bounds.
    """
    bounds = [0 for _ in range(max_h)]
    for graph in data_list:
        for root in range(n):
            _, _, _, hop_labels = k_hop_subgraph(
                root, max_h, graph.edge_index, True, n, node_label='hop',
                max_nodes_per_hop=None)
            for depth in range(1, max_h + 1):
                nodes_at_depth = (hop_labels == depth).sum().item()
                if nodes_at_depth > bounds[depth - 1]:
                    bounds[depth - 1] = nodes_at_depth
    return bounds


def _resistance_to_root(edge_index, num_nodes):
    """Resistance distance from node 0 to every node, as a 1-D float tensor.

    See "Link prediction in complex networks: A survey".
    """
    adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()
    laplacian = ssp.csgraph.laplacian(adj).toarray()
    try:
        L_inv = linalg.pinv(laplacian)
    except linalg.LinAlgError:
        # pinv failed to converge: regularise the diagonal and invert again,
        # so that L_inv is always bound below.
        laplacian = laplacian + 0.01 * np.eye(*laplacian.shape)
        L_inv = linalg.pinv(laplacian)
    lxx = L_inv[0, 0]
    lyy = np.diag(L_inv)
    lxy = L_inv[0, :]
    lyx = L_inv[:, 0]
    # Stays 1-D even for a one-node subgraph, so torch.cat can join the
    # per-root results.
    return torch.FloatTensor(lxx + lyy - lxy - lyx)


def khop_feature_trans(data, h=1, k=3, sample_ratio=1.0, max_nodes_per_hop=None,
                       node_label='hop', use_rd=False, use_ss=False, bound_list=None):
    """Attach the rooted ``h``-hop subgraph of every node to ``data``.

    For each node the subgraph is extracted with ``k_hop_subgraph`` (relabelled
    so the root is node 0) and converted with ``seq_label_trans``; the per-root
    results are concatenated and stored on ``data``.

    Args:
        data (Data): graph to annotate; modified in place.
        h (int): number of hops of each rooted subgraph.
        k (int): currently unused.
        sample_ratio (float): currently unused.
        max_nodes_per_hop (int, optional): forwarded to ``k_hop_subgraph``.
        node_label (str): forwarded to ``k_hop_subgraph``.
        use_rd (bool): if true, also store the resistance distance of every
            subgraph node to its root in ``data.subg_rd``.
        use_ss (bool): currently unused.
        bound_list (list[int], optional): per-hop node bounds passed to
            ``seq_label_trans``; defaults to ``[1, 10, 15]``.

    Returns:
        Data: the same ``data`` object with ``total_seq_nodes``,
        ``number_edges``, ``subg_nodes``, ``subg_nodes_seq``, ``subg_edges``,
        ``subg_masks``, ``subg_node_size``, ``subg_edge_size``,
        ``subg_mask_size`` (and ``subg_rd`` when ``use_rd``) set.
    """
    assert (isinstance(data, Data))
    if bound_list is None:
        bound_list = [1, 10, 15]
    edge_index, num_nodes = data.edge_index, data.num_nodes

    data.total_seq_nodes = sum(bound_list) + 1
    data.number_edges = edge_index.size(1)

    subg_nodes = []      # original node ids of every subgraph
    subg_nodes_seq = []  # first output of seq_label_trans, per root
    subg_edges = []      # second output of seq_label_trans, per root
    subg_masks = []      # edge masks over the full edge list, per root
    subg_rd = []
    subg_num_nodes = []
    subg_num_edges = []
    subg_mask_size = []

    for ind in range(num_nodes):
        nodes_, edge_index_, edge_mask_, z_ = k_hop_subgraph(
            ind, h, edge_index, True, num_nodes, node_label=node_label,
            max_nodes_per_hop=max_nodes_per_hop
        )
        # NOTE(review): seq_label_trans is neither defined nor imported in this
        # file as far as visible — confirm where it is expected to come from.
        seq_index, seq_edge_index = seq_label_trans(z_, edge_index_, bound_list)

        subg_num_nodes.append(nodes_.size(0))
        subg_num_edges.append(int(edge_mask_.sum()))
        subg_mask_size.append(edge_mask_.size(0))

        subg_nodes.append(nodes_)
        subg_nodes_seq.append(seq_index)
        subg_edges.append(seq_edge_index)
        subg_masks.append(edge_mask_)

        if use_rd:
            subg_rd.append(_resistance_to_root(edge_index_, nodes_.shape[0]))

    data.subg_nodes = torch.cat(subg_nodes, dim=0)
    data.subg_nodes_seq = torch.cat(subg_nodes_seq, dim=0)
    # NOTE(review): concatenating along dim=1 and then reshaping to (-1, 2)
    # only yields (source, target) rows if seq_edge_index is laid out for it —
    # verify against seq_label_trans.
    data.subg_edges = torch.cat(subg_edges, dim=1).reshape(-1, 2)
    data.subg_masks = torch.cat(subg_masks, dim=0)
    if use_rd:
        data.subg_rd = torch.cat(subg_rd, dim=0)
    data.subg_node_size = subg_num_nodes
    data.subg_edge_size = subg_num_edges
    data.subg_mask_size = subg_mask_size

    return data


def canlabel1(data, direct=False):
    """Attach igraph's canonical permutation of the graph to ``data``.

    Args:
        data: graph exposing ``x`` (may be ``None``) and ``edge_index``.
        direct (bool): whether the igraph graph is built as directed.

    Returns:
        The same ``data`` object with ``data.cl`` (LongTensor holding the
        result of ``igraph.Graph.canonical_permutation``) and
        ``data.num_vertex`` (number of vertices used) set.
    """
    if data.x is None:
        # Without node features the vertex count is inferred from the largest
        # node id that appears in an edge.
        num_node = data.edge_index.max().item() + 1
    else:
        num_node = data.x.shape[0]

    # Plain (source, target) int tuples, converted in one pass instead of
    # pairing 0-dim tensors one by one.
    sources, targets = data.edge_index.tolist()
    edges = list(zip(sources, targets))

    g = igraph.Graph(directed=direct)
    g.add_vertices(num_node)
    g.add_edges(edges)

    data.cl = torch.LongTensor(g.canonical_permutation())
    data.num_vertex = num_node
    return data

def within_graph_color(data):
    """Colour every node by the rank of its feature vector within the graph.

    The distinct rows of ``data.x`` are sorted; each node receives the index
    of its own row in that sorted list, so nodes with identical features
    share a colour.

    Args:
        data: graph whose ``x`` is iterated row by row (one row per node).

    Returns:
        list[int]: one colour per node, in node order.
    """
    # Hashable tuples allow set/dict lookups instead of repeated list scans
    # and string-keyed dictionaries.
    rows = [tuple(row) for row in data.x.tolist()]
    color_of = {row: color for color, row in enumerate(sorted(set(rows)))}
    return [color_of[row] for row in rows]


" Models"


class NestedGIN(torch.nn.Module):
    """Stack of GIN layers pooled over rooted subgraphs.

    Node embeddings are summed per subgraph (``data.node_to_subgraph``) and,
    when the module-level ``args`` carries a truthy ``graph`` attribute, summed
    again per graph (``data.subgraph_to_graph``).

    Args:
        num_layers (int): total number of GIN layers.
        hidden (int): width of every layer.
    """

    def __init__(self, num_layers, hidden):
        super().__init__()
        self.conv1 = GINConv(self._mlp(1, hidden), train_eps=False)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(
                GINConv(self._mlp(hidden, hidden), train_eps=False))
        # lin1/lin2 are not used by forward() but stay registered so the
        # module's parameters are unchanged.
        self.lin1 = torch.nn.Linear(hidden, hidden)
        self.lin2 = Linear(hidden, hidden)

    @staticmethod
    def _mlp(in_dim, hidden):
        # Two-layer ReLU MLP wrapped by each GINConv.
        return Sequential(
            Linear(in_dim, hidden),
            ReLU(),
            Linear(hidden, hidden),
            ReLU(),
        )

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        edge_index = data.edge_index
        if 'x' in data:
            x = data.x
        else:
            # No node features: every node gets the constant feature 1.
            x = torch.ones([data.num_nodes, 1]).to(edge_index.device)
        x = self.conv1(x, edge_index)
        for conv in self.convs:
            x = conv(x, edge_index)

        x = global_add_pool(x, data.node_to_subgraph)
        # The argument parser in this file defines no `--graph` option, so a
        # plain `args.graph` raises AttributeError; default to subgraph-level
        # output when the attribute is absent.
        if getattr(args, 'graph', False):
            x = global_add_pool(x, data.subgraph_to_graph)

        return x

    def __repr__(self):
        return self.__class__.__name__



class CLGCN(torch.nn.Module):
    """GCN graph classifier whose node inputs are augmented with an embedding
    of each node's canonical label (``data.cl``).

    Args:
        dataset: only ``dataset.num_classes`` is read (output size).
        num_layers (int): number of GCN layers.
        hidden (int): width of every layer.
        cl_dim (int): number of rows of the canonical-label embedding table.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, dataset, num_layers, hidden, cl_dim, *args, **kwargs):
        super().__init__()
        self.feat_map = torch.nn.Linear(1, hidden)
        self.conv1 = GCNConv(hidden, hidden)
        self.cl_embedding = torch.nn.Embedding(cl_dim, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
        self.lin1 = torch.nn.Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module.

        The label embedding is included so that nothing learned in one
        cross-validation split carries over into the next.
        """
        self.feat_map.reset_parameters()
        self.cl_embedding.reset_parameters()
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        edge_index, batch, cl = data.edge_index, data.batch, data.cl
        if 'x' in data:
            x = data.x
        else:
            # No node features: every node gets the constant feature 1.
            x = torch.ones([data.num_nodes, 1]).to(edge_index.device)
        x = self.feat_map(x) + self.cl_embedding(cl)

        x = F.relu(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs.append(x)
        # Graph readout over the concatenation of all layer outputs.
        x = global_add_pool(torch.cat(xs, dim=1), batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__

class CLGIN(torch.nn.Module):
    """GIN graph classifier whose node inputs are augmented with an embedding
    of each node's canonical label (``data.cl``).

    Args:
        dataset: only ``dataset.num_classes`` is read (output size).
        num_layers (int): number of GIN layers.
        hidden (int): width of every layer.
        cl_dim (int): number of rows of the canonical-label embedding table.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, dataset, num_layers, hidden, cl_dim, *args, **kwargs):
        super().__init__()
        self.feat_map = torch.nn.Linear(1, hidden)
        self.conv1 = GINConv(self._mlp(hidden), train_eps=True)
        self.cl_embedding = torch.nn.Embedding(cl_dim, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(self._mlp(hidden), train_eps=True))
        self.lin1 = torch.nn.Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    @staticmethod
    def _mlp(hidden):
        # Two-layer ReLU MLP followed by batch norm, wrapped by each GINConv.
        return Sequential(
            Linear(hidden, hidden),
            ReLU(),
            Linear(hidden, hidden),
            ReLU(),
            BN(hidden),
        )

    def reset_parameters(self):
        """Re-initialise every learnable sub-module.

        The input map and the label embedding are included so that nothing
        learned in one cross-validation split carries over into the next.
        """
        self.feat_map.reset_parameters()
        self.cl_embedding.reset_parameters()
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        edge_index, batch, cl = data.edge_index, data.batch, data.cl
        if 'x' in data:
            x = data.x
        else:
            # No node features: every node gets the constant feature 1.
            x = torch.ones([data.num_nodes, 1]).to(edge_index.device)
        x = self.feat_map(x) + self.cl_embedding(cl)
        x = self.conv1(x, edge_index)
        xs = [x]
        for conv in self.convs:
            x = conv(x, edge_index)
            xs.append(x)
        # Graph readout over the concatenation of all layer outputs.
        x = global_add_pool(torch.cat(xs, dim=1), batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__


# Command-line interface of the script. Arguments are parsed at import time
# and stored in the module-level `args`.
# NOTE(review): the description string spells "Canonial" — presumably
# "Canonical"; it is user-facing text, so confirm before changing it.
parser = argparse.ArgumentParser(description='Canonial Subgraph GNN for CSL datasets')
parser.add_argument('--dataset', type=str, default='CSL')  # Dataset name
parser.add_argument('--model', type=str, default='GIN')  # Base GNN used, GIN or GCN
parser.add_argument('--h', type=int, default=3,
                    help='largest height of rooted subgraphs to simulate')
# NOTE(review): the help text of --k, --hidden1 and --hidden2 is identical to
# that of --h and looks copy-pasted — confirm what these options control.
parser.add_argument('--k', type=int, default=3,
                    help='largest height of rooted subgraphs to simulate')
parser.add_argument('--hidden1', type=int, default=128,
                    help='largest height of rooted subgraphs to simulate')
parser.add_argument('--hidden2', type=int, default=64,
                    help='largest height of rooted subgraphs to simulate')

parser.add_argument('--layers', type=int, default=3)  # Number of GNN layers

parser.add_argument('--width', type=int, default=64)  # Dimensionality of GNN embeddings
parser.add_argument('--epochs', type=int, default=500)  # Number of training epochs
parser.add_argument('--learnRate', type=float, default=0.001)  # Learning Rate

# Device selection.
parser.add_argument('--cuda_id', type=int, default=1, metavar='N',
                    help='id of GPU')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--all-gpus', action='store_true', default=False,
                    help='use all available GPUs')

# Ablation switches and dataset cache control.
parser.add_argument("--wo_path_encoding", action="store_true", help="If true, remove path encoding from model")
parser.add_argument("--wo_edge_feature", action="store_true", help="If true, remove edge feature from model")
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')

args = parser.parse_args()


def print_or_log(input_data, log=False, log_file_path="Debug.txt"):
    """Print ``input_data`` or append it to a log file.

    Args:
        input_data: value to report; converted with ``str`` when logged.
        log (bool): if false (default) the value is printed; otherwise it is
            appended to ``log_file_path`` followed by ``"\\r\\n"``.
        log_file_path (str): file appended to when ``log`` is true.
    """
    if not log:  # If not logging, we should just print
        print(input_data)
        return
    # The file is opened and closed per call; `with` guarantees the handle is
    # released even if the write raises.
    with open(log_file_path, "a+") as log_file:
        log_file.write(str(input_data) + "\r\n")


class MyFilter:
    """Dataset filter that accepts every graph."""

    def __call__(self, data):
        # No Filtering
        keep = True
        return keep


class MyPreTransform:
    """Pre-transform replacing ``data.x`` by a two-class one-hot encoding of
    its first column."""

    def __call__(self, data):
        # Convert node labels to one-hot
        node_labels = data.x[:, 0]
        one_hot = F.one_hot(node_labels, num_classes=2)
        data.x = one_hot.to(torch.float)
        return data



def cl_counter(dataset):
    """Return the largest canonical label over the dataset (at least 1).

    Each graph is passed through ``canlabel1`` and the maximum entry of its
    ``cl`` tensor is tracked; progress is shown with tqdm.
    """
    largest = 1
    for position in tqdm(range(len(dataset))):
        labelled = canlabel1(dataset[position])
        graph_max = torch.max(labelled.cl).item()
        if graph_max > largest:
            largest = graph_max
    return largest




# Command Line Arguments
DATASET = args.dataset
LAYERS = args.layers
EPOCHS = args.epochs
WIDTH = args.width
LEARNING_RATE = args.learnRate

# Per-dataset results directory; exist_ok avoids the check-then-create race
# and makes re-runs a no-op.
args.res_dir = 'results/{}'.format(args.dataset)
os.makedirs(args.res_dir, exist_ok=True)
log_file = os.path.join(args.res_dir, 'log.txt')

# Model tag used in the console log name; the learning rate is appended only
# when it differs from the default.
MODEL = "CLGNN"
if LEARNING_RATE != 0.001:
    MODEL = MODEL + "lr" + str(LEARNING_RATE) + "-"

BATCH = 16
MODULO = 4
MOD_THRESH = 1

# Root folder of the dataset on disk.
path = 'data/' + DATASET

# Transform applied to every graph when it is read from the dataset.
transform = post_transform(args.wo_path_encoding, args.wo_edge_feature)

def pre_transform(g):
    """Dataset pre-processing hook: return ``g`` after ``canlabel1`` has run on it."""
    labelled = canlabel1(g)
    return labelled

if args.reprocess:
    # Drop the cached processed files so the dataset is rebuilt. A missing
    # cache directory just means there is nothing to clear.
    processed_dir = path + '/processed'
    if os.path.isdir(processed_dir):
        shutil.rmtree(processed_dir)

dataset = GNNBenchmarkDataset(root=path, name=DATASET,
                              pre_transform=pre_transform,
                              transform=transform)
# Size of the canonical-label embedding table: largest label + 1.
cl_dim = cl_counter(dataset) + 1

# Device selection (done once; the model is created directly on it).
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
    device = torch.device("cuda:{}".format(args.cuda_id))
else:
    device = torch.device("cpu")

if args.model == 'GIN':
    model = CLGIN(dataset=dataset, num_layers=args.layers, hidden=args.width, cl_dim=cl_dim).to(device)
elif args.model == 'GCN':
    model = CLGCN(dataset=dataset, num_layers=args.layers, hidden=args.width, cl_dim=cl_dim).to(device)
else:
    raise NotImplementedError('model type not supported')

# Seed every generator the run draws from: `random` drives the split shuffle,
# torch drives parameter re-initialisation and dropout, numpy is kept for
# completeness. Seeding numpy and CUDA alone leaves the run non-reproducible.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if args.cuda:
    torch.cuda.manual_seed(1)

def train(epoch, loader, optimizer):
    """Run one optimisation pass over ``loader`` with the module-level model.

    Args:
        epoch: accepted for the caller's convenience; not used.
        loader: iterable of batches exposing ``to``, ``y`` and ``num_graphs``,
            with a ``dataset`` attribute supporting ``len``.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        float: NLL loss averaged over the graphs of ``loader.dataset``.
    """
    model.train()
    weighted_loss = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(batch), batch.y)
        batch_loss.backward()
        # Weight the batch mean by its graph count so the final division
        # yields a per-graph average.
        weighted_loss += batch.num_graphs * batch_loss.item()
        optimizer.step()
    return weighted_loss / len(loader.dataset)


def val(loader):
    """Return the per-graph NLL loss of the module-level model on ``loader``.

    The model is switched to eval mode and gradients are disabled, so no
    autograd graph is built during evaluation.
    """
    model.eval()
    loss_all = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            loss_all += F.nll_loss(model(data), data.y, reduction='sum').item()
    return loss_all / len(loader.dataset)


def test(loader):
    """Return the accuracy of the module-level model on ``loader``.

    The model is switched to eval mode and gradients are disabled, so no
    autograd graph is built during evaluation.
    """
    model.eval()
    correct = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            pred = torch.argmax(model(data), dim=1)
            correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)


# ---------------------------------------------------------------------------
# Cross-validation over SPLITS folds: every fold re-initialises the model,
# trains for EPOCHS epochs and records the last-epoch train / test accuracy.
# ---------------------------------------------------------------------------
SPLITS = 10
acc, tr_acc = [], []
tr_accuracies, tst_accuracies, tst_exp_accuracies, tst_lrn_accuracies = (
    np.zeros((EPOCHS, SPLITS)) for _ in range(4))

# One random permutation of the graphs, shared by all folds.
idx = list(range(len(dataset)))
shuffle(idx)
idx = torch.LongTensor(idx)

# Path handed to print_or_log (which only prints with its default log=False).
console_log_path = "log" + MODEL + DATASET + "," + str(LAYERS) + "," + str(WIDTH) + ".txt"
epoch_template = ('Epoch: {:03d}, LR: {:7f}, Train Loss: {:.7f}, Val Loss: {:.7f}, '
                  'Test Acc: {:.7f}, Val Acc: {:.7f}, Train Acc: {:.7f}')


def _report(message):
    # Emit a message through print_or_log and append it to the results log.
    print_or_log(message, log_file_path=console_log_path)
    with open(log_file, 'a') as f:
        print(message, file=f)


for i in range(SPLITS):
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # NOTE(review): min_lr equals the initial learning rate, so this scheduler
    # can never lower it — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.7, patience=5, min_lr=LEARNING_RATE)

    # Test fold: the i-th slice of the shuffled indices.
    fold_size = len(dataset) // SPLITS
    test_mask = torch.zeros(len(dataset), dtype=torch.bool)
    test_exp_mask = torch.zeros(len(dataset), dtype=torch.bool)
    test_lrn_mask = torch.zeros(len(dataset), dtype=torch.bool)
    test_mask[idx[i * fold_size:(i + 1) * fold_size]] = 1

    test_dataset = dataset[test_mask]
    train_dataset = dataset[~test_mask]

    # Validation fold: the i-th slice of what remains; the rest is training.
    n = len(train_dataset) // SPLITS
    val_mask = torch.zeros(len(train_dataset), dtype=torch.bool)
    val_mask[i * n:(i + 1) * n] = 1
    val_dataset = train_dataset[val_mask]
    train_dataset = train_dataset[~val_mask]

    val_loader = DataLoader(val_dataset, batch_size=BATCH)
    test_loader = DataLoader(test_dataset, batch_size=BATCH)
    train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True)

    _report('---------------- Split {} ----------------'.format(i))

    best_val_loss, test_acc = 100, 0
    for epoch in tqdm(range(EPOCHS)):
        lr = scheduler.optimizer.param_groups[0]['lr']
        train_loss = train(epoch, train_loader, optimizer)
        val_loss = val(val_loader)
        scheduler.step(val_loss)
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
        train_acc = test(train_loader)
        test_acc = test(test_loader)
        val_acc = test(val_loader)
        tr_accuracies[epoch, i] = train_acc
        tst_accuracies[epoch, i] = test_acc
        _report(epoch_template.format(
            epoch + 1, lr, train_loss, val_loss, test_acc, val_acc, train_acc))

    # Accuracies of the last epoch of this split.
    acc.append(test_acc)
    tr_acc.append(train_acc)

acc = torch.tensor(acc)
tr_acc = torch.tensor(tr_acc)

_report('---------------- Final Result ----------------')
_report('Mean: {:7f}, Std: {:7f}'.format(acc.mean(), acc.std()))
_report('Tr Mean: {:7f}, Std: {:7f}'.format(tr_acc.mean(), tr_acc.std()))