!pip install torch torchvision

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np

# Map-style Dataset of (location, power) samples; a template to be completed.
class PowerDataset(Dataset):
    """Template ``Dataset`` for (location, power) samples stored in ``data_dir``.

    Loading is not implemented yet: fill in ``__init__``, ``__len__`` and
    ``__getitem__`` for the real data format. The training loop in this
    script unpacks each batch as ``(locations, powers)``, so ``__getitem__``
    should return a ``(location, power)`` pair.
    NOTE(review): presumably ``location`` holds ``condition_dim`` values and
    ``power`` holds ``output_dim`` values -- confirm against the real data.

    Args:
        data_dir: directory the samples are loaded from.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Load your dataset here

    def __len__(self):
        """Return the number of samples."""
        # Raise explicitly: a bare ``pass`` returns None, which makes len() and
        # DataLoader fail later with an unrelated TypeError.
        raise NotImplementedError(
            "PowerDataset.__len__: return the number of samples in data_dir")

    def __getitem__(self, idx):
        """Return the ``(location, power)`` sample at position ``idx``."""
        raise NotImplementedError(
            "PowerDataset.__getitem__: load and return (location, power) for idx")

# Data directory (presumably a Google Drive mount in Colab -- adjust as needed).
data_dir = '/content/drive/MyDrive/data'
batch_size = 4  # samples per mini-batch

# Build the dataset and wrap it in a shuffling mini-batch DataLoader.
power_dataset = PowerDataset(data_dir)
power_dataloader = DataLoader(
    power_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Define the Generator architecture
class Generator(nn.Module):
    """Conditional MLP generator mapping (condition, noise) to values in [-1, 1].

    Hidden widths are 128 -> 256 -> 512 -> 1024, each followed by ReLU; the
    final Linear layer is followed by Tanh, which bounds the output.

    Args:
        condition_dim: width of the conditioning vector.
        noise_dim: width of the latent noise vector.
        output_dim: width of the generated sample.
    """

    def __init__(self, condition_dim, noise_dim, output_dim):
        super().__init__()
        stack = []
        fan_in = condition_dim + noise_dim
        for width in (128, 256, 512, 1024):
            stack += [nn.Linear(fan_in, width), nn.ReLU(inplace=True)]
            fan_in = width
        stack += [nn.Linear(fan_in, output_dim), nn.Tanh()]
        # One nn.Sequential named `model` keeps state_dict keys as
        # model.<idx>.weight / model.<idx>.bias, so saved checkpoints still load.
        self.model = nn.Sequential(*stack)

    def forward(self, condition, noise):
        """Concatenate condition and noise on the last axis and run the MLP."""
        joint = torch.cat([condition, noise], dim=-1)
        return self.model(joint)

# Define the Discriminator architecture
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (condition, sample) pairs.

    Hidden widths are 512 -> 256 -> 64, each followed by LeakyReLU(0.2); the
    final Linear layer has one unit followed by Sigmoid, so each row gets a
    score in [0, 1].

    Args:
        condition_dim: width of the conditioning vector.
        input_dim: width of the sample being judged.
    """

    def __init__(self, condition_dim, input_dim):
        super().__init__()
        stack = []
        fan_in = condition_dim + input_dim
        for width in (512, 256, 64):
            stack.append(nn.Linear(fan_in, width))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            fan_in = width
        stack.append(nn.Linear(fan_in, 1))
        stack.append(nn.Sigmoid())
        # Same `model` Sequential layout as before, so checkpoint keys match.
        self.model = nn.Sequential(*stack)

    def forward(self, condition, input_data):
        """Concatenate condition and sample on the last axis and score them."""
        joint = torch.cat([condition, input_data], dim=-1)
        return self.model(joint)

# Function to initialize the weights of the model
def weights_init_normal(m):
    """Initialise one module's parameters in place; intended for ``Module.apply``.

    Linear (and Bilinear) layers get weights drawn from a normal distribution
    with mean 0 and std 0.02, and a zero bias when they have one. Every other
    module is left untouched.

    Args:
        m: any ``nn.Module`` visited by ``Module.apply``.
    """
    # isinstance rather than searching for 'Linear' in the class name: a name
    # match also hit unrelated modules (e.g. a user-defined ``LinearBlock``
    # without a ``weight``) and crashed with AttributeError.
    if isinstance(m, (nn.Linear, nn.Bilinear)):
        # nn.init functions already run under no_grad, so `.data` is not needed.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Hyperparameters
condition_dim = 2  # conditioning features per sample (example: x, y coordinates)
noise_dim = 100    # length of the latent noise vector fed to the generator
output_dim = 1     # generated values per sample (example: a single power value)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Construct both networks before initialising either one: keeping this order
# fixes the sequence of random draws, and hence the initial weights per seed.
generator = Generator(condition_dim, noise_dim, output_dim)
discriminator = Discriminator(condition_dim, output_dim)
generator.apply(weights_init_normal).to(device)
discriminator.apply(weights_init_normal).to(device)

# Binary cross-entropy shared by the generator and discriminator objectives.
adversarial_loss = torch.nn.BCELoss().to(device)

# Adam for both players (lr=2e-4, betas=(0.5, 0.999)).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Physical constraints loss function
def physical_constraints_loss(generated_data, condition):
    """Placeholder physics penalty pulling generated values toward fixed targets.

    Adds up the mean squared distance of ``generated_data`` to three hard-coded
    example targets: direct (1.0), reflection (0.5) and diffraction (0.2).
    The detailed loss is calculated offline in DCEM; these constants are for
    demonstration only.

    Args:
        generated_data: generator output tensor.
        condition: conditioning input; currently unused.

    Returns:
        The summed loss as a 0-dim tensor.
    """
    targets = (1.0, 0.5, 0.2)  # direct, reflection, diffraction (example values)
    return sum(torch.mean((generated_data - t) ** 2) for t in targets)

# Training loop
num_epochs = 50
real_label = 1.
fake_label = 0.
phy_weight = 0.1  # weight of the physical-constraints term in the generator loss
log_every = 50    # print progress every `log_every` mini-batches

for epoch in range(num_epochs):
    for i, (locations, powers) in enumerate(power_dataloader):
        # Cast to float32 on the way to the device so the batch matches the
        # networks' float32 Linear weights; float64 (e.g. NumPy-derived) or
        # integer batches would otherwise fail with a dtype mismatch.
        locations = locations.to(device=device, dtype=torch.float32)
        powers = powers.to(device=device, dtype=torch.float32)
        batch_len = locations.size(0)

        # (1) Update Discriminator: minimise BCE(D(x), 1) + BCE(D(G(z)), 0)
        discriminator.zero_grad()

        # Train with real data. reshape (unlike view) also accepts a
        # non-contiguous batch; output_dim keeps the width in sync with the nets.
        # NOTE(review): the generator ends in Tanh, so real powers presumably
        # need scaling to [-1, 1] in the dataset -- confirm.
        real_data = powers.reshape(-1, output_dim)
        label = torch.full((batch_len,), real_label, dtype=torch.float, device=device)
        output = discriminator(locations, real_data).view(-1)
        errD_real = adversarial_loss(output, label)
        errD_real.backward()
        D_x = output.mean().detach()  # mean discriminator score on real samples

        # Train with fake data; detach() keeps this loss from reaching G.
        noise = torch.randn(batch_len, noise_dim, device=device)
        fake_data = generator(locations, noise)
        label.fill_(fake_label)
        output = discriminator(locations, fake_data.detach()).view(-1)
        errD_fake = adversarial_loss(output, label)
        errD_fake.backward()
        D_G_z = output.mean().detach()  # mean score on fakes, before the D step
        optimizer_D.step()

        # (2) Update Generator: push D to label the fakes as real.
        generator.zero_grad()
        label.fill_(real_label)
        output = discriminator(locations, fake_data).view(-1)
        errG_adv = adversarial_loss(output, label)

        # Calculate physical constraints loss
        errG_phy = physical_constraints_loss(fake_data, locations)

        # Total generator loss
        errG = errG_adv + phy_weight * errG_phy
        errG.backward()
        optimizer_G.step()

        if i % log_every == 0:
            # D(x) and D(G(z)) are mean discriminator outputs, not losses.
            print(f'[{epoch}/{num_epochs}][{i}/{len(power_dataloader)}]\t'
                  f'Loss_D: {(errD_real + errD_fake).item():.4f}\t'
                  f'Loss_G: {errG.item():.4f}\t'
                  f'D(x): {D_x.item():.4f}\tD(G(z)): {D_G_z.item():.4f}')

# Save the models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
