
import torch
import torchvision
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import os
from lightly.loss import NegativeCosineSimilarity


def pca_reduce(Y, n_components):
    """Project ``Y`` onto its leading principal components and standardize.

    The covariance matrix of the mean-centered data is accumulated in
    batches of at most 50000 rows, its eigenvectors are computed with
    ``torch.linalg.eigh``, and the centered data is projected onto the
    ``n_components`` eigenvectors with the largest eigenvalues. The
    projection is then passed through ``normalize_features``.

    The matrix work in this function runs on CUDA when it is available and
    on the CPU otherwise.

    Args:
        Y: 2-D tensor of shape ``(n_samples, n_features)``.
        n_components: Number of principal components to keep; must satisfy
            ``1 <= n_components <= n_features``.

    Returns:
        The result of ``normalize_features`` applied to a float32 CPU tensor
        of shape ``(n_samples, n_components)``. Columns are ordered by
        ascending eigenvalue, so the last column corresponds to the
        strongest component.

    Raises:
        ValueError: If ``Y`` has no rows or ``n_components`` is out of range.
    """
    n_samples, n_features = Y.shape
    if n_samples == 0:
        raise ValueError("pca_reduce requires at least one sample")
    # Guard explicitly: ``eigenvectors[:, -0:]`` would silently select every
    # column, and a value above n_features would fail later with an opaque
    # shape mismatch.
    if not 1 <= n_components <= n_features:
        raise ValueError(
            f"n_components must be in [1, {n_features}], got {n_components}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = min(50000, n_samples)
    Y_mean = Y.mean(dim=0).to(device=device, dtype=torch.float32)

    def centered_batches():
        # Center each batch on the fly so that no full-size centered copy of
        # Y has to be held in memory.
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch = Y[start:end].to(device=device, dtype=torch.float32)
            yield start, end, batch - Y_mean

    # Accumulate the (unnormalized) covariance matrix batch by batch.
    covariance_matrix = torch.zeros(
        (n_features, n_features), dtype=torch.float32, device=device
    )
    for _, _, batch in centered_batches():
        covariance_matrix += torch.mm(batch.t(), batch)

    # The scale does not affect the eigenvectors; max() only avoids a
    # division by zero when there is a single sample.
    covariance_matrix /= max(n_samples - 1, 1)

    # eigh returns eigenvalues in ascending order, so the trailing columns
    # are the directions of largest variance.
    _, eigenvectors = torch.linalg.eigh(covariance_matrix, UPLO="U")
    principal_components = eigenvectors[:, -n_components:]

    # Project batch by batch and collect the result on the CPU.
    Y_reduced = torch.zeros(
        (n_samples, n_components), dtype=torch.float32, device="cpu"
    )
    for start, end, batch in centered_batches():
        Y_reduced[start:end] = torch.mm(batch, principal_components).cpu()

    return normalize_features(Y_reduced)


def normalize_features(Y):
    """Standardize each column of ``Y`` to zero mean and unit variance.

    Per-column sums and sums of squares are accumulated in batches of at
    most 50000 rows (on CUDA when available, otherwise on the CPU), and the
    population variance (divisor ``n_samples``) is derived from them.

    Args:
        Y: 2-D tensor of shape ``(n_samples, n_features)``.

    Returns:
        A tensor on ``Y``'s device with the same shape as ``Y`` holding
        ``(Y - mean) / std`` per column. Columns whose computed standard
        deviation is exactly zero are centered but left unscaled, so they
        come out as zeros instead of NaN/inf.

    Raises:
        ValueError: If ``Y`` has no rows.
    """
    n_samples = Y.shape[0]
    dim = Y.shape[1]
    if n_samples == 0:
        raise ValueError("normalize_features requires at least one sample")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = min(50000, n_samples)

    # Accumulate in float64: the E[x^2] - E[x]^2 formula below subtracts two
    # nearly equal numbers and loses most of its precision in float32.
    global_sum = torch.zeros(dim, dtype=torch.float64, device=device)
    global_sqr_sum = torch.zeros(dim, dtype=torch.float64, device=device)

    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)
        batch = Y[start:end].to(device=device, dtype=torch.float64)
        global_sum += batch.sum(dim=0)
        global_sqr_sum += (batch**2).sum(dim=0)

    global_mean = global_sum / n_samples
    # Rounding can push the variance slightly below zero, which would turn
    # into NaN under sqrt; clamp it.
    global_var = ((global_sqr_sum / n_samples) - (global_mean**2)).clamp_min(0)

    mean = global_mean.to(device=Y.device, dtype=torch.float32)
    std = torch.sqrt(global_var).to(device=Y.device, dtype=torch.float32)
    # Constant columns have zero spread; dividing by 1 keeps them at 0
    # rather than producing 0/0.
    std = torch.where(std > 0, std, torch.ones_like(std))

    return (Y - mean) / std


def convert(prior_model, dataset, size=32, batch_size=256, num_workers=8,
            device=None):
    """Run ``prior_model`` over ``dataset`` and collect its outputs on the CPU.

    The dataset is iterated in order (no shuffling) and every batch is
    passed through the model under ``torch.no_grad()``.

    Note that ``prior_model`` is modified in place: it is switched to eval
    mode and moved to ``device``.

    Args:
        prior_model: Module mapping a batch of inputs to a tensor of outputs.
        dataset: Dataset whose batches unpack into ``(inputs, second)``; the
            second element is ignored.
        size: Unused; kept so existing callers keep working.
        batch_size: Number of samples per forward pass.
        num_workers: Number of DataLoader worker processes.
        device: Device to run the model on. Defaults to CUDA when available,
            otherwise the CPU.

    Returns:
        A float32 CPU tensor with the model outputs of all batches
        concatenated along dimension 0, in dataset order.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # non_blocking host-to-device copies are only asynchronous when the
        # source tensor lives in pinned memory.
        pin_memory=device.type == "cuda",
    )

    outputs = []
    with torch.no_grad():
        observer = prior_model.eval().to(device)
        for x, _ in tqdm(data_loader):
            y = observer(x.to(device, non_blocking=True))
            outputs.append(y.float().detach().cpu())
    return torch.cat(outputs, dim=0)



def train_matric_x(dataset, prior_model_1, prior_model_2, num_epochs=20,
                   device=None):
    """Train a linear layer mapping ``prior_model_1`` features onto
    ``prior_model_2`` features.

    A fresh ``nn.Linear(512, 512)`` is optimized with AdamW (lr 0.001,
    weight decay 0.01) to minimize ``NegativeCosineSimilarity`` between its
    output on ``prior_model_1(inputs)`` and ``prior_model_2(inputs)``. The
    prior models are only evaluated under ``torch.no_grad()``; their
    parameters are not optimized. The mean loss is printed after each epoch.

    Note that both prior models are modified in place: they are moved to
    ``device`` and switched to eval mode.

    Args:
        dataset: Dataset whose batches unpack into ``(inputs, labels)``;
            the labels are ignored.
        prior_model_1: Source feature extractor. Its output is fed to the
            512-input linear layer, so it must produce 512-dim features.
        prior_model_2: Target feature extractor (presumably also 512-dim,
            to match the linear layer's output -- confirm against callers).
        num_epochs: Number of passes over ``dataset``.
        device: Device to train on. Defaults to CUDA when available,
            otherwise the CPU.

    Returns:
        The trained ``nn.Linear`` module, on ``device``.

    Raises:
        ValueError: If an epoch sees no samples.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # The linear map being optimized.
    C_matrix = nn.Linear(512, 512).to(device)

    # The prior models act as fixed feature extractors, so run them in eval
    # mode: otherwise layers such as dropout or batch norm would behave
    # stochastically / update running statistics during the forward passes.
    prior_model_1 = prior_model_1.to(device).eval()
    prior_model_2 = prior_model_2.to(device).eval()

    optimizer = optim.AdamW(C_matrix.parameters(), lr=0.001, weight_decay=0.01)
    criterion = NegativeCosineSimilarity().to(device)

    # Built once; with shuffle=True every pass over it is reshuffled.
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=True
    )

    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        total_samples = 0

        for inputs, _ in trainloader:
            inputs = inputs.to(device)
            with torch.no_grad():
                features_A = prior_model_1(inputs)
                features_B = prior_model_2(inputs)

            outputs = C_matrix(features_A)
            loss = criterion(outputs, features_B)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample mean
            # even when the last batch is smaller.
            batch_count = inputs.size(0)
            running_loss += loss.item() * batch_count
            total_samples += batch_count

        if total_samples == 0:
            raise ValueError("dataset yielded no samples")
        epoch_loss = running_loss / total_samples

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    return C_matrix