import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def set_seed(seed):
    """Seed every random generator this script uses and pin cuDNN to deterministic mode.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU, the
    current CUDA device and all CUDA devices). It then turns off cuDNN's
    benchmark-based algorithm selection and requests deterministic cuDNN
    kernels, to make runs reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Framework-independent generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then the current GPU and every GPU (multi-GPU setups).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: only deterministic kernels, and no autotuning that could pick different ones per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed = 42  # Global seed for reproducible runs; any integer works.
set_seed(seed)  # Must run before the model is built so weight initialisation is reproducible.

def compute_mse(predicted, true):
    """Return the mean squared error between ``predicted`` and ``true``.

    Both arguments must support NumPy element-wise subtraction (broadcasting
    applies). The squared residuals are averaged over every element.
    """
    residual = predicted - true
    return np.mean(np.square(residual))

# Compute Mean Absolute Percentage Error (MAPE)
def compute_mape(predicted, true):
    """Return the Mean Absolute Percentage Error of ``predicted`` w.r.t. ``true``, in percent.

    MAPE = 100 * mean(|predicted - true| / |true|)

    MAPE is undefined where ``true == 0``. Those elements are excluded from the
    average; otherwise a single zero target would turn the whole metric into
    ``inf``/``nan``. If no element has a non-zero target, ``nan`` is returned.

    Args:
        predicted: Array-like of predictions.
        true: Array-like of reference values, broadcast-compatible with ``predicted``.

    Returns:
        The MAPE as a NumPy scalar (percent), or ``nan`` if it is undefined.
    """
    predicted, true = np.broadcast_arrays(np.asarray(predicted), np.asarray(true))
    defined = true != 0  # positions where the percentage error is well defined
    if not np.any(defined):
        return np.nan
    relative_error = np.abs(predicted[defined] - true[defined]) / np.abs(true[defined])
    return np.mean(relative_error) * 100


# Define the DCSNet model
class DCSNet(nn.Module):
    """Fully-connected network mapping a vector of low-res samples to a high-res vector.

    Layout: ``encoder = Linear(measurement_dim -> hidden_dim)``, then
    ``decoder = Linear(hidden_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> input_dim)``.

    Args:
        input_dim: Length of the reconstructed output vector.
        measurement_dim: Length of each input vector (last dimension of ``x``).
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, input_dim, measurement_dim, hidden_dim):
        super().__init__()
        # Construction order matters: each nn.Linear draws its initial weights from the
        # global RNG, so encoder -> hidden -> output keeps seeded runs reproducible.
        self.encoder = nn.Linear(measurement_dim, hidden_dim)
        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, input_dim)
        self.decoder = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Map ``x`` (last dim ``measurement_dim``) to an output with last dim ``input_dim``."""
        return self.decoder(self.encoder(x))


# Create measurement matrix function
def create_measurement_matrix(t_low, t_high):
    """Build a 0/1 sampling matrix selecting, for each low-res time, the nearest high-res sample.

    Args:
        t_low: 1-D (or column-vector) array-like of observation times, length L.
        t_high: 1-D (or column-vector) array-like of the dense time grid, length H.

    Returns:
        ``np.ndarray`` of shape ``(L, H)`` with a single ``1`` per row. The ``1`` is in the
        column of the ``t_high`` entry closest to ``t_low[i]``; ties resolve to the first
        such index, as ``np.argmin`` does.
    """
    # np.asarray lets plain Python lists work; ravel accepts column vectors too.
    t_low = np.asarray(t_low).ravel()
    t_high = np.asarray(t_high).ravel()
    A = np.zeros((t_low.size, t_high.size))
    # (L, H) table of |t_low[i] - t_high[j]|; the row-wise argmin is the nearest grid point.
    nearest = np.abs(t_low[:, None] - t_high[None, :]).argmin(axis=1)
    A[np.arange(t_low.size), nearest] = 1
    return A


# Load high- and low-resolution data.
# Every loader below comes from MFN_NODE.Interpolation.get_date_send and is unpacked into the
# same seven names; keep exactly one loader / data_name pair uncommented to select a dataset.
from MFN_NODE.Interpolation.get_date_send import *
# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_200bus_9_10() # this is 200 bus data
# data_name = "200bus"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_PV_online() # this is PV data
# data_name = "PV"

# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_load_oncor() # this is load data
# data_name = "load"


# t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_ari_quality() # this is air-quality data
# data_name = "air"

t_high, y_high, t_low, y_low, y_low_high_res, t_low_index, Interval = get_data_spiral() # this is spiral data
data_name = "spiral"  # used in output filenames and to enable the 2-D spiral plot

model_name = "DCS"  # prefix used in plot titles and output filenames


# Number of signals = columns of y_low; each signal becomes one row of the model's input batch.
input_size = y_low.shape[1]

# print(t_high.shape, y_high.shape, t_low.shape, y_low.shape, y_low_high_res.shape, t_low_index.shape)
# Observed fraction of the high-res time grid, as a percentage.
print(f"Rate: {len(t_low)/ len(t_high) * 100} %")

# Parameters: the network maps one signal's len(t_low) observed samples to
# len(t_high) reconstructed samples.
input_dim = len(t_high)
measurement_dim = len(t_low)
hidden_dim = 32

# Create model (weight initialisation draws from the RNG seeded by set_seed above)
model = DCSNet(input_dim, measurement_dim, hidden_dim)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create the measurement matrix (nearest high-res sample for each low-res time).
# NOTE(review): A / A_tensor are not referenced anywhere else in the code visible here;
# DCSNet learns its own linear encoder instead. Confirm whether they are still needed.
A = create_measurement_matrix(t_low, t_high)
A_tensor = torch.tensor(A, dtype=torch.float32)

# Prepare the training data: transpose so each row holds one signal's len(t_low)
# observed samples, matching the model's input size measurement_dim.
y_train_low = torch.tensor(y_low.T, dtype=torch.float32)


# Training loop
# know_y_high=True: supervise on the full high-res targets y_high.
# know_y_high=False: supervise only at the observed low-res positions (t_low_index).
# NOTE(review): if y_high is the same ground truth as y_low_high_res (used for evaluation
# below), the reported metrics measure training fit rather than pure interpolation — confirm.
know_y_high = True
num_epochs = 1000

# Build the supervision target once instead of every epoch, in the model's float32 dtype.
# If the data is float64 (NumPy's default), a float64 target against float32 outputs makes
# MSELoss's backward pass fail with a dtype mismatch.
if know_y_high:
    target = torch.tensor(y_high, dtype=torch.float32)
else:
    target = torch.tensor(y_low, dtype=torch.float32)

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(y_train_low)  # one row per signal, len(t_high) values each
    if know_y_high:
        prediction = outputs.T
    else:
        # Keep only the reconstructed values at the observed time indices.
        prediction = outputs[:, t_low_index].T
    loss = criterion(prediction, target)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# Testing the model
model.eval()
with torch.no_grad():
    # Same inputs as training: each row is one signal's low-res samples.
    y_test_low = torch.tensor(y_low.T, dtype=torch.float32)
    dcs_reconstructed = model(y_test_low)  # (n_signals, len(t_high))
    print(dcs_reconstructed.shape)
    print(y_test_low.shape)
    # Shift the whole reconstruction by one scalar so element [0, 0] equals the first
    # low-res sample of signal 0.
    # NOTE(review): the offset comes from signal 0 only but is added to every signal, and it
    # assumes t_low[0] coincides with t_high[0]. A per-signal offset aligned via
    # t_low_index may have been intended — confirm.
    dcs_reconstructed = dcs_reconstructed + (y_test_low[0, 0] - dcs_reconstructed[0, 0])
    # dcs_reconstructed = torch.tensor(dcs_reconstructed)
    dcs_reconstructed = dcs_reconstructed.T  # -> (len(t_high), n_signals), one column per signal


print(dcs_reconstructed.shape )
print(y_low_high_res.shape )

### dcs results: -----------------------------------------------------------------
# Error of the (offset-shifted) reconstruction against the reference high-res signals.
cs_mse = compute_mse(dcs_reconstructed.detach().numpy(), y_low_high_res)
cs_mape = compute_mape(dcs_reconstructed.detach().numpy(), y_low_high_res)
print(f'DCS: Low-Res MSE: {cs_mse:.4f}, Low-Res MAPE: {cs_mape:.4f}%')


# Dump the full reconstruction of signal 0.
print("dcs_res", dcs_reconstructed[:, 0].tolist())




# Plot the results: one figure per signal (true high-res, observed low-res, reconstruction).
os.makedirs('results', exist_ok=True)  # savefig raises if the output directory is missing

for i in range(input_size):
    plt.figure(figsize=(10, 5))
    plt.plot(t_high, y_low_high_res[:, i], label=f'High-res signal {i} (true)', linestyle='dashed')
    plt.plot(t_low, y_low[:, i], 'o', label=f'Low-res signal {i} (observed)', color='blue', alpha=1)
    plt.plot(t_high, dcs_reconstructed[:, i], label=f'Interpolated signal {i}')

    plt.title(f'{model_name} Interpolation', fontsize=14)
    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Amplitude', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # The filename includes the dataset and signal index. A single fixed name would be
    # overwritten on every iteration, leaving only the last signal's plot on disk.
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_{i}.pdf', format='pdf', dpi=300)


if data_name == "spiral":
    plt.figure(figsize=(6, 5))
    # NOTE(review): both curves put column 0 of y_high on the horizontal axis and column 0
    # of the true / reconstructed signals on the vertical axis. An x-vs-y spiral view would
    # normally plot columns 0 and 1 of the same array — confirm intent.
    plt.plot(y_high[:, 0], y_low_high_res[:, 0], label=f'True {data_name}', linestyle='dashed')
    plt.plot(y_high[:, 0], dcs_reconstructed[:, 0], label=f'Interpolated {data_name}', linestyle='dashed')

    plt.title(f'{model_name} Interpolation 2d', fontsize=14)
    plt.xlabel('X', fontsize=14)
    plt.ylabel('Y', fontsize=14)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.savefig(f'results/{model_name}_{data_name}_interpolation_2d.pdf', format='pdf', dpi=300)

plt.show()