# %%
from datasets.triangle import TriangleSCM
from datasets.chain import ChainSCM
from datasets.collider import ColliderSCM
from datasets.mgraph import MGraphSCM
from datasets.loan import LoanSCM
from datasets.german import GermanSCM


from datasets.transforms import ToTensor
import torch

# Build the linear Loan SCM, prepare it, set an intervention on 'x3' and
# inspect the first item.
dataset_lin = LoanSCM(equations_type='linear', transform=ToTensor())
dataset_lin.prepare_data(1000, add_self_loop=True)
dataset_lin.set_intervention(x_I={'x3': 0})

batch = dataset_lin[0]

# batch.x_i = torch.cat([batch.x, batch.x_i], dim=1)
print(batch)
for attr_name in ('x', 'edge_index'):
    print(getattr(batch, attr_name))

# %%
from datasets._adjacency import Adjacency
import numpy as np
# 4x4 matrix with ones on the diagonal and on the first superdiagonal.
my_adj = np.array([[1, 1, 0, 0],
                  [0, 1, 1, 0],
                  [0, 0, 1, 1],
                  [0, 0, 0, 1]])
adj = Adjacency(my_adj)

print('Adjacency Matrix')
print(adj.adj_matrix)
print(adj.edge_index)
print(adj.edge_attr)

# Intervene on node 2 and show the intervened counterparts of the
# attributes printed above.
adj.set_intervention(node_id_list=[2], add_self_loop=True)
print('Adjacency Matrix Intervention')  # heading typo fixed
print(adj.adj_matrix_i)
print(adj.edge_index_i)
print(adj.edge_attr_i)


# %%  ZINC Dataset. Chemical (molecular graph) dataset. Graph-level regression
from torch_geometric.datasets import ZINC
import numpy as np

# Load the ZINC training subset and report the extremes of the per-graph
# edge and node counts.
dataset = ZINC(root='../Data/graph',
               subset=True,
               split='train',
               transform=None,
               pre_transform=None,
               pre_filter=None)

num_nodes_list, num_edges_list = [], []
for graph_idx in range(dataset.len()):
    graph = dataset.get(graph_idx)
    num_nodes_list.append(graph.x.shape[0])
    num_edges_list.append(graph.edge_index.shape[1])

# Printed order: max/min number of edges, then max/min number of nodes.
for counts in (num_edges_list, num_nodes_list):
    print(np.max(counts))
    print(np.min(counts))

# %% Test Twitter SE

from datasets.twitter_se import TwitterSE

# Instantiate the TwitterSE dataset and prepare it.
dataset = TwitterSE(
    root_dir='../Data',
    T=30,
    L=10,
    Q=0,
    sigma=8,
    transform=None,
    features='11111111111',
)
dataset.prepare_data()

# %%
# List every entry of x_labels next to its position.
for position, label in enumerate(dataset.x_labels):
    print(f'{position} : {label}')

# %%
# Mean of y over the whole dataset, shown as a percentage.
prop_members = dataset.y.mean()
print(f'Proportion of members: {prop_members * 100:.2f} %')

# %%
# Same proportion, restricted to rows whose column 11 of x_se is positive.
cond = dataset.x_se[:, 11] > 0
x_tmp, y_tmp = dataset.x_se[cond], dataset.y[cond]
print(f'Proportion of members: {y_tmp.mean() * 100:.2f} %')

# %%
# Rows with y == 1: proportion of y and mean of column 22.
cond = dataset.y == 1
x_tmp, y_tmp = dataset.x_se[cond], dataset.y[cond]
print(f'Proportion of members: {y_tmp.mean() * 100:.2f} %')
print(f'Proportion of feature 22: {x_tmp[:, 22].mean() * 100:.2f} %')
# %%
# Rows with y == 0 or a positive column 22.
is_non_member = dataset.y == 0
has_feature_22 = dataset.x_se[:, 22] > 0
cond = is_non_member | has_feature_22
x_tmp, y_tmp = dataset.x_se[cond], dataset.y[cond]
print(f'Proportion of members: {y_tmp.mean() * 100:.2f} %')
print(f'Proportion of feature 22: {x_tmp[:, 22].mean() * 100:.2f} %')

# %% Logistic regression
import numpy as np
from sklearn.model_selection import train_test_split

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, roc_auc_score

from sklearn.preprocessing import StandardScaler

# Convert features and labels to NumPy and make a fixed-seed 80/20 split.
X, y = dataset.x_se.numpy(), dataset.y.numpy()

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=14,
)
# Compare the label mean (in percent) between the two partitions.
print(f'y_mean_train: {np.mean(y_train) * 100:.2f} y_mean_test: {np.mean(y_test) * 100:.2f}')

# %% Test Twitter WALL

from datasets.twitter_wall import TwitterWall
import numpy as np

# Instantiate the TwitterWall dataset and report its size plus the extremes
# of the per-graph edge and node counts.
dataset = TwitterWall(
    root='../Data',
    T=30,
    L=10,
    Q=0,
    sigma=8,
    transform=None,
    only_id=None,
    pre_transform=None,
    features='11111111111',
)
print(len(dataset))

num_nodes_list, num_edges_list = [], []
for graph in dataset:
    num_nodes_list.append(graph.x.shape[0])
    num_edges_list.append(graph.edge_index.shape[1])

# Printed order: max/min number of edges, then max/min number of nodes.
for counts in (num_edges_list, num_nodes_list):
    print(np.max(counts))
    print(np.min(counts))

# %%
from torch_geometric.utils import degree
import torch

idx = 1  # indegree
# One degree vector per graph, computed from row `idx` of edge_index.
d_list = [
    degree(graph.edge_index[idx], num_nodes=graph.num_nodes, dtype=torch.long)
    for graph in dataset
]

# Statistics for the LAST graph of the dataset only.
d = d_list[-1]
print(torch.mean(d.float()))
deg1 = torch.bincount(d, minlength=d.numel()).float()
print(torch.mean(deg1))

# The same two statistics over the degrees of all graphs concatenated.
d = torch.cat(d_list)
print(torch.mean(d.float()))
deg = torch.bincount(d, minlength=d.numel()).float()
print(torch.mean(deg))

# %%
import torch

from sklearn.preprocessing import PowerTransformer

# Gather node features and edge attributes of every graph, printing the
# running index as progress.
x_list, edge_list = [], []
for i, data in enumerate(dataset):
    print(i)
    x_list.append(data.x)
    edge_list.append(data.edge_attr)
x_total = torch.cat(x_list, dim=0)
edge_total = torch.cat(edge_list, dim=0)

# Fit one Yeo-Johnson transformer on the node features and one on the edge
# attributes with their first column dropped.
power = PowerTransformer(method='yeo-johnson', standardize=True)
power = power.fit(x_total)
power_e = PowerTransformer(method='yeo-johnson', standardize=True)
power_e = power_e.fit(edge_total[:, 1:])

# %%
# Column-wise maxima after transformation.
x_tr = power.transform(x_total)
print(x_tr.max(0))
e_tr = power_e.transform(edge_total[:, 1:])
print(e_tr.max(0))
# %%
# Inspect a single graph (index 19).
data = dataset.get(19)
print(data.x.shape[0])
print(data.y_label.shape)

positive_labels = (data.y_label == 1).sum()
print(positive_labels)
print(data.is_missing.sum())

# %% Twitter SE (Seeker-expert)
import torch

# Rows of x selected by y_index, reshaped to 22 columns.
x_se = data.x[data.y_index].view(-1, 22)
num_rows, num_edge_feats = x_se.shape[0], data.edge_attr.shape[1]
edge_se = torch.zeros([num_rows, num_edge_feats])

# The source is the seeker and the target is not the seeker
sources, targets = data.edge_index[0, :], data.edge_index[1, :]
cond = (sources == 0) & (targets != 0)
edge_index_se = data.edge_index[:, cond]

# Write each selected edge's attributes into the row given by its target
# index shifted down by one.
edge_se[edge_index_se[1, :] - 1, :] = data.edge_attr[cond, :]

# %%
# Alternative selection: keep the attributes of every edge whose source is
# node 0. The zero-filled `edge_se` that used to be allocated here was
# overwritten immediately and has been dropped.
edge_index_s = torch.where(data.edge_index[0, :] == 0)  # Index of edges in which the seeker is the source
edge_se = data.edge_attr[edge_index_s]
# %%
import torch

# %%
import utils.args_parser as atools
import os
import glob
import torch_geometric.utils.convert as gconvert
import datasets.utils as dutils
import networkx as nx

# Locate the pickled graphs for one (T, L, Q, sigma) configuration and load
# the first one into a node DataFrame.
T, L, Q, sigma = 30, 10, 0, 8
root = f'../Data/TwitterWalls/{T}_{L}_{Q}_{sigma}'

graph_pattern = os.path.join(root, '**/G_*.pkl')
G_file_list = glob.glob(graph_pattern, recursive=True)
print(G_file_list)

first_graph_path = G_file_list[0]
G = nx.convert_node_labels_to_integers(atools.load_obj(first_graph_path))
df_nodes = dutils.convert_nodes_to_df(G)

# Share of rows whose 'id_str' is null.
print(f"Percentage of missing nodes: {df_nodes['id_str'].isnull().sum() / len(df_nodes) * 100:.2f}")
df_nodes.iloc[1]
# data = gconvert.from_networkx(G)
# %%
# Accumulate the union of the keys of every dict-valued 'bow' entry across
# all graph files.
vocabulary = []
for i, G_raw_path in enumerate(G_file_list):
    print(f"{i} | {G_raw_path}")

    G = atools.load_obj(G_raw_path)
    if not nx.is_directed(G):
        G = G.to_directed()
    G = nx.convert_node_labels_to_integers(G)

    df_nodes = dutils.convert_nodes_to_df(G)
    for bow in df_nodes['bow']:
        # Entries that are not dicts are skipped.
        if not isinstance(bow, dict):
            continue
        vocabulary = atools.list_union(vocabulary, list(bow.keys()))

# %%
from data_modules.citation_full import CitationFullDataModule

# Iterate the training loader of the 'cora' data module, printing each batch
# and the range of its node features.
data_module = CitationFullDataModule(data_dir='../Data/graph', name='cora')

train_loader = data_module.train_dataloader()
for batch in train_loader:
    print(batch)
    features = batch.x
    print(features.max())
    print(features.min())
# %% CitationFull
from torch_geometric.datasets import CitationFull, Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.utils as gutils

dataset = CitationFull(root='../Data/graph', name='CiteSeer')
print(len(dataset))
print(dataset[0])

data = dataset[0]

# Apply add_remaining_self_loops twice, printing the original and the
# resulting edge_index shapes after each pass.
edges = data.edge_index
for _pass in range(2):
    edges = gutils.add_remaining_self_loops(edges)[0]
    print(data.edge_index.shape)
    print(edges.shape)

# %%
# from torch_geometric.utils.convert import to_scipy_sparse_matrix
from torch_geometric.utils import to_scipy_sparse_matrix, dense_to_sparse, to_dense_adj

# Try the different adjacency representations of the first graph.
edge_index = dataset[0].edge_index

# Dense adjacency. to_dense_adj returns a batched tensor, so the single
# graph is taken with [0] before converting back.
dense_adj = to_dense_adj(edge_index)

# dense_to_sparse expects a dense adjacency matrix; it was previously handed
# the (2, E) edge_index itself, which does not round-trip the graph.
sparse_edge_index, sparse_edge_attr = dense_to_sparse(dense_adj[0])

# As before, `edges` ends up holding the SciPy sparse matrix.
edges = to_scipy_sparse_matrix(edge_index)
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

# %%
# Planetoid / CiteSeer: number of graphs and the first graph.
dataset = Planetoid(
    root='../Data/graph',
    name='CiteSeer',
)
print(len(dataset))
print(dataset[0])

# %% Entities
from torch_geometric.datasets import Entities

# Entities / MUTAG: number of graphs and the first graph.
dataset = Entities(
    root='../Data/graph',
    name='MUTAG',
)
print(len(dataset))
print(dataset[0])
# %% Twitter

from datasets.twitter import Twitter
import networkx as nx
from torch_geometric.utils import from_scipy_sparse_matrix

# Instantiate the Twitter dataset with caching enabled and no transform.
dataset = Twitter(
    root='../Data',
    use_cache=True,
    transform=None,
    features='00000001',
)
# dataset.set_data()
# dataset.set_node_features(rate=True)
# dataset.set_data()
# print(dataset.data)
# out = nx.to_scipy_sparse_matrix(dataset.G)
#
# edge_index, edge_weight = from_scipy_sparse_matrix(out)

# %%
import os
import pandas as pd
import utils.args_parser as autils
import itertools

# Load the raw adjacency and feature tables.
# NOTE: pd.read_pickle unpickles arbitrary objects; only use it on trusted files.
adj_file = os.path.join('../Data/TwitterLists/raw', 'df_adj.pkl')
df_adj = pd.read_pickle(adj_file)
feat_file = os.path.join('../Data/TwitterLists/raw', 'df_feat.pkl')
df_feat = pd.read_pickle(feat_file)

# Unique ids on each side of the adjacency table. (id_str_e_list used to be
# computed twice with identical results; one computation is enough.)
id_str_s_list = list(df_adj['id_str_s'].unique())
id_str_e_list = list(df_adj['id_str_e'].unique())

id_str_inter = autils.list_intersection(id_str_s_list, id_str_e_list)
print(len(id_str_e_list))
# Remove from the 'e' ids those that also appear among the 's' ids and show
# the size before and after.
id_str_e_list = autils.list_substract(id_str_e_list, id_str_s_list)
print(len(id_str_e_list))

# Every (s, e) pair; the last two prints should agree.
combinations = list(itertools.product(id_str_s_list, id_str_e_list))
print(len(combinations))
print(len(id_str_e_list) * len(id_str_s_list))
# %%
# Reload the feature table and build one row per (id_str_s, id_str_e) pair,
# indexed by that pair, with every indicator column initialised to 0.
feat_file = os.path.join('../Data/TwitterLists/raw', 'df_feat.pkl')
df_feat = pd.read_pickle(feat_file)

pair_columns = {0: 'id_str_s', 1: 'id_str_e'}
df_entries = pd.DataFrame.from_dict(combinations).rename(columns=pair_columns)
df_entries = df_entries.set_index(['id_str_s', 'id_str_e'])

indicator_columns = (
    'e_member_s',
    's_repliesto_b',
    's_likes_b',
    's_rted_b',
    's_follows_b',
)
for indicator in indicator_columns:
    df_entries[indicator] = 0

# %%

# NOTE(review): `df` is not assigned anywhere above this cell in the visible
# script, so running the cells top to bottom raises NameError here. The same
# rename is applied to `df_entries` in the previous cell; this looks like a
# leftover of that step — confirm and remove or retarget.
df = df.rename(columns={0: 'id_str_s', 1: 'id_str_e'})
# %%
# Scan every '*set_str*' column of every user row and intersect the ids it
# holds with id_str_e_list.
# NOTE(review): id_str_list, j, entries_dict and U_inter are never read in
# this cell; they look like scaffolding for code not written yet.
id_str_list = autils.list_union(id_str_e_list, id_str_s_list)
j = 0
entries_dict = {'id_str_s': [],
                'id_str_e': [],
                's_repliesto_b': [],
                's_likes_b': [],
                's_rted_b': [],
                's_follows_b': []}

# The set-valued columns do not change between rows, so select them once
# instead of once per row.
set_cols = [c for c in df_feat.columns if 'set_str' in c]

for i, (_, user) in enumerate(df_feat.iterrows()):
    for set_col in set_cols:
        U_set = user[set_col]
        # Non-string cells carry no id list and are skipped.
        if not isinstance(U_set, str):
            continue

        U_list = U_set.split('___')
        U_inter = autils.list_intersection(id_str_e_list, U_list)
    # Progress report every 1000 rows.
    if i % 1000 == 0:
        print(f"{i}/{len(df_feat)}")
# %%

# For one hard-coded user id, compare the retweet id set stored in the
# features table with the s_rted_b column of the adjacency table.
user_id = '1000034813531729921'
df_adj_s = df_adj[df_adj.id_str_s == user_id]
df_feats_s = dataset.df_feats.loc[user_id]
U_retweets = df_feats_s['retweet_U_set_str'].split('___')

print(len(U_retweets))
print(f"{df_adj_s.s_rted_b.sum()}")

# The bare name `list_intersection` is never imported in this script; the
# helper is reached through the `autils` module alias, as in the cells above.
list_inter = autils.list_intersection(id_str_e_list, U_retweets)
print(len(list_inter))

# %%
import torch
import numpy as np

# Convert the SciPy corpus matrix to a torch sparse COO tensor and check the
# last row against the dense SciPy version.
coo = dataset.corpus.tocoo()

# Row/column coordinates of the stored entries, one row each.
indices = np.vstack((coo.row, coo.col))

# torch.sparse_coo_tensor replaces the deprecated torch.sparse.LongTensor
# constructor, whose name was also misleading because the values are floats.
X_sp = torch.sparse_coo_tensor(torch.LongTensor(indices),
                               torch.FloatTensor(coo.data),
                               torch.Size(coo.shape))

# Element-wise comparison of the last row (tensor values cast to int).
for i, j in zip(X_sp.to_dense()[-1, :].type(torch.int), np.array(dataset.corpus.todense()[-1, :]).flatten()):
    assert i == j

# %%

from models.lda import LDA

# Project LDA wrapper trained on a random sample of 1000 rows.
lda = LDA(num_topics=10)
lda_sample = dataset.df_feats.sample(1000)
lda.prepare_data(df_complete=lda_sample)
lda.train()

# %%
from sklearn.feature_extraction.text import CountVectorizer

# Bag-of-words matrix from the 'text' column of another random sample of
# 1000 rows.
vectorizer = CountVectorizer(stop_words='english', max_df=0.5, min_df=0.01)
sampled_texts = list(dataset.df_feats.sample(1000)['text'])
corpus = vectorizer.fit_transform(sampled_texts)

# %%
from sklearn.decomposition import LatentDirichletAllocation

# random_state is fixed for reproducibility.
model = LatentDirichletAllocation(n_components=10, random_state=45)
model.fit(corpus)

# %%

from wordcloud import WordCloud

# Join every processed text into one comma-separated string and render it as
# a word cloud.
# NOTE(review): `df` is not assigned anywhere above this cell in the visible
# script, and the only visible assignment (the pairplot cell near the end)
# has columns x1/x2/x3, not 'text_processed'. This cell presumably relies on
# a frame built elsewhere — confirm before running.
long_string = ','.join(list(df['text_processed'].values))
wordcloud = WordCloud(background_color="white", max_words=5000, contour_width=3, contour_color='steelblue')

wordcloud.generate(long_string)
# Bare expression: only displays the image when run as a notebook cell.
wordcloud.to_image()

# %%
from torch_geometric.data import DataLoader

# Iterate the dataset in unshuffled batches of 10 and print each batch with
# the maximum of its node features.
loader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
for batch in loader:
    print(batch)
    feature_max = batch.x.max()
    print(feature_max)

# %%

from data_modules.twitter import TwitterDataModule

# Build and prepare the Twitter data module.
data_module = TwitterDataModule(
    data_dir='../Data',
    normalize='std-log',
    features='00000000',
)
data_module.prepare_data()
# %%
# Print every training batch and the range of its node features.
train_loader = data_module.train_dataloader()
for batch in train_loader:
    print(batch)
    features = batch.x
    print(features.max())
    print(features.min())

# %%

dataset.set_node_features(rate=True)
# %%
dataset.set_data()
# %%

# The graph and the features table must hold the same number of entries.
assert len(dataset.G.nodes()) == len(dataset.df_feats)

# Spot-check one node: 'favourites_count' as stored on the graph node and as
# stored in the features table.
idx = 1023
node_id = list(dataset.G.nodes())[idx]
graph_value = dataset.G.nodes[node_id]['favourites_count']
print(graph_value)
table_value = dataset.df_feats.loc[node_id]['favourites_count']
print(table_value)

# %%

from datasets.toy_scm3 import ToySCM3
from datasets.utils import normalize_adj
import torch

# Prepare the toy SCM without adjacency normalisation and fetch its first
# item.
dataset = ToySCM3()
dataset.prepare_data(normalize_A=None)
data = dataset[0]

# Show the graph attributes, then the adjacency after 'col' normalisation.
for attr_name in ('edge_index', 'edge_attr', 'SCM_adj'):
    print(getattr(dataset, attr_name))
print(normalize_adj(dataset.SCM_adj, 'col'))

# %%
from torch_geometric.utils import dense_to_sparse
import torch

# dense_to_sparse returns an (edge_index, edge_attr) pair; unpack it so that
# `edge_index` holds the index tensor rather than the whole tuple.
edge_index, edge_attr = dense_to_sparse(torch.tensor(dataset.SCM_adj))

# %%
import seaborn as sns
import pandas as pd

# Pairwise scatter plots of the three columns of dataset.X.
column_names = ['x1', 'x2', 'x3']
df = pd.DataFrame(data=dataset.X, columns=column_names)
sns.pairplot(df)

# %%

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

# ENZYMES from the TU collection, including node attributes.
dataset = TUDataset(
    root='./tmp/ENZYMES',
    name='ENZYMES',
    use_node_attr=True,
)
