import numpy as np
import torch
import json
import argparse
import pickle
import random

from graphloader import get_graph_data
from utils import prepare_dataset, prepare_model
from evaluation import mmd_evaluation, generate_samples


def set_seed(seed: int = 42):
    """Seed every random number generator used by this script.

    Covers Python's ``random`` module, NumPy's global RNG, the PyTorch CPU
    generator and the generator of the current CUDA device.

    Args:
        seed: Value passed verbatim to each seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

def set_everything(graph_name, num_signals, walk_type, **kwargs):
    """Build every training artefact: graph data, datasets, configs and model.

    Reads ``config.json`` from the working directory and uses its
    ``model_config``, ``train_config`` and ``data_config[graph_name]`` sections.
    Both ``model_config`` and ``train_config`` get ``block_size`` set to
    ``walk_length - 1``. ``model_config`` also gets ``vocab_size`` set to the
    number of nodes in the loaded graph.

    Args:
        graph_name: Key of the dataset; used for ``get_graph_data`` and to
            select ``data_config[graph_name]``.
        num_signals: Forwarded to ``get_graph_data``.
        walk_type: Forwarded to ``prepare_dataset``.
        **kwargs: Forwarded to both ``get_graph_data`` and ``prepare_dataset``.

    Returns:
        Tuple ``(input_graph, input_signal, train_data, val_data,
        train_config, model)``.

    Raises:
        KeyError: If ``config.json`` has no ``data_config`` entry for
            ``graph_name``.
    """
    with open("config.json", "r", encoding="utf-8") as file:
        config = json.load(file)

    # Resolve the per-graph config before loading any data so that an unknown
    # graph name fails fast with a readable message.
    try:
        data_config = config["data_config"][graph_name]
    except KeyError as err:
        raise KeyError(
            f"config.json has no data_config entry for graph {graph_name!r}"
        ) from err
    model_config = config["model_config"]
    train_config = config["train_config"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_graph, input_signal = get_graph_data(graph_name, num_signals, **kwargs)
    print(input_graph.number_of_nodes())

    # The model is trained on (walk[:-1], walk[1:]) pairs, hence walk_length - 1.
    block_size = data_config["walk_length"] - 1
    model_config["block_size"] = block_size
    train_config["block_size"] = block_size
    # One token per graph node.
    model_config["vocab_size"] = input_graph.number_of_nodes()

    train_data, val_data = prepare_dataset(input_graph, input_signal, walk_type, data_config, **kwargs)
    model = prepare_model(model_config, device)

    return input_graph, input_signal, train_data, val_data, train_config, model

def get_batch(data, block_size, device, batch_size=32):
    """Draw a random mini-batch of next-token training pairs.

    Rows of ``data`` are sampled uniformly with replacement. The input is the
    first ``block_size`` columns of each sampled row. The target is the same
    window shifted one column to the right.

    Args:
        data: Tensor whose first dimension indexes sequences.
        block_size: Number of columns in each input/target window.
        device: Device the returned tensors are moved to.
        batch_size: Number of rows to sample.

    Returns:
        ``(inputs, targets)`` tensors on ``device``.
    """
    rows = torch.randint(data.shape[0], (batch_size,))
    input_cols = slice(None, block_size)
    target_cols = slice(1, block_size + 1)
    inputs = data[rows, input_cols]
    targets = data[rows, target_cols]
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss(model, eval_iters, block_size, device, batch_size, train_data, val_data):
    """Average the model's loss over random batches of each data split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning. Gradients are disabled by the decorator.

    Args:
        model: Callable as ``model(x, y) -> (logits, loss)``.
        eval_iters: Number of batches averaged per split.
        block_size: Window length passed to ``get_batch``.
        device: Device batches are moved to.
        batch_size: Rows per batch.
        train_data: Data for the ``"train"`` split.
        val_data: Data for the ``"val"`` split.

    Returns:
        ``(model, out)`` where ``out`` maps ``"train"``/``"val"`` to a 0-dim
        tensor holding the mean loss.
    """
    splits = {"train": train_data, "val": val_data}
    model.eval()
    out = {}
    for split_name, split_data in splits.items():
        batch_losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split_data, block_size, device, batch_size)
            _, batch_loss = model(xb, yb)
            batch_losses[i] = batch_loss.item()
        out[split_name] = batch_losses.mean()
    model.train()
    return model, out

def train_model(model, train_dataset, val_dataset, train_config, patience=800):
    """Train ``model`` with AdamW and validation-loss early stopping.

    Hyper-parameters come from ``train_config``: ``lr``, ``max_iters``,
    ``eval_iters``, ``eval_interval``, ``block_size``, ``batch_size`` and
    (optionally) ``patience``. Every ``eval_interval`` steps, and on the final
    step, the train/val losses are estimated. Training stops once the
    validation loss has not improved for ``patience`` iterations.

    Args:
        model: Callable as ``model(x, y) -> (logits, loss)``; trained in place.
        train_dataset: Data sampled for optimisation steps and train-loss estimates.
        val_dataset: Data used for validation-loss estimates.
        train_config: Dict of hyper-parameters (see above).
        patience: Fallback early-stopping patience, in training iterations.
            Used only when ``train_config`` has no (non-null) ``"patience"``.

    Returns:
        The trained model. This is the model after the last executed step,
        not the best-validation checkpoint.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lr = train_config.get("lr")
    max_iters = train_config.get("max_iters")
    eval_iters = train_config.get("eval_iters")
    eval_interval = train_config.get("eval_interval")
    block_size = train_config.get("block_size")
    batch_size = train_config.get("batch_size")
    # A config value overrides the argument. The argument is kept as a fallback
    # so a missing key does not turn patience into None.
    configured_patience = train_config.get("patience")
    if configured_patience is not None:
        patience = configured_patience

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_val_loss = float("inf")
    no_improve_count = 0
    for step in range(max_iters):
        if step % eval_interval == 0 or step == max_iters - 1:
            # estimate_loss switches the model to eval mode itself.
            model, losses = estimate_loss(
                model, eval_iters, block_size, device, batch_size, train_dataset, val_dataset
            )
            if losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                no_improve_count = 0
            else:
                # Counted in training iterations, not in evaluations.
                no_improve_count += eval_interval
                if no_improve_count >= patience:
                    break
            if step % (eval_interval * 4) == 0:
                print(f"Step: {step} | Train Loss: {losses['train']:.4f} | Val Loss: {losses['val']:.4f}")

            model.train()

        xbatch, ybatch = get_batch(train_dataset, block_size, device, batch_size)
        _, loss = model(xbatch, ybatch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model


def evaluate_model(model, input_graph, num_tokens=800, n=20, save_as_img=False, save_path="./gen_graphs.pkl"):
    """Generate graphs with ``model``, pickle them, and run MMD evaluation.

    Args:
        model: Trained model passed to ``generate_samples``.
        input_graph: Reference graph; used for generation and as the MMD baseline.
        num_tokens: Forwarded to ``generate_samples`` (the generation length
            budget; main describes it as a minimum sequence length).
        n: Number of graphs to generate.
        save_as_img: Forwarded to ``generate_samples`` as ``save``.
        save_path: File the list of generated graphs is pickled to. Defaults
            to the previously hard-coded ``./gen_graphs.pkl``.

    Returns:
        The generated graphs, as returned by ``generate_samples``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generated_graphs = generate_samples(model, input_graph, num_tokens, n, device, save=save_as_img)

    with open(save_path, "wb") as f:
        pickle.dump(generated_graphs, f)

    mmd_evaluation(input_graph, generated_graphs)

    return generated_graphs


def main(args):
    """Run the full pipeline: seed, build data/model, train, then evaluate.

    Args:
        args: Parsed CLI namespace. Reads ``graph_name``, ``num_signals``,
            ``walk_type`` and ``temporal_degree``.
    """
    set_seed(42)
    (input_graph, input_signal, train_data, val_data,
     train_config, model) = set_everything(
        args.graph_name,
        args.num_signals,
        args.walk_type,
        degree=args.temporal_degree,
    )
    print(train_config)
    # Read generation settings before training so a missing key fails early.
    graphs_to_generate = train_config["num_to_generate"]  # number of graphs to generate for evaluation
    min_tokens = train_config["num_tokens"]  # minimum generated sequence length; increase for larger graphs
    model = train_model(model, train_data, val_data, train_config)

    # NOTE(review): args.p and args.q are parsed by the CLI but never forwarded here.
    evaluate_model(
        model,
        input_graph,
        num_tokens=min_tokens,
        n=graphs_to_generate,
        save_as_img=False,
    )

if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Train a model on walks over a graph dataset and evaluate the generated graphs with MMD"
    )

    parser.add_argument('--graph-name', type=str, required=True, help="The graph dataset to train on")
    parser.add_argument('--walk-type', type=str, required=True, help="Walk type to be used: [random, random_plus, brn, brn_plus]")
    parser.add_argument('--num-signals', type=int, default=168, help="Number of timestamps to use (node signals)")
    parser.add_argument("--temporal-degree", type=int, default=1, help="One-hop or second-hop temporal bias")
    # argparse applies `type` only to string defaults, so float defaults are
    # written as floats to keep args.p / args.q consistently typed.
    # NOTE(review): main() does not currently forward --p / --q anywhere.
    parser.add_argument("--p", type=float, default=1.0, help="p (return parameter) for biased random walk")
    parser.add_argument("--q", type=float, default=0.01, help="q (in-out parameter) for biased random walk")

    args = parser.parse_args()

    main(args)