import os
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from model import Discriminator, Generator  # Assumes model definitions are in a separate file
import torch.nn as nn

# ---------------------------
# Configuration Parameters
# ---------------------------
LEARNING_RATE = 2e-3     # Step size shared by the generator and discriminator SGD optimizers
BATCH_SIZE = 256         # Number of samples per DataLoader batch
EPOCHS = 200             # Number of full passes over the dataset
NOISE_CHANNELS = 3       # Width of the latent noise vector fed to the generator
NUM_SAMPLES = 25_600     # Size of the synthetic Gaussian-mixture dataset

# Run on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---------------------------
# Model Initialization
# ---------------------------
# NOTE: construction order (generator first) is kept as-is, since it fixes the
# order in which the random weight initialisation draws happen.
gen_model = Generator().to(device=device)
disc_model = Discriminator().to(device=device)

# criterion: BCE objective for the real-vs-fake discriminator.
# mse_loss:  regression objective used to pull the generator onto the moved particles.
criterion, mse_loss = nn.BCELoss(), nn.MSELoss()

# ---------------------------
# Synthetic Gaussian Mixture Dataset
# ---------------------------
class GaussianMixtureDataset(Dataset):
    """
    Dataset of samples drawn from an isotropic Gaussian mixture.

    The mixture has one component per vertex of the hypercube
    ``{1, -1} ** num_dimensions`` (the cartesian product of ``[1, -1]`` over
    every dimension), i.e. ``2 ** num_dimensions`` components. Each component
    has covariance ``0.0125 * I``. All samples are drawn once in ``__init__``
    and stored in ``self.data``, grouped by component. Enable shuffling in the
    DataLoader when random order matters.

    Args:
        num_samples: Total number of samples to draw. If it is not a multiple
            of the number of components, the remainder is spread (one extra
            sample each) over the first components. This keeps ``len(self)``
            equal to the number of stored rows.
        num_dimensions: Dimensionality of each sample.
    """

    def __init__(self, num_samples=NUM_SAMPLES, num_dimensions=3):
        self.num_samples = num_samples
        self.num_dimensions = num_dimensions
        # cartesian_prod returns a 1-D tensor when given a single input.
        # Reshaping to (num_components, num_dimensions) keeps every mean a
        # vector; MultivariateNormal requires loc to be at least 1-D.
        self.means_combinations = torch.cartesian_prod(
            *(torch.tensor([1, -1]) for _ in range(num_dimensions))
        ).float().reshape(-1, num_dimensions)
        self.data = self.generate_data(num_samples, num_dimensions)

    def generate_data(self, num_samples, num_dimensions):
        """
        Draw ``num_samples`` points from the mixture.

        Returns:
            Tensor of shape ``(num_samples, num_dimensions)``. Rows are grouped
            by component, in the order of ``self.means_combinations``.
        """
        num_components = len(self.means_combinations)
        # Integer split. The first ``remainder`` components get one extra
        # sample, so the total is exactly ``num_samples``. The previous
        # truncating split left the dataset shorter than ``__len__`` claimed.
        per_component, remainder = divmod(num_samples, num_components)
        covariance = 0.0125 * torch.eye(num_dimensions)

        samples = []
        for i, mean in enumerate(self.means_combinations):
            count = per_component + (1 if i < remainder else 0)
            distribution = torch.distributions.MultivariateNormal(
                loc=mean,
                covariance_matrix=covariance,
            )
            samples.append(distribution.sample((count,)))

        return torch.cat(samples, dim=0)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]


# Build the synthetic dataset once, then serve it in shuffled mini-batches of BATCH_SIZE.
dataset = GaussianMixtureDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ---------------------------
# Optimizers and Label Definitions
# ---------------------------
# Plain SGD for both networks, created in the order generator, discriminator.
gen_optimizer, disc_optimizer = (
    torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
    for model in (gen_model, disc_model)
)

# Integer class labels for real and fake samples.
real_label, fake_label = 1, 0

# Gradient-norm statistics; the training loop appends one value per checkpoint.
record_frobenius_norm = []

# Output locations: PDF scatter plots go in a directory, norm values in a text file.
generated_images_dir = './generated_images'
frobenius_norm_file = './frobenius_norm.txt'
os.makedirs(generated_images_dir, exist_ok=True)

# ---------------------------
# Main Training Loop
# ---------------------------
# Each batch runs three phases:
#   (1) A "particle" update of the generator. Generated samples are shifted by
#       grad_x D(x) / D(x), which equals grad_x log D(x) assuming the
#       discriminator scores each sample independently. The generator is then
#       regressed onto the shifted samples with MSE.
#   (2) 15 discriminator steps with BCE on real vs. freshly generated samples.
#   (3) Every 100th batch, a 3D scatter plot and a gradient-norm statistic are
#       written to disk.
for epoch in range(EPOCHS):
    for i, imgs in enumerate(dataloader):
        # Move real data to device
        real_cpu = imgs.to(device)
        batch_size = real_cpu.size(0)

        # ---------------------------
        # Train the Generator (Particle Model Approach)
        # ---------------------------
        noise = torch.randn(batch_size, NOISE_CHANNELS, device=device)
        # The particles are used as constants here, so no generator graph is
        # needed. The generator is re-run with gradients in the fit loop below.
        with torch.no_grad():
            gen_imgs = gen_model(noise)

        # Use a fresh leaf tensor so the discriminator gradient is taken with
        # respect to the samples. torch.autograd.grad returns that gradient
        # without accumulating into the discriminator's parameter .grad buffers.
        gen_imgs_t = gen_imgs.clone().requires_grad_(True)
        score = disc_model(gen_imgs_t)
        particle_grad = torch.autograd.grad(
            score,
            gen_imgs_t,
            grad_outputs=torch.ones_like(score),
            allow_unused=True,
        )[0]

        if particle_grad is not None:
            # Divide each sample's gradient by its own score. Broadcasting over
            # the feature axis avoids tying this to NOISE_CHANNELS, which is the
            # latent size and need not equal the data dimensionality.
            direction_y = particle_grad / score.detach().view(-1, 1)
            # Target positions for the particles. Detached so the MSE below
            # back-propagates only into the generator.
            target = (gen_imgs_t + direction_y).detach()
            for _ in range(5):
                gen_model.zero_grad()
                g_loss = mse_loss(gen_model(noise), target)
                g_loss.backward()
                gen_optimizer.step()

        # ---------------------------
        # Train the Discriminator
        # ---------------------------
        for _ in range(15):
            disc_model.zero_grad()

            # Train with real images
            output_real = disc_model(real_cpu)
            real_labels = torch.ones_like(output_real)
            real_disc_loss = criterion(output_real, real_labels)

            # Train with fake images. The generator is not updated here, so no
            # autograd graph is built for it.
            noise = torch.randn(batch_size, NOISE_CHANNELS, device=device)
            with torch.no_grad():
                fake = gen_model(noise)
            output_fake = disc_model(fake)
            fake_labels = torch.zeros_like(output_fake)
            fake_disc_loss = criterion(output_fake, fake_labels)

            # Average loss over real and fake samples. The graph is rebuilt on
            # every iteration, so it does not need to be retained.
            disc_loss = (real_disc_loss + fake_disc_loss) / 2
            disc_loss.backward()
            disc_optimizer.step()

        print(f'Epoch: {epoch} ===== Batch: {i}/{len(dataloader)}')

        # ---------------------------
        # Save Generated Images and Compute Frobenius Norm
        # ---------------------------
        # NOTE(review): the default config gives exactly 100 batches per epoch
        # (25600 / 256), so this fires only at i == 0. With more batches, the
        # epoch's PDF is overwritten.
        if i % 100 == 0:
            # Generate a fresh (resampled each time) batch of 256 samples for visualization
            noise_vis = torch.randn(256, NOISE_CHANNELS, device=device)
            with torch.no_grad():
                gen_imgs = gen_model(noise_vis)
            gen_cpu = gen_imgs.cpu()

            # Plot the generated and original data in 3D scatter
            fig = plt.figure(figsize=(6, 6))
            ax1 = fig.add_subplot(111, projection='3d')
            ax1.scatter(gen_cpu[:, 0],
                        gen_cpu[:, 1],
                        gen_cpu[:, 2],
                        label='Generated',
                        alpha=1,
                        edgecolors='darkblue')
            ax1.scatter(imgs[:, 0].cpu(),
                        imgs[:, 1].cpu(),
                        imgs[:, 2].cpu(),
                        label='Original',
                        alpha=1,
                        edgecolors=(0.8, 0.4, 0.0))
            ax1.view_init(elev=10)
            ax1.set_xlim([-1.5, 1.5])
            ax1.set_ylim([-1.5, 1.5])
            ax1.set_zlim([-1.5, 1.5])
            ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
            ax1.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
            ax1.tick_params(axis='z', which='both', left=False, right=False, labelleft=False)
            plt.tight_layout()
            fig.savefig(os.path.join(generated_images_dir, f'epoch_{epoch}.pdf'), transparent=True)
            plt.close(fig)

            # ---------------------------
            # Compute Frobenius Norm for Generated Samples
            # ---------------------------
            # Same per-sample direction as in the generator step. It is computed
            # with autograd.grad so the discriminator's parameter gradients are
            # left untouched.
            gen_imgs_t = gen_imgs.clone().requires_grad_(True)
            score = disc_model(gen_imgs_t).reshape(-1)
            sample_grad = torch.autograd.grad(
                score,
                gen_imgs_t,
                grad_outputs=torch.ones_like(score),
                allow_unused=True,
            )[0]
            if sample_grad is None:
                # The input did not influence the score, so its gradient is zero.
                sample_grad = torch.zeros_like(gen_imgs_t)
            direction_y = sample_grad / score.detach().view(-1, 1)

            # Norm of each sample's direction (flattened to 2D), then the
            # 87.5th percentile across the batch.
            flat_direction_y = direction_y.reshape(direction_y.size(0), -1)
            norms = torch.norm(flat_direction_y, p='fro', dim=1)
            frobenius_norm = torch.quantile(norms, q=0.875)

            record_frobenius_norm.append(float(frobenius_norm))

            # Rewrite the whole file each time so it always lists every value
            # recorded so far in this run.
            with open(frobenius_norm_file, 'w') as file:
                file.writelines(f'Frobenius norm: {val}\n' for val in record_frobenius_norm)
