import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
import logging
from utils.model_loader import load_models_from_folder
from utils.config import config
from utils.data_loader import get_cifar100_data_half

# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Input width of every SingleValueNet (presumably the CIFAR-100 class count).
OUTDIM = 100

class SingleValueNet(nn.Module):
    """Tiny scoring head: one linear layer (OUTDIM -> 1) followed by a sigmoid."""

    def __init__(self):
        super().__init__()
        # Linear projection from OUTDIM input features down to a single value.
        self.fc = nn.Linear(OUTDIM, 1)

    def forward(self, x):
        """Flatten ``x`` into rows of OUTDIM floats and return one sigmoid score per row."""
        # Any leading dimensions are folded into the row axis: result is [rows, OUTDIM].
        rows = x.view(-1, OUTDIM).float()
        logits = self.fc(rows)
        return torch.sigmoid(logits)


def _score_with_fingerprint_nets(group_outputs, fp_nets):
    """Score each sampled model's prediction on every fingerprint with that fingerprint's net.

    Args:
        group_outputs: list of N logit tensors, one per sampled model; each is
            assumed to be laid out as [B, C] (fingerprint, class).
        fp_nets: the B fingerprint nets of the current mini-batch, ordered like
            the fingerprints along dim 0 of every logit tensor.

    Returns:
        Tensor of sigmoid scores stacked as [N, B, 1].
    """
    # Softmax over the class axis, then stack the models: [N, B, C].
    probs = torch.stack([torch.softmax(logits, dim=1) for logits in group_outputs], dim=0)
    # Fingerprint `idx` is scored by its own net on the N class-probability rows.
    scores = [net(probs[:, idx]) for idx, net in enumerate(fp_nets)]  # each [N, 1]
    return torch.stack(scores, dim=1)


def train_fingerprints(parameters_conf, data_conf, paths):
    """Jointly train fingerprint inputs and one ``SingleValueNet`` per fingerprint.

    The fingerprints and their nets are optimised so that the nets output ~1 on
    the softmax predictions of the 'protected' model group and ~0 on those of
    the 'independent' group. The result is written to
    ``<fingerprints_dir>/EffiFP_fingerprints.pt``.

    Args:
        parameters_conf: mapping read via ``.get`` for the keys ``fp_size``,
            ``batch_size``, ``learning_rate``, ``num_epochs``,
            ``initialization`` ("random" or "data"), ``subset_size`` and
            ``fingerprints_dir``.
        data_conf: accepted for interface compatibility; not used here.
        paths: mapping providing ``trainsets_dir``; models are loaded from its
            ``Mine/protected``, ``Mine/independent`` and ``Mine/target``
            sub-folders.

    Raises:
        ValueError: on an unsupported initialization method, a non-positive
            ``batch_size`` / fingerprint count, an empty model group, or a
            dataset too small to seed the requested number of fingerprints.
    """
    # Hyperparameters
    fp_size = parameters_conf.get("fp_size")
    batch_size = parameters_conf.get("batch_size")
    learning_rate = parameters_conf.get("learning_rate")
    num_epochs = parameters_conf.get("num_epochs")
    fp_init = parameters_conf.get("initialization")
    subset_size = parameters_conf.get("subset_size")
    file_path = parameters_conf.get('fingerprints_dir')

    num_fps = fp_size[0]
    if not isinstance(batch_size, int) or batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if num_fps < 1:
        raise ValueError(f"fp_size[0] must be at least 1, got {num_fps!r}")

    os.makedirs(file_path, exist_ok=True)

    # Paths to model groups
    checkpoints_dir = os.path.join(paths.get('trainsets_dir'), "Mine")
    path_first_group = os.path.join(checkpoints_dir, 'protected')
    path_second_group = os.path.join(checkpoints_dir, 'independent')
    path_protected_model = os.path.join(checkpoints_dir, 'target')

    # Load models
    first_group_models = load_models_from_folder(path_first_group, device=device)
    second_group_models = load_models_from_folder(path_second_group, device=device)
    protected_model = load_models_from_folder(path_protected_model, device=device)

    # Fail early with a readable message instead of an IndexError / empty
    # torch.stack error deep inside the training loop.
    for group_name, group in (
        ("protected", first_group_models),
        ("independent", second_group_models),
        ("target", protected_model),
    ):
        if not group:
            raise ValueError(f"No models found for the '{group_name}' group in {checkpoints_dir}")

    # Freeze every model: no parameter gradients, and eval mode so that any
    # BatchNorm/Dropout layers neither update running statistics nor inject
    # noise while the fingerprints are optimised against them.
    for model in first_group_models + second_group_models + protected_model:
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
    protected_model = protected_model[0]

    # Fingerprints Initialization
    if fp_init == "random":
        trainable_fps = nn.Parameter(torch.randn(*fp_size, device=device) * 0.1, requires_grad=True)
    elif fp_init == "data":
        train_dataset_half, _, _, _ = get_cifar100_data_half("first")
        indices = torch.randperm(len(train_dataset_half))[:num_fps].tolist()
        if len(indices) < num_fps:
            raise ValueError(
                f"Dataset has only {len(indices)} samples but {num_fps} fingerprints were requested"
            )
        fp_images = torch.stack([train_dataset_half[i][0] for i in indices]).to(device)
        trainable_fps = nn.Parameter(fp_images.clone(), requires_grad=True)
    else:
        raise ValueError("Unsupported fingerprint initialization method")

    # One scoring net per fingerprint, trained together with the fingerprints.
    fingerprint_nets = nn.ModuleList([SingleValueNet().to(device) for _ in range(num_fps)])

    optimizer = optim.Adam([trainable_fps] + list(fingerprint_nets.parameters()), lr=learning_rate)

    # Number of models sampled from each group per mini-batch (loop-invariant).
    first_group_num = len(first_group_models)
    second_group_num = len(second_group_models)
    first_sample_count = min(subset_size, first_group_num)
    second_sample_count = min(subset_size, second_group_num)

    for epoch in range(num_epochs):
        total_loss, total_loss_1, total_loss_2 = 0.0, 0.0, 0.0

        # Shuffle the fingerprints and split them into mini-batches. split()
        # keeps the final, possibly smaller, batch so every fingerprint is
        # trained each epoch and the per-epoch averages below are exact.
        perm = torch.randperm(num_fps)
        for batch in perm.split(batch_size):
            batch_inputs = trainable_fps[batch]
            batch_fp_nets = [fingerprint_nets[i] for i in batch.tolist()]
            optimizer.zero_grad()

            # 1) Randomly pick the model subset from each group
            first_indices = random.sample(range(first_group_num), first_sample_count)
            second_indices = random.sample(range(second_group_num), second_sample_count)

            # 2) Run only those sampled models
            first_group_outputs = [first_group_models[i](batch_inputs) for i in first_indices]
            second_group_outputs = [second_group_models[i](batch_inputs) for i in second_indices]

            # Protected model: raw class probabilities (not passed through the nets).
            protected_model_preds = torch.softmax(protected_model(batch_inputs), dim=1)

            # First / second group: per-fingerprint net scores, [N, B, 1] each.
            first_stacked_preds = _score_with_fingerprint_nets(first_group_outputs, batch_fp_nets)
            second_stacked_preds = _score_with_fingerprint_nets(second_group_outputs, batch_fp_nets)

            # Create target tensors
            ones = torch.ones_like(first_stacked_preds)
            zeros = torch.zeros_like(second_stacked_preds)
            ones_protected = torch.ones_like(protected_model_preds)

            # Compute binary loss
            # NOTE(review): loss_0 pushes every softmax entry of the protected
            # model towards 1, unlike the groups which are scored through the
            # fingerprint nets -- confirm this is the intended objective.
            loss_0 = nn.functional.binary_cross_entropy(protected_model_preds, ones_protected)
            loss_1 = nn.functional.binary_cross_entropy(first_stacked_preds, ones)  # Encourage first group → 1
            loss_2 = nn.functional.binary_cross_entropy(second_stacked_preds, zeros)  # Encourage second group → 0

            # Total loss
            batch_loss = 0.5 * loss_0 + loss_1 + 1.5 * loss_2

            batch_loss.backward()
            optimizer.step()

            # Keep the fingerprint values inside a bounded range after each step.
            with torch.no_grad():
                trainable_fps.clamp_(-10.0, 10.0)

            # Weight by batch size so the epoch average is per fingerprint.
            current_batch_size = batch_inputs.size(0)
            total_loss += batch_loss.item() * current_batch_size
            total_loss_1 += loss_1.item() * current_batch_size
            total_loss_2 += loss_2.item() * current_batch_size

        avg_loss = total_loss / num_fps
        avg_loss_1 = total_loss_1 / num_fps
        avg_loss_2 = total_loss_2 / num_fps
        logger.info(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}, loss_1: {avg_loss_1}, loss_2: {avg_loss_2}.")

    # Save fingerprints and the state dict of every fingerprint net in one file
    nets_state_dicts = [net.state_dict() for net in fingerprint_nets]
    torch.save({
        'fingerprints': trainable_fps.detach().cpu(),
        'nets_state_dicts': nets_state_dicts,
    }, os.path.join(file_path, "EffiFP_fingerprints.pt"))

    logger.info(f"Saved {len(trainable_fps)} fingerprints + nets to {file_path}")



if __name__ == "__main__":
    # Script entry point: hand the "PTAFP", "data" and "modelsets_paths"
    # sections of the project config straight to the trainer.
    train_fingerprints(
        config.get("PTAFP"),
        config.get("data"),
        config.get("modelsets_paths"),
    )



