import os
import json
import glob
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import copy

# Torch device string used for training and sampling. An explicit DEVICE
# environment variable wins; otherwise prefer CUDA when torch reports it as
# available, falling back to the CPU.
_FALLBACK_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = os.environ.get("DEVICE", _FALLBACK_DEVICE)


def load_matrix(path, max_items=None):
    """Load a JSONL response file into a dense subject x item matrix.

    Each non-blank line must be a JSON object with a ``subject_id`` key and a
    ``responses`` dict mapping item ids to values. Item ids are ordered by the
    integer after their last underscore (so ``q_2`` comes before ``q_10``);
    ids sharing the same number are ordered by the id string itself so the
    column order is deterministic across runs.

    Args:
        path: Path of the JSONL file (read as UTF-8).
        max_items: If truthy, keep only the first ``max_items`` items in the
            order above; responses to dropped items are ignored.

    Returns:
        ``(mat, items)``: ``mat`` is a float array of shape
        ``(n_subjects, len(items))`` holding ``NaN`` where a subject has no
        response for an item, and ``items`` is the ordered list of item ids.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``subject_id`` or ``responses``.
        ValueError: if an item id does not end in ``_<integer>``.
    """
    subjects, responses, items_set = [], [], set()
    with open(path, "r", encoding="utf-8") as f:
        for ln in f:
            # Blank lines (e.g. a trailing empty line) are not valid JSON;
            # skip them instead of aborting the whole load.
            if not ln.strip():
                continue
            d = json.loads(ln)
            subjects.append(d["subject_id"])
            responses.append(d["responses"])
            items_set.update(d["responses"])

    # Secondary key breaks ties between ids with equal numeric suffixes; the
    # set's iteration order is hash-dependent and would otherwise leak into
    # the column order.
    items = sorted(items_set, key=lambda x: (int(x.split("_")[-1]), x))
    if max_items:
        items = items[:max_items]

    idx = {iid: i for i, iid in enumerate(items)}
    mat = np.full((len(subjects), len(items)), np.nan)
    for i, resp in enumerate(responses):
        for iid, val in resp.items():
            col = idx.get(iid)
            if col is not None:  # item was cut by max_items
                mat[i, col] = val
    return mat, items


class ResponseDataset(Dataset):
    """Torch dataset over a subject x item response matrix with missing values.

    Missing entries (``NaN``) are mean-imputed per item to build the model
    input, and a mask records which entries were actually observed.

    Attributes:
        item_means: Per-item mean of the observed values (float64). Items with
            no observed value get ``empty_item_fill``.
        data: Imputed matrix as float32, same shape as ``matrix``.
        mask: Boolean matrix, True where ``matrix`` was observed.
        targets: float32 matrix with observed values kept and ``-1`` at
            missing positions.
    """

    def __init__(self, matrix, empty_item_fill=0.0):
        """
        Args:
            matrix: 2-D array (subjects x items) with ``NaN`` for missing
                responses.
            empty_item_fill: Imputation value for items that have no observed
                response at all. ``np.nanmean`` would give ``NaN`` there, and
                that ``NaN`` would propagate through the whole network.
        """
        matrix = np.asarray(matrix, dtype=np.float64)
        self.mask = ~np.isnan(matrix)

        # Per-item mean over observed entries, computed without nanmean so
        # all-missing items get a finite fill value (and no RuntimeWarning).
        counts = self.mask.sum(axis=0)
        sums = np.where(self.mask, matrix, 0.0).sum(axis=0)
        self.item_means = np.full(matrix.shape[1], float(empty_item_fill))
        np.divide(sums, counts, out=self.item_means, where=counts > 0)

        self.data = np.where(self.mask, matrix, self.item_means).astype(np.float32)
        # Built from the RAW matrix: missing entries become -1 so they can never
        # be read as a positive (== 1) response.
        self.targets = np.where(self.mask, matrix, -1.0).astype(np.float32)

    def __len__(self):
        """Return the number of subjects (rows)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, mask, original_x)`` float32 tensors for row ``idx``.

        ``x`` is the imputed row, ``mask`` is 1.0 where observed and 0.0 where
        missing, and ``original_x`` holds the observed values with -1 at
        missing positions.
        """
        return (
            torch.tensor(self.data[idx]),
            torch.tensor(self.mask[idx].astype(np.float32)),
            torch.tensor(self.targets[idx]),
        )


class VAE(nn.Module):
    """Fully connected variational autoencoder with a sigmoid output layer.

    Encoder: ``input_dim -> 128 -> 64`` (Linear + BatchNorm1d + ReLU per
    stage), followed by two linear heads that produce the latent mean and
    log-variance. Decoder: ``latent_dim -> 64 -> 128`` with the same stage
    layout, then ``Linear(128, input_dim)`` and a sigmoid, so every output lies
    in (0, 1).
    """

    def __init__(self, input_dim=200, latent_dim=32):
        super().__init__()
        # Stage order matters: it fixes both the Sequential indices (state_dict
        # keys) and the order in which parameters draw their random init.
        self.encoder = nn.Sequential(
            *self._dense_stage(input_dim, 128),
            *self._dense_stage(128, 64),
        )
        self.fc_mu = nn.Linear(64, latent_dim)
        self.fc_logvar = nn.Linear(64, latent_dim)
        self.decoder = nn.Sequential(
            *self._dense_stage(latent_dim, 64),
            *self._dense_stage(64, 128),
            nn.Linear(128, input_dim),
            nn.Sigmoid(),
        )

    @staticmethod
    def _dense_stage(in_features, out_features):
        """Return the (Linear, BatchNorm1d, ReLU) triple for one hidden stage."""
        return (
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
        )

    def encode(self, x):
        """Map a batch of inputs to ``(mu, logvar)`` of the latent Gaussian."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (differentiable)."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        """Map latent codes to per-output probabilities in (0, 1)."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch ``x``."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar


def _masked_vae_loss(model, x, mask, original_x, kld_weight):
    """Return masked BCE + ``kld_weight`` * KLD for one batch (scalar tensor).

    BCE is averaged over observed entries only (``mask == 1``), with target 1
    where ``original_x == 1`` and 0 otherwise. KLD is the mean, over batch and
    latent dimensions, of the KL divergence to a standard normal prior.
    """
    x_hat, mu, logvar = model(x)
    bce = F.binary_cross_entropy(x_hat, (original_x == 1).float(), reduction="none")
    # clamp_min guards a batch with no observed entries (0 / 0 -> NaN).
    bce = (bce * mask).sum() / mask.sum().clamp_min(1.0)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld_weight * kld


def train_vae(train_mat, val_mat, input_dim, latent_dim, epochs, batch_size, beta, patience):
    """Train a VAE with KL warm-up and early stopping on validation loss.

    The KL weight ramps linearly from 0 to ``beta`` over the first quarter of
    ``epochs`` for the training loss; the validation loss always uses the full
    ``beta``. Training stops after ``patience`` consecutive epochs without a
    lower validation loss, and the best-validation weights are restored.

    Args:
        train_mat: Training matrix (subjects x items, ``NaN`` = missing).
        val_mat: Validation matrix with the same columns.
        input_dim: Number of items (columns); must match the matrices.
        latent_dim: Size of the latent code.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size; must be at least 2 (BatchNorm training).
        beta: Final KL weight.
        patience: Epochs without validation improvement before stopping.

    Returns:
        The trained ``VAE`` on ``DEVICE``. It holds the best-validation weights
        when any epoch improved on the previous best, otherwise the final
        weights.

    Raises:
        ValueError: if ``batch_size < 2``, the training split has fewer rows
            than ``batch_size``, or the validation split is empty.
    """
    if batch_size < 2:
        raise ValueError(f"batch_size must be >= 2 for BatchNorm training, got {batch_size}")

    train_dataset = ResponseDataset(train_mat)
    val_dataset = ResponseDataset(val_mat)
    # drop_last discards the incomplete final batch, so no training batch can
    # contain a single row (BatchNorm1d rejects that in training mode).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    if len(train_loader) == 0:
        raise ValueError(
            f"Training split has {len(train_dataset)} rows, fewer than batch_size={batch_size}"
        )
    if len(val_loader) == 0:
        raise ValueError("Validation split is empty")

    model = VAE(input_dim, latent_dim).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_val_loss = float("inf")
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(epochs):
        # Linear KL warm-up over the first 25% of the epochs.
        kld_weight = beta * min(1.0, epoch / (epochs * 0.25))

        model.train()
        total_train_loss = 0.0
        for x, mask, original_x in train_loader:
            x, mask, original_x = x.to(DEVICE), mask.to(DEVICE), original_x.to(DEVICE)
            loss = _masked_vae_loss(model, x, mask, original_x, kld_weight)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for x, mask, original_x in val_loader:
                x, mask, original_x = x.to(DEVICE), mask.to(DEVICE), original_x.to(DEVICE)
                total_val_loss += _masked_vae_loss(model, x, mask, original_x, beta).item()

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | KLD Weight: {kld_weight:.2f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best model. Val Loss: {best_val_loss:.4f}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement.")
            break

    # No snapshot exists when epochs == 0 or every validation loss was NaN;
    # loading None would crash, so keep the current weights instead.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("No validation improvement recorded; keeping final model weights.")
    return model


def generate_synthetic(model, n_samples, latent_dim, input_dim, items, batch_size=256):
    """Sample synthetic binary response records from the model's latent prior.

    Latent codes are drawn from N(0, I) and decoded to per-item probabilities.
    Each item is then sampled as an independent Bernoulli draw.

    Args:
        model: A torch module exposing ``decode(z)`` that maps a
            ``(batch, latent_dim)`` tensor to probabilities of shape
            ``(batch, >= input_dim)``.
        n_samples: Number of records to generate (``<= 0`` yields ``[]``).
        latent_dim: Size of the latent code expected by ``model.decode``.
        input_dim: Number of decoder outputs (items) kept per record.
        items: Item ids in decoder-output order; ``items[j]`` names column ``j``.
        batch_size: Number of latent codes decoded per forward pass.

    Returns:
        A list of ``{"subject_id": "vae_<i>", "responses": {item_id: 0 or 1}}``
        dicts, with ``i`` running from 0 to ``n_samples - 1``.

    Raises:
        ValueError: if ``items`` has fewer than ``input_dim`` entries.
    """
    item_ids = list(items)[:input_dim]
    if len(item_ids) < input_dim:
        raise ValueError(f"items has {len(item_ids)} entries, expected at least {input_dim}")

    model.eval()
    print("Generating synthetic samples...")
    # Sample on whatever device the model lives on; fall back to DEVICE only
    # for a parameterless model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    output = []
    with torch.no_grad():
        for start in range(0, max(n_samples, 0), batch_size):
            count = min(batch_size, n_samples - start)
            z = torch.randn(count, latent_dim, device=device)
            # Keep the batch dimension explicit (no squeeze), so input_dim == 1
            # still yields one row per sample.
            probs = model.decode(z)[:, :input_dim]
            binary = torch.bernoulli(probs).cpu().numpy().astype(int)
            for offset, row in enumerate(binary):
                response = {iid: int(v) for iid, v in zip(item_ids, row)}
                output.append({"subject_id": f"vae_{start + offset}", "responses": response})
    return output


def main(args):
    """Load responses, train a VAE, and write synthetic samples as JSONL.

    Each setting is taken from ``args`` when provided. Numeric settings count
    as provided when not None; the two paths count as provided when non-empty.
    Otherwise the setting comes from the matching environment variable
    (INPUT_PATH, OUTPUT_PATH, N_ITEMS, LATENT_DIM, EPOCHS, BATCH_SIZE,
    N_SYNTHETIC, VALIDATION_SPLIT, PATIENCE, BETA), and failing that from the
    default written below.

    Raises:
        FileNotFoundError: if the input path is empty or does not exist.
        ValueError: if the input file contains no items.
    """
    input_path = args.input_path or os.environ.get("INPUT_PATH", "")
    output_path = args.output_path or os.environ.get("OUTPUT_PATH", "aug_vae.jsonl")
    n_items = args.n_items if args.n_items is not None else int(os.environ.get("N_ITEMS", "2000"))
    latent_dim = args.latent_dim if args.latent_dim is not None else int(os.environ.get("LATENT_DIM", "32"))
    epochs = args.epochs if args.epochs is not None else int(os.environ.get("EPOCHS", "100"))
    batch_size = args.batch_size if args.batch_size is not None else int(os.environ.get("BATCH_SIZE", "8"))
    n_synth = args.n_synthetic if args.n_synthetic is not None else int(os.environ.get("N_SYNTHETIC", "300"))
    val_split = args.validation_split if args.validation_split is not None else float(os.environ.get("VALIDATION_SPLIT", "0.2"))
    patience = args.patience if args.patience is not None else int(os.environ.get("PATIENCE", "10"))
    beta = args.beta if args.beta is not None else float(os.environ.get("BETA", "0.5"))

    if not input_path or not os.path.exists(input_path):
        raise FileNotFoundError(f"Input path not found: {input_path}")

    mat, items = load_matrix(input_path, max_items=n_items)
    if not items:
        raise ValueError(f"No items found in input: {input_path}")
    # n_items is only an upper bound on the number of columns. The data may
    # contain fewer items, and the network must match the columns actually
    # loaded, or the first Linear layer gets a shape mismatch.
    input_dim = len(items)

    train_mat, val_mat = train_test_split(mat, test_size=val_split, random_state=42)
    print(f"Data loaded. Train: {train_mat.shape[0]}, Val: {val_mat.shape[0]}")

    model = train_vae(
        train_mat=train_mat,
        val_mat=val_mat,
        input_dim=input_dim,
        latent_dim=latent_dim,
        epochs=epochs,
        batch_size=batch_size,
        beta=beta,
        patience=patience,
    )

    samples = generate_synthetic(
        model=model,
        n_samples=n_synth,
        latent_dim=latent_dim,
        input_dim=input_dim,
        items=items,
    )

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(s) + "\n" for s in samples)



if __name__ == "__main__":
    # Every option defaults to None so that main() can fall back to the
    # corresponding environment variable and then to its built-in default.
    cli = argparse.ArgumentParser()
    for path_flag in ("--input-path", "--output-path"):
        cli.add_argument(path_flag)
    for numeric_flag, caster in (
        ("--n-items", int),
        ("--latent-dim", int),
        ("--epochs", int),
        ("--batch-size", int),
        ("--n-synthetic", int),
        ("--validation-split", float),
        ("--patience", int),
        ("--beta", float),
    ):
        cli.add_argument(numeric_flag, type=caster)
    main(cli.parse_args())
