import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Path of the JSON file holding the word-vector lookup table. The parsed
# content is kept in ``type_content`` and read by get_word2vec_vector().
type_vector = './ast2features/word_vectors.json'
# Decode explicitly as UTF-8 (the JSON interchange encoding) instead of relying
# on the platform's locale default, matching how read_json_files() opens files.
with open(type_vector, 'r', encoding='utf-8') as f:
    type_content = json.load(f)


def read_json_files(directory):
    """Load every ``*.json`` file found directly inside *directory*.

    Entries are visited in sorted filename order so that the returned lists
    are reproducible across runs and filesystems (``os.listdir`` makes no
    ordering guarantee).

    Args:
        directory: Path of the directory to scan (sub-directories are not
            descended into).

    Returns:
        A ``(files_data, file_names)`` tuple of parallel lists: the parsed
        JSON content of each file, and that file's name with the ``.json``
        suffix replaced by ``.sol``.

    Raises:
        OSError: If the directory or one of its files cannot be read.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    suffix = ".json"
    files_data = []
    file_names = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(suffix):
            continue
        with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
            files_data.append(json.load(file))
        # Strip the ".json" suffix by slicing and attach ".sol" instead.
        file_names.append(filename[:-len(suffix)] + '.sol')
    return files_data, file_names


def get_word2vec_vector(word):
    """Return the vector stored for *word* in the module-level lookup table.

    Words that have no entry in ``type_content`` map to a list of 64 zeros.
    """
    return type_content[word] if word in type_content else [0] * 64


def extract_features(data):
    """Flatten one parsed AST record into a single feature list.

    Layout of the returned list, in order:
      1. the vectors for ``data['type']`` and ``data['kind']``;
      2. ``data['subnode_num']``;
      3. for every entry of ``data['sub_node_info']``: six scalar fields,
         the vectors for ``node_type`` and ``node_body_type`` (a missing key
         is looked up as the empty string), then one vector per item of
         ``sub_body_type`` (missing key treated as an empty list);
      4. one code per entry of ``data['sub_node_visibility']``: 1 for
         ``'private'``, 2 for ``'public'``; any other value adds nothing;
      5. the label ``data['y']`` as the final element.

    The length of the result therefore varies from record to record.

    Args:
        data: Mapping with the keys named above.

    Returns:
        The flat feature list described above.
    """
    scalar_keys = (
        'node_is_function',
        'node_parameters',
        'node_return_parameters',
        'node_is_constructor',
        'node_is_variable',
        'node_variable_has_number',
    )
    visibility_codes = (('private', 1), ('public', 2))

    # Record-level embeddings followed by the sub-node count.
    row = []
    row.extend(get_word2vec_vector(data['type']))
    row.extend(get_word2vec_vector(data['kind']))
    row.append(data['subnode_num'])

    # Per-sub-node scalars and embeddings, concatenated in document order.
    for node in data['sub_node_info']:
        row.extend(node[key] for key in scalar_keys)
        row.extend(get_word2vec_vector(node.get('node_type', '')))
        row.extend(get_word2vec_vector(node.get('node_body_type', '')))
        for body_subtype in node.get('sub_body_type', []):
            row.extend(get_word2vec_vector(body_subtype))

    # Visibility codes; unrecognised visibilities are skipped entirely.
    for visibility in data['sub_node_visibility']:
        row.extend(code for label, code in visibility_codes if visibility == label)

    # The label goes last.
    row.append(data['y'])
    return row


def process_directory(directory):
    """Build equal-length feature rows from the JSON files in *directory*.

    Each file is converted with ``extract_features`` and every row is then
    right-padded with zeros up to the length of the longest row.

    Args:
        directory: Directory containing the ``.json`` records.

    Returns:
        A ``(all_features, file_names, max_length)`` tuple: the padded feature
        rows, the matching contract file names as returned by
        ``read_json_files``, and the common row length.

    Raises:
        ValueError: If *directory* contains no ``.json`` files.
    """
    files_data, file_names = read_json_files(directory)
    if not files_data:
        # Fail with an actionable message instead of max()'s generic
        # "arg is an empty sequence" error.
        raise ValueError(f"No .json files found in directory: {directory!r}")

    all_features = [extract_features(data) for data in files_data]
    max_length = max(len(features) for features in all_features)

    # Pad in place. No row can exceed max_length, so truncation is never
    # needed; rows already at full length are extended by an empty list.
    for features in all_features:
        features.extend([0] * (max_length - len(features)))

    return all_features, file_names, max_length


class Autoencoder(nn.Module):
    """Fully connected autoencoder with three hidden layers on each side.

    Encoder: ``input_dim -> 2048 -> 1024 -> 512 -> encoding_dim``, with a ReLU
    after every linear layer (including the last, so codes are non-negative).
    Decoder: the mirror image, with ReLU between layers and a Sigmoid on the
    output, so reconstructions are bounded to the unit interval.
    """

    def __init__(self, input_dim, encoding_dim):
        super().__init__()
        hidden_widths = (2048, 1024, 512)

        # Encoder: every Linear is followed by a ReLU.
        encoder_layers = []
        in_features = input_dim
        for out_features in (*hidden_widths, encoding_dim):
            encoder_layers.append(nn.Linear(in_features, out_features))
            encoder_layers.append(nn.ReLU())
            in_features = out_features
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: hidden widths in reverse, then a Sigmoid-capped output layer.
        decoder_layers = []
        in_features = encoding_dim
        for out_features in reversed(hidden_widths):
            decoder_layers.append(nn.Linear(in_features, out_features))
            decoder_layers.append(nn.ReLU())
            in_features = out_features
        decoder_layers.append(nn.Linear(in_features, input_dim))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the reconstruction of *x* (encode, then decode)."""
        return self.decoder(self.encode(x))

    def encode(self, x):
        """Return the encoder output for *x*."""
        return self.encoder(x)


# Load and preprocess data
directory = './processed_ast'
features_list, contract_names, max_length = process_directory(directory)
features_array = np.array(features_list)
print(f"Max feature length: {max_length}")

# Convert data to PyTorch tensors
features_tensor = torch.tensor(features_array, dtype=torch.float32)

# Create DataLoader. Input and target are the same tensor because the model
# is trained to reconstruct its own input.
batch_size = 128
dataset = TensorDataset(features_tensor, features_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define model, loss function, and optimizer
# NOTE(review): Autoencoder's decoder ends in a Sigmoid, so reconstructions are
# confined to [0, 1], while the feature rows are used here without rescaling
# and include the label column -- confirm both are intended.
input_dim = features_array.shape[1]
encoding_dim = 512
model = Autoencoder(input_dim, encoding_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
num_epochs = 300
patience = 150  # epochs without a new best loss before stopping early
min_loss = float('inf')
consecutive_increasing_losses = 0
best_model_state = None
best_features = None

for epoch in range(num_epochs):
    epoch_loss_total = 0.0
    samples_seen = 0
    for inputs, _ in dataloader:
        outputs = model(inputs)
        loss = criterion(outputs, inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a sample-weighted sum so the epoch loss is the mean over
        # all samples rather than the value of whichever batch came last.
        batch_count = inputs.size(0)
        epoch_loss_total += loss.item() * batch_count
        samples_seen += batch_count

    current_loss = epoch_loss_total / samples_seen

    # Track the best epoch so far and count epochs without improvement.
    if current_loss < min_loss:
        min_loss = current_loss
        consecutive_increasing_losses = 0
        # state_dict() returns references to the live parameters, which later
        # optimizer steps overwrite in place; clone them to keep a snapshot.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        # No gradients are needed for the encoded features.
        with torch.no_grad():
            best_features = model.encode(features_tensor).numpy()
    else:
        consecutive_increasing_losses += 1

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {current_loss:.4f}')

    # Stop early once `patience` epochs have passed without a new best loss.
    if consecutive_increasing_losses >= patience:
        print(f'Stopping early at epoch {epoch + 1} due to consecutive increasing losses.')
        break

if best_features is None:
    # Happens only if no epoch produced a loss below infinity (e.g. NaN).
    raise RuntimeError("Training never produced a finite loss; nothing to save.")

# Save the best model and encoded features
best_features_list = best_features.tolist()  # Convert numpy array to list
best_features_dict = {
    name: [round(val, 8) for val in row]
    for name, row in zip(contract_names, best_features_list)
}
output_dir = './ast2features'
model_dir = './model'
# Create the target directories up front so a missing folder cannot discard
# the results of a finished training run.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
torch.save(best_model_state, os.path.join(model_dir, 'best_autoencoder_model.pth'))
with open(os.path.join(output_dir, 'best_encoded_features.json'), 'w') as f:
    json.dump(best_features_dict, f)

print("Encoded features of the best model have been saved.")
