from sklearn import preprocessing
import numpy as np
import torch.nn.functional as F
import torch
import os.path as osp
from temporal_graph.datasets import STARDataset
from temporal_graph.models import STAR
from temporal_graph.transforms import RandomNodeSplit, ToTemporalUndirected, StratifyNodeSplit
from logger import setup_logger
from torch_geometric.transforms import Compose
from torch_geometric import seed_everything
import time
from sklearn import metrics
import argparse





def minmax_normalization_tensor(tensor):
    """Min-max scale each (time step, feature) column across the node axis.

    Args:
        tensor: 3-D NumPy array laid out as (n_node, n_steps, n_dim).

    Returns:
        A new float64 array of the same shape. Within every time step each
        feature is shifted by its minimum over nodes and divided by its
        range over nodes (plus 1e-12).
    """
    n_node, n_steps, n_dim = tensor.shape

    # Reduce over nodes only, keeping the axis so the per-(step, feature)
    # extrema broadcast back against the full tensor.
    lo = np.min(tensor, axis=0, keepdims=True)
    hi = np.max(tensor, axis=0, keepdims=True)

    # The 1e-12 keeps constant columns (hi == lo) from dividing by zero;
    # such columns come out as all zeros.
    scaled = np.ones([n_node, n_steps, n_dim])
    scaled[...] = (tensor - lo) / (hi - lo + 1e-12)
    return scaled


def meanstd_normalization_tensor(tensor):
    """Standardize each node's features jointly over all steps and dims.

    Args:
        tensor: 3-D array-like laid out as (n_node, n_steps, n_dim).

    Returns:
        A torch tensor of shape (n_node, n_steps, n_dim) built from the
        scaled NumPy array.
    """
    n_node, n_steps, n_dim = tensor.shape

    # Flatten time and feature axes so every node becomes a single row;
    # axis=1 asks sklearn to standardize each row independently.
    per_node = np.reshape(tensor, [n_node, n_steps * n_dim])
    per_node_scaled = preprocessing.scale(per_node, axis=1)

    # Restore the original 3-D layout before handing the data to torch.
    restored = np.reshape(per_node_scaled, [n_node, n_steps, n_dim])
    return torch.from_numpy(restored)


# Command-line options for a single training run.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=2024)
# Restrict --dataset to the supported names so that a typo is rejected with
# a usage message at parse time, before any logging or data loading starts.
parser.add_argument('--dataset', type=str, default='dblp3',
                    choices=['dblp3', 'dblp5', 'reddit', 'brain'])
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=0.01)
# Fraction of nodes outside the test split (the validation share is taken
# out of it further below).
parser.add_argument('--train_size', type=float, default=0.8)
# Index of the CUDA device to use when CUDA is available.
parser.add_argument('--device', type=int, default=0)
# Switches to the fixed 0.81 / 0.09 / 0.10 train/val/test split.
parser.add_argument('--special', action='store_true', default=False)
# Suffix appended to the "logs" directory name.
parser.add_argument('--log_name', type=str, default="")
args = parser.parse_args()


# Seed the run from --seed before anything random happens.
seed_everything(args.seed)

# Derive the train/val/test fractions and stash them on `args`.
if not args.special:
    # Test gets whatever --train_size leaves over; a fixed 0.05 validation
    # share is then carved out of the training fraction.
    args.test_size, args.val_size = 1 - args.train_size, 0.05
    args.train_size -= 0.05
else:
    # Fixed split that ignores --train_size.
    args.train_size, args.val_size, args.test_size = 0.81, 0.09, 0.1


# Log output location is keyed by the --log_name suffix, the dataset name
# and the (already adjusted) training fraction.
logger = setup_logger(
    name="test",
    output=f"logs{args.log_name}/{args.dataset}/{args.train_size:.2f}",
)

dataset = args.dataset
assert dataset in ('dblp3', 'dblp5', 'reddit', 'brain')

# reddit always trains for 200 epochs, regardless of --epochs.
epochs = 200 if dataset == "reddit" else args.epochs

# Load the dataset with a stratified node split applied as a transform;
# only the validation and test fractions are passed to the splitter.
path = './data/'
transform = Compose([
    StratifyNodeSplit(
        num_val=args.val_size, num_test=args.test_size, unknown=-1),
])
data = STARDataset(
    root=path, name=dataset, transform=transform, force_reload=False)[0]

# Use the requested CUDA device when CUDA is available, otherwise the CPU.
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

# Gather every snapshot's node features along a new axis 1 and normalize
# them together.
x_temp = torch.cat(
    [data.snapshot(start=i).x.unsqueeze(1)
     for i in range(data.num_snapshots)],
    dim=1,
)
x_temp = meanstd_normalization_tensor(x_temp)

# Write each normalized slice back onto its snapshot.
# NOTE(review): this only has a lasting effect if data.snapshot() returns
# an object whose attribute writes are visible to the later snapshot()
# calls below -- confirm against the temporal_graph implementation.
for i in range(data.num_snapshots):
    data.snapshot(start=i).x = x_temp[:, i, :]

# Final list of snapshots, moved to the training device.
snapshots = [
    data.snapshot(start=i).to(device) for i in range(data.num_snapshots)
]

# Model sized from the data: last feature dimension and (max label + 1).
# Presumably these are STAR's input width and class count -- verify against
# the STAR constructor.
model = STAR(data.x.size(-1), data.y.max().item() + 1).to(device)
# Learning rate comes from --learning_rate (default 0.01) instead of a
# hard-coded constant, so the command-line option actually takes effect.
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
# Weight of the attention penalty added to the loss in train().
reg_attn = 1e-6


def train():
    """Run one optimization step over all snapshots and return the loss.

    Relies on the module-level ``model``, ``optimizer``, ``snapshots`` and
    ``reg_attn``. The loss is the cross-entropy on the training nodes of
    the last snapshot plus ``reg_attn * norm(A - I)**2 / 2``, where ``A``
    is the attention returned by the model and ``I`` an identity matrix of
    matching size.

    Returns:
        The total loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits, attn = model(snapshots, return_attention=True)

    # Labels and the training mask are read from the final snapshot.
    last = snapshots[-1]
    train_mask = last.train_mask
    ce_loss = F.cross_entropy(logits[train_mask], last.y[train_mask])

    # Penalize the attention's deviation from the identity.
    identity = torch.eye(attn.size(-1), device=attn.device)
    attn_penalty = reg_attn * torch.norm(attn - identity).square() / 2

    loss = ce_loss + attn_penalty
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    """Score the model on the train, validation and test masks.

    Relies on the module-level ``model`` and ``snapshots``; labels and
    masks are read from the last snapshot.

    Returns:
        ``(macro_f1, micro_f1)``: two lists ordered as [train, val, test].
        A split whose mask selects no nodes is reported as 0 in both lists.
    """
    model.eval()
    pred = model(snapshots).argmax(dim=-1)
    last = snapshots[-1]

    macro_scores, micro_scores = [], []
    for mask in (last.train_mask, last.val_mask, last.test_mask):
        # Nothing to score for an empty split; record a placeholder.
        if mask.sum() == 0:
            macro_scores.append(0)
            micro_scores.append(0)
            continue

        # Move to NumPy once and reuse for both averages.
        y_true = last.y[mask].cpu().numpy()
        y_pred = pred[mask].cpu().numpy()
        macro_scores.append(
            metrics.f1_score(y_true, y_pred, average='macro'))
        micro_scores.append(
            metrics.f1_score(y_true, y_pred, average='micro'))
    return macro_scores, micro_scores


# Model selection state: the epoch with the highest validation micro-F1
# determines the reported test micro-F1 and the reported macro-F1 triple.
best_val = best_test = -1e5
best_metric_macros = None

start_time = time.time()
for epoch in range(1, epochs + 1):
    loss = train()
    metric_macros, metric_micros = test()
    train_acc, val_acc, test_acc = metric_micros

    # Strict improvement only, so ties keep the earlier epoch.
    if val_acc > best_val:
        best_val, best_test = val_acc, test_acc
        best_metric_macros = metric_macros

    logger.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    # Micro scores are those of the current epoch (plus the best test).
    logger.info(
        f'MICROS: Train: {train_acc:.2%}, Val: {val_acc:.2%}, '
        f'Test: {test_acc:.2%}, Best Test: {best_test:.2%}')
    # Macro scores are those of the best-validation epoch so far.
    macro_train, macro_val, macro_test = best_metric_macros
    logger.info(
        f'MACROS: Train: {macro_train:.2%}, Val: {macro_val:.2%}, '
        f'Test: {macro_test:.2%}')

elapsed = time.time() - start_time
logger.info(f'Time: {elapsed:.2f}s')
