from __future__ import print_function
import os
import numpy as np
import csv
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from pytorch_fid import fid_score

from model import Generator, Discriminator  # Import model definitions
from config import get_config  # Import configuration settings
import utils  # Import utility functions

# ---------------------------
# Setup and Configuration
# ---------------------------
opt = get_config()
utils.set_random_seed(opt.manualSeed)
eval_num = opt.eval_num
# main() evaluates steepness / Frobenius norm on the first eval_num // division samples.
division = 50

# Every artefact lives under one root, grouped first by kind and then by dataset name.
base_dir = "/path/to/base_directory"
ckpt_dir, real_image_dir, fake_image_dir, metrics_dir, generated_image_dir = (
    os.path.join(base_dir, kind, opt.dataset)
    for kind in (
        "checkpoints",
        "real_images",
        "fake_images",
        "metrics",
        "generated_images",
    )
)

# Make sure each output location exists before anything is written into it.
for path in (
    base_dir,
    generated_image_dir,
    ckpt_dir,
    real_image_dir,
    fake_image_dir,
    metrics_dir,
):
    os.makedirs(path, exist_ok=True)

# Start a fresh metrics log (mode "w" truncates any earlier file) holding only the
# header row; per-epoch rows are written later by utils.save_metrics_to_csv
# (presumably appending — confirm in utils).
csv_file_path = os.path.join(metrics_dir, f"{opt.dataset}_metrics.csv")
metric_columns = [
    "Epoch",
    "FID Score",
    "Frobenius Norm",
    "Steepness",
    "d_loss",
    "g_loss",
    "vanilla_dg",
    "local_random_dg",
]
with open(csv_file_path, mode="w", newline="") as csv_handle:
    csv.writer(csv_handle).writerow(metric_columns)

cuda_left_unused = torch.cuda.is_available() and not opt.cuda
if cuda_left_unused:
    print("WARNING: CUDA device available; consider running with --cuda")

# ---------------------------
# Dataset Setup
# ---------------------------
# --dataset value -> (torchvision dataset class, number of image channels).
_DATASET_SPECS = {
    "cifar10": (dset.CIFAR10, 3),
    "fashionmnist": (dset.FashionMNIST, 1),
    "mnist": (dset.MNIST, 1),
}
if opt.dataset not in _DATASET_SPECS:
    raise ValueError("Unsupported dataset")

_dataset_cls, nc = _DATASET_SPECS[opt.dataset]
# Resize to 64x64 and map pixel values from [0, 1] to [-1, 1] on every channel.
dataset = _dataset_cls(
    root=opt.dataroot,
    download=True,
    transform=transforms.Compose(
        [
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5,) * nc, (0.5,) * nc),
        ]
    ),
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
)

# ---------------------------
# Device and Model Setup
# ---------------------------
device = torch.device("cuda") if opt.cuda else torch.device("cpu")
# Coerce the architecture hyperparameters from the config to plain ints.
ngpu, nz, ngf, ndf = (int(value) for value in (opt.ngpu, opt.nz, opt.ngf, opt.ndf))


def weights_init(m):
    """Initialise a single module's weights; meant for ``model.apply(weights_init)``.

    * Any module whose class name contains ``"Conv"`` (this includes
      ``ConvTranspose*``) gets its weight drawn from N(0, 0.02**2).
    * Any module whose class name contains ``"BatchNorm"`` gets its weight
      drawn from N(1, 0.02**2) and its bias set to zero.
    * Every other module is left untouched.

    Layers whose ``weight``/``bias`` is ``None`` (for example
    ``BatchNorm2d(affine=False)``) are skipped instead of raising.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        # affine=False BatchNorm layers have no learnable weight/bias.
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)


# Initialize Generator and Discriminator.
# Checkpoints are loaded with map_location=device so that weights saved from a
# GPU run can be resumed on a CPU-only machine (torch.load otherwise tries to
# restore tensors onto the device they were saved from).
netG = Generator(ngpu, nc, nz, ngf).to(device)
netG.apply(weights_init)
if opt.netG != "":
    netG.load_state_dict(torch.load(opt.netG, map_location=device))
print(netG)

netD = Discriminator(ngpu, nc, ndf, noise_std=opt.noise_std).to(device)
netD.apply(weights_init)
if opt.netD != "":
    netD.load_state_dict(torch.load(opt.netD, map_location=device))
print(netD)

# Binary cross-entropy on the discriminator's outputs; the numeric labels are
# broadcast into tensors with torch.full in the training loop.
criterion = nn.BCELoss()
real_label = 1
fake_label = 0

# Setup optimizers for main networks (same Adam hyperparameters for both).
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))


# ---------------------------
# Duality Gap Calculation Functions
# ---------------------------
def dg_score(generator, discriminator, num_batches=20):
    """Estimate the minimax objective for a (generator, discriminator) pair.

    For each of ``num_batches`` real batches drawn from the module-level
    ``dataloader`` (plus ``opt.batchSize`` generated samples), computes
    ``-(BCE(D(x), 1) + BCE(D(G(z)), 0))`` and returns the mean.

    Args:
        generator: module mapping noise of shape (batch, nz, 1, 1) to images.
        discriminator: module whose outputs are probabilities (required by
            ``nn.BCELoss``).
        num_batches: number of real batches to average over; defaults to 20,
            the value that was previously hard-coded.

    Returns:
        The mean score (NumPy float).
    """
    criterion = nn.BCELoss()
    scores = []
    batches = iter(dataloader)
    # Pure evaluation: no backward pass follows, so don't build an autograd graph.
    with torch.no_grad():
        for _ in range(num_batches):
            try:
                x = next(batches)[0]
            except StopIteration:
                # Loader has fewer than num_batches batches: start another pass.
                batches = iter(dataloader)
                x = next(batches)[0]
            x = x.to(device)
            noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
            fake = generator(noise)
            output_real = discriminator(x)
            output_fake = discriminator(fake)
            # Targets are shaped like the outputs, so a short final batch (or a
            # batch of size 1) no longer causes a shape mismatch.
            disc_loss = criterion(
                output_real, torch.ones_like(output_real)
            ) + criterion(output_fake, torch.zeros_like(output_fake))
            scores.append(-disc_loss.item())
    return np.mean(scores)


def _dg_load_weights(dst, src, noise_scale=0.0):
    """Overwrite ``dst`` with ``src``'s parameters and buffers, optionally adding noise.

    When ``noise_scale`` is non-zero, N(0, noise_scale**2) noise is added to every
    parameter of ``dst`` afterwards. ``src`` itself is never modified.
    """
    dst.load_state_dict(src.state_dict())
    if noise_scale:
        with torch.no_grad():
            for param in dst.parameters():
                param.add_(torch.randn_like(param) * noise_scale)


def _dg_adam(model):
    """Return a fresh Adam optimizer over ``model`` with the main-training hyperparameters."""
    return optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))


def _dg_real_batches():
    """Yield real image batches from ``dataloader`` indefinitely, pass after pass.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches (avoids looping forever).
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch[0]
        if not produced:
            raise RuntimeError("dataloader yielded no batches")


def _dg_fit_generator(gen, disc, steps):
    """Take ``steps`` Adam steps on ``gen`` minimising BCE(disc(gen(z)), 1); ``disc`` is not stepped."""
    criterion = nn.BCELoss()
    optimizer = _dg_adam(gen)
    for _ in range(steps):
        noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
        output = disc(gen(noise))
        loss = criterion(output, torch.ones_like(output))
        optimizer.zero_grad()
        # Also accumulates into disc's .grad; harmless because disc is not stepped
        # here and _dg_fit_discriminator zeroes grads before each of its backwards.
        loss.backward()
        optimizer.step()


def _dg_fit_discriminator(gen, disc, steps, real_batches):
    """Take ``steps`` Adam steps on ``disc`` minimising BCE(disc(x), 1) + BCE(disc(gen(z)), 0).

    ``gen`` is held fixed; ``real_batches`` is an iterator of real image batches.
    """
    criterion = nn.BCELoss()
    optimizer = _dg_adam(disc)
    for _ in range(steps):
        real_images = next(real_batches).to(device)
        noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
        with torch.no_grad():
            # gen is fixed in this search, so no graph through it is needed.
            fake = gen(noise)
        output_real = disc(real_images)
        output_fake = disc(fake)
        loss = criterion(output_real, torch.ones_like(output_real)) + criterion(
            output_fake, torch.zeros_like(output_fake)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def compute_duality_gap(netG, netD, dg_gen_model, dg_disc_model):
    """Estimate the duality gap of (netG, netD) with a vanilla and a local-random approach.

    Each approach:

    1. copies the current networks into ``dg_gen_model`` / ``dg_disc_model``,
       trains the generator copy for 30 Adam steps against the fixed
       discriminator copy, and scores the pair with ``dg_score``;
    2. re-copies the networks, trains the discriminator copy for 30 Adam steps
       against the fixed generator copy, and scores again;
    3. records ``|score_2 - score_1|``.

    In the vanilla approach both searches start from exact copies. In the
    local-random approach N(0, 0.01**2) noise is added to one network's weights
    before each search.

    Every search uses a freshly created Adam optimizer (lr=opt.lr,
    betas=(opt.beta1, 0.999)), so moment estimates from one search do not leak
    into the next. The module-level ``dg_gen_optimizer`` / ``dg_disc_optimizer``
    are not used. ``netG``/``netD`` are not modified; the two dg models are
    overwritten.

    Side effects:
        Appends the two gaps to the module-level ``DG["vanilla"]`` and
        ``DG["local_random"]`` lists (``DG`` must already exist).

    Returns:
        ``(vanilla_gap, local_random_gap)``.
    """
    steps = 30
    noise_scale = 0.01
    # One persistent iterator instead of building a new (multi-worker)
    # DataLoader iterator on every step.
    real_batches = _dg_real_batches()

    # --- Vanilla approach: both searches start from exact copies ---
    _dg_load_weights(dg_gen_model, netG)
    _dg_load_weights(dg_disc_model, netD)
    _dg_fit_generator(dg_gen_model, dg_disc_model, steps)
    M_u_worst_v = dg_score(dg_gen_model, dg_disc_model)

    _dg_load_weights(dg_gen_model, netG)
    _dg_load_weights(dg_disc_model, netD)
    _dg_fit_discriminator(dg_gen_model, dg_disc_model, steps, real_batches)
    M_u_v_worst = dg_score(dg_gen_model, dg_disc_model)

    vanilla_gap = abs(M_u_v_worst - M_u_worst_v)
    DG["vanilla"].append(vanilla_gap)

    # --- Local random approach ---
    # NOTE(review): the network that is perturbed is the one held FIXED during
    # each search (D while optimising G, G while optimising D), which changes
    # the objective being optimised rather than the search's starting point.
    # This is kept as originally written; confirm it matches the intended
    # "local random" definition.
    _dg_load_weights(dg_gen_model, netG)
    _dg_load_weights(dg_disc_model, netD, noise_scale)
    _dg_fit_generator(dg_gen_model, dg_disc_model, steps)
    M_u_worst_v_local = dg_score(dg_gen_model, dg_disc_model)

    _dg_load_weights(dg_gen_model, netG, noise_scale)
    _dg_load_weights(dg_disc_model, netD)
    _dg_fit_discriminator(dg_gen_model, dg_disc_model, steps, real_batches)
    M_u_v_worst_local = dg_score(dg_gen_model, dg_disc_model)

    local_random_gap = abs(M_u_v_worst_local - M_u_worst_v_local)
    DG["local_random"].append(local_random_gap)

    return vanilla_gap, local_random_gap


# ---------------------------
# Main Training Loop
# ---------------------------
def _save_real_images():
    """Save up to ``eval_num`` real training images as ``real_{i}.jpg`` in ``real_image_dir``.

    Stops pulling batches from ``dataloader`` as soon as ``eval_num`` images
    have been written, instead of draining the rest of the dataset.

    Returns:
        The number of images written.
    """
    saved = 0
    if eval_num <= 0:
        return saved
    for images, _ in dataloader:
        for img in images:
            save_image(
                img,
                os.path.join(real_image_dir, f"real_{saved}.jpg"),
                normalize=True,
            )
            saved += 1
            if saved >= eval_num:
                return saved
    return saved


def _train_one_epoch(epoch):
    """Run one pass over ``dataloader``: per batch, one D step followed by one G step.

    Uses the module-level ``netG``/``netD``, ``optimizerG``/``optimizerD`` and
    ``criterion``. Shows losses and mean discriminator outputs on a tqdm bar.
    """
    with tqdm(
        dataloader,
        total=len(dataloader),
        desc=f"Epoch {epoch+1}/{opt.niter}",
    ) as pbar:
        for data in pbar:
            # ---------------------------
            # Train Discriminator
            # ---------------------------
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)

            # Train with real data
            output = netD(real_cpu)
            label = torch.full(
                (batch_size,), real_label, device=device, dtype=output.dtype
            )
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # Train with fake data (detached so no generator gradients are computed)
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            optimizerD.step()
            errD = errD_real + errD_fake

            # ---------------------------
            # Train Generator
            # ---------------------------
            netG.zero_grad()
            label.fill_(real_label)  # Use real labels for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()

            optimizerG.step()

            pbar.set_postfix(
                {
                    "Loss_D": f"{errD.item():.4f}",
                    "Loss_G": f"{errG.item():.4f}",
                    "D(x)": f"{D_x:.4f}",
                    "D(G(z))": f"{D_G_z1:.4f} / {D_G_z2:.4f}",
                }
            )


def main():
    """Train netG/netD for ``opt.niter`` epochs, evaluating after every epoch.

    Before training, up to ``eval_num`` real images are exported for FID. After
    each epoch it:

    * saves netG/netD checkpoints to ``ckpt_dir``;
    * computes both duality-gap estimates via ``compute_duality_gap``;
    * generates ``eval_num`` samples (an 8x8 grid plus individual images);
    * computes losses, FID, steepness and the Frobenius norm via ``utils`` and
      ``pytorch_fid``, and appends one row to ``csv_file_path``.

    Side effects:
        Creates the module-level ``DG`` dict and rebinds the module-level
        ``dg_gen_optimizer`` / ``dg_disc_optimizer`` each epoch.
    """
    global DG, dg_gen_optimizer, dg_disc_optimizer

    _save_real_images()

    # Initialize records for metrics
    record_steepness = []
    record_fid_score = []
    record_largest_frobenius_norm = []
    DG = {"vanilla": [], "local_random": []}

    # Size of the subset used for steepness / Frobenius norm; at least 1 so the
    # utilities never receive an empty batch when eval_num < division.
    metric_subset = max(1, eval_num // division)

    for epoch in range(opt.niter):
        _train_one_epoch(epoch)

        # Save model checkpoints after each epoch
        torch.save(netG.state_dict(), os.path.join(ckpt_dir, f"netG_epoch_{epoch}.pth"))
        torch.save(netD.state_dict(), os.path.join(ckpt_dir, f"netD_epoch_{epoch}.pth"))

        # Create duality gap models and optimizers
        dg_gen_model = Generator(1, nc, nz, ngf).to(device)
        dg_disc_model = Discriminator(1, nc, ndf, noise_std=opt.noise_std).to(device)
        dg_gen_optimizer = optim.Adam(
            dg_gen_model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
        )
        dg_disc_optimizer = optim.Adam(
            dg_disc_model.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
        )

        # Compute duality gap metrics (appends to DG)
        compute_duality_gap(netG, netD, dg_gen_model, dg_disc_model)

        # Generate fake images and save them
        noise = torch.randn(eval_num, nz, 1, 1, device=device)
        with torch.no_grad():
            fake = netG(noise)
            save_image(
                fake[:64],
                os.path.join(generated_image_dir, f"epoch_{epoch}.jpg"),
                nrow=8,
                normalize=True,
                pad_value=1,
            )
            for j, img in enumerate(fake):
                save_image(
                    img,
                    os.path.join(fake_image_dir, f"fake_{j}.jpg"),
                    normalize=True,
                )

        # Calculate additional metrics
        real_images = next(iter(dataloader))[0].to(device)
        d_loss, g_loss = utils.calculate_loss(
            netG, netD, real_images, noise[: real_images.size(0)], device
        )
        print(f"Epoch {epoch} - D Loss: {d_loss}, G Loss: {g_loss}")

        print(f"Calculating metrics for Epoch {epoch}...")
        print("Calculating FID...")
        fid_score_value = fid_score.calculate_fid_given_paths(
            [real_image_dir, fake_image_dir],
            opt.batchSize,
            device,
            dims=2048,
        )
        record_fid_score.append(fid_score_value)

        print("Calculating Steepness...")
        steepness_value = utils.calculate_steepness(netG, noise[:metric_subset])
        record_steepness.append(steepness_value)

        print("Calculating Frobenius Norm...")
        largest_frobenius_norm = utils.calculate_frobenius_norm(
            netD, fake[:metric_subset]
        )
        record_largest_frobenius_norm.append(largest_frobenius_norm)

        # Save metrics to CSV
        utils.save_metrics_to_csv(
            epoch,
            fid_score_value,
            largest_frobenius_norm,
            steepness_value,
            d_loss,
            g_loss,
            DG["vanilla"][-1],
            DG["local_random"][-1],
            save_path=csv_file_path,
        )


if __name__ == "__main__":
    main()
