import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import wandb
import torch
import copy
import sys
from torch.utils.data import Dataset, DataLoader
import pickle
import os
import random
from typing import List, Set, Optional
import train_svhn_folder.train_svnh as train
import train_california.train_california as train_california
import train_ames.train_ames as train_ames
import torch.nn as nn
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import MinMaxScaler
from sklearn.impute import SimpleImputer
import pandas as pd
from torch.utils.data import random_split
from torchvision.datasets import CIFAR10
from model import *  # provides ResNetv2_rej


def images_std_mean(dataset):
    """Return the per-channel normalisation constants of an image dataset.

    Args:
        dataset: Dataset key; one of 'cifar10', 'cifar10H', 'cifar100',
            'mnist', 'pascal' or 'svhn'.

    Returns:
        A ``(mean, std)`` pair of tuples with one entry per image channel.

    Raises:
        KeyError: If ``dataset`` is not one of the known keys.
    """
    # One row per dataset: (channel means, channel standard deviations).
    stats = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar10H': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        'mnist': ((0.1307,), (0.3081,)),
        'pascal': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        'svhn': ((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
    }
    channel_mean, channel_std = stats[dataset]
    return channel_mean, channel_std


class CIFAR100WithAgents(torchvision.datasets.CIFAR100):
    """CIFAR-100 dataset whose samples can also carry agent outputs.

    Attributes:
        agents: Agent outputs indexed as ``agents[:, index]``, or None.
        agent_mode: When True, ``__getitem__`` returns a third element with
            the agent outputs of the sample.
    """

    def __init__(self, *args, agents=None, **kwargs):
        """Forward all arguments to ``CIFAR100`` and store ``agents``."""
        super().__init__(*args, **kwargs)
        self.agents = agents
        # Off by default so the dataset can be indexed before a caller opts in;
        # previously the attribute did not exist until it was set externally.
        self.agent_mode = False

    def __getitem__(self, index):
        """Return ``(image, label)`` or, in agent mode, ``(image, label, agent)``."""
        image, label = super().__getitem__(index)
        if not self.agent_mode:
            return image, label
        # One column of `agents` per sample.
        return image, label, self.agents[:, index]

class SVHNWithAgents(torchvision.datasets.SVHN):
    """SVHN dataset whose samples can also carry agent outputs.

    Attributes:
        agents: Agent outputs indexed as ``agents[:, index]``, or None.
        agent_mode: When True, ``__getitem__`` returns a third element with
            the agent outputs of the sample.
    """

    def __init__(self, *args, agents=None, **kwargs):
        """Forward all arguments to ``SVHN`` and store ``agents``."""
        super().__init__(*args, **kwargs)
        self.agents = agents
        # Off by default so the dataset can be indexed before a caller opts in;
        # previously the attribute did not exist until it was set externally.
        self.agent_mode = False

    def __getitem__(self, index):
        """Return ``(image, label)`` or, in agent mode, ``(image, label, agent)``."""
        image, label = super().__getitem__(index)
        if not self.agent_mode:
            return image, label
        # One column of `agents` per sample.
        return image, label, self.agents[:, index]

class CaliforniaHousingDataset(Dataset):
    """Tabular regression dataset whose samples can also carry agent outputs.

    Attributes:
        features: float32 tensor built from the ``features`` argument.
        targets: float32 tensor built from the ``targets`` argument.
        agents: Agent outputs indexed as ``agents[:, idx]``, or None.
        agent_mode: When True, ``__getitem__`` also returns the agent outputs.
    """

    def __init__(self, features, targets, agents):
        """Convert ``features``/``targets`` to float32 tensors and store ``agents``."""
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        self.agents = agents
        # Off by default so the dataset can be indexed before a caller opts in;
        # previously the attribute did not exist until it was set externally.
        self.agent_mode = False

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(x, y)`` or, in agent mode, ``(x, y.squeeze(), agent)``."""
        if not self.agent_mode:
            return self.features[idx], self.targets[idx]
        # Agent mode squeezes the target (the plain mode does not).
        return self.features[idx], self.targets[idx].squeeze(), self.agents[:, idx]

class AmesDataset(Dataset):
    """Tabular regression dataset whose samples can also carry agent outputs.

    Attributes:
        features: float32 tensor built from the ``features`` argument.
        targets: float32 tensor built from the ``targets`` argument.
        agents: Agent outputs indexed as ``agents[:, idx]``, or None.
        agent_mode: When True, ``__getitem__`` also returns the agent outputs.
    """

    def __init__(self, features, targets, agents):
        """Convert ``features``/``targets`` to float32 tensors and store ``agents``."""
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        self.agents = agents
        # Off by default so the dataset can be indexed before a caller opts in;
        # previously the attribute did not exist until it was set externally.
        self.agent_mode = False

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(x, y)`` or, in agent mode, ``(x, y.squeeze(), agent)``."""
        if not self.agent_mode:
            return self.features[idx], self.targets[idx]
        # Agent mode squeezes the target (the plain mode does not).
        return self.features[idx], self.targets[idx].squeeze(), self.agents[:, idx]

class CIFAR10WithAgents(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose samples can also carry agent outputs.

    Attributes:
        agents: Agent outputs indexed as ``agents[:, index]``, or None.
        agent_mode: When True, ``__getitem__`` returns a third element with
            the agent outputs of the sample.
    """

    def __init__(self, *args, agents=None, **kwargs):
        """Forward all arguments to ``CIFAR10`` and store ``agents``."""
        super().__init__(*args, **kwargs)
        self.agents = agents
        # Off by default so the dataset can be indexed before a caller opts in;
        # previously the attribute did not exist until it was set externally.
        self.agent_mode = False

    def __getitem__(self, index):
        """Return ``(image, label)`` or, in agent mode, ``(image, label, agent)``."""
        image, label = super().__getitem__(index)
        if not self.agent_mode:
            return image, label
        # One column of `agents` per sample.
        return image, label, self.agents[:, index]

class CIFAR10HWithAgents(Dataset):
    """In-memory dataset of pre-transformed samples with optional agent outputs.

    Attributes:
        features: float32 copy of the ``features`` argument.
        targets: float32 copy of the ``targets`` argument.
        agents: Agent outputs, stored as given. ``__getitem__`` indexes them
            as ``agents[:, idx]``.
        agent_mode: When True, ``__getitem__`` also returns the agent outputs.
    """

    def __init__(self, features, targets, agents):
        """Copy ``features``/``targets`` into float32 tensors and store ``agents``."""
        self.features = self._to_float_tensor(features)
        self.targets = self._to_float_tensor(targets)
        self.agents = agents
        # Off by default so the dataset can be indexed before a caller opts in;
        # previously the attribute did not exist until it was set externally.
        self.agent_mode = False

    @staticmethod
    def _to_float_tensor(values):
        """Return a detached float32 copy of ``values`` (tensor or array-like)."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(tensor) warns; clone().detach() is the recommended
            # way to copy an existing tensor.
            return values.clone().detach().to(torch.float32)
        return torch.tensor(values, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(x, y)`` or, in agent mode, ``(x, y.squeeze(), agent)``."""
        if not self.agent_mode:
            return self.features[idx], self.targets[idx]
        return self.features[idx], self.targets[idx].squeeze(), self.agents[:, idx]


def processing(batch_size, batch_size_val, overfit=False, args=None, predictor=None, seed=42):
    """Build the train/test loaders of ``args.dataset`` with agent outputs attached.

    The NumPy and torch (CPU and, when available, CUDA) generators are seeded
    with ``seed`` for the duration of the call and restored to their previous
    state afterwards, including when the preparation raises.

    Args:
        batch_size: Batch size of the training loader.
        batch_size_val: Batch size of the test loaders.
        overfit: If True, the test dataset is also used as training dataset.
        args: Namespace read for ``dataset``, ``task``, ``n_points``,
            ``n_agents`` and ``subset_test`` (it is also forwarded to the
            agent generators). Required despite the ``None`` default.
        predictor: Forwarded to the CIFAR-10/CIFAR-100 agent generators;
            mandatory for 'cifar100'.
        seed: Seed applied to NumPy and torch while the data is prepared.

    Returns:
        ``(train_loader, test_loader_subset, test_loader, dict_train,
        dict_test, max_loss)``. ``test_loader_subset`` is ``test_loader``
        itself unless ``args.subset_test`` is truthy.

    Raises:
        ValueError: If ``args`` is None, the dataset is unsupported, or
            'cifar100' is requested without a ``predictor``.
    """
    if args is None:
        raise ValueError("processing() requires an `args` namespace.")

    # Snapshot the global RNG states so the fixed seed does not leak out.
    cpu_rng_state = torch.get_rng_state()
    gpu_rng_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    state = np.random.get_state()

    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        return _processing_seeded(batch_size, batch_size_val, overfit, args, predictor)
    finally:
        # Restore the RNG states even if data preparation failed.
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state_all(gpu_rng_state)
        np.random.set_state(state)


def _processing_seeded(batch_size, batch_size_val, overfit, args, predictor):
    """Run the data/agent preparation of ``processing`` under the seeded RNGs."""
    # 1. Load the requested dataset
    print('\n[Phase 1] : Data Preparation')
    train_dataset, test_dataset, num_classes = _load_datasets(args)
    if overfit:
        train_dataset = test_dataset
        print('overfitting on test')

    # 2. Prepare Agents
    if args.n_points != -1:
        random_arrange_tr = np.random.permutation(args.n_points)
        random_arrange_val = np.random.permutation(args.n_points)
    else:
        random_arrange_tr = np.arange(len(train_dataset))
        random_arrange_val = np.arange(len(test_dataset))

    agents_train, agents_test, dict_train, dict_test, max_loss = _load_agents(
        args, train_dataset, test_dataset, num_classes, predictor,
        random_arrange_tr, random_arrange_val,
    )

    # Add agents to dataset
    train_dataset.agents = agents_train
    train_dataset.agent_mode = True
    test_dataset.agents = agents_test
    test_dataset.agent_mode = True

    if args.subset_test:
        test_dataset_subset = torch.utils.data.Subset(test_dataset, random_arrange_val)
        test_loader_subset = DataLoader(dataset=test_dataset_subset, batch_size=batch_size_val, shuffle=False)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size_val, shuffle=False)

    if not args.subset_test:
        test_loader_subset = test_loader

    return train_loader, test_loader_subset, test_loader, dict_train, dict_test, max_loss


def _load_datasets(args):
    """Return ``(train_dataset, test_dataset, num_classes)`` for ``args.dataset``."""
    name = args.dataset
    is_cifar = 'cifar' in name
    is_image = is_cifar or name == 'svhn'

    # Image datasets always need their statistics for Normalize, whatever the
    # task; other classification datasets keep the original lookup.
    if args.task == 'classification' or is_image:
        mean, std = images_std_mean(name)

    if is_image:
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        if is_cifar:
            # Only the CIFAR variants get crop/flip augmentation at train time.
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])

    if name == 'cifar100':
        print("| Preparing CIFAR-100 dataset...")
        train_dataset = CIFAR100WithAgents(root='./data', train=True, download=True, transform=transform_train)
        test_dataset = CIFAR100WithAgents(root='./data', train=False, download=False, transform=transform_test)
        num_classes = 100
    elif name == 'svhn':
        print("| Preparing SVHN dataset...")
        train_dataset = SVHNWithAgents(root='./data/SVHN/', split='train', download=True, transform=transform_train)
        # Mirror `labels` under the `targets` name used for the other datasets.
        train_dataset.targets = train_dataset.labels
        test_dataset = SVHNWithAgents(root='./data/SVHN/', split='test', download=True, transform=transform_test)
        test_dataset.targets = test_dataset.labels
        num_classes = 10
    elif name == 'california':
        print("| Preparing California dataset...")
        data = fetch_california_housing()
        X, y = data.data, data.target
        y_r = y.reshape(-1, 1)
        X_scaler = MinMaxScaler(feature_range=(0, 1))
        y_scaler = MinMaxScaler(feature_range=(0, 1))
        X_scaled = X_scaler.fit_transform(X)
        y_scaled = y_scaler.fit_transform(y_r)
        X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, test_size=0.2, random_state=42)
        train_dataset = CaliforniaHousingDataset(X_train, y_train, None)
        test_dataset = CaliforniaHousingDataset(X_test, y_test, None)
        num_classes = 1
    elif name == 'ames':
        print("| Preparing AMES dataset...")
        ames = fetch_openml(name="house_prices", version=1, as_frame=True)
        df = ames.frame.copy()
        # Keep target and drop row-ID
        df = df.drop(columns=["Id"])
        df = df.dropna(subset=["SalePrice"])
        # Separate target
        y = df["SalePrice"].values.reshape(-1, 1)
        X_df = df.drop(columns=["SalePrice"])
        # Impute numeric features with the median
        num_cols = X_df.select_dtypes(include=["number"]).columns
        imp = SimpleImputer(strategy="median")
        X_df[num_cols] = imp.fit_transform(X_df[num_cols])
        # Fill missing categoricals, then one-hot encode
        cat_cols = X_df.select_dtypes(include=["object", "category"]).columns
        X_df[cat_cols] = X_df[cat_cols].fillna("Missing")
        X_df = pd.get_dummies(X_df, columns=cat_cols, drop_first=True)
        print(f"Using {X_df.shape[1]} features (after one-hot)")
        # Scale X, y to [0, 1]
        X_scaler = MinMaxScaler()
        y_scaler = MinMaxScaler()
        X_scaled = X_scaler.fit_transform(X_df)
        y_scaled = y_scaler.fit_transform(y)
        # Train/test split
        X_train, X_test, y_train, y_test = train_test_split(
            X_scaled, y_scaled, test_size=0.2, random_state=42
        )
        train_dataset = AmesDataset(X_train, y_train, None)
        test_dataset = AmesDataset(X_test, y_test, None)
        num_classes = 1
    elif name == 'cifar10':
        print("| Preparing CIFAR-10 dataset...")
        train_dataset = CIFAR10WithAgents(root='./data', train=True, download=True, transform=transform_train)
        test_dataset = CIFAR10WithAgents(root='./data', train=False, download=False, transform=transform_test)
        num_classes = 10
    elif name == 'cifar10H':
        print("| Preparing CIFAR10H dataset...")
        # Load the raw CIFAR-10 *test* split and materialise it as tensors.
        base = CIFAR10(root="data/",
                       train=False,
                       download=True,
                       transform=transform_test)
        X = torch.stack([img for img, _ in base])  # (10000,3,32,32)
        y = torch.tensor([lbl for _, lbl in base], dtype=torch.long)  # (10000,)
        # One agent: the argmax of the stored probabilities, as a column.
        agents = torch.argmax(torch.tensor(np.load('data/cifar10h/cifar10h-probs.npy')), dim=1)[:, None]
        # The 80/20 split always uses seed 42, independently of `seed`.
        torch.manual_seed(42)
        n = X.size(0)
        n_train = int(0.8 * n)
        perm = torch.randperm(n)
        train_idx = perm[:n_train]
        test_idx = perm[n_train:]
        # Create two independent datasets
        train_dataset = CIFAR10HWithAgents(X[train_idx], y[train_idx], agents[train_idx])
        test_dataset = CIFAR10HWithAgents(X[test_idx], y[test_idx], agents[test_idx])
        num_classes = 1
    else:
        raise ValueError(f"Unsupported dataset: {name!r}")

    return train_dataset, test_dataset, num_classes


def _load_agents(args, train_dataset, test_dataset, num_classes, predictor,
                 random_arrange_tr, random_arrange_val):
    """Return ``(agents_train, agents_test, dict_train, dict_test, max_loss)``."""
    name = args.dataset

    if name == 'cifar100':
        if predictor is None:
            # Without a predictor no agents were ever produced and the call
            # used to fail later with an unbound-variable error.
            raise ValueError("dataset 'cifar100' requires a `predictor` to generate agents.")
        agents_train, dict_train = agents_generator_final(train_dataset, state='Tr', num_classes=num_classes,
                                                          num_agents=args.n_agents,
                                                          predictor=predictor, args=args,
                                                          random_subset=random_arrange_tr)
        agents_test, dict_test = agents_generator_final(test_dataset, state='Val', num_classes=num_classes,
                                                        num_agents=args.n_agents,
                                                        predictor=predictor, args=args,
                                                        random_subset=random_arrange_val)
        max_loss = 1
    elif name == 'svhn':
        agents_train, dict_train = agent_svhn(train_dataset, args)
        agents_test, dict_test = agent_svhn(test_dataset, args)
        max_loss = 1
    elif name == 'california':
        agents_train, dict_train, max_loss_train = agent_california(train_dataset, args)
        agents_test, dict_test, max_loss_test = agent_california(test_dataset, args)
        max_loss = max(max_loss_train, max_loss_test)
    elif name == 'ames':
        agents_train, dict_train, max_loss_train = agent_ames(train_dataset, args)
        agents_test, dict_test, max_loss_test = agent_ames(test_dataset, args)
        max_loss = max(max_loss_train, max_loss_test)
    elif name == 'cifar10':
        agents_train, dict_train = agents_generator_cifar10(train_dataset, state='Tr', num_classes=num_classes,
                                                            num_agents=args.n_agents,
                                                            predictor=predictor, args=args,
                                                            random_subset=random_arrange_tr)
        agents_test, dict_test = agents_generator_cifar10(test_dataset, state='Val', num_classes=num_classes,
                                                          num_agents=args.n_agents,
                                                          predictor=predictor, args=args,
                                                          random_subset=random_arrange_val)
        max_loss = 1
    elif name == 'cifar10H':
        # The datasets hold the agent as a column; transpose it to the
        # one-row-per-agent layout that __getitem__ indexes.
        agents_train = train_dataset.agents.T
        agents_test = test_dataset.agents.T
        # Accuracy compares the agent's full prediction row with the targets
        # (`[:, 0]` would only pick the prediction for the first sample).
        acc_agent = torch.mean((agents_train[0] == train_dataset.targets) * 1.0)
        dict_train = {'acc_list_full': acc_agent, 'knowledge_list': acc_agent, 'acc_list_subset': acc_agent}
        acc_agent = torch.mean((agents_test[0] == test_dataset.targets) * 1.0)
        dict_test = {'acc_list_full': acc_agent, 'knowledge_list': acc_agent, 'acc_list_subset': acc_agent}
        max_loss = 1
    else:
        raise ValueError(f"Unsupported dataset: {name!r}")

    return agents_train, agents_test, dict_train, dict_test, max_loss

import random
from typing import List, Set, Optional

import random


def generate_numbers(N, k, seed=None):
    """Draw ``k`` integers that sum to ``N`` with randomly sized gaps.

    A start value is drawn first, then each following value adds a random
    gap in ``[max(1, N // 20), max(min_gap + 1, N // 10)]``. The largest
    value is finally adjusted so the total is exactly ``N`` and the list is
    reversed. Because of that adjustment the result is not guaranteed to be
    in descending order.

    Args:
        N: Required sum of the returned numbers.
        k: How many numbers to generate; must be at least 1.
        seed: If given, the numbers come from a private generator seeded
            with it, so the call is reproducible and leaves the global
            ``random`` state untouched. If None, the global ``random``
            module is used.

    Returns:
        A list of ``k`` integers whose sum is ``N``.

    Raises:
        ValueError: If ``k < 1`` or ``N`` is too small for the constraints.
    """
    if k < 1:
        raise ValueError("k must be at least 1.")

    # A seeded local generator yields the same sequence as reseeding the
    # module-level one, without clobbering the caller's global RNG state.
    rng = random.Random(seed) if seed is not None else random

    if k * (k - 1) // 2 * (N // 10) > N:  # Quick feasibility check
        raise ValueError("N is too small to distribute numbers with given constraints.")

    # Define gap limits
    min_gap = max(1, N // 20)
    max_gap = max(min_gap + 1, N // 10)

    # Range for the smallest value, leaving room for the k - 1 gaps.
    min_xk = max(1, N // k - max_gap * (k - 1))
    max_xk = min(N // k, N - (k - 1) * max_gap)
    if min_xk > max_xk:
        raise ValueError("N is too small to distribute numbers with given constraints.")

    numbers = [rng.randint(min_xk, max_xk)]
    for _ in range(k - 1):
        numbers.append(numbers[-1] + rng.randint(min_gap, max_gap))

    # Adjust last number to make the sum exactly N
    numbers[-1] += N - sum(numbers)

    return numbers[::-1]




def generate_agent_knowledge_with_bad_agent_subset(
        num_classes: int,
        num_agents: int,
        overlap: int,
        overlap_probability: float = 0.4,
        bad_agent_fraction: float = 0.1,  # Fraction of total classes for the bad agent.
        subset_fraction: float = 0.2,  # Fraction of the union of good agents' classes for the subset agent.
        seed: Optional[int] = 42,
        not_assigned_fraction: float = 0.08,
) -> List[List[int]]:
    """
    Generate knowledge sets for each agent with two special roles:

      - The penultimate agent (index num_agents-2) is the bad agent and receives
        only a small fraction of the total classes.
      - The last agent (index num_agents-1) is the subset agent; its knowledge is a random
        subset of the union of the knowledge of the "good" agents (all agents except the
        bad agent and the subset agent).

    The classes are shuffled and a ``not_assigned_fraction`` share of them is left out,
    so those classes appear in no agent's knowledge. The kept classes go first to the
    bad agent and then, in equal contiguous chunks, to the good agents (the last good
    agent also takes the remainder). Optionally, extra overlaps are introduced among
    the good agents.

    In the special case of num_agents == 2, there are no "good" agents, so the subset agent
    samples from the bad agent's knowledge.

    :param num_classes: Total number of unique classes.
    :param num_agents: Total number of agents. (Note: to have distinct bad and subset agents,
                       num_agents should ideally be >= 3.)
    :param overlap: Maximum number of extra classes a good agent may copy from other good agents.
    :param overlap_probability: Probability (between 0 and 1) that a good agent receives extra overlaps.
    :param bad_agent_fraction: Fraction of ``num_classes`` assigned to the bad agent (at least one class).
    :param subset_fraction: Fraction of the union of the good agents' classes to assign to the subset agent.
    :param seed: Optional seed for reproducibility.
    :param not_assigned_fraction: Fraction (between 0 and 1) of the classes that no agent knows.
    :return: A list of sorted lists of class indices representing each agent's knowledge.
    :raises ValueError: If any argument is outside its documented range.
    """
    # Validate parameters.
    if not (0 <= overlap_probability <= 1):
        raise ValueError("overlap_probability must be between 0 and 1.")
    if not (0 < bad_agent_fraction < 1):
        raise ValueError("bad_agent_fraction must be between 0 and 1 (exclusive).")
    if not (0 < subset_fraction < 1):
        raise ValueError("subset_fraction must be between 0 and 1 (exclusive).")
    if not (0 <= not_assigned_fraction <= 1):
        # Values above 1 would turn the slice below into a negative bound.
        raise ValueError("not_assigned_fraction must be between 0 and 1.")
    if num_agents < 2:
        raise ValueError("There must be at least 2 agents to have both a bad and a subset agent.")
    if num_classes < 1:
        raise ValueError("There must be at least one class.")

    rnd = random.Random(seed)

    # One knowledge set per agent.
    agent_knowledge: List[Set[int]] = [set() for _ in range(num_agents)]

    # Shuffle all classes, then keep only the share that gets assigned.
    all_classes = list(range(num_classes))
    rnd.shuffle(all_classes)
    all_classes = all_classes[:int(num_classes * (1 - not_assigned_fraction))]

    # --- 1. Assign classes to the bad agent (penultimate agent) ---
    bad_agent_index = num_agents - 2
    subset_agent_index = num_agents - 1
    num_bad_classes = max(1, int(num_classes * bad_agent_fraction))
    agent_knowledge[bad_agent_index].update(set(all_classes[:num_bad_classes]))

    # --- 2. Distribute remaining classes among the "good" agents ---
    good_indices = [i for i in range(num_agents) if i not in {bad_agent_index, subset_agent_index}]
    remaining_classes = all_classes[num_bad_classes:]
    num_good = len(good_indices)
    if num_good > 0:
        per_good = len(remaining_classes) // num_good
        for pos, agent in enumerate(good_indices):
            start_idx = pos * per_good
            # The last good agent also takes whatever the integer division left over.
            end_idx = start_idx + per_good if pos < num_good - 1 else len(remaining_classes)
            agent_knowledge[agent].update(remaining_classes[start_idx:end_idx])

    # --- 3. Optionally introduce random overlaps among good agents ---
    for i in good_indices:
        if rnd.random() < overlap_probability:
            num_overlaps = rnd.randint(1, overlap) if overlap > 0 else 0
            for _ in range(num_overlaps):
                # Choose another good agent (with at least one class) to copy a class from.
                possible = [j for j in good_indices if j != i and len(agent_knowledge[j]) > 0]
                if not possible:
                    break
                other = rnd.choice(possible)
                # Only classes the current agent does not know yet are candidates.
                available = list(agent_knowledge[other] - agent_knowledge[i])
                if not available:
                    continue
                agent_knowledge[i].add(rnd.choice(available))

    # --- 4. Assign the subset agent (last agent) ---
    if good_indices:
        union_good = set()
        for i in good_indices:
            union_good.update(agent_knowledge[i])
    else:
        # Fallback: if no good agents, use the bad agent's knowledge.
        union_good = set(agent_knowledge[bad_agent_index])

    # At least one class, but never more than the pool offers.
    num_subset = min(max(1, int(len(union_good) * subset_fraction)), len(union_good))
    if union_good:
        subset_agent_classes = set(rnd.sample(list(union_good), num_subset))
    else:
        # Empty pool: the subset agent ends up with no knowledge.
        subset_agent_classes = set()
    agent_knowledge[subset_agent_index] = subset_agent_classes

    # Convert each agent's set to a sorted list.
    return [sorted(k) for k in agent_knowledge]


def generate_agent_simple(num_classes: int, num_agents: int, overlap: int, overlap_agents: int):
    """
    Generate a list of knowledge lists for each agent.

    There are `num_classes` total classes (e.g. 100).
    Each agent i (0-indexed) is assigned a consecutive block of classes,
    computed as the range [i*d, i*d + L), where:

        d = floor((num_classes - overlap) / (num_agents + overlap_agents - 2))
        L = overlap + (overlap_agents - 1) * d

    With this layout the blocks of `overlap_agents` consecutive agents
    intersect in exactly `overlap` classes.

    If an agent's block would extend beyond the available classes, the window
    is shifted to end at num_classes.

    Parameters:
      num_classes   : Total number of classes (e.g., 100)
      num_agents    : Total number of agents
      overlap       : The number of classes that are common to a group of agents;
                      must lie in [0, num_classes]
      overlap_agents: The number of consecutive agents that share the overlapping classes

    Returns:
      agent_knowledge: A list where each element is a list of class indices known to that agent.

    Raises:
      ValueError: If num_agents + overlap_agents - 2 <= 0, or if overlap is
                  negative or larger than num_classes.
    """
    span = num_agents + overlap_agents - 2
    if span <= 0:
        raise ValueError("Invalid combination of num_agents and overlap_agents")
    if not (0 <= overlap <= num_classes):
        # A larger overlap makes the step negative and yields negative class indices.
        raise ValueError("overlap must be between 0 and num_classes")

    # Step between the starts of two consecutive blocks.
    d = (num_classes - overlap) // span
    # Length of every block.
    L = overlap + (overlap_agents - 1) * d

    agent_knowledge = []
    for i in range(num_agents):
        start = i * d
        end = start + L
        # Shift a window that runs past the end so it covers the last L classes.
        if end > num_classes:
            start, end = num_classes - L, num_classes
        agent_knowledge.append(list(range(start, end)))

    return agent_knowledge





def generate_agent_knowledge_with_bad_agent(
    num_classes: int,
    num_agents: int,
    overlap: int,
    overlap_probability: float = 0.5,
    bad_agent_index: Optional[int] = None,
    bad_agent_fraction: float = 0.1,  # share of the classes given to the bad agent
    seed: Optional[int] = 42  # Optional seed for reproducibility
) -> List[List[int]]:
    """
    Build one knowledge list per agent, with a single "bad" agent that only
    knows a small share of the classes.

    The classes are shuffled; the first ``max(1, int(num_classes * bad_agent_fraction))``
    go to the bad agent and the rest is cut into equal contiguous chunks for the other
    ("good") agents. Each good agent may then copy a few extra classes from other good
    agents, and classes left over by the integer division are handed to randomly chosen
    good agents.

    :param num_classes: Total number of unique classes available.
    :param num_agents: Total number of agents (experts).
    :param overlap: Upper bound on the number of extra classes a good agent may copy
                    from the other good agents.
    :param overlap_probability: Probability (between 0 and 1) that a good agent
                                receives extra classes. Default is 0.5.
    :param bad_agent_index: Index (0-based) of the bad agent. If None, it is drawn at random.
    :param bad_agent_fraction: Fraction of classes for the bad agent; strictly between 0 and 1.
    :param seed: Seed of the private random generator used for every random choice.
    :return: A list where each element is the sorted list of class indices known by one agent.
    :raises ValueError: If an argument is outside its documented range.
    """
    # Argument checks
    if not (0 <= overlap_probability <= 1):
        raise ValueError("overlap_probability must be between 0 and 1.")
    if not (0 < bad_agent_fraction < 1):
        raise ValueError("bad_agent_fraction must be between 0 and 1 (exclusive).")
    if num_agents < 1:
        raise ValueError("There must be at least one agent.")
    if num_classes < 1:
        raise ValueError("There must be at least one class.")

    rng = random.Random(seed)
    knowledge: List[Set[int]] = [set() for _ in range(num_agents)]

    # Pick (or validate) the bad agent before anything else touches the generator.
    if bad_agent_index is None:
        bad_agent_index = rng.randint(0, num_agents - 1)
    elif not (0 <= bad_agent_index < num_agents):
        raise ValueError("bad_agent_index must be within the range of agents.")

    bad_count = max(1, int(num_classes * bad_agent_fraction))  # never zero classes

    shuffled = list(range(num_classes))
    rng.shuffle(shuffled)

    # The head of the shuffled pool belongs to the bad agent, the tail to the others.
    knowledge[bad_agent_index].update(set(shuffled[:bad_count]))
    pool = shuffled[bad_count:]

    good_agents = [a for a in range(num_agents) if a != bad_agent_index]
    if not good_agents:
        # A single agent: it is the bad one and nothing else is left to do.
        return [sorted(known) for known in knowledge]

    chunk = len(pool) // len(good_agents)

    # Hand every good agent its own contiguous chunk of the pool.
    for slot, agent in enumerate(good_agents):
        first = slot * chunk
        # The agent with the highest index takes everything up to the end.
        last = len(pool) if agent == num_agents - 1 else first + chunk
        knowledge[agent].update(pool[first:last])

    # Remember what has been handed out before any class is duplicated.
    assigned: Set[int] = set()
    for known in knowledge:
        assigned.update(known)

    # Randomly copy classes between good agents.
    for agent in good_agents:
        if not (rng.random() < overlap_probability):
            continue
        extra = rng.randint(1, overlap) if overlap > 0 else 0
        # Donors cannot change while this agent is processed, so list them once.
        donors = [
            other for other in good_agents
            if other != agent and len(knowledge[other]) > 0
        ]
        for _ in range(extra):
            if not donors:
                break  # nobody to copy from
            donor = rng.choice(donors)
            candidates = list(knowledge[donor] - knowledge[agent])
            if not candidates:
                continue  # this donor has nothing new to offer
            knowledge[agent].add(rng.choice(candidates))

    # Classes the integer division left unassigned go to random good agents.
    for cls in set(shuffled) - assigned:
        knowledge[rng.choice(good_agents)].add(cls)

    return [sorted(known) for known in knowledge]

def agent_cifar10H(dataset, args):
    """Placeholder for a CIFAR-10H agent generator; not implemented.

    The previous body loaded 'data/cifar10h/cifar10h-probs.npy' and then
    evaluated an undefined name, so every call ended in a NameError. The
    failure is now explicit.

    Args:
        dataset: Unused.
        args: Unused.

    Raises:
        NotImplementedError: Always.
    """
    # TODO: implement once the intended return value is decided.
    raise NotImplementedError("agent_cifar10H is an unfinished stub.")


def agent_california(dataset, args):
    """Run the saved California experts on ``dataset`` and score them.

    Loads ``args.n_agents`` checkpoints named
    './train_california/experts/expert_{j}_best.ckpt' (j starting at 1),
    evaluates each on every sample of ``dataset`` in order, and compares the
    predictions with ``dataset.targets``.

    Args:
        dataset: Dataset with ``targets`` and an ``agent_mode`` switch; a deep
            copy with ``agent_mode = False`` is iterated, so the argument is
            left untouched.
        args: Namespace read for ``batch_size``, ``n_agents`` and ``device``.

    Returns:
        ``(predictions_agents, info, max_loss)`` where ``predictions_agents``
        is the NumPy array of the per-agent predictions concatenated along
        axis 0, ``info`` maps 'acc_list_full', 'knowledge_list' and
        'acc_list_subset' to the same list of per-agent RMSE values, and
        ``max_loss`` is the largest element-wise squared error (a tensor).
    """
    eval_dataset = copy.deepcopy(dataset)
    eval_dataset.agent_mode = False  # iterate plain (features, target) pairs
    loader = DataLoader(dataset=eval_dataset, batch_size=args.batch_size, shuffle=False)

    # load experts
    agents_model = []
    for j in range(args.n_agents):
        path = f'./train_california/experts/expert_{j + 1}_best.ckpt'
        abs_path = os.path.abspath(path)
        model = train_california.TinyRegressor().eval()
        print('path', path, flush=True)
        print('abs path', abs_path, flush=True)
        # Load onto the CPU so checkpoints saved on a GPU also work on
        # CPU-only hosts; the model is moved to args.device below.
        state_dict = torch.load(abs_path, weights_only=True, map_location='cpu')
        print('load expert', flush=True)
        model.load_state_dict(state_dict)
        agents_model.append(model)

    predictor_outputs = []
    with torch.no_grad():
        for model in agents_model:
            model.to(args.device)
            batch_outputs = [model(features.to(args.device)) for features, _ in loader]
            # Concatenate the batches in dataset order, then transpose.
            predictor_outputs.append(torch.cat(batch_outputs).T.cpu().numpy())
    predictions_agents = np.concatenate(predictor_outputs, axis=0)

    # RMSE per agent and the worst single squared error.
    t = (dataset.targets.squeeze()).repeat(predictions_agents.shape[0], 1)
    squared_errors = nn.MSELoss(reduction='none')(torch.tensor(predictions_agents), t)
    accuracy_agents = squared_errors.mean(axis=1).sqrt().tolist()
    max_loss = torch.max(squared_errors)

    return predictions_agents, {'acc_list_full': accuracy_agents, 'knowledge_list': accuracy_agents, 'acc_list_subset': accuracy_agents}, max_loss


def agent_ames(dataset, args):
    """Run the pre-trained Ames-housing experts over ``dataset``.

    One ``TinyRegressor(261)`` checkpoint per agent is read from
    ``./train_ames/experts/expert_{k}_best.ckpt`` (k = 1..``n_agents``) and
    evaluated without gradients on a deep copy of ``dataset`` whose
    ``agent_mode`` is switched off, so the loader yields two-element batches.

    Args:
        dataset: Dataset exposing a ``targets`` tensor and accepting an
            ``agent_mode`` attribute.
        args: Namespace providing ``batch_size``, ``n_agents`` and ``device``.

    Returns:
        A 3-tuple ``(predictions_agents, info, max_loss)``:

        * ``predictions_agents`` -- NumPy array built by concatenating the
          transposed outputs of every expert along axis 0
          (assumes each expert returns ``(batch, 1)`` so that every agent
          contributes one row -- TODO confirm).
        * ``info`` -- dict whose keys ``'acc_list_full'``, ``'knowledge_list'``
          and ``'acc_list_subset'`` all map to the same list of per-agent
          RMSE values (despite the "acc" naming, lower is better).
        * ``max_loss`` -- 0-d tensor holding the largest element-wise squared
          error over all agents and samples.
    """
    eval_dataset = copy.deepcopy(dataset)
    eval_dataset.agent_mode = False
    loader = DataLoader(dataset=eval_dataset, batch_size=args.batch_size, shuffle=False)

    # load experts
    agents_model = []
    for j in range(args.n_agents):
        path = f'./train_ames/experts/expert_{j + 1}_best.ckpt'
        abs_path = os.path.abspath(path)
        model = train_ames.TinyRegressor(261).eval()
        print('path', path, flush=True)
        print('abs path', abs_path, flush=True)
        # map_location keeps checkpoints saved on a GPU loadable on CPU-only
        # hosts; the model is moved to args.device right before inference.
        state_dict = torch.load(abs_path, map_location='cpu', weights_only=True)
        print('load expert', flush=True)
        model.load_state_dict(state_dict)
        agents_model.append(model)

    predictor_outputs = []
    with torch.no_grad():
        for expert in agents_model:
            expert.to(args.device)
            # Collect whole batches and concatenate once instead of iterating
            # over every output row in Python.
            batch_outputs = [expert(features.to(args.device)) for features, _ in loader]
            predictor_outputs.append(torch.cat(batch_outputs, dim=0).T.cpu().numpy())
    predictions_agents = np.concatenate(predictor_outputs, axis=0)

    # RMSE: tile the targets once per agent row, then reuse the element-wise
    # squared error for both the per-agent RMSE and the global maximum.
    targets = dataset.targets.squeeze().repeat(predictions_agents.shape[0], 1)
    squared_error = nn.MSELoss(reduction='none')(torch.tensor(predictions_agents), targets)
    accuracy_agents = squared_error.mean(dim=1).sqrt().tolist()
    max_loss = torch.max(squared_error)

    return predictions_agents, {'acc_list_full': accuracy_agents, 'knowledge_list': accuracy_agents, 'acc_list_subset': accuracy_agents}, max_loss


def agent_svhn(dataset, args):
    """Run the pre-trained SVHN expert classifiers over ``dataset``.

    One ``ResNetv2_rej(10, 10, dropout=0.0)`` checkpoint per agent is read
    from ``train_svhn_folder/experts/expert_{k}_best.pth``
    (k = 1..``n_agents``) and evaluated without gradients on a deep copy of
    ``dataset`` whose ``agent_mode`` is switched off.

    Args:
        dataset: Dataset exposing ``labels`` (compared element-wise against
            the predictions) and accepting an ``agent_mode`` attribute.
        args: Namespace providing ``batch_size_eval``, ``n_agents`` and
            ``device``.

    Returns:
        A 2-tuple ``(predictions_agents, info)``: ``predictions_agents`` is a
        NumPy array with one row of arg-max class predictions per agent, and
        ``info`` maps ``'acc_list_full'``, ``'knowledge_list'`` and
        ``'acc_list_subset'`` to the same list of per-agent accuracies.
    """
    eval_dataset = copy.deepcopy(dataset)
    eval_dataset.agent_mode = False
    loader = DataLoader(dataset=eval_dataset, batch_size=args.batch_size_eval, shuffle=False)

    # load experts
    agents_model = []
    for j in range(args.n_agents):
        path = f'train_svhn_folder/experts/expert_{j + 1}_best.pth'
        abs_path = os.path.abspath(path)
        model = ResNetv2_rej(10, 10, dropout=0.0).eval()
        print('path', path, flush=True)
        print('abs path', abs_path, flush=True)
        # map_location keeps checkpoints saved on a GPU loadable on CPU-only
        # hosts; the model is moved to args.device right before inference.
        state_dict = torch.load(abs_path, map_location='cpu', weights_only=True)
        print('load expert', flush=True)
        model.load_state_dict(state_dict)
        agents_model.append(model)

    predictor_outputs = []
    with torch.no_grad():
        for expert in agents_model:
            expert.to(args.device)
            # One arg-max tensor per batch, concatenated once: avoids building
            # a Python list with one 0-d tensor per sample.
            batch_predictions = [
                torch.argmax(expert(images.to(args.device)), dim=1)
                for images, _ in loader
            ]
            predictor_outputs.append(torch.cat(batch_predictions)[None, :].cpu().numpy())
    predictions_agents = np.concatenate(predictor_outputs, axis=0)
    accuracy_agents = (predictions_agents == dataset.labels).mean(axis=1).tolist()

    return predictions_agents, {'acc_list_full': accuracy_agents, 'knowledge_list': accuracy_agents, 'acc_list_subset': accuracy_agents}

def agents_generator_final(dataset, state='tr', num_classes=None, num_agents=None, predictor=None, args=None, random_subset=None, seed=42, p=0.06):
    """Evaluate ``predictor`` on ``dataset`` and simulate class-specialised experts.

    Row 0 of the returned matrix holds the predictor's arg-max labels. Each
    following row belongs to one synthetic expert whose knowledge set comes
    from ``generate_agent_simple``: on samples whose label is in that set the
    expert answers with the true label, except that each such sample is, with
    probability ``p``, answered like an unknown one; everywhere else the
    answer is a class drawn uniformly from ``range(num_classes)`` (which may
    coincide with the true label).

    The global NumPy RNG is seeded with ``seed`` for the duration of the call
    and its previous state is restored on exit, including when an error is
    raised part-way through.

    Args:
        dataset: Dataset exposing ``targets`` and accepting ``agent_mode``.
        state: Tag used only in the printed log lines.
        num_classes: Number of classes to draw random answers from.
        num_agents: Total number of agents; ``num_agents - 1`` experts are
            requested from ``generate_agent_simple``.
        predictor: Model put in eval mode and run on every batch.
        args: Namespace providing ``batch_size`` and ``device``.
        random_subset: Index applied to predictions and labels to compute the
            "subset" accuracies.
        seed: Seed for the temporary NumPy RNG state.
        p: Probability of discarding a sample from an expert's known set.

    Returns:
        ``(agents_predictions, info)`` where ``agents_predictions`` has one
        row for the predictor followed by one row per knowledge set, and
        ``info`` holds ``'acc_list_full'``, ``'knowledge_list'`` (with
        ``'Predictor'`` first) and ``'acc_list_subset'``.
    """
    state_rng = np.random.get_state()
    np.random.seed(seed)
    try:
        eval_dataset = copy.deepcopy(dataset)
        eval_dataset.agent_mode = False
        loader = DataLoader(dataset=eval_dataset, batch_size=args.batch_size, shuffle=False)

        predictor.eval()
        batch_predictions = []
        with torch.no_grad():
            for images, _ in loader:
                outputs = predictor(images.to(args.device))
                batch_predictions.append(torch.argmax(outputs, dim=1))
            predictor_outputs = torch.cat(batch_predictions).cpu().numpy()

        labels = np.array(dataset.targets)
        accuracy_predictor_full = np.mean(predictor_outputs == labels)
        print(f"Full {state} Predictor_accuracy = {accuracy_predictor_full}", flush=True)
        accuracy_predictor_subset = np.mean(predictor_outputs[random_subset] == labels[random_subset])
        print(f"Subset {state} Predictor_accuracy = {accuracy_predictor_subset}", flush=True)

        overlap = num_classes // 3
        choices = np.arange(num_classes)
        num_experts = num_agents - 1
        agents_knowledge = generate_agent_simple(len(choices), num_experts, overlap, num_experts // 2 + 1)
        for i, knowledge in enumerate(agents_knowledge):
            print(f"Expert_{i+1}_knowledge = {knowledge}", flush=True)
        knowledge_list = list(agents_knowledge)

        agents_predictions = np.zeros((len(agents_knowledge), len(labels)))
        acc_list = []
        acc_list_subset = []
        for i, knowledge in enumerate(agents_knowledge):
            # Draw order (choice, then rand) is kept so a given seed
            # reproduces the same experts as before.
            uniform_prediction = np.random.choice(choices, size=len(labels))
            known = np.isin(labels, knowledge)
            random_value = np.random.rand(*known.shape)
            # Drop a p-fraction of the known samples so the expert is not
            # perfect even inside its own knowledge set.
            known = known & ~(random_value < p)
            agents_predictions[i] = np.where(known, labels, uniform_prediction)
            accuracy = np.mean(agents_predictions[i] == labels)
            accuracy_subset = np.mean(agents_predictions[i][random_subset] == labels[random_subset])
            print(f"Expert_{i+1}_accuracy_{state} = {accuracy}", flush=True)
            print(f"Expert_{i + 1}_accuracy_Subset_{state} = {accuracy_subset}", flush=True)
            acc_list.append(accuracy)
            acc_list_subset.append(accuracy_subset)

        # The predictor is reported as agent 0 in every returned structure.
        agents_predictions = np.concatenate((predictor_outputs[None, :], agents_predictions), axis=0)
        acc_list.insert(0, accuracy_predictor_full)
        acc_list_subset.insert(0, accuracy_predictor_subset)
        knowledge_list.insert(0, 'Predictor')
    finally:
        # Always hand the caller's RNG stream back, even on failure.
        np.random.set_state(state_rng)
    return agents_predictions, {'acc_list_full': acc_list, 'knowledge_list': knowledge_list, 'acc_list_subset': acc_list_subset}


def agents_generator_cifar10(dataset, state='tr', num_classes=None, num_agents=None, predictor=None, args=None, random_subset=None, seed=42, p=0.06):
    """Simulate ``num_agents`` class-specialised experts for ``dataset``.

    Knowledge sets come from ``generate_agent_simple``. On samples whose label
    is in an expert's set the expert answers with the true label, except that
    each such sample is, with probability ``p``, answered like an unknown
    one; everywhere else the answer is a class drawn uniformly from
    ``range(num_classes)`` (which may coincide with the true label).

    The global NumPy RNG is seeded with ``seed`` for the duration of the call
    and its previous state is restored on exit, including when an error is
    raised part-way through.

    Args:
        dataset: Object exposing ``targets`` (converted with ``np.array``).
        state: Tag used only in the printed log lines.
        num_classes: Number of classes to draw random answers from.
        num_agents: Number of experts requested from ``generate_agent_simple``.
        predictor: Unused; kept for signature compatibility.
        args: Unused; kept for signature compatibility.
        random_subset: Index applied to predictions and labels to compute the
            "subset" accuracies.
        seed: Seed for the temporary NumPy RNG state.
        p: Probability of discarding a sample from an expert's known set.

    Returns:
        ``(agents_predictions, info)`` where ``agents_predictions`` has one
        row per knowledge set and ``len(labels)`` columns, and ``info`` holds
        ``'acc_list_full'``, ``'knowledge_list'`` and ``'acc_list_subset'``.
    """
    state_rng = np.random.get_state()
    np.random.seed(seed)
    try:
        labels = np.array(dataset.targets)

        overlap = num_classes // 4
        choices = np.arange(num_classes)
        num_experts = num_agents
        agents_knowledge = generate_agent_simple(len(choices), num_experts, overlap, num_experts // 2 + 1)
        for i, knowledge in enumerate(agents_knowledge):
            print(f"Expert_{i+1}_knowledge = {knowledge}", flush=True)
        knowledge_list = list(agents_knowledge)

        agents_predictions = np.zeros((len(agents_knowledge), len(labels)))
        acc_list = []
        acc_list_subset = []
        for i, knowledge in enumerate(agents_knowledge):
            # Draw order (choice, then rand) is kept so a given seed
            # reproduces the same experts as before.
            uniform_prediction = np.random.choice(choices, size=len(labels))
            known = np.isin(labels, knowledge)
            random_value = np.random.rand(*known.shape)
            # Drop a p-fraction of the known samples so the expert is not
            # perfect even inside its own knowledge set.
            known = known & ~(random_value < p)
            agents_predictions[i] = np.where(known, labels, uniform_prediction)
            accuracy = np.mean(agents_predictions[i] == labels)
            accuracy_subset = np.mean(agents_predictions[i][random_subset] == labels[random_subset])
            print(f"Expert_{i+1}_accuracy_{state} = {accuracy}", flush=True)
            print(f"Expert_{i + 1}_accuracy_Subset_{state} = {accuracy_subset}", flush=True)
            acc_list.append(accuracy)
            acc_list_subset.append(accuracy_subset)
    finally:
        # Always hand the caller's RNG stream back, even on failure.
        np.random.set_state(state_rng)
    return agents_predictions, {'acc_list_full': acc_list, 'knowledge_list': knowledge_list, 'acc_list_subset': acc_list_subset}



def processing_multi_task(batch_size, batch_size_val, overfit=False, args=None, predictor=None):
    """Build the Pascal VOC 2012 train/val loaders with agent predictions attached.

    The VOC datasets are first created without agents so that
    ``generate_agents_pascal_specialized`` can produce the agent predictions
    (validation first, then training); the datasets are then re-created with
    those predictions and ``agent_mode`` enabled.

    ``batch_size``, ``batch_size_val``, ``overfit`` and ``predictor`` are not
    read by this function; the effective settings come from ``args``
    (``n_points``, ``batch_size``, ``batch_size_eval``, ``subset_test``).

    Returns:
        ``(train_loader, test_loader_subset, test_loader, acc_tr, acc_val)``.
        ``test_loader_subset`` iterates over the first ``args.n_points``
        validation samples when ``args.subset_test`` is truthy and is the full
        ``test_loader`` otherwise.
    """
    dataset_root = 'data'

    def make_voc(image_set, is_train, agent):
        # Fresh transform object per dataset, as in every call site below.
        return VOCDetectionDataset(root=dataset_root, year='2012', image_set=image_set,
                                   transforms=get_transform(train=is_train), agent=agent)

    def make_loader(source, size, shuffle):
        return torch.utils.data.DataLoader(source, batch_size=size, shuffle=shuffle,
                                           collate_fn=collate_fn, pin_memory=False, num_workers=0)

    print('start loading data training', flush=True)
    plain_train = make_voc('train', True, None)
    print('start loading data testing', flush=True)
    plain_val = make_voc('val', False, None)

    # Agent-free pass: the datasets only feed the agent generation below.
    plain_train.agent_mode = False
    plain_val.agent_mode = False

    # Leading slice of the validation set used for the "subset" evaluation.
    subset_indices = list(range(len(plain_val))[:args.n_points])

    plain_train_loader = make_loader(plain_train, 2, False)
    plain_val_loader = make_loader(plain_val, 2, False)

    agents_val, acc_val = generate_agents_pascal_specialized(
        plain_val_loader, state='Val', num_classes=len(plain_val.classes),
        args=args, id_samples=subset_indices)
    agents_tr, acc_tr = generate_agents_pascal_specialized(
        plain_train_loader, state='Tr', num_classes=len(plain_train.classes), args=args)

    train_dataset = make_voc('train', True, agents_tr)
    test_dataset = make_voc('val', False, agents_val)
    train_dataset.agent_mode = True
    test_dataset.agent_mode = True

    train_loader = make_loader(train_dataset, args.batch_size, True)
    test_loader = make_loader(test_dataset, args.batch_size_eval, False)

    test_loader_subset = test_loader
    if args.subset_test:
        test_loader_subset = make_loader(SubsampleDataset(test_dataset, subset_indices),
                                         args.batch_size_eval, False)

    return train_loader, test_loader_subset, test_loader, acc_tr, acc_val


class SubsampleDataset(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``indices``.

    Item ``k`` of this dataset is ``dataset[indices[k]]``; the length is
    ``len(indices)``. Neither argument is copied.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the position in the subset into the wrapped dataset's index.
        return self.dataset[self.indices[idx]]



def _load_pascal_model(num_classes, args, model_name, weights_suffix):
    """Build ``model_name`` through ``get_model`` and load its Pascal weights in eval mode."""
    # The architecture name is handed to get_model via args (presumably read
    # as args.model -- verify against get_model); the attribute is
    # deliberately left set afterwards, as callers may observe it.
    args.model = model_name
    net = get_model(num_classes, args).to(args.device)
    # map_location lets checkpoints saved on another device load on args.device.
    net.load_state_dict(torch.load(f'Models_weights/pascal/{args.model}{weights_suffix}',
                                   map_location=args.device))
    net.eval()
    return net


def _cached_pascal_evaluation(net, data_loader, device, cache_path):
    """Return ``(predictions, targets, mAP)`` for ``net``, using a pickle cache at ``cache_path``."""
    if not os.path.exists(cache_path):
        store_metric, predictions, targets, map_value = evaluate(net, data_loader, device, mode='total')
        record = {'predictions': predictions, 'targets': targets, 'map': map_value,
                  'store_metric_model': store_metric}
        with open(cache_path, 'wb') as f:
            pickle.dump(record, f)
    else:
        # SECURITY NOTE: pickle.load can execute arbitrary code; only cache
        # files written by this function on a trusted machine should be here.
        with open(cache_path, 'rb') as f:
            record = pickle.load(f)
    return record['predictions'], record['targets'], record['map']


def generate_agents_pascal_specialized(data_loader, state='tr', num_classes=None, args=None, id_samples=None):
    """Collect Pascal VOC predictions of the base model and two specialised experts.

    Three networks are loaded from ``Models_weights/pascal/``: a ``'mobile'``
    model and two ``'rcnn_full'`` experts (``_animal`` and ``_vehicle``
    checkpoints). Each is evaluated on ``data_loader`` unless a cached result
    ``Models_weights/pascal/{state}_*.pkl`` already exists, in which case the
    cached result is used instead. The cache is keyed by ``state`` only, so a
    stale file is reused even if ``data_loader`` changed.

    Side effects: ``args.model`` is overwritten and left at ``'rcnn_full'``;
    missing cache files are written.

    Args:
        data_loader: Loader passed to ``evaluate`` on a cache miss.
        state: Tag used in cache file names and in the keys of the result.
        num_classes: Forwarded to ``get_model``.
        args: Namespace providing ``device``; also forwarded to ``get_model``
            and ``compute_map``.
        id_samples: Optional sample indices; when given, mAP is additionally
            computed on that subset of each agent's predictions.

    Returns:
        ``(agent_predictions, map_list)``: the three prediction arrays
        concatenated along axis 0 (model, expert1, expert2) and a dict with
        key ``f'full_map_{state}'`` and, when ``id_samples`` is given,
        ``f'subset_map{state}'`` (no underscore before the state tag; the
        subset entry also carries ``id_samples`` as its last element).
    """
    device = args.device
    model = _load_pascal_model(num_classes, args, 'mobile', '_best_model.pth')
    expert1 = _load_pascal_model(num_classes, args, 'rcnn_full', '_animal_best_model.pth')
    expert2 = _load_pascal_model(num_classes, args, 'rcnn_full', '_vehicle_best_model.pth')

    # target_store is overwritten by each call; the targets stored with
    # expert2's evaluation are the ones used for the subset metrics below.
    model_prediction, target_store, map_model = _cached_pascal_evaluation(
        model, data_loader, device, f'Models_weights/pascal/{state}_model.pkl')
    expert1_prediction, target_store, map_expert1 = _cached_pascal_evaluation(
        expert1, data_loader, device, f'Models_weights/pascal/{state}_expert1_animal.pkl')
    expert2_prediction, target_store, map_expert2 = _cached_pascal_evaluation(
        expert2, data_loader, device, f'Models_weights/pascal/{state}_expert2_vehicle.pkl')

    map_list = {f'full_map_{state}': [map_model, map_expert1, map_expert2]}
    if id_samples is not None:
        target_store = np.asarray(target_store)
        subset_targets = target_store[id_samples]
        # Predictions are indexed as [0, id_samples]: a leading axis of the
        # prediction arrays is stripped before scoring.
        map_subset_model = compute_map(model_prediction[0, id_samples], subset_targets, args)
        map_subset_expert1 = compute_map(expert1_prediction[0, id_samples], subset_targets, args)
        map_subset_expert2 = compute_map(expert2_prediction[0, id_samples], subset_targets, args)
        map_list[f'subset_map{state}'] = [map_subset_model, map_subset_expert1, map_subset_expert2, id_samples]

    agent_predictions = np.concatenate((model_prediction, expert1_prediction, expert2_prediction), axis=0)
    return agent_predictions, map_list







