import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def set_seed(seed):
    """Seed every random source the script touches and pin cuDNN behaviour.

    Seeds PyTorch (CPU generator, the current CUDA device and all CUDA
    devices), NumPy's global RNG and Python's ``random`` module with ``seed``,
    then turns off cuDNN benchmark autotuning and requests deterministic
    cuDNN kernels, to make runs reproducible.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # relevant when several GPUs are used
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


seed = 42  # You can set any value you want here
set_seed(seed)


def compute_mse(predicted, true):
    """Return the mean squared error between ``predicted`` and ``true``.

    Expects operands that support elementwise subtraction (e.g. NumPy arrays
    of broadcastable shape); the mean is taken over every element.
    """
    residual = predicted - true
    return np.mean(np.square(residual))

# Compute Mean Absolute Percentage Error (MAPE)
def compute_mape(predicted, true):
    """Return the Mean Absolute Percentage Error between two arrays, in percent.

    MAPE is ``mean(|(predicted - true) / true|) * 100``. Entries where
    ``true`` is exactly zero are skipped. The relative error is undefined
    there, and leaving them in makes the whole result ``inf``/``nan`` (with a
    RuntimeWarning). If no entry of ``true`` is non-zero (including empty
    input), ``nan`` is returned.

    Args:
        predicted: array-like of predicted values.
        true: array-like of reference values, broadcastable against
            ``predicted``.

    Returns:
        The MAPE as a float in percent, or ``nan`` when it is undefined.
    """
    # Broadcast first so one boolean mask can index both operands.
    predicted, true = np.broadcast_arrays(np.asarray(predicted), np.asarray(true))
    nonzero = true != 0
    if not nonzero.any():
        return float("nan")
    relative_error = (predicted[nonzero] - true[nonzero]) / true[nonzero]
    return np.mean(np.abs(relative_error)) * 100


# Define the neural network
class SemiSupervisedNN(nn.Module):
    """Fully connected regressor from a scalar time value to ``output_size`` outputs.

    Layer widths are 1 -> 64 -> 128 -> 64 -> ``output_size``. A ReLU follows
    each of the three hidden layers; the output layer is linear.
    """

    def __init__(self, output_size):
        super().__init__()
        # Creation order is kept fc1..fc4 so that seeded weight initialisation
        # and parameter ordering (state_dict keys, optimizer groups) are stable.
        self.fc1 = nn.Linear(1, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, output_size)
        self.relu = nn.ReLU()

    def forward(self, t):
        """Predict outputs for time values ``t``.

        A trailing axis of size 1 is appended to ``t`` so each time value
        becomes a one-feature input to ``fc1``.
        """
        hidden = t.unsqueeze(-1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return self.fc4(hidden)


# Function to train the model
def train_model(model, criterion, optimizer, t_low, y_low, epochs=10000):
    """Fit ``model`` to ``(t_low, y_low)`` with full-batch updates.

    Every epoch puts the model in training mode, runs one forward pass over
    all of ``t_low``, evaluates ``criterion(model(t_low), y_low)`` and takes a
    single ``optimizer`` step. The loss is printed every 100 epochs. Nothing
    is returned; ``model`` is updated in place by ``optimizer``.
    """
    report_every = 100
    for step in range(1, epochs + 1):
        model.train()
        prediction = model(t_low)  # predictions at the labeled time points
        batch_loss = criterion(prediction, y_low)  # predicted vs. observed y_low
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        if step % report_every == 0:
            print(f'Epoch [{step}/{epochs}], Loss: {batch_loss.item():.4f}')


from MFN_NODE.Interpolation.get_date_send import *
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10()
# data_name = "200bus"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online() # this is PV data
# data_name = "PV"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor() # this is load data
# data_name = "load"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality() # this is air-quality data
# data_name = "air"


# Active dataset: the spiral benchmark (the alternatives are commented out above).
t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral()
data_name = "spiral"
model_name = "Simi-NN"

# Number of signal columns in y_low. The network's input is always the scalar
# time value; this count sets the output width and the number of plotted signals.
output_size = input_size = y_low.shape[1]
print(input_size)
# Fraction of high-resolution time points that are actually observed.
print(f"Rate: {len(t_low) / len(t_high) * 100} %")


# section Prepare labeled data: the observed times are the single input
# feature and the observed low-resolution values are the regression targets.
t_low, y_low = (torch.tensor(arr, dtype=torch.float32) for arr in (t_low, y_low))

# Initialize the model, loss function, and optimizer.
model = SemiSupervisedNN(output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# section Train initial model on the labeled points only.
train_model(model, criterion, optimizer, t_low=t_low, y_low=y_low)

# section Generate pseudo-labels for high-resolution time points with no observation.
# The network has no dropout/batch-norm, but put it in inference mode explicitly.
model.eval()
with torch.no_grad():
    # NumPy versions of both time grids for the membership test.
    t_high_np = t_high.numpy() if isinstance(t_high, torch.Tensor) else np.asarray(t_high)
    t_low_np = t_low.numpy() if isinstance(t_low, torch.Tensor) else np.asarray(t_low)

    # t_low was cast to float32 above while t_high may still be float64.
    # np.isin on mixed precision misses most labeled times (float32(0.1) != 0.1),
    # which would re-add labeled points as "unlabeled" with pseudo-labels.
    # Compare both grids in float32.
    # NOTE(review): assumes every t_low value was taken from t_high — confirm.
    unlabeled_indices = ~np.isin(t_high_np.astype(np.float32), t_low_np.astype(np.float32))
    # Tensor form of the mask, kept for any downstream use.
    unlabeled_indices_tensor = torch.from_numpy(unlabeled_indices)

    # Unlabeled time points as float32 (the model's input dtype). Index with the
    # NumPy mask so NumPy arrays are not indexed with torch tensors.
    t_unlabeled = torch.as_tensor(t_high_np[unlabeled_indices], dtype=torch.float32)

    # The current model's predictions serve as pseudo-labels.
    pseudo_labels = model(t_unlabeled)

# section Combine labeled and pseudo-labeled data
combined_t = torch.cat((t_low, t_unlabeled), dim=0)
combined_labels = torch.cat((y_low, pseudo_labels), dim=0)

# section Re-train the model on labeled + pseudo-labeled points.
train_model(model, criterion, optimizer, combined_t, combined_labels)

# Predict on the full high-resolution time grid.
model.eval()
with torch.no_grad():
    # as_tensor accepts either an array or an existing tensor. torch.tensor()
    # warns and copy-constructs when given a tensor.
    interpolated_low = model(torch.as_tensor(t_high, dtype=torch.float32))

SemiNN_interpolated_low = interpolated_low.numpy()


### results: -----------------------------------------------------------------
# Score the interpolation against the true high-resolution signal.
Semi_mse, Semi_mape = (
    metric(SemiNN_interpolated_low, y_low_high_res)
    for metric in (compute_mse, compute_mape)
)
print(f'SemiNN: Low-Res MSE: {Semi_mse:.4f}, Low-Res MAPE: {Semi_mape:.4f}%')

# Dump the first interpolated signal as a plain Python list.
print("SemiNN_res", SemiNN_interpolated_low[:, 0].tolist())


# plt.figure(figsize=(10, 5))
# plt.plot(t_high, y_low_high_res[:, 0], label='High-res signal 1 (true)',  linestyle='dashed', c="blue")
# plt.plot(t_high, y_low_high_res[:, 1], label='High-res signal 2 (true)', linestyle='dashed', c="orange")
#
# plt.plot(t_low, y_low[:, 0], 'o', label='Low-res signal 1 (observed)', color='blue', alpha=1, )
# plt.plot(t_low, y_low[:, 1], 'o', label='Low-res signal 2 (observed)', color='orange', alpha=1, )
#
# plt.plot(t_high, SemiNN_interpolated_low[:, 0], label='Interpolated signal 1', linestyle='solid', c="blue")
# plt.plot(t_high, SemiNN_interpolated_low[:, 1], label='Interpolated signal 2', linestyle='solid', c="orange")
#
#
# if_use_HR = False
# plt.title("SemiDNN (if use HR? {}) interpolation".format(if_use_HR), fontsize=14)
# plt.xlabel('Time', fontsize=14)
# plt.ylabel('Amplitude', fontsize=14)
# plt.legend(fontsize=12)
# plt.xticks(fontsize=14)
# plt.yticks(fontsize=14)
# plt.savefig('results/SemiDNN (if use HR? {}) interpolation.pdf'.format(if_use_HR), format='pdf', dpi=300)
# plt.show()


# Per-signal comparison: true high-resolution signal (dashed), observed
# low-resolution samples (markers) and the interpolation (solid). All three
# use the same colour for a given signal.
plt.figure(figsize=(10, 5))
for i in range(input_size):
    true_line, = plt.plot(t_high, y_low_high_res[:, i], label=f'High-res signal {i} (true)', linestyle='dashed')
    # Reuse the colour matplotlib picked for this signal. A hard-coded colour
    # makes every signal's samples look alike.
    signal_color = true_line.get_color()
    plt.plot(t_low, y_low[:, i], 'o', label=f'Low-res signal {i} (observed)', color=signal_color, alpha=1)
    plt.plot(t_high, SemiNN_interpolated_low[:, i], label=f'Interpolated signal {i}', color=signal_color)

# Decorate and save once, after all signals are drawn (previously this ran per signal).
plt.title(f'{model_name} Interpolation', fontsize=14)
plt.xlabel('Time', fontsize=14)
plt.ylabel('Amplitude', fontsize=14)
plt.legend(fontsize=12)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
os.makedirs('results', exist_ok=True)  # savefig does not create missing directories
plt.savefig(f'results/{model_name}_{data_name}_interpolation.pdf', format='pdf', dpi=300)
# plt.show()


if data_name == "spiral":
    # 2-D view of the spiral.
    # NOTE(review): assumes y_high holds the first coordinate and
    # y_low_high_res / SemiNN_interpolated_low the second — confirm against
    # get_data_spiral.
    plt.figure(figsize=(6, 5))
    plt.plot(y_high, y_low_high_res, label=f'True {data_name}', linestyle='dashed')
    # Solid, as in the per-signal figure, so the interpolation can be told apart
    # from the dashed ground truth (both were dashed before).
    plt.plot(y_high, SemiNN_interpolated_low, label=f'Interpolated {data_name}', linestyle='solid')

    plt.title(f'{model_name} Interpolation 2d', fontsize=14)
    plt.xlabel('X', fontsize=14)
    plt.ylabel('Y', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    os.makedirs('results', exist_ok=True)  # savefig does not create missing directories
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_2d.pdf', format='pdf', dpi=300)

plt.show()