import torchvision.transforms as transforms
from torchvision import datasets
from sklearn.utils import check_random_state
import pickle
import torch
import numpy as np
import builtins
import os
from functools import lru_cache
from exp.real_world_dataloader import load_real_world_data

# Process-wide logging switch. It is stored on `builtins` so it can be reached
# from any module without an import; `log` below prints only while it is truthy.
builtins.IS_LOG = True
# Counter initialised to zero. It is not read in this part of the file --
# presumably updated by data-logging code elsewhere (TODO confirm).
builtins.DATALOG_COUNT = 0
def log(*args, sep=False, **kwargs):
    """Print a message, but only while ``builtins.IS_LOG`` is truthy.

    Args:
        *args: Values to print.
        sep: This is NOT ``print``'s separator. When truthy, the arguments are
            joined with single spaces and printed as a section banner: the
            message centred in an 80-character line padded with ``=``. A
            message of 80 or more characters is printed without padding.
            Because this name is taken, ``print``'s own ``sep`` cannot be
            passed through this function.
        **kwargs: Forwarded to ``print`` (e.g. ``end``, ``file``, ``flush``)
            in both plain and banner mode.
    """
    if not builtins.IS_LOG:
        return

    if not sep:
        print(*args, **kwargs)
        return

    banner_width = 80
    msg = " ".join(map(str, args))
    # With an even width, str.center puts the odd leftover pad character on the
    # right and returns the message untouched once it is already wide enough.
    print(msg.center(banner_width, "="), **kwargs)

def load_data(dataset, N_more, N_less, rs, check, need_labels=True):
    """Load one of the supported datasets and return the two samples as tensors.

    The loader is chosen by substring match on ``dataset``; the first matching
    key in the table below wins.

    Args:
        dataset: Dataset name containing one of 'blob', 'higgs', 'cifar10',
            'asvspoof_wavlm' or 'deepfake_clip'.
        N_more: Requested size of the first sample ``X``.
        N_less: Requested size of the second sample ``Y``.
        rs: Random seed handed to the selected loader.
        check: Handed to the selected loader; in the loaders of this file a
            truthy value makes ``Y`` come from a different source than ``X``.
        need_labels: When truthy the labels of ``X`` are returned as well.

    Returns:
        ``(X, Y, X_labels)`` if ``need_labels`` else ``(X, Y)``. ``X`` and ``Y``
        are ``torch`` tensors; ``X_labels`` is whatever the loader produced
        (``None`` for several loaders).

    Raises:
        ValueError: If no loader matches ``dataset`` or if the per-sample
            shapes of ``X`` and ``Y`` differ.
    """
    # Ordered (substring, loader) pairs. The generic `load_real_world_data`
    # route for other real-world embeddings is not wired in here.
    dispatch = (
        ('blob', sample_blob),
        ('higgs', sample_higgs),
        ('cifar10', load_cifar10),
        ('asvspoof_wavlm', load_asvspoof),
        ('deepfake_clip', load_deepfake),
    )

    for key, loader in dispatch:
        if key in dataset:
            X, Y, X_labels = loader(N_more, N_less, rs, check)
            break
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    X, Y = torch.tensor(X), torch.tensor(Y)

    # Everything after the leading (sample) axis must agree between X and Y.
    if X.shape[1:] != Y.shape[1:]:
        raise ValueError(f"X and Y must have the same feature size, but got X: {X.shape[1:]}, Y: {Y.shape[1:]}")

    return (X, Y, X_labels) if need_labels else (X, Y)
        
@lru_cache(maxsize=1)
def load_embeddings(task, embedding_type):
    """
    Read a pair of pre-extracted embedding arrays from disk.

    Files are looked up under ``<builtins.ROOT_PATH>/data/<task>/embeddings``
    as ``<embedding_type>_<suffix>.npy``, where the suffix pair depends on the
    task (e.g. ``bonafide``/``spoof`` for 'asvspoof').

    Args:
        task: 'asvspoof', 'mgt', or 'deepfake'
        embedding_type: file-name prefix such as 'wavlm', 'e5', 'clip'

    Returns:
        ``(X, Y)``: the arrays loaded from the "real" and the "fake" file.
        If either file is missing a message is printed and ``(None, None)``
        is returned instead.

    Raises:
        KeyError: if ``task`` is not one of the known tasks.

    Note:
        The result of the most recent ``(task, embedding_type)`` call is kept
        by ``lru_cache`` -- including a ``(None, None)`` miss.
    """
    layout = {
        'asvspoof': ('asvspoof/embeddings', 'bonafide', 'spoof'),
        'mgt': ('mgt/embeddings', 'human', 'machine'),
        'deepfake': ('deepfake/embeddings', 'real', 'fake'),
    }
    subdir, real_tag, fake_tag = layout[task]

    root = os.path.join(builtins.ROOT_PATH, "data", subdir)
    real_path = os.path.join(root, f'{embedding_type}_{real_tag}.npy')
    fake_path = os.path.join(root, f'{embedding_type}_{fake_tag}.npy')

    # Both files are required; bail out early if either one is absent.
    if not (os.path.exists(real_path) and os.path.exists(fake_path)):
        print(f"Not found: {real_path} or {fake_path}")
        return None, None

    return np.load(real_path), np.load(fake_path)

def _sample_embedding_subsets(task, embedding_type, N_more, N_less, rs, check):
    """Draw the X / Y subsets shared by the embedding-based loaders.

    Args:
        task: Task name passed to ``load_embeddings``.
        embedding_type: Embedding name passed to ``load_embeddings``.
        N_more: Requested number of rows for ``X`` (capped at what is available).
        N_less: Requested number of rows for ``Y`` (capped at what is available).
        rs: Seed for ``np.random.seed``.
        check: If truthy, ``Y`` is sampled from the second ("fake") array.
            Otherwise ``Y`` is sampled from rows of the first ("real") array
            that were not already used for ``X``.

    Returns:
        ``(X_sample, Y_sample, None)``; no labels are available for embeddings.

    Raises:
        FileNotFoundError: If ``load_embeddings`` could not find the files.
    """
    X, Y = load_embeddings(task, embedding_type)
    if X is None or Y is None:
        # Fail with a clear message rather than a TypeError from len(None).
        raise FileNotFoundError(
            f"Embeddings for task '{task}' / type '{embedding_type}' could not be loaded"
        )

    # Reseeds numpy's *global* generator, exactly as the loaders always did.
    np.random.seed(rs)
    X_idx = np.random.choice(len(X), size=min(N_more, len(X)), replace=False)
    X_sample = X[X_idx]

    if check:
        Y_idx = np.random.choice(len(Y), size=min(N_less, len(Y)), replace=False)
        Y_sample = Y[Y_idx]
    else:
        # Same-source case: take Y from the rows of X that were not drawn above.
        remaining = np.setdiff1d(np.arange(len(X)), X_idx)
        Y_idx = np.random.choice(remaining, size=min(N_less, len(remaining)), replace=False)
        Y_sample = X[Y_idx]

    return X_sample, Y_sample, None


def load_asvspoof(N_more, N_less, rs, check):
    """Sample X / Y from the 'asvspoof' task's 'wavlm' embeddings.

    See ``_sample_embedding_subsets`` for the meaning of the arguments and
    of the returned ``(X, Y, X_labels)`` triple (``X_labels`` is ``None``).
    """
    return _sample_embedding_subsets('asvspoof', 'wavlm', N_more, N_less, rs, check)


def load_deepfake(N_more, N_less, rs, check):
    """Sample X / Y from the 'deepfake' task's 'clip' embeddings.

    See ``_sample_embedding_subsets`` for the meaning of the arguments and
    of the returned ``(X, Y, X_labels)`` triple (``X_labels`` is ``None``).
    """
    return _sample_embedding_subsets('deepfake', 'clip', N_more, N_less, rs, check)

def check_device():
    """Return the preferred torch device: CUDA device 0, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def create_grid(n_rows, n_cols):
    """
    Enumerate the integer (row, col) coordinates of an n_rows x n_cols grid.

    Points are listed row by row, e.g. for a 3 x 3 grid:
    [[0, 0], [0, 1], [0, 2],
     [1, 0], [1, 1], [1, 2],
     [2, 0], [2, 1], [2, 2]]
    """
    points = []
    for row in range(n_rows):
        for col in range(n_cols):
            points.append([row, col])
    return np.array(points)

def create_cov_matrix(n_locs=9, variance=0.03, min_corr=0.02):
    """Build a stack of 2x2 matrices ``[[variance, c], [c, variance]]``.

    The off-diagonal values ``c`` are symmetric around zero: with
    ``k = n_locs // 2`` they are ``min_corr + 0.002 * i`` for ``i < k``, their
    negatives in mirrored order, and a single 0 in the middle. The result
    therefore holds ``2 * k + 1`` matrices.

    Returns:
        Array of shape ``(2 * (n_locs // 2) + 1, 2, 2)`` rounded to 4 decimals.
    """
    half = n_locs // 2
    positive = min_corr + 0.002 * np.arange(half)
    negative = -positive[::-1]
    off_diagonals = np.concatenate([negative, [0], positive])

    matrices = []
    for corr in off_diagonals:
        matrices.append([[variance, corr], [corr, variance]])
    return np.round(np.array(matrices), 4)

def sample_blob(N_more, N_less, rs, check, rows=3, cols=3, var=0.03, min_corr=0.02):
    """Sample two 2-D point clouds arranged as blobs on a rows x cols grid.

    ``X`` is always isotropic Gaussian noise with variance ``var - 0.01``,
    shifted to a uniformly chosen grid cell. With ``check`` falsy ``Y`` is
    generated the same way; with ``check`` truthy every grid cell transforms
    its standard-normal ``Y`` points with the Cholesky factor of that cell's
    matrix from ``create_cov_matrix``.

    Args:
        N_more: Number of points in ``X``.
        N_less: Number of points in ``Y``.
        rs: Seed (or random state) accepted by ``check_random_state``.
        check: Selects the ``Y`` generator as described above.
        rows, cols: Grid dimensions.
        var, min_corr: Passed to ``create_cov_matrix`` as ``variance`` and
            ``min_corr``; ``var`` also sets the isotropic variance.

    Returns:
        ``(X, Y, X_labels)`` where ``X`` is ``(N_more, 2)``, ``Y`` is
        ``(N_less, 2)`` and ``X_labels`` is the flat cell index
        ``row * cols + col`` of every ``X`` point.
    """
    origin = np.zeros(2)
    iso_cov = (var - 0.01) * np.eye(2)
    cell_covs = create_cov_matrix(n_locs=rows * cols, variance=var, min_corr=min_corr)

    rng = check_random_state(rs)

    # X: noise first, then row and column draws (the draw order fixes the stream).
    X = rng.multivariate_normal(origin, iso_cov, size=N_more)
    X_row = rng.randint(rows, size=N_more)
    X_col = rng.randint(cols, size=N_more)
    X += np.column_stack((X_row, X_col))
    X_labels = cols * X_row + X_col

    if not check:
        # Same generator as X.
        Y = rng.multivariate_normal(origin, iso_cov, size=N_less)
        Y[:, 0] += rng.randint(rows, size=N_less)
        Y[:, 1] += rng.randint(cols, size=N_less)
        return X, Y, X_labels

    Y = rng.multivariate_normal(origin, np.eye(2), size=N_less)
    Y_row = rng.randint(rows, size=N_less)
    Y_col = rng.randint(cols, size=N_less)

    for cell_idx, loc in enumerate(create_grid(rows, cols)):
        cell_row, cell_col = loc
        in_cell = (Y_row == cell_row) & (Y_col == cell_col)
        chol = np.linalg.cholesky(cell_covs[cell_idx])
        # NOTE(review): right-multiplying row vectors by the lower factor gives
        # covariance chol.T @ chol rather than cell_covs[cell_idx]; presumably
        # kept for parity with a reference benchmark -- confirm before changing.
        Y[in_cell] = Y[in_cell] @ chol + loc

    return X, Y, X_labels

@lru_cache(maxsize=1)
def _load_cifar10_test(path):
    """Return the whole CIFAR-10 test split as ``(images, labels)`` numpy arrays.

    The dataset is opened with ``download=True`` under ``path`` and the result
    for the most recent ``path`` is cached by ``lru_cache``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    split = datasets.CIFAR10(root=path, train=False, download=True, transform=to_tensor)
    # One batch as large as the split: the first batch is the full data set.
    loader = torch.utils.data.DataLoader(
        split, batch_size=len(split), shuffle=False, num_workers=0
    )
    images, targets = next(iter(loader))
    return images.numpy(), targets.numpy()

@lru_cache(maxsize=1)
def _load_cifar10_train(path):
    """Return the whole CIFAR-10 training split as ``(images, labels)`` numpy arrays.

    The dataset is opened with ``download=True`` under ``path`` and the result
    for the most recent ``path`` is cached by ``lru_cache``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    split = datasets.CIFAR10(root=path, train=True, download=True, transform=to_tensor)
    # One batch as large as the split: the first batch is the full data set.
    loader = torch.utils.data.DataLoader(
        split, batch_size=len(split), shuffle=False, num_workers=0
    )
    images, targets = next(iter(loader))
    return images.numpy(), targets.numpy()

@lru_cache(maxsize=1)
def _load_cifar10_adv(path):
    if builtins.MODEL_ARCH == "Res18":
        adv_path = os.path.join(path, "Adv_cifar10_pgd_5_eps4_linf.npz")
    elif builtins.MODEL_ARCH == "WRN28":
        adv_path = os.path.join(path, "Adv_cifar10_pgd_5_eps4_linf_transfer_wrn28.npz")
    # adv_path = os.path.join(path, "Adv_cifar10_pgd_5_eps4_linf_transfer_swin.npz")
    # adv_path = os.path.join(path, "Adv_cifar10_pgd_5_eps4_linf_transfer_wrn70.npz")
    data = np.load(adv_path)
    adv = data['X_adv']
    original_labels = data['predicted_original_labels']
    predicted_labels = data['predicted_adv_labels']
    mask = (predicted_labels != original_labels)
    return adv[mask]

def load_cifar10(N_more, N_less, rs, check):
    """Sample ``X`` from the CIFAR-10 test split and ``Y`` from a second pool.

    Args:
        N_more: Requested number of ``X`` images (capped at the split size).
        N_less: Requested number of ``Y`` images (capped at the pool size).
        rs: Seed (or random state) accepted by ``check_random_state``.
        check: If truthy, ``Y`` comes from the adversarial examples returned by
            ``_load_cifar10_adv``; otherwise from the CIFAR-10 training split.

    Returns:
        ``(X, Y, X_labels)`` where ``X_labels`` are the test labels of ``X``.
    """
    rng = check_random_state(rs)
    root = os.path.join(builtins.ROOT_PATH, "data", "cifar10")

    X_all, labels_all = _load_cifar10_test(root)
    picked = rng.choice(len(X_all), size=min(N_more, len(X_all)), replace=False)
    X, X_labels = X_all[picked], labels_all[picked]

    # Only the source pool depends on `check`; the sampling step is identical.
    Y_pool = _load_cifar10_adv(root) if check else _load_cifar10_train(root)[0]
    Y_picked = rng.choice(len(Y_pool), size=min(N_less, len(Y_pool)), replace=False)
    Y = Y_pool[Y_picked]

    if N_more > len(X_all):
        log(f"Warning: N_more is greater than total number of {len(X_all)} samples")
    return X, Y, X_labels

@lru_cache(maxsize=1)
def _load_higgs_data():
    path = os.path.join(builtins.ROOT_PATH, "data", "HIGGS_TST.pckl")
    return pickle.load(open(path, "rb"))

def sample_higgs(N_more, N_less, rs, check):
    """Draw two samples from the cached HIGGS data.

    Args:
        N_more: Number of rows for ``X``.
        N_less: Number of rows for ``Y``.
        rs: Integer seed. It seeds torch's global generator (as before) and
            now also the index sampling, so a given ``rs`` always yields the
            same split.
        check: If truthy, ``X`` is the first ``N_more`` rows of ``data[0]`` and
            ``Y`` the first ``N_less`` rows of ``data[1]`` (no random
            sampling). Otherwise both are disjoint random subsets of
            ``data[0]``.

    Returns:
        ``(X, Y, None)``; this dataset has no labels.

    Raises:
        ValueError: In the ``check`` falsy case, if ``N_more + N_less`` exceeds
            the number of rows in ``data[0]``.
    """
    torch.manual_seed(rs)
    data = _load_higgs_data()

    if check:
        X, Y = data[0], data[1]
    else:
        pool = data[0]
        # Seed the index draw from `rs`: previously only torch was seeded while
        # indices came from numpy's unseeded global generator. The modulo keeps
        # every seed torch accepts inside RandomState's 32-bit range.
        rng = np.random.RandomState(rs % 2**32)
        # One draw without replacement keeps X and Y disjoint.
        indices = rng.choice(len(pool), N_more + N_less, replace=False)
        X = pool[indices[:N_more]]
        Y = pool[indices[N_more:]]

    return X[:N_more], Y[:N_less], None

def setup_time_log():
    """Reset every ``*_TIME_LOG`` accumulator on ``builtins`` to zero."""
    timer_names = (
        "MMDFUSE_TIME_LOG",
        "MMDAGG_TIME_LOG",
        "IT_TIME_LOG",
        "C2ST_TIME_LOG",
        "MMDM_TIME_LOG",
        "RLTST_TIME_LOG",
        "MMDDEEP_TIME_LOG",
    )
    for timer in timer_names:
        setattr(builtins, timer, 0)