import numpy as np
from sklearn import preprocessing
from sklearn import metrics
import time
from tqdm import tqdm
import argparse
from torch_geometric import seed_everything
from temporal_graph.transforms import RandomNodeSplit, ToTemporalUndirected, StratifyNodeSplit
from temporal_graph.datasets import DBLP, Tmall, STARDataset
from torch_geometric.transforms import Compose, NormalizeFeatures
import torch.nn.functional as F
import torch
from logger import setup_logger
import os.path as osp




def meanstd_normalization_tensor(tensor):
    """Standardise each node's features jointly over all non-leading axes.

    The input is flattened to ``(n_node, rest)`` and every row is scaled with
    :func:`sklearn.preprocessing.scale` (``axis=1``), i.e. each node is
    centred/scaled using statistics pooled over all of its time steps and
    feature dimensions. The result is reshaped back to the input shape.

    Args:
        tensor: ``torch.Tensor`` or array-like of shape
            ``(n_node, n_steps, n_dim)``. Any rank >= 2 is accepted; all axes
            after the first are normalised together. A ``torch.Tensor`` is
            detached and copied to host memory before normalisation.

    Returns:
        ``torch.Tensor`` with the same shape as ``tensor``.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if isinstance(tensor, torch.Tensor):
        # Convert explicitly: relying on np.reshape's implicit fallback to
        # np.asarray fails for CUDA tensors and tensors that require grad.
        array = tensor.detach().cpu().numpy()
    else:
        array = np.asarray(tensor)
    if array.ndim < 2:
        raise ValueError(
            "expected a tensor with at least 2 dims (n_node, ...), "
            f"got shape {array.shape}")

    n_node = array.shape[0]
    n_rest = int(np.prod(array.shape[1:]))
    # preprocessing.scale copies its input by default, so `tensor` is untouched.
    flat = preprocessing.scale(array.reshape(n_node, n_rest), axis=1)
    return torch.from_numpy(flat.reshape(array.shape))


# Command-line interface. Datasets with a loading branch in this script:
# 'DBLP', 'dblp3', 'dblp5', 'reddit', 'brain'.
_CLI_OPTIONS = (
    ('--seed', dict(type=int, default=2024)),
    ('--dataset', dict(type=str, default="DBLP")),
    ('--hidden_channels', dict(type=int, default=32)),
    ('--epochs', dict(type=int, default=200)),
    ('--learning_rate', dict(type=float, default=0.01)),
    ('--weight_decay', dict(type=float, default=0)),
    ('--ssm_format', dict(type=str, default='siso')),
    ('--token_mixer', dict(type=str, default='interp')),
    ('--train_size', dict(type=float, default=0.8)),
    ('--device', dict(type=int, default=0)),
    ('--special', dict(action='store_true', default=False)),
    ('--log_name', dict(type=str, default="")),
    ('--model_name', dict(type=str, default="ssm")),
)

parser = argparse.ArgumentParser()
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

seed_everything(args.seed)

# Derive the train/val/test fractions used by the node-split transforms.
if not args.special:
    # 5% of the requested training share becomes validation; everything else
    # is test. The right-hand side is evaluated before any assignment, so both
    # expressions see the user-supplied train_size.
    args.test_size, args.train_size, args.val_size = (
        1 - args.train_size, args.train_size - 0.05, 0.05)
else:
    # Fixed split: 81% train / 9% val / 10% test.
    args.test_size, args.train_size, args.val_size = 0.1, 0.81, 0.09

# Import the model implementation selected by --model_name.
if args.model_name == "ssm":
    from models.ssm import DiagonalSSM, SimpleRoland
elif args.model_name == "mamba":
    from models.mamba import DiagonalSSM, SimpleRoland
elif args.model_name == "mambav2":
    # models.mamba_v2 names its class DiagonalS6SSM; alias it so the rest of
    # the script can refer to DiagonalSSM uniformly.
    from models.mamba_v2 import DiagonalS6SSM as DiagonalSSM
else:
    # Fail fast: otherwise DiagonalSSM would be undefined and the run would
    # only crash with a NameError after the dataset has been loaded.
    raise ValueError(
        f"Unknown --model_name {args.model_name!r}; "
        "expected one of 'ssm', 'mamba', 'mambav2'.")

# Hyper-parameters of this run; also written to the log once training ends.
config = {
    "dataset": args.dataset,
    # Post-adjustment value: the validation share has already been removed.
    "train_size": args.train_size,
    'hidden_channels': args.hidden_channels,
    'learning_rate': args.learning_rate,
    # Key previously misspelled as "weight_delay".
    "weight_decay": args.weight_decay,
    "ssm_format": args.ssm_format,
    "token_mixer": args.token_mixer,
    'epochs': args.epochs,
}


# Per-run log directory: logs<log_name>/<dataset>/<train fraction>. The train
# fraction is the post-adjustment value (validation share already removed).
log_dir = f"logs{args.log_name}/{args.dataset}/{args.train_size:.2f}"
logger = setup_logger(output=log_dir, name="test")

# Default preprocessing: ToTemporalUndirected followed by a stratified node
# split using the fractions derived from the CLI (with unknown=-1).
transform = Compose([
    ToTemporalUndirected(),
    StratifyNodeSplit(num_val=args.val_size, num_test=args.test_size, unknown=-1),
])

# Load the requested dataset; `data` holds the full temporal graph.
dataset = args.dataset
if dataset == 'DBLP':
    path = './data/DBLP'
    data = DBLP(root=path, transform=transform, force_reload=False)[0]
elif dataset in ['dblp3', 'dblp5', 'reddit', 'brain']:
    path = './data/'
    # These datasets use a split-only transform (no ToTemporalUndirected).
    transform = Compose([StratifyNodeSplit(
        num_val=args.val_size, num_test=args.test_size, unknown=-1)])
    data = STARDataset(root=path, name=dataset,
                       transform=transform, force_reload=False)[0]

    # Normalise node features jointly over time: stack each snapshot's x
    # along a new axis 1 (assumes every snapshot has the same node set, as
    # torch.cat requires), normalise, then write the slices back.
    x_temp = torch.cat(
        [data.snapshot(start=i).x.unsqueeze(1)
         for i in range(data.num_snapshots)],
        dim=1)
    x_temp = meanstd_normalization_tensor(x_temp)

    # NOTE(review): this write-back only has an effect if data.snapshot()
    # returns an object whose attribute assignment is visible to later
    # data.snapshot() calls -- the training snapshots are rebuilt from `data`
    # afterwards. Confirm against temporal_graph; otherwise the normalised
    # features are silently discarded.
    for i in range(data.num_snapshots):
        snapshot = data.snapshot(start=i)
        snapshot.x = x_temp[:, i, :]
else:
    # Without this guard `data` would be undefined and the script would crash
    # later with a NameError. (Tmall is imported but has no branch here.)
    raise ValueError(
        f"Unknown --dataset {dataset!r}; expected 'DBLP', 'dblp3', 'dblp5', "
        "'reddit' or 'brain'.")

print('Loading dataset...')
# Use the GPU selected by --device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = f'cuda:{args.device}'
else:
    device = 'cpu'

print('To snapshots...')
# One graph object per time step, all moved to the training device up front.
snapshots = [data.snapshot(start=t) for t in range(data.num_snapshots)]
snapshots = [snap.to(device) for snap in snapshots]

print(snapshots[:2])
print(data)

num_features = data.x.size(-1)
# assumes class labels are 0..C-1 so max + 1 is the class count -- TODO confirm
num_classes = data.y.max().item() + 1
model = DiagonalSSM(
    num_features,
    num_classes,
    hidden_channels=config["hidden_channels"],
    ssm_format=config["ssm_format"],
    token_mixer=config["token_mixer"],
).to(device)
print(model)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=args.weight_decay,
)


def train():
    """Run one optimisation step and return the training loss as a float.

    The model consumes the whole snapshot sequence; the cross-entropy loss is
    taken over the train-masked nodes of the last snapshot only.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(snapshots)
    last = snapshots[-1]
    train_nodes = last.train_mask
    loss = F.cross_entropy(logits[train_nodes], last.y[train_nodes])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    """Evaluate the model on the train/val/test splits of the last snapshot.

    Returns:
        ``(metric_macros, metric_micros)``: two 3-element lists holding the
        macro- and micro-averaged F1 scores, ordered ``[train, val, test]``.
        A split whose mask selects no nodes scores 0.
    """
    model.eval()
    pred = model(snapshots).argmax(dim=-1)
    last = snapshots[-1]
    # Copy labels and predictions to host once, instead of once per mask and
    # per metric.
    y_true = last.y.cpu().numpy()
    y_pred = pred.cpu().numpy()

    metric_macros = []
    metric_micros = []
    for mask in [last.train_mask, last.val_mask, last.test_mask]:
        if mask.sum() == 0:
            metric_macros.append(0)
            metric_micros.append(0)
            continue
        selected = mask.cpu().numpy()
        split_true, split_pred = y_true[selected], y_pred[selected]
        metric_macros.append(
            metrics.f1_score(split_true, split_pred, average='macro'))
        metric_micros.append(
            metrics.f1_score(split_true, split_pred, average='micro'))
    return metric_macros, metric_micros


# Model selection: keep the scores from the epoch with the best validation
# micro-F1 (an epoch must be strictly better to replace the current best).
best_val = -1e5
best_test = -1e5
best_metric_macros = None

start_time = time.time()
for epoch in range(1, config['epochs'] + 1):
    loss = train()
    metric_macros, metric_micros = test()
    train_acc, val_acc, test_acc = metric_micros
    if val_acc > best_val:
        best_val, best_test, best_metric_macros = val_acc, test_acc, metric_macros

    logger.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    logger.info(
        f'MICROS: Train: {train_acc:.2%}, Val: {val_acc:.2%}, Test: {test_acc:.2%}, Best Test: {best_test:.2%}')
    # The MACROS line reports the macro-F1 of the best-validation epoch so
    # far, not of the current epoch.
    macro_train, macro_val, macro_test = best_metric_macros
    logger.info(
        f'MACROS: Train: {macro_train:.2%}, Val: {macro_val:.2%}, Test: {macro_test:.2%}')

elapsed = time.time() - start_time
logger.info(f'Time: {elapsed:.2f}s')
logger.info(config)
