import torch
from types import SimpleNamespace
from val_epoch import Validator
from model import (
    GraphAttentionEmbedding,
    ProductLayer,
    ReshapeLayer,
    col_RNN,
    MergeLayer,
)
from torch_geometric.nn import TGNMemory
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)
from torch_geometric.data import TemporalData
from bikeshare import get_cbs_data
from generate import generate
from model import preprocess


def _load_checkpoint(path, device):
    """Unpickle a training checkpoint, mapping its tensors onto ``device``.

    The checkpoint stores whole pickled Python objects (e.g. the
    ``neighbor_loader`` entry used below), which PyTorch >= 2.6 refuses to
    load under its default ``weights_only=True``. Unpickling can execute
    arbitrary code: only load checkpoints produced by this project.
    """
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` argument (it always unpickles).
        return torch.load(path, map_location=device)


def get_synthetic_data(
    name="wikipedia",
    model_type="train_loss",
    model_desc="concat",
    seed=100,
    batch_size=200,
    num_batches=500,
    device="cuda",
    model_path=None,
    data_path=None,
    out_path=None,
):
    """Generate a synthetic temporal graph from a trained checkpoint and save it.

    Loads the checkpoint, rebuilds the TGN memory and the generator modules
    from the checkpoint's ``config_dict``, restores their weights, calls
    ``generate`` and saves the resulting ``TemporalData`` with ``torch.save``.
    All arguments are optional; the defaults reproduce the original
    hard-coded settings.

    Args:
        name: Dataset name. ``"wikipedia"`` / ``"reddit"`` are read with
            ``JODIEDataset``; any other name falls back to ``get_cbs_data()``.
        model_type: Sub-directory of ``./saved_models`` holding the checkpoint.
        model_desc: Suffix used in the checkpoint and output file names.
        seed: Seed forwarded to ``generate``.
        batch_size: Forwarded to ``generate`` and used in the output name.
        num_batches: Forwarded to ``generate`` and used in the output name.
        device: Torch device for the checkpoint tensors, models and buffers.
        model_path: Checkpoint path; defaults to
            ``./saved_models/{model_type}/{name}_{model_desc}``.
        data_path: Dataset root; defaults to ``./data/{name}``.
        out_path: Output file; defaults to
            ``./ablation_data/{name}_bs{batch_size}_nb{num_batches}_{model_desc}``.
            Missing parent directories are created.

    Returns:
        A ``TemporalData`` holding the generated ``src``, ``dst``, ``t`` and
        ``msg`` fields.
    """
    import os

    if model_path is None:
        model_path = f"./saved_models/{model_type}/{name}_{model_desc}"
    if data_path is None:
        data_path = f"./data/{name}"
    if out_path is None:
        out_path = (
            f"./ablation_data/{name}_bs{batch_size}_nb{num_batches}_{model_desc}"
        )

    checkpoint = _load_checkpoint(model_path, device)
    cfg = SimpleNamespace(**checkpoint["config_dict"])

    print(f"Using {name} dataset. ")
    if name in ["wikipedia", "reddit"]:
        data = JODIEDataset(data_path, name=name)[0]
    else:
        print("for sure bike data!!")
        data = get_cbs_data()
    # Only the node count of the raw dataset is used, so it stays on the CPU.
    num_nodes = data.num_nodes

    msg_dim = cfg.data.msg.size(-1)

    memory = TGNMemory(
        cfg.data.num_nodes,
        msg_dim,
        cfg.memory_dim,
        cfg.time_dim,
        message_module=IdentityMessage(msg_dim, cfg.memory_dim, cfg.time_dim),
        aggregator_module=LastAggregator(),
    ).to(device)
    # Switched to eval BEFORE the weights are restored (same order as the
    # original): PyG's TGNMemory may flush its message store into the memory
    # buffer on a train -> eval transition, which must not touch loaded state.
    memory.eval()

    gnn = GraphAttentionEmbedding(
        in_channels=cfg.memory_dim,
        out_channels=cfg.embedding_dim,
        msg_dim=msg_dim,
        time_enc=memory.time_enc,
    ).to(device)
    embd_to_score_dst = ProductLayer(in_channels=cfg.embedding_dim).to(device)
    embd_to_score_src = ReshapeLayer(
        in_channels=cfg.embedding_dim, out_channels=1
    ).to(device)

    print(cfg.smp.shape)
    print(cfg.memory_dim)

    feats_model = col_RNN(
        cfg.smp.shape[1],
        cfg.col2K,
        embed_dim=1,
        hidden_size=cfg.feats_model_h_dim,
        num_layers=1,
        n_comp=cfg.n_comp,
    ).to(device)
    embd_to_h0 = MergeLayer(
        in_channels=cfg.embedding_dim, out_channels=cfg.feats_model_h_dim
    ).to(device)

    # Restore every module's weights (same order as before).
    for module, key in (
        (memory, "memory_state_dict"),
        (gnn, "gnn_state_dict"),
        (embd_to_score_dst, "embd_to_score_dst_state_dict"),
        (embd_to_score_src, "embd_to_score_src_state_dict"),
        (feats_model, "feats_model_state_dict"),
        (embd_to_h0, "embd_to_h0_state_dict"),
    ):
        module.load_state_dict(checkpoint[key])

    # Generation is inference: put the remaining modules in eval mode so any
    # dropout / normalization layers they contain are not in training mode.
    # (``memory`` is already in eval mode; see above.)
    for module in (gnn, embd_to_score_src, embd_to_score_dst, feats_model, embd_to_h0):
        module.eval()

    # The neighbor loader's state is restored wholesale from the checkpoint.
    neighbor_loader = checkpoint["neighbor_loader"]

    # --- Generate ---
    all_n_id = torch.arange(num_nodes, dtype=torch.long, device=device)
    embeddings = torch.empty([num_nodes, cfg.embedding_dim], device=device)

    data_synthetic = generate(
        best_memory=memory,
        best_gnn=gnn,
        best_embd_to_score_src=embd_to_score_src,
        best_embd_to_score_dst=embd_to_score_dst,
        best_feats_model=feats_model,
        best_embd_to_h0=embd_to_h0,
        neighbor_loader=neighbor_loader,
        all_n_id=all_n_id,
        embeddings=embeddings,
        seed=seed,
        all_src=cfg.all_src,
        all_dst=cfg.all_dst,
        min_dst_idx=cfg.min_dst_idx,
        preprocess=preprocess,
        batch_size=batch_size,
        num_batches=num_batches,
    )

    # Keep only the core event fields.
    data_synthetic = TemporalData(
        src=data_synthetic.src,
        dst=data_synthetic.dst,
        t=data_synthetic.t,
        msg=data_synthetic.msg,
    )

    print(data_synthetic.msg.shape)

    # torch.save does not create missing directories.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(data_synthetic, out_path)

    return data_synthetic


if __name__ == "__main__":
    # Script entry point: run generation with the default settings.
    # get_synthetic_data also writes its result to disk via torch.save,
    # so the returned value is not used further here.
    data = get_synthetic_data()
