import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from config import config
from models.model_factory import load_mnist_model, load_cifar10_model, cifar10_class_labels, load_gtsrb_model, \
    gtsrb_class_labels, load_taxinet_model
import numpy as np
from torch.utils.data import Subset, DataLoader, Dataset, TensorDataset
import logging
import pickle


def to_nchw(X: np.ndarray) -> np.ndarray:
    """Return a 4-D image batch as float32 with the channel axis at position 1.

    The channel axis is recognised by having size 1 or 3. Axis 1 is checked
    first, so an array whose axis 1 already has size 1 or 3 is treated as
    channels-first and only cast; otherwise the last axis is checked and, if it
    matches, moved to position 1.

    Args:
        X: 4-D array, either (N, C, H, W) or (N, H, W, C).

    Returns:
        The float32 array in (N, C, H, W) order (no copy is made when the
        dtype is already float32).

    Raises:
        ValueError: If neither axis 1 nor the last axis has size 1 or 3.
    """
    assert X.ndim == 4, f"Expected 4D array, got {X.shape}"
    channel_sizes = (1, 3)
    if X.shape[1] not in channel_sizes:
        # Not channels-first; the only other supported layout is channels-last.
        if X.shape[-1] not in channel_sizes:
            raise ValueError(f"Cannot infer channels axis for shape {X.shape}")
        X = np.transpose(X, (0, 3, 1, 2))
    return X.astype(np.float32, copy=False)


# ---- minimal in-memory dataset: 4-D image array X paired with (N, 1) targets Y ----
class TaxiNetDataset(Dataset):
    """In-memory dataset wrapping a 4-D input array and a column of targets.

    Both arrays are cast to float32 and held as torch tensors; item `i` is the
    pair `(X[i], Y[i])`.
    """

    def __init__(self, X: np.ndarray, Y: np.ndarray):
        assert X.ndim == 4 and Y.ndim == 2 and Y.shape[1] == 1, f"Shapes X={X.shape}, Y={Y.shape}"
        inputs = X.astype(np.float32, copy=False)
        targets = Y.astype(np.float32, copy=False)
        # from_numpy shares memory with the (possibly uncopied) source arrays.
        self.X = torch.from_numpy(inputs)
        self.Y = torch.from_numpy(targets)

    def __len__(self):
        """Number of samples (size of X along its first axis)."""
        return self.X.shape[0]

    def __getitem__(self, i):
        """Return the `(input, target)` tensors at index `i`."""
        return self.X[i], self.Y[i]





def prepare_vision_batch(dataset_name, net, test_data, save_path, device, sample_ids, class_labels=None, verbose=0):
    """Run `net` on selected samples, save the results, and return inputs with logit margins.

    Each sample is evaluated on its own (batch of one). For every sample the two
    highest logits are located; the index of the highest is the prediction and
    the index of the second highest is the runner-up.

    Args:
        dataset_name: Name used only in the verbose log message.
        net: Model whose output is indexed as (1, num_classes) logits.
        test_data: Indexable dataset; `test_data[i]` yields `(x, label)` with `x` a tensor.
        save_path: Destination passed to `np.save` for the result dict.
        device: Device each sample is moved to before the forward pass.
        sample_ids: A single index, or a list/tuple of indices into `test_data`.
        class_labels: Optional index -> name lookup, used only for verbose logging.
        verbose: When truthy, log the save location and the true, predicted
            and runner-up labels of every sample.

    Returns:
        `(X_batch, diffs)`: the stacked inputs as a float tensor, and a float32
        tensor with, per sample, the top logit minus the second-highest logit.
    """
    net.eval()
    if isinstance(sample_ids, (list, tuple)):
        ids = sample_ids
    else:
        ids = [sample_ids]

    # One record per sample:
    # (input array, true label, winner index, runner-up index, winner logit, margin)
    records = []
    with torch.no_grad():
        for idx in ids:
            image, true_label = test_data[idx]
            batched = image.unsqueeze(0).to(device)  # add a leading batch axis of size 1
            logits = net(batched)
            first, second = torch.topk(logits, 2, dim=1).indices[0].tolist()
            top_logit = logits[0, first].item()
            second_logit = logits[0, second].item()
            records.append(
                (batched.cpu().numpy(), true_label, first, second, top_logit, top_logit - second_logit)
            )

    # Turn the per-sample records into one array per field.
    X_np = np.vstack([rec[0] for rec in records])
    labels_np = np.array([rec[1] for rec in records], dtype=np.int64)
    predicted_np = np.array([rec[2] for rec in records], dtype=np.int64)
    runner_np = np.array([rec[3] for rec in records], dtype=np.int64)
    top_logits_np = np.array([rec[4] for rec in records], dtype=np.float32)
    diffs = torch.tensor([rec[5] for rec in records], dtype=torch.float32)

    # np.save stores the dict as a pickled object array.
    np.save(save_path, {
        'X': X_np,
        'label': labels_np,
        'predicted': predicted_np,
        'runner_up': runner_np,
        'winner_logit': top_logits_np,
    })
    X_batch = torch.from_numpy(X_np).float()

    if verbose:
        logging.info(f"saved {dataset_name.upper()} samples {sample_ids} to {save_path}")
        if not class_labels:
            logging.info(f"true label: {labels_np}")
            logging.info(f"predicted labels: {predicted_np}")
            logging.info(f"runner-up labels: {runner_np}")
        else:
            # Every sample's label is listed, alongside its human-readable name.
            logging.info(f"true label: {labels_np} ({[class_labels[l] for l in labels_np]})")
            logging.info(f"predicted labels: {predicted_np} ({[class_labels[p] for p in predicted_np]})")
            logging.info(f"runner-up labels: {runner_np} ({[class_labels[r] for r in runner_np]})")

    return X_batch, diffs


def prepare_mnist_batch(net, test_data, save_path, device, sample_ids, target=5):
    """Evaluate `net` on flattened samples, save the results, and return inputs with logit margins.

    Every sample is reshaped to a single row before the forward pass, so `net`
    receives inputs of shape (1, num_features).

    Args:
        net: Model whose output is indexed as (1, num_classes) logits.
        test_data: Indexable dataset; `test_data[i]` yields `(x, label)` with `x` a tensor.
        save_path: Destination passed to `np.save` for the result dict.
        device: Device each sample is moved to before the forward pass.
        sample_ids: A single index, or a list/tuple of indices into `test_data`.
        target: Not used by this function; kept in the signature for callers.

    Returns:
        `(X_batch, diffs)`: the stacked flattened inputs as a float tensor with
        one row per sample, and a float32 tensor with, per sample, the top logit
        minus the second-highest logit.
    """
    if isinstance(sample_ids, (list, tuple)):
        ids = sample_ids
    else:
        ids = [sample_ids]
    net.eval()

    flat_inputs, true_labels, winners, runners_up, top_logits, margins = [], [], [], [], [], []
    with torch.no_grad():
        for idx in ids:
            image, true_label = test_data[idx]
            flat = image.view(1, -1).to(device)  # one row holding all input values
            logits = net(flat)
            best, second = torch.topk(logits, 2, dim=1).indices[0].tolist()
            best_logit = logits[0, best].item()
            second_logit = logits[0, second].item()

            flat_inputs.append(flat.cpu().numpy())
            true_labels.append(true_label)
            winners.append(best)
            runners_up.append(second)
            top_logits.append(best_logit)
            margins.append(best_logit - second_logit)

    stacked = np.vstack(flat_inputs)  # one row per sample
    # np.save stores the dict as a pickled object array.
    np.save(save_path, {
        'X': stacked,
        'label': np.array(true_labels, dtype=np.int64),
        'predicted': np.array(winners, dtype=np.int64),
        'runner_up': np.array(runners_up, dtype=np.int64),
        'winner_logit': np.array(top_logits, dtype=np.float32),
    })

    diffs = torch.tensor(margins, dtype=torch.float32)
    return torch.from_numpy(stacked).float(), diffs



def prepare_taxinet_batch(model, test_data, save_path, device, sample_ids):
    """Evaluate a scalar-output model on selected samples, save the results, and return inputs and errors.

    Args:
        model: Model whose output for a single sample reduces to one scalar.
        test_data: Indexable dataset; `test_data[i]` yields `(x, y)` tensors,
            where `y` holds a single value.
        save_path: Destination passed to `np.save` for the result dict.
        device: Device each sample is moved to before the forward pass.
        sample_ids: A single index, or a list/tuple of indices into `test_data`.

    Returns:
        `(X_batch, diffs)`: the stacked inputs as a float tensor, and a float32
        tensor with the signed error `prediction - target` per sample.
    """
    model.eval()
    if isinstance(sample_ids, (list, tuple)):
        ids = sample_ids
    else:
        ids = [sample_ids]

    inputs, targets, predictions, abs_errors = [], [], [], []
    with torch.no_grad():
        for idx in ids:
            image, target = test_data[idx]
            batched = image.unsqueeze(0).to(device)  # add a leading batch axis of size 1
            predicted_value = model(batched).squeeze().item()
            target_value = target.squeeze().item()

            inputs.append(batched.cpu().numpy())
            targets.append(target_value)
            predictions.append(predicted_value)
            abs_errors.append(abs(predicted_value - target_value))

    X_np = np.vstack(inputs).astype(np.float32, copy=False)
    y_true_np = np.array(targets, dtype=np.float32)
    y_pred_np = np.array(predictions, dtype=np.float32)

    # np.save stores the dict as a pickled object array.
    np.save(save_path, {
        "X": X_np,
        "y_true": y_true_np,
        "y_pred": y_pred_np,
        "abs_error": np.array(abs_errors, dtype=np.float32),
    })

    # The signed error is computed on the float32 arrays, not on the Python floats.
    signed_error = torch.from_numpy(y_pred_np - y_true_np).float()
    return torch.from_numpy(X_np).float(), signed_error


def prepare_batch(dataset_name, model, test_data, exp_paths, device, sample_ids):
    """Dispatch to the dataset-specific batch preparation routine.

    The save location is read from `exp_paths` under the key that belongs to
    the dataset family ('mnist_sample_path', 'cifar10_sample_path',
    'taxinet_sample_path' or 'gtsrb_sample_path').

    Args:
        dataset_name: One of "mnist", "cifar10", "cifar10-big", "cifar10-small",
            "taxinet" or "gtsrb".
        model: Model forwarded to the dataset-specific routine.
        test_data: Dataset forwarded to the dataset-specific routine.
        exp_paths: Mapping holding the per-dataset sample save paths.
        device: Device forwarded to the dataset-specific routine.
        sample_ids: Sample index or indices forwarded to the routine.

    Returns:
        Whatever the selected routine returns: `(X_batch, diffs)`.

    Raises:
        ValueError: If `dataset_name` is not one of the supported names.
    """
    if dataset_name == "mnist":
        mnist_path = exp_paths['mnist_sample_path']
        return prepare_mnist_batch(model, test_data, mnist_path, device, sample_ids)

    if dataset_name in ["cifar10-big", "cifar10-small", "cifar10"]:
        cifar_path = exp_paths['cifar10_sample_path']
        return prepare_vision_batch(
            dataset_name, model, test_data, cifar_path, device, sample_ids, cifar10_class_labels
        )

    if dataset_name == "taxinet":
        taxinet_path = exp_paths['taxinet_sample_path']
        return prepare_taxinet_batch(model, test_data, taxinet_path, device, sample_ids)

    if dataset_name == "gtsrb":
        gtsrb_path = exp_paths['gtsrb_sample_path']
        return prepare_vision_batch(
            dataset_name, model, test_data, gtsrb_path, device, sample_ids, gtsrb_class_labels
        )

    raise ValueError(f"Unknown dataset: {dataset_name}")

def load_dataset(dataset_name, batch_size=100):
    """Build the train/test datasets and loaders for a named dataset.

    MNIST, GTSRB and CIFAR-10 come from torchvision (downloaded into
    `config['paths']['data_root_path']` when missing); TaxiNet is read from the
    pickle file at `config['paths']['taxinet_data_path']`. GTSRB images are
    resized to 32x32 and left un-normalised; CIFAR-10 images are normalised
    per channel. The pickle-based GTSRB loader `verix_gtsrb_dset_oading` is not
    used here.

    Args:
        dataset_name: "mnist", "gtsrb", "taxinet", "cifar10", "cifar10-big"
            or "cifar10-small".
        batch_size: Batch size for both loaders.

    Returns:
        `(train_gen, test_data, test_gen)`: a shuffled training loader, the raw
        test dataset, and an unshuffled test loader.

    Raises:
        ValueError: If `dataset_name` is not one of the supported names.
    """

    def _make_loaders(train_set, test_set):
        # Training batches are shuffled; test order is kept fixed.
        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
        return train_loader, test_set, test_loader

    if dataset_name == "mnist":
        root = config['paths']['data_root_path']
        to_tensor = transforms.ToTensor()
        mnist_train = dsets.MNIST(root=root, train=True, transform=to_tensor, download=True)
        mnist_test = dsets.MNIST(root=root, train=False, transform=to_tensor, download=True)
        return _make_loaders(mnist_train, mnist_test)

    if dataset_name == "gtsrb":
        root = config['paths']['data_root_path']
        # Resize only; no mean/std normalisation is applied to GTSRB.
        resize_to_tensor = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])
        # Official torchvision train/test splits.
        gtsrb_train = dsets.GTSRB(root=root, split="train", transform=resize_to_tensor, download=True)
        gtsrb_test = dsets.GTSRB(root=root, split="test", transform=resize_to_tensor, download=True)
        return _make_loaders(gtsrb_train, gtsrb_test)

    if dataset_name == "taxinet":
        # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
        with open(config['paths']['taxinet_data_path'], 'rb') as fh:
            raw = pickle.load(fh)

        train_images = to_nchw(raw["x_train"])
        test_images = to_nchw(raw["x_test"])

        # Labels are cast to float32 once, up front.
        train_targets = raw["y_train"].astype(np.float32, copy=False)
        test_targets = raw["y_test"].astype(np.float32, copy=False)

        taxinet_train = TaxiNetDataset(train_images, train_targets)
        taxinet_test = TaxiNetDataset(test_images, test_targets)
        return _make_loaders(taxinet_train, taxinet_test)

    if dataset_name in ["cifar10-big", "cifar10-small", "cifar10"]:
        root = config['paths']['data_root_path']
        normalise = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        ])
        cifar_train = dsets.CIFAR10(root=root, train=True, download=True, transform=normalise)
        cifar_test = dsets.CIFAR10(root=root, train=False, download=True, transform=normalise)
        return _make_loaders(cifar_train, cifar_test)

    raise ValueError(f"Unknown dataset: {dataset_name}")


def verix_gtsrb_dset_oading(batch_size):
    """Load GTSRB from the pickle file at `config['paths']['gtsrb_data_path']`.

    The pickle is expected to hold 'x_train', 'y_train', 'x_test' and 'y_test'
    arrays. Images are cast to float32, divided by 255 and permuted with
    (0, 3, 1, 2), i.e. the last axis is moved to position 1; labels are cast to
    int64.

    Args:
        batch_size: Batch size for both loaders.

    Returns:
        `(train_gen, test_data, test_gen)`: a shuffled training loader, the test
        TensorDataset, and an unshuffled test loader.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
    with open(config['paths']['gtsrb_data_path'], 'rb') as fh:
        raw = pickle.load(fh)

    def _images(key):
        scaled = raw[key].astype('float32') / 255.0
        return torch.from_numpy(scaled).permute(0, 3, 1, 2)

    def _labels(key):
        return torch.from_numpy(raw[key].astype('int64'))

    train_data = TensorDataset(_images('x_train'), _labels('y_train'))
    test_data = TensorDataset(_images('x_test'), _labels('y_test'))

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, test_data, test_loader


def load_model(dataset_name, device):
    """Load the pretrained model registered for `dataset_name`.

    The weights path is read from `config['paths']` under the key
    `"<dataset_name>_state_dict_path"`.

    Args:
        dataset_name: "mnist", "cifar10-big", "cifar10-small", "gtsrb" or
            "taxinet". The plain name "cifar10" is rejected here.
        device: Device handed to the model factory function.

    Returns:
        `(model, model_path)`: the loaded model and the path its weights came from.

    Raises:
        ValueError: If `dataset_name` is not one of the supported names.
    """
    supported = ("mnist", "cifar10-big", "cifar10-small", "gtsrb", "taxinet")
    if dataset_name not in supported:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    model_path = config['paths'][f"{dataset_name}_state_dict_path"]

    if dataset_name == "mnist":
        return load_mnist_model(model_path=model_path, device=device), model_path
    if dataset_name in ["cifar10-big", "cifar10-small"]:
        # The CIFAR-10 factory picks the architecture from the dataset name itself.
        cifar_model = load_cifar10_model(model_path=model_path, device=device, model_type=dataset_name)
        return cifar_model, model_path
    if dataset_name == "taxinet":
        return load_taxinet_model(model_path=model_path, device=device), model_path
    if dataset_name == "gtsrb":
        return load_gtsrb_model(model_path=model_path, device=device), model_path

    # Defensive: the membership check above should make this unreachable.
    raise ValueError(f"Unknown dataset: {dataset_name}")

def load_dataset_and_model(dataset_name, device, exp_paths, batch_size=100, sample_ids=None):
    """Load a dataset and its model, then prepare and save a batch of test samples.

    Args:
        dataset_name: Dataset identifier passed to `load_dataset`, `load_model`
            and `prepare_batch`.
        device: Device used for the model and the forward passes.
        exp_paths: Mapping holding the per-dataset sample save paths.
        batch_size: Batch size for the data loaders.
        sample_ids: Test-set indices to prepare. When None, a single default
            index is used: 9349 for "mnist", 453 for "taxinet", 675 for "gtsrb"
            and 114 for any other name (the CIFAR-10 variants).

    Returns:
        `(model, model_path, train_gen, test_gen, test_data, batch_path, X,
        winner_runner_logit_diff)`, where `batch_path` is the `exp_paths` entry
        the batch was saved under and the last two items come from `prepare_batch`.
    """
    train_gen, test_data, test_gen = load_dataset(dataset_name, batch_size)
    model, model_path = load_model(dataset_name, device)

    if sample_ids is None:
        # Per-dataset default sample; anything not listed falls back to the CIFAR-10 one.
        default_sample = {"mnist": 9349, "taxinet": 453, "gtsrb": 675}
        sample_ids = [default_sample.get(dataset_name, 114)]

    X, winner_runner_logit_diff = prepare_batch(dataset_name, model, test_data, exp_paths, device, sample_ids)

    # Same fallback rule for the save path: unlisted names use the CIFAR-10 key.
    path_key_by_dataset = {
        "mnist": 'mnist_sample_path',
        "taxinet": 'taxinet_sample_path',
        "gtsrb": 'gtsrb_sample_path',
    }
    batch_path = exp_paths[path_key_by_dataset.get(dataset_name, 'cifar10_sample_path')]
    return model, model_path, train_gen, test_gen, test_data, batch_path, X, winner_runner_logit_diff

def load_sample(dataset_name, model, test_data, device, sample_id, exp_paths):
    """Deprecated single-sample wrapper around `prepare_batch`.

    Prefer calling `prepare_batch` directly; it handles one or many samples.
    Returns whatever `prepare_batch` returns for the one-element list `[sample_id]`.
    """
    single_sample = [sample_id]
    return prepare_batch(dataset_name, model, test_data, exp_paths, device, single_sample)

def get_mean_activations_cifar10(net, train_gen, device, num_samples=100):
    """Compute mean Conv2d activations over a fixed random subset of the training set.

    A generator seeded with 42 picks the same `num_samples` dataset indices on
    every call, so the statistics are reproducible across runs.

    Args:
        net: Model passed to `compute_mean_activations_vision`.
        train_gen: Training DataLoader; its dataset, batch size, worker count
            and pin-memory setting are reused for the subset loader.
        device: Device the batches are moved to.
        num_samples: Number of training samples to draw and average over.

    Returns:
        The result of `compute_mean_activations_vision` on the fixed subset.
    """
    gen = torch.Generator().manual_seed(42)
    dataset = train_gen.dataset
    num_total = len(dataset)
    # Honour num_samples rather than a hard-coded 100, so the subset size
    # matches what the caller asked for.
    rand_idxs = torch.randperm(num_total, generator=gen)[:num_samples].tolist()

    fixed_subset = Subset(dataset, rand_idxs)
    fixed_loader = DataLoader(
        fixed_subset,
        batch_size=train_gen.batch_size,
        shuffle=False,  # keep the same order each run
        num_workers=train_gen.num_workers,
        pin_memory=train_gen.pin_memory
    )
    return compute_mean_activations_vision(net, fixed_loader, device, num_samples)


def get_mean_activations_mnist(net, train_gen, device, num_samples=100):
    """Compute mean Linear-layer activations over a fixed random subset of the training set.

    A generator seeded with 42 picks the same `num_samples` dataset indices on
    every call, so the statistics are reproducible across runs.

    Args:
        net: Model passed to `compute_mean_activations_mnist`.
        train_gen: Training DataLoader; its dataset, batch size, worker count
            and pin-memory setting are reused for the subset loader.
        device: Device the batches are moved to.
        num_samples: Number of training samples to draw and average over.

    Returns:
        The result of `compute_mean_activations_mnist` on the fixed subset.
    """
    # Seed for reproducibility
    gen = torch.Generator().manual_seed(42)
    dataset = train_gen.dataset
    num_total = len(dataset)
    # Honour num_samples rather than a hard-coded 100, so the subset size
    # matches what the caller asked for.
    rand_idxs = torch.randperm(num_total, generator=gen)[:num_samples].tolist()

    fixed_subset = Subset(dataset, rand_idxs)
    fixed_loader = DataLoader(
        fixed_subset,
        batch_size=train_gen.batch_size,
        shuffle=False,  # keep the same order each run
        num_workers=train_gen.num_workers,
        pin_memory=train_gen.pin_memory
    )

    return compute_mean_activations_mnist(net, fixed_loader, device, num_samples)
def compute_mean_activations_mnist(net, data_loader, device, num_samples):
    """Average the outputs and input-weighted weights of every nn.Linear layer in `net`.

    Each batch is flattened to (batch, features) and pushed through the
    nn.Linear modules of `net` in `named_modules()` order, each layer feeding
    the next. Only nn.Linear modules are applied; `net.forward` itself is not
    called, so any other modules are skipped.
    NOTE(review): confirm that skipping non-Linear modules between layers is intended.

    Args:
        net: Model whose nn.Linear sub-modules are inspected.
        data_loader: Iterable of `(images, labels)` batches; labels are ignored.
        device: Device the images and the activation accumulators live on.
        num_samples: Stop once at least this many samples have been processed
            (whole batches only).

    Returns:
        `(mean_activations, weight_contributions)`: two dicts keyed by layer name.
        The first holds the per-unit mean output of each layer. The second has
        the shape of the layer's weight; entry [o, i] is weight[o, i] times the
        mean of input i.
    """
    net.eval()
    linear_layers = [
        (name, module) for name, module in net.named_modules() if isinstance(module, nn.Linear)
    ]
    mean_activations = {
        name: torch.zeros(module.out_features, device=device) for name, module in linear_layers
    }
    weight_contributions = {name: torch.zeros_like(module.weight) for name, module in linear_layers}

    seen = 0
    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)
            n_in_batch = images.shape[0]
            activation = images.view(n_in_batch, -1)

            for name, module in linear_layers:
                previous = activation
                activation = module(previous)
                mean_activations[name] += activation.sum(dim=0)
                # (B, in, 1) * (in, out) -> (B, in, out); sum over the batch, then
                # transpose back to the (out, in) layout of the weight matrix.
                weight_contributions[name] += (previous.unsqueeze(2) * module.weight.T).sum(dim=0).T

            seen += n_in_batch
            if seen >= num_samples:
                break

    # Turn the accumulated sums into per-sample means.
    for name in mean_activations:
        mean_activations[name] /= seen
        weight_contributions[name] /= seen

    return mean_activations, weight_contributions


def compute_mean_activations_vision(net, data_loader, device, num_samples=100):
    """Average the output of every nn.Conv2d layer in `net` over a set of samples.

    A forward hook on each Conv2d module accumulates the layer output summed
    over the batch axis, plus the number of samples seen. The hooks are always
    removed before returning, even if a forward pass raises.

    Args:
        net: Model to run; put into eval mode.
        data_loader: Iterable of `(images, labels)` batches; labels are ignored.
        device: Device the images are moved to.
        num_samples: Stop once at least this many samples have been processed
            (whole batches only).

    Returns:
        `(mean_act, None)`. `mean_act` maps each Conv2d layer name to its output
        averaged over the processed samples (the layer output shape without the
        batch axis). A layer whose hook never fired — for example when
        `data_loader` is empty or the layer is not used in the forward pass —
        maps to None. The second element is always None.
    """
    net.eval()
    mean_act = {}   # layer name -> running sum of outputs over samples (None until first batch)
    count_act = {}  # layer name -> number of samples accumulated for that layer
    hooks = []

    def _make_hook(layer_name):
        # Bind the layer name per hook so each closure updates its own entry.
        def hook_fn(mod, inp, out):
            batch_sum = out.detach().sum(dim=0)
            if mean_act[layer_name] is None:
                mean_act[layer_name] = torch.zeros_like(batch_sum)
            mean_act[layer_name] += batch_sum
            count_act[layer_name] += out.shape[0]
        return hook_fn

    try:
        for name, module in net.named_modules():
            if isinstance(module, nn.Conv2d):
                mean_act[name] = None
                count_act[name] = 0
                hooks.append(module.register_forward_hook(_make_hook(name)))

        sample_count = 0
        with torch.no_grad():
            for imgs, _ in data_loader:
                imgs = imgs.to(device)
                net(imgs)
                sample_count += imgs.size(0)
                if sample_count >= num_samples:
                    break
    finally:
        # Never leave hooks attached to the caller's model, even on failure.
        for handle in hooks:
            handle.remove()

    for name in mean_act:
        # Layers that never produced an output stay None instead of raising on `None / 0`.
        if mean_act[name] is not None:
            mean_act[name] /= count_act[name]

    return mean_act, None

def get_mean_activations_gtsrb(net, train_gen, device, num_samples=100):
    """Compute mean Conv2d activations over a fixed random subset of the training set.

    A generator seeded with 42 picks the same `num_samples` dataset indices on
    every call; the subset loader reuses the batch size, worker count and
    pin-memory setting of `train_gen`.

    Returns:
        The result of `compute_mean_activations_vision` on the fixed subset.
    """
    rng = torch.Generator().manual_seed(42)
    source = train_gen.dataset
    permutation = torch.randperm(len(source), generator=rng)
    chosen = permutation[:num_samples].tolist()

    subset_loader = DataLoader(
        Subset(source, chosen),
        batch_size=train_gen.batch_size,
        shuffle=False,  # fixed order on every run
        num_workers=train_gen.num_workers,
        pin_memory=train_gen.pin_memory,
    )
    return compute_mean_activations_vision(net, subset_loader, device, num_samples)

def get_mean_activations_taxinet(net, train_gen, device, num_samples=100):
    """Compute mean Conv2d activations over a fixed random subset of the training set.

    A generator seeded with 42 picks the same `num_samples` dataset indices on
    every call; the subset loader reuses the batch size, worker count and
    pin-memory setting of `train_gen`.

    Returns:
        The result of `compute_mean_activations_vision` on the fixed subset.
    """
    seeded = torch.Generator().manual_seed(42)
    train_set = train_gen.dataset
    shuffled_indices = torch.randperm(len(train_set), generator=seeded)
    picked = shuffled_indices[:num_samples].tolist()

    loader_options = {
        "batch_size": train_gen.batch_size,
        "shuffle": False,  # fixed order on every run
        "num_workers": train_gen.num_workers,
        "pin_memory": train_gen.pin_memory,
    }
    fixed_loader = DataLoader(Subset(train_set, picked), **loader_options)
    return compute_mean_activations_vision(net, fixed_loader, device, num_samples)