import torch
import numpy as np
import json
import networkx as nx

from generate_walks import generate_walks
from model import TransformerGG

def prepare_dataset(input_graph, input_signals, walk_type, config, **kwargs):
    """
    Generate walks on the input graph and split them into train / validation tensors.

    Reads "num_walks_per_node", "walk_length", "ratio" (default 0.8) and
    "train_split" (default 0.7) from ``config``. When ``walk_type`` ends with
    "plus", a similarity matrix computed from ``input_signals`` is passed on to
    ``generate_walks``; otherwise ``None`` is passed. Extra keyword arguments
    are forwarded to ``generate_walks`` untouched.

    The walks are packed into a long tensor of shape
    (num_nodes, num_walks_per_node, walk_length) and split along the
    walks-per-node axis, so every node contributes walks to both subsets.

    Returns:
        (train_data, val_data): two long tensors, each reshaped to
        (-1, walk_length).
    """
    num_walks_per_node = config.get("num_walks_per_node")
    walk_length = config.get("walk_length")
    ratio = config.get("ratio", 0.8)
    train_fraction = config.get("train_split", 0.7)
    print(f"Num walks per node: {num_walks_per_node}\nWalk length: {walk_length}\nRatio: {ratio}")

    # Only the "...plus" walk types receive a similarity matrix.
    similarity_matrix = get_similarity_matrix(input_signals) if walk_type.endswith("plus") else None

    walks = generate_walks(
        input_graph,
        num_walks_per_node,
        walk_length,
        walk_type,
        similarity_matrix,
        ratio,
        **kwargs,
    )

    # One (num_walks_per_node, walk_length) tensor per node, filled row by row.
    per_node_tensors = []
    for node in walks.keys():
        node_walks = walks[node]
        walk_block = torch.zeros(num_walks_per_node, walk_length, dtype=torch.long)
        for row in range(num_walks_per_node):
            walk_block[row] = torch.tensor(node_walks[row], dtype=torch.long)
        per_node_tensors.append(walk_block)
    dataset = torch.stack(per_node_tensors)

    # Split on the walks-per-node axis, then flatten nodes and walks together.
    n_train_walks = int(dataset.shape[1] * train_fraction)
    train_data = dataset[:, :n_train_walks, :].reshape(-1, walk_length)
    val_data = dataset[:, n_train_walks:, :].reshape(-1, walk_length)

    return train_data, val_data

def series2graph(series):
    """
    Build an undirected nx.Graph from a node sequence.

    Every pair of consecutive entries that differ becomes an edge. Consecutive
    repeats are skipped, so no self-loops are created. Nodes are only added
    through edges, so a series without any differing consecutive pair yields an
    empty graph.
    """
    graph = nx.Graph()
    graph.add_edges_from(
        (series[pos], series[pos + 1])
        for pos in range(len(series) - 1)
        if series[pos] != series[pos + 1]
    )
    return graph


def prepare_model(train_config, device):
    """
    Instantiate a TransformerGG from the training configuration and move it to ``device``.

    The hyperparameters "vocab_size", "n_embd", "n_head", "block_size",
    "n_layer" and "dropout" are looked up with ``train_config.get``, so a
    missing key is passed to the constructor as ``None``.
    """
    # Keys listed in the positional order expected by the TransformerGG constructor.
    constructor_keys = ("vocab_size", "n_embd", "n_head", "block_size", "n_layer", "dropout")
    hyperparams = [train_config.get(key) for key in constructor_keys]

    return TransformerGG(*hyperparams, device).to(device)

def get_similarity_matrix(input_signal, sim_type="corr", skip_steps=5):
    """
    given feature matrix, calculates similarity matrix (S)

    Args:
        input_signal: array exposing ``.shape``, slicing, ``reshape`` and
            NumPy-style ``transpose`` (e.g. ``np.ndarray``).
            * 2-D input: columns are treated as the variables
              (``rowvar=False``), giving one row/column of S per input column.
            * Input with more than 2 dims: axis 1 is taken as the node axis
              and axis 2 as the channel axis; axis 0 (time) and the channels
              are flattened into one feature vector per node.
        sim_type: similarity measure; only "corr" (Pearson correlation) is
            implemented.
        skip_steps: number of leading entries along axis 0 dropped before the
            correlation. Only applied to input with more than 2 dims.
            Defaults to 5, the value that used to be hard-coded.
            NOTE(review): presumably a warm-up period -- confirm with the
            data source.

    Returns:
        The N x N correlation matrix produced by ``np.corrcoef``.

    Raises:
        ValueError: if ``sim_type`` is not supported, ``skip_steps`` is
            negative, or no entries remain along axis 0 after skipping.
    """
    # Fail early and clearly instead of reaching the return statement with an
    # unbound local when an unknown similarity type is requested.
    if sim_type != "corr":
        raise ValueError(
            f"Unsupported sim_type {sim_type!r}; only 'corr' is implemented"
        )
    if skip_steps < 0:
        raise ValueError(f"skip_steps must be non-negative, got {skip_steps}")

    print(input_signal.shape)

    if len(input_signal.shape) > 2:
        N = input_signal.shape[1]
        n_channels = input_signal.shape[2]

        trimmed = input_signal[skip_steps:]
        n_steps = trimmed.shape[0]
        if n_steps == 0:
            # An empty slice would only yield an all-NaN matrix (or an opaque
            # reshape error), so reject it explicitly.
            raise ValueError(
                f"input_signal has {input_signal.shape[0]} entries along axis 0; "
                f"nothing is left after skipping {skip_steps}"
            )

        # Flatten time + channels for correlation: one row per node.
        node_vectors = (
            trimmed.reshape(n_steps, N, n_channels)
            .transpose(1, 0, 2)
            .reshape(N, -1)
        )
        print(node_vectors.shape)
        # Compute similarity
        similarity_matrix = np.corrcoef(node_vectors)  # N x N
        print(similarity_matrix.shape)
    else:
        similarity_matrix = np.corrcoef(input_signal, rowvar=False)  # N x N

    return similarity_matrix