import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from margflow.nn.cond_mlp import Hypernetwork


class ContinuousConditionalDataset(Dataset):
    """Synthetic regression dataset whose target depends on a continuous condition.

    Each sample is a triple ``(x, c, y)``:

    * ``x`` is drawn uniformly from ``[0, 10)`` in each of ``input_dim`` entries,
    * ``c`` is drawn uniformly from ``[0, 1)`` in each of ``condition_dim`` entries,
    * ``y = sum(x * sin(pi * c)) + sum(c)`` is a scalar.

    The elementwise product ``x * sin(pi * c)`` relies on tensor broadcasting,
    so ``condition_dim`` must be 1 or equal to ``input_dim``; otherwise torch
    raises a shape error during construction.

    Args:
        num_samples: Number of samples to generate (0 yields an empty dataset).
        input_dim: Number of features in each input vector ``x``.
        condition_dim: Number of entries in each condition vector ``c``.
    """

    def __init__(self, num_samples, input_dim, condition_dim):
        super().__init__()
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.condition_dim = condition_dim

        # Random input features in range [0, 10), shape (num_samples, input_dim).
        self.inputs = torch.rand(num_samples, input_dim) * 10

        # Continuous condition vectors in range [0, 1), shape (num_samples, condition_dim).
        self.conditions = torch.rand(num_samples, condition_dim)

        # Scalar targets derived from inputs and conditions, shape (num_samples,).
        self.outputs = self.generate_targets()

    def generate_targets(self):
        """Compute the scalar target for every sample.

        Returns:
            Tensor of shape ``(num_samples,)`` where
            ``y_i = sum(x_i * sin(pi * c_i)) + sum(c_i)``.
        """
        # Vectorised over the sample dimension. This replaces a per-sample
        # Python loop and also works when num_samples == 0, where the old
        # torch.cat([]) raised.
        scale = torch.sin(self.conditions * torch.pi)  # (N, condition_dim)
        shift = self.conditions.sum(dim=1)  # (N,)
        return (self.inputs * scale).sum(dim=1) + shift

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(input, condition, target)`` for sample ``idx``."""
        return self.inputs[idx], self.conditions[idx], self.outputs[idx]


def visualize_predictions(model, data_loader):
    """Plot model predictions against true targets for a single batch.

    Takes the first batch yielded by ``data_loader`` and runs ``model`` on it
    without gradient tracking. It then shows a scatter plot of predicted vs.
    true outputs, together with the ideal ``y = x`` line.

    Args:
        model: Conditional model, called as ``model(x=inputs, context=conditions)``
            (the same convention as the training loop).
        data_loader: Iterable yielding ``(inputs, conditions, targets)`` batches.

    Note:
        Leaves ``model`` in eval mode.
    """
    # Put the *given* model in eval mode. The previous code referenced a
    # module-level ``system`` global, which only exists when run as a script.
    model.eval()

    inputs, conditions, targets = next(iter(data_loader))

    with torch.no_grad():
        # Keyword arguments match the training call, so this does not depend on
        # the model's positional parameter order.
        predictions = model(x=inputs, context=conditions)

    # Flatten both to 1-D so each prediction pairs with its own target.
    targets = targets.detach().cpu().reshape(-1)
    predictions = predictions.detach().cpu().reshape(-1)

    # Plain floats for the axis limits of the ideal-fit line.
    lo, hi = targets.min().item(), targets.max().item()

    plt.figure(figsize=(8, 6))
    plt.scatter(targets.numpy(), predictions.numpy(), alpha=0.7)
    plt.plot([lo, hi], [lo, hi], "r--", label="Ideal Fit")
    plt.xlabel("True Outputs")
    plt.ylabel("Predicted Outputs")
    plt.title("True vs Predicted Outputs")
    plt.legend()
    plt.show()


if __name__ == "__main__":

    input_dim = 5
    output_dim = 1  # Single scalar output
    hidden_dim = 128
    target_n_layers = 4
    condition_dim = 1  # Dimensionality of continuous condition vector
    hypernet_n_layers = 3
    hypernet_hidden_dim = 128
    num_samples = 1000
    batch_size = 32
    epochs = 20
    learning_rate = 1e-3

    train_dataset = ContinuousConditionalDataset(num_samples, input_dim, condition_dim)
    test_dataset = ContinuousConditionalDataset(num_samples // 2, input_dim, condition_dim)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Evaluation data need not be shuffled; samples are already i.i.d. random.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    system = Hypernetwork(
        # Network mapping the input features to the scalar output.
        in_dim=input_dim,
        hid_dim=hidden_dim,
        n_layers=target_n_layers,
        out_dim=output_dim,
        # Condition vector is the hypernetwork input.
        hypernet_in_dim=condition_dim,
        hypernet_n_layers=hypernet_n_layers,
        hypernet_hid_dim=hypernet_hidden_dim,
        skip_connection=True,  # Enable skip connection
    )

    optimizer = optim.Adam(system.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    def train():
        """Train ``system`` on ``train_loader``.

        Returns:
            List with the mean per-batch loss of each completed epoch. Pressing
            Ctrl-C stops training early and returns the losses collected so far.
        """
        system.train()
        losses = []
        try:
            for epoch in range(epochs):
                epoch_loss = 0.0
                for inputs, conditions, targets in train_loader:
                    optimizer.zero_grad()
                    outputs = system(x=inputs, context=conditions)
                    # Use reshape rather than squeeze(). With a size-1 batch,
                    # squeeze() collapses the output to a 0-d tensor that no
                    # longer matches the (1,)-shaped targets.
                    loss = loss_fn(outputs.reshape(targets.shape), targets)
                    epoch_loss += loss.item()
                    loss.backward()
                    optimizer.step()

                avg_loss = epoch_loss / len(train_loader)
                losses.append(avg_loss)
                print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
        except KeyboardInterrupt:
            # Deliberate early stop: keep what we have, but say so.
            print(f"Training interrupted after {len(losses)} completed epoch(s).")

        return losses

    losses = train()

    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(losses) + 1), losses, label="Training Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training Loss Curve")
    plt.legend()
    plt.show()

    visualize_predictions(model=system, data_loader=test_loader)
