#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, os, sys, math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import pennylane as qml
from pennylane import numpy as pnp
from sklearn.metrics import confusion_matrix
import itertools
from PIL import Image

# ---- 16 pt everywhere ----
# One font size for every text element so all exported figures look uniform.
matplotlib.rcParams.update(dict.fromkeys(
    ("font.size", "axes.labelsize", "xtick.labelsize",
     "ytick.labelsize", "legend.fontsize"),
    16,
))

# ------------------ Global config ------------------
# Device used for training: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN benchmark and pick convolution algorithms (pays off when input shapes stay constant).
torch.backends.cudnn.benchmark = True

# ------------------ Utils ------------------

def model_device(model: nn.Module) -> torch.device:
    """Report where *model* lives.

    Returns the device of the first parameter yielded by ``model.parameters()``,
    or ``torch.device("cpu")`` for a module that has no parameters at all.
    """
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device

def plot_curves(histories, ylabel, outpath):
    """Plot one line per named series against a 1-based epoch index and save it.

    Args:
        histories: dict[name] -> list of values per epoch; each entry becomes
            one labelled line with circle markers.
        ylabel: label for the y axis (the x axis is always "epoch").
        outpath: destination file (str or Path) passed to ``savefig``.
    """
    fig = plt.figure()
    try:
        for name, vals in histories.items():
            plt.plot(range(1, len(vals) + 1), vals, marker="o", label=name)
        plt.xlabel("epoch")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        fig.savefig(outpath, bbox_inches="tight")
    finally:
        # pyplot keeps every figure it creates alive until closed; release it
        # even if plotting/saving fails so repeated calls don't accumulate figures.
        plt.close(fig)
    print(f"[plot] wrote {outpath}")

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', outpath='cm.pdf'):
    """Render a confusion matrix as an annotated heatmap and save it.

    Args:
        cm: square array-like; rows are true labels, columns predicted labels.
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum before plotting and
            print cells with two decimals.
        title: accepted for API compatibility; it is not drawn on the figure.
        outpath: destination file (str or Path) passed to ``savefig``.
    """
    cm = np.asarray(cm)
    if normalize:
        # The epsilon keeps all-zero rows at 0 instead of producing NaN.
        cm = cm.astype(float) / (cm.sum(axis=1, keepdims=True) + 1e-12)
        fmt = '.2f'
    else:
        # Integer counts print with 'd'; a float matrix would make format(..., 'd') raise.
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    fig = plt.figure()
    try:
        plt.imshow(cm, interpolation='nearest')
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45, ha="right")
        plt.yticks(tick_marks, classes)
        cmax = cm.max()
        # Cells above half the maximum get white text for contrast.
        thresh = cmax / 2. if cmax > 0 else 0.5
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        fig.savefig(outpath, bbox_inches="tight")
    finally:
        # Release the pyplot figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"[plot] wrote {outpath}")

def pool_to_n(arr2d, n):
    """Average-pool an array down to ``n`` values.

    The input is flattened and cut into ``n`` contiguous, near-equal chunks
    with ``np.array_split``; each chunk contributes its mean. The result is a
    1-D float32 array of length ``n``.
    """
    flat = np.ravel(np.asarray(arr2d))
    chunk_means = [float(chunk.mean()) for chunk in np.array_split(flat, n)]
    return np.asarray(chunk_means, dtype=np.float32)

# ------------------ CIFAR Dataset wrappers ------------------

class CIFARSubset(Dataset):
    """CIFAR-10 restricted to a chosen set of classes.

    Each item is ``(x_img, x_vec, y)``:
      - ``x_img``: CxHxW float tensor in [0, 1] (``ToTensor``), or None when
        ``return_image`` is False (note: torch's default collate_fn rejects
        None, so batching then needs a custom collate_fn);
      - ``x_vec``: grayscale image pooled to ``n_wires`` float32 values;
      - ``y``: long tensor holding the position of the original label in
        ``classes``.

    ``split == 'train'`` loads the CIFAR-10 training set; any other value loads
    the test set. ``limit`` keeps the first ``limit`` matching samples in
    dataset order; a falsy ``limit`` (None or 0) keeps all of them.
    """

    def __init__(self, split, classes=(0,1), limit=None, root="./_cache_cifar",
                 return_image=True, n_wires=6):
        from torchvision import datasets

        self.classes = list(classes)
        self.return_image = return_image
        self.n_wires = n_wires

        use_train_set = split == 'train'
        self.ds = datasets.CIFAR10(root=root, train=use_train_set,
                                   download=True, transform=None)

        kept_images, kept_targets = [], []
        for pil_img, raw_label in self.ds:
            if raw_label not in self.classes:
                continue
            kept_images.append(pil_img)  # PIL image
            kept_targets.append(self.classes.index(raw_label))
            if limit and len(kept_images) >= limit:
                break
        self.X, self.y = kept_images, kept_targets

        # Image branch transform for CNN/QCNN: PIL -> [0,1] CxHxW tensor.
        self.tx_img = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        pil_img = self.X[idx]
        target = self.y[idx]
        x_img = self.tx_img(pil_img) if self.return_image else None
        # Quantum-branch input: grayscale scaled to [0, 1], pooled to n_wires values.
        luminance = np.asarray(pil_img.convert("L"), dtype=float) / 255.0
        x_vec = torch.from_numpy(pool_to_n(luminance, self.n_wires))
        return x_img, x_vec, torch.tensor(target, dtype=torch.long)

# ------------------ EuroSAT Dataset wrapper (deterministic split) ------------------

class EuroSATSubset(Dataset):
    """
    EuroSAT subset with a reproducible, filename-based train/val split.

    Split rule (stable across processes and runs; the built-in ``hash`` of a
    str is salted per interpreter via PYTHONHASHSEED and would not be):
      - train:  crc32(filename) % 4 != 0  (~75%)
      - val:    crc32(filename) % 4 == 0  (~25%)
    Any other ``split`` value applies no split filter.

    Applies per-split, per-class limits (limit // num_classes, at least 1) and
    returns image tensors for CNN/QCNN plus a pooled grayscale vector of length
    n_wires for QNN/QCNN quantum input.

    Args:
        csv_path: CSV read with pandas; must provide "Filename" and "Label" columns.
        rgb_root: directory that "Filename" entries are joined onto; rows whose
            file does not exist there are skipped.
        tif_root: accepted for interface compatibility; not read by this class.
        classes: labels to keep; their order defines targets 0..C-1. Defaults to
            the first two sorted unique labels in the CSV.
        split: "train" or "val".
        limit: total sample budget for this split, divided evenly per class.
        return_image: when False, x_img is None (torch's default collate_fn
            cannot batch None).
        n_wires: length of the pooled grayscale vector.

    Raises:
        RuntimeError: if no rows survive the filters.
    """
    def __init__(self, csv_path, rgb_root, tif_root, classes=None,
                 split='train', limit=None, return_image=True, n_wires=6):
        import pandas as pd
        self.return_image = return_image
        self.n_wires = n_wires
        self.split = split
        df = pd.read_csv(csv_path)

        # Resolve class set
        if classes is None:
            uniq = sorted(df["Label"].unique().tolist())
            classes = uniq[:2]
            print(f"[eurosat] classes defaulted to: {classes}")
        self.classes = list(classes)

        # Per-split limits per class
        if limit is None:
            per_class_limit = float("inf")
        else:
            per_class_limit = max(1, limit // max(1, len(self.classes)))

        # Collect rows for this split (column zip is much faster than iterrows).
        rows = []
        per_class_counts = {c: 0 for c in self.classes}
        for fname, lab in zip(df["Filename"], df["Label"]):
            try:
                lab = int(lab)
            except (TypeError, ValueError, OverflowError):
                pass  # non-numeric labels (e.g. class-name strings) are compared as-is
            if lab not in self.classes:
                continue

            # Split filter first: it is cheap and avoids a disk stat for the other split.
            want_val = self._is_val(fname)
            if (self.split == "val" and not want_val) or (self.split == "train" and want_val):
                continue

            rgb_path = os.path.join(rgb_root, fname)
            if not os.path.isfile(rgb_path):
                continue

            if per_class_counts[lab] >= per_class_limit:
                continue

            rows.append((rgb_path, lab))
            per_class_counts[lab] += 1

        if len(rows) == 0:
            raise RuntimeError(
                f"[eurosat] No samples found for split='{self.split}'. "
                f"Check paths and class ids. Requested classes={self.classes}"
            )

        self.rows = rows
        self.tx_img = T.Compose([T.ToTensor()])

    @staticmethod
    def _is_val(filename) -> bool:
        """Return True if *filename* belongs to the validation split (~25% of names)."""
        import zlib
        # crc32 gives the same value in every interpreter run, unlike hash(str).
        return zlib.crc32(str(filename).encode("utf-8")) % 4 == 0

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        rgb_path, lab = self.rows[idx]
        # convert() returns an in-memory copy, so the file handle can be closed right away.
        with Image.open(rgb_path) as src:
            img = src.convert("RGB")
        x_img = self.tx_img(img) if self.return_image else None

        # quantum vector via grayscale pooling (length = n_wires)
        gray = img.convert("L")
        vec = pool_to_n(np.asarray(gray, dtype=float)/255.0, self.n_wires).astype(np.float32)
        x_vec = torch.from_numpy(vec)

        # remap label to 0..C-1 based on self.classes order
        y = torch.tensor(self.classes.index(lab), dtype=torch.long)
        return x_img, x_vec, y

# ------------------ Models ------------------

class SmallCNN(nn.Module):
    """Compact three-stage CNN classifier.

    Conv(3x3, 16) -> ReLU -> MaxPool(2) -> Conv(3x3, 32) -> ReLU -> MaxPool(2)
    -> Conv(3x3, 64) -> ReLU -> global average pool -> Linear(64, num_classes).
    The global pool makes the head independent of the input's spatial size.
    """

    def __init__(self, in_ch=3, num_classes=2):
        super().__init__()
        widths = [in_ch, 16, 32, 64]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU()]
            if stage < 2:  # only the first two stages downsample
                layers.append(nn.MaxPool2d(2))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.net = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        pooled = self.net(x)
        return self.fc(torch.flatten(pooled, start_dim=1))

# ---- Quantum layer via PennyLane Torch integration ----

def make_quantum_torch_layer(n_wires=6, n_layers=3, n_outputs=None):
    """Build a PennyLane TorchLayer: angle embedding + strongly entangling layers.

    Args:
        n_wires: number of qubits; each input feature is embedded as a
            Y-rotation on its own wire.
        n_layers: number of StronglyEntanglingLayers; the trainable
            ``weights`` tensor has shape (n_layers, n_wires, 3).
        n_outputs: number of PauliZ expectation values returned, measured on
            wires 0..n_outputs-1. None (default) measures all n_wires wires.

    Returns:
        qml.qnn.TorchLayer mapping R^{n_wires} -> R^{n_outputs} (each output
        is a PauliZ expectation, hence in [-1, 1]).

    Raises:
        ValueError: if n_outputs is not between 1 and n_wires.
    """
    n_out = n_wires if n_outputs is None else int(n_outputs)
    if not 1 <= n_out <= n_wires:
        raise ValueError(
            f"n_outputs must be between 1 and n_wires={n_wires}, got {n_outputs!r}"
        )

    # shots=None -> analytic (exact) expectation values.
    dev = qml.device("default.qubit", wires=n_wires, shots=None)

    @qml.qnode(dev, interface="torch")
    def qnode(inputs, weights):
        qml.AngleEmbedding(inputs, wires=range(n_wires), rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=range(n_wires))
        return [qml.expval(qml.PauliZ(i)) for i in range(n_out)]

    weight_shapes = {"weights": (n_layers, n_wires, 3)}
    return qml.qnn.TorchLayer(qnode, weight_shapes)

class QNNHead(nn.Module):
    """Quantum-only classifier.

    The pooled input vector (length n_wires) goes through the variational
    circuit built by ``make_quantum_torch_layer``; a linear layer maps its
    n_wires expectation values to ``num_classes`` logits.
    """

    def __init__(self, n_wires=6, n_layers=3, num_classes=2):
        super().__init__()
        self.qlayer = make_quantum_torch_layer(n_wires, n_layers)
        self.fc = nn.Linear(in_features=n_wires, out_features=num_classes)

    def forward(self, x_vec):
        quantum_features = self.qlayer(x_vec)
        return self.fc(quantum_features)

class QCNN(nn.Module):
    """Hybrid model: small conv feature extractor -> quantum layer -> linear head.

    Pipeline: two Conv+ReLU+MaxPool stages and a global average pool give a
    32-dim feature; a Linear layer compresses it to n_wires circuit inputs; the
    quantum layer returns n_wires expectation values; a final Linear layer
    produces ``num_classes`` logits.
    """

    def __init__(self, in_ch=3, n_wires=6, n_layers=3, num_classes=2):
        super().__init__()
        conv_stack = [
            nn.Conv2d(in_ch, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.feat = nn.Sequential(*conv_stack)
        self.to_vec = nn.Linear(32, n_wires)
        self.quantum = make_quantum_torch_layer(n_wires, n_layers)
        self.fc = nn.Linear(n_wires, num_classes)

    def forward(self, x_img):
        features = torch.flatten(self.feat(x_img), start_dim=1)
        circuit_in = self.to_vec(features)
        return self.fc(self.quantum(circuit_in))

# ------------------ Training / Eval ------------------

def _model_logits(model, x_img, x_vec, y, device):
    """Feed *model* the input branch it consumes and return its logits.

    SmallCNN/QCNN receive the image batch (an all-zero (B, 3, 32, 32) batch is
    substituted when the dataset yielded no image); QNNHead receives the pooled
    vector cast to float.

    Raises:
        RuntimeError: for any other model type.
    """
    if isinstance(model, (SmallCNN, QCNN)):
        if x_img is None:
            x_img = torch.zeros((y.shape[0], 3, 32, 32))
        return model(x_img.to(device))
    if isinstance(model, QNNHead):
        return model(x_vec.to(device).float())
    raise RuntimeError("Unknown model type")


def run_training(model, train_loader, val_loader, epochs=10, lr=1e-3, name="model"):
    """Train *model* with Adam + cross-entropy, validating after every epoch.

    Both loaders must yield ``(x_img, x_vec, y)`` triples; the model type
    decides which input is used (train and validation route inputs the same
    way via ``_model_logits``).

    Args:
        model: SmallCNN, QCNN or QNNHead; moved to DEVICE.
        train_loader / val_loader: iterables of (x_img, x_vec, y) batches.
        epochs: number of passes over train_loader.
        lr: Adam learning rate.
        name: tag used in the per-epoch log line.

    Returns:
        (acc_hist, loss_hist, model): per-epoch validation accuracy in percent,
        per-epoch mean validation cross-entropy, and the trained model.

    Raises:
        RuntimeError: if the model type is not supported.
    """
    model = model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=lr)
    ce = nn.CrossEntropyLoss()
    acc_hist, loss_hist = [], []

    for ep in range(epochs):
        model.train()
        for x_img, x_vec, y in train_loader:
            y = y.to(DEVICE)
            logits = _model_logits(model, x_img, x_vec, y, DEVICE)
            loss = ce(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()

        # validation
        model.eval()
        tot, correct, vloss = 0, 0, 0.0
        with torch.no_grad():
            for x_img, x_vec, y in val_loader:
                y = y.to(DEVICE)
                logits = _model_logits(model, x_img, x_vec, y, DEVICE)
                # Weight each batch's mean loss by its size for a true per-sample mean.
                vloss += float(ce(logits, y)) * y.size(0)
                correct += int((logits.argmax(1) == y).sum().item())
                tot += y.size(0)
        acc = 100.0 * correct / max(1, tot)
        vloss = vloss / max(1, tot)
        acc_hist.append(acc)
        loss_hist.append(vloss)
        print(f"[{name}] epoch {ep+1:02d}: val_acc={acc:5.2f}%  val_loss={vloss:.4f}")

    return acc_hist, loss_hist, model

def get_preds(model, data_loader):
    """Collect ground-truth labels and argmax predictions for a confusion matrix.

    Inference runs on the device of the model's first parameter (CPU for a
    parameter-less model); the model itself is not moved, only put in eval mode.
    SmallCNN/QCNN consume ``x_img`` (an all-zero (B, 3, 32, 32) batch stands in
    when it is None); every other model consumes ``x_vec``.

    Args:
        model: trained classifier.
        data_loader: iterable of (x_img, x_vec, y) batches.

    Returns:
        (y_true, y_pred): 1-D int64 numpy arrays of equal length (empty when
        the loader yields no batches).
    """
    mdl_dev = model_device(model)
    model.eval()
    uses_image = isinstance(model, (SmallCNN, QCNN))
    ys, yhats = [], []
    with torch.no_grad():
        for x_img, x_vec, y in data_loader:
            if uses_image:
                if x_img is None:
                    x = torch.zeros((y.shape[0], 3, 32, 32), dtype=torch.float32, device=mdl_dev)
                else:
                    x = x_img.to(mdl_dev, non_blocking=True)
            else:  # QNNHead
                x = x_vec.to(mdl_dev, non_blocking=True).float()
            logits = model(x)
            ys.append(y.cpu().long())
            yhats.append(logits.argmax(1).cpu())
    if not ys:
        empty = np.empty(0, dtype=np.int64)
        return empty, empty.copy()
    return torch.cat(ys).numpy(), torch.cat(yhats).numpy()

# ------------------ Main ------------------

def main():
    """CLI entry point: train CNN, QNN and QCNN on the same data and compare them.

    Output files share a prefix (``--out_prefix`` or
    ``compare_<dataset>_<n_wires>q_<n_layers>L``):
      - <prefix>_acc_vs_epoch.pdf and <prefix>_loss_vs_epoch.pdf (validation curves);
      - <prefix>_cm_cnn.pdf, <prefix>_cm_qnn.pdf, <prefix>_cm_qcnn.pdf (confusion matrices).
    Exits with status 1 when --dataset eurosat is used without all three EuroSAT paths.
    """
    ap = argparse.ArgumentParser(description="Compare CNN vs QNN vs QCNN (6 qubits, 3 layers).")
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True)
    ap.add_argument("--epochs", type=int, default=10)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--n_wires", type=int, default=6)
    ap.add_argument("--n_layers", type=int, default=3)
    # CIFAR
    ap.add_argument("--classes", type=int, nargs="+", default=[0,1])
    ap.add_argument("--max_train", type=int, default=1000)
    ap.add_argument("--max_val", type=int, default=400)
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    # EuroSAT
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_rgb_root", type=str, default="")
    ap.add_argument("--eurosat_tif_root", type=str, default="")
    ap.add_argument("--eurosat_classes", type=int, nargs="+", default=None)
    ap.add_argument("--out_prefix", type=str, default="")
    args = ap.parse_args()

    n_wires, n_layers = args.n_wires, args.n_layers
    # Encode the actual circuit size in default filenames (the defaults still give "6q_3L").
    out_prefix = args.out_prefix or f"compare_{args.dataset}_{n_wires}q_{n_layers}L"

    # ----- Data -----
    if args.dataset == "cifar":
        train_ds = CIFARSubset("train", classes=tuple(args.classes),
                               limit=args.max_train, root=args.cifar_root,
                               return_image=True, n_wires=n_wires)
        val_ds   = CIFARSubset("val",   classes=tuple(args.classes),
                               limit=args.max_val, root=args.cifar_root,
                               return_image=True, n_wires=n_wires)
        in_ch, num_classes = 3, len(args.classes)
        class_names = [str(c) for c in args.classes]
    else:
        if not (args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root):
            print("[error] Provide EuroSAT paths: --eurosat_csv --eurosat_rgb_root --eurosat_tif_root", file=sys.stderr)
            sys.exit(1)
        train_ds = EuroSATSubset(args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root,
                                 classes=args.eurosat_classes, split="train",
                                 limit=args.max_train, return_image=True, n_wires=n_wires)
        # Reuse the resolved training classes so both splits share one label mapping.
        val_ds   = EuroSATSubset(args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root,
                                 classes=train_ds.classes, split="val",
                                 limit=args.max_val, return_image=True, n_wires=n_wires)
        in_ch, num_classes = 3, len(train_ds.classes)
        class_names = [str(c) for c in train_ds.classes]

    # Pinned host memory only helps for host->CUDA copies (and warns on CPU-only machines).
    pin = DEVICE.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch, shuffle=True, num_workers=2, pin_memory=pin)
    val_loader   = DataLoader(val_ds,   batch_size=args.batch, shuffle=False, num_workers=2, pin_memory=pin)

    print(f"[data] train size={len(train_ds)}  val size={len(val_ds)}  classes={class_names}")

    # ----- Models -----
    # All three are constructed before any training starts, in this order.
    models = {
        "CNN": SmallCNN(in_ch=in_ch, num_classes=num_classes),                                  # only CNN
        "QNN": QNNHead(n_wires=n_wires, n_layers=n_layers, num_classes=num_classes),             # only QNN (pooled vector)
        "QCNN": QCNN(in_ch=in_ch, n_wires=n_wires, n_layers=n_layers, num_classes=num_classes),  # CNN -> quantum -> linear
    }

    # ----- Train -----
    # Keying results by model name keeps log tags, curves and matrices consistent.
    results = {}
    for name, model in models.items():
        acc_hist, loss_hist, trained = run_training(model, train_loader, val_loader,
                                                    epochs=args.epochs, lr=args.lr, name=name)
        results[name] = (acc_hist, loss_hist, trained)

    # ----- Curves -----
    plot_curves({name: res[0] for name, res in results.items()},
                ylabel="accuracy (%)", outpath=Path(f"{out_prefix}_acc_vs_epoch.pdf"))
    plot_curves({name: res[1] for name, res in results.items()},
                ylabel="cross-entropy loss", outpath=Path(f"{out_prefix}_loss_vs_epoch.pdf"))

    # ----- Confusion matrices -----
    # get_preds uses each model's own device. labels= keeps every matrix
    # num_classes x num_classes (matching class_names) even if a class is absent.
    eval_loader = DataLoader(val_ds, batch_size=128, shuffle=False)
    labels = list(range(num_classes))
    for name, (_, _, trained) in results.items():
        y_true, y_pred = get_preds(trained, eval_loader)
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        plot_confusion_matrix(cm, class_names, normalize=False,
                              title=f"{name} Confusion Matrix",
                              outpath=Path(f"{out_prefix}_cm_{name.lower()}.pdf"))

    print("[done] Comparison complete.")

if __name__ == "__main__":
    main()
