import time 
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models  # If you want to use ResNet18, torchvision must be installed
from util_real import process_weights  # Your existing utility function
from mobilenetv2 import MobileNetV2
from util_real import upload_parameters, download_parameters, check_param_path, get_adj_from_neib
import itertools
import pandas as pd


def grid_search(Xs, ys, X_val, y_val, base_model, init_parameters,
                param_grid, refit=True):
    """
    Exhaustively evaluate every combination in ``param_grid`` on ``base_model``.

    For each combination the values are assigned as attributes of
    ``base_model``, the model is reset with ``set_parameters(init_parameters)``,
    trained with ``adapt_refit(Xs, ys)`` and scored with
    ``refit_loss(X_val, y_val)``. The score of a combination is the smallest
    entry of ``base_model.loss_``; the combination with the lowest score wins.

    Args:
    - Xs, ys: Training data forwarded to ``base_model.adapt_refit``
    - X_val, y_val: Validation data forwarded to ``base_model.refit_loss``
    - base_model: Object exposing ``set_parameters``, ``adapt_refit`` and
      ``refit_loss``, and providing ``loss_`` and ``weights`` after scoring
    - init_parameters: Starting parameters restored before every trial
    - param_grid: Dict mapping attribute name -> iterable of candidate values
    - refit: If True, the winning values are assigned back onto ``base_model``
      as attributes (the model itself is NOT retrained here)

    Returns:
    - losses: List with the squeezed ``loss_`` of every trial, in grid order
    - best_params: Dict of the winning combination, or None if no trial
      produced a loss below infinity (e.g. empty candidate list, NaN losses)
    - best_weights: Copy of ``base_model.weights`` from the winning trial,
      or None in the same situation
    - base_model: The (mutated) model that was passed in
    """
    best_loss = np.inf
    # Pre-initialise so that an empty grid or all-NaN losses returns None
    # instead of raising UnboundLocalError further down.
    best_params = None
    best_weights = None
    losses = []
    names = list(param_grid.keys())

    for params in itertools.product(*param_grid.values()):
        param_dict = dict(zip(names, params))
        # Dynamically set parameters on the model
        for key, value in param_dict.items():
            setattr(base_model, key, value)
        # Every trial starts from the same initial parameters
        base_model.set_parameters(init_parameters)
        base_model.adapt_refit(Xs, ys)

        # refit_loss stores its result on the model (loss_, weights)
        base_model.refit_loss(X_val, y_val)
        loss = base_model.loss_.squeeze()
        losses.append(loss)

        trial_best = np.min(loss)
        if trial_best < best_loss:
            best_loss = trial_best
            best_params = param_dict.copy()  # Save grid parameters
            best_weights = base_model.weights.copy()

    if refit and best_params is not None:
        for key, value in best_params.items():
            setattr(base_model, key, value)

    return losses, best_params, best_weights, base_model


def split_bucket_indices(Xs_all, ys_all, val_ratio=0.2, random_state=42):
    """
    Randomly split every bucket into a training part and a validation part.

    The global NumPy random generator is seeded once with ``random_state``;
    each bucket is then shuffled independently, and the first
    ``int(n * (1 - val_ratio))`` shuffled samples become the training part.

    Args:
    - Xs_all: List of per-bucket sample containers; each must support
      ``len()`` and indexing with an integer index array
    - ys_all: List of per-bucket label containers, indexed the same way
    - val_ratio: Fraction of each bucket assigned to validation
    - random_state: Seed passed to ``np.random.seed``

    Returns:
    - Xs_train_all: List of training samples for each bucket
    - ys_train_all: List of training labels for each bucket
    - Xs_valid_all: List of validation samples for each bucket
    - ys_valid_all: List of validation labels for each bucket
    """
    np.random.seed(random_state)
    train_X, train_y = [], []
    valid_X, valid_y = [], []

    for bucket_X, bucket_y in zip(Xs_all, ys_all):
        size = len(bucket_X)
        order = np.arange(size)
        np.random.shuffle(order)
        # Everything before the cut trains, everything after validates
        cut = int(size * (1 - val_ratio))
        head, tail = order[:cut], order[cut:]
        train_X.append(bucket_X[head])
        valid_X.append(bucket_X[tail])
        train_y.append(bucket_y[head])
        valid_y.append(bucket_y[tail])

    return train_X, train_y, valid_X, valid_y


class LogisticRegressionModel(nn.Module):
    """
    Multinomial logistic regression: a single linear layer producing logits.
    Expects input of shape [batch_size, input_dim]; no softmax is applied.
    """
    def __init__(self, input_dim, num_classes=10):
        super().__init__()
        # One affine map from the features straight to the class scores
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return the raw class scores for x."""
        return self.linear(x)


class LeNet5(nn.Module):
    """
    LeNet-5 style CNN for single-channel 28x28 images (e.g. MNIST).
    Input shape: [batch_size, 1, 28, 28]; output: [batch_size, n_classes] logits.
    """
    def __init__(self, n_classes=10):
        super().__init__()
        # Feature extractor: two 5x5 convolutions sharing one 2x2 average pool
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # Classifier head on the flattened 16 x 4 x 4 feature map
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, n_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        # conv -> relu -> pool, twice:
        # [B, 1, 28, 28] -> [B, 6, 24, 24] -> [B, 6, 12, 12]
        feats = self.pool(F.relu(self.conv1(x)))
        # [B, 6, 12, 12] -> [B, 16, 8, 8] -> [B, 16, 4, 4]
        feats = self.pool(F.relu(self.conv2(feats)))
        # Flatten everything except the batch dimension -> [B, 256]
        flat = feats.view(feats.size(0), -1)
        hidden = F.relu(self.fc1(flat))     # -> [B, 120]
        hidden = F.relu(self.fc2(hidden))   # -> [B, 84]
        return self.fc3(hidden)             # -> [B, n_classes]


class TwoLayerNet(nn.Module):
    """
    Fully connected network: Linear -> ReLU -> Linear with a 256-unit hidden layer.
    Expects input of shape [batch_size, input_dim] and returns class logits.
    """
    def __init__(self, input_dim=512, num_classes=10):
        super().__init__()
        width = 256  # size of the single hidden layer
        layers = [
            nn.Linear(input_dim, width),
            nn.ReLU(inplace=True),
            nn.Linear(width, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the stacked layers and return the logits."""
        logits = self.net(x)
        return logits


def build_model(model_type, num_classes=10, input_dim=None, pretrained=False):
    """
    Return a model instance based on model_type (matched case-insensitively).

    Args:
    - model_type (str): 'logistic', 'lenet5', 'resnet18' or 'nn'
    - num_classes (int): Number of output classes
    - input_dim (int): Required for 'logistic'. Optional for 'nn': when it is
      None, TwoLayerNet's own default input dimension is used
    - pretrained (bool): For 'resnet18', forwarded to torchvision's constructor

    Returns:
    - The constructed nn.Module

    Raises:
    - ValueError: If model_type is unknown, or if 'logistic' is requested
      without input_dim
    """
    kind = model_type.lower()

    if kind == 'logistic':
        if input_dim is None:
            raise ValueError("Input dimension must be specified for logistic regression.")
        return LogisticRegressionModel(input_dim=input_dim, num_classes=num_classes)

    if kind == 'lenet5':
        return LeNet5(n_classes=num_classes)

    if kind == 'resnet18':
        # To adapt to MNIST, either convert 1 channel to 3 or rewrite the first layer
        # Only a basic example here
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is for compatibility with older versions.
        model = models.resnet18(pretrained=pretrained)
        # Replace the final layer: fully connected output => num_classes
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if kind == 'nn':
        # Passing input_dim=None explicitly would override TwoLayerNet's default
        # and fail inside nn.Linear, so only forward it when it was given.
        if input_dim is None:
            return TwoLayerNet(num_classes=num_classes)
        return TwoLayerNet(input_dim=input_dim, num_classes=num_classes)

    raise ValueError(f"Unknown model_type: {model_type}")

###################################################
#  initialization function
###################################################
def init_weights(m):
    """
    Initialization hook intended for ``module.apply``.

    Conv2d and Linear layers get Xavier-uniform weights and, when they have a
    bias, a zero bias. Every other module type is left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)


class Optimizer(nn.Module):
    """
    A network bundled with its loss function and its SGD optimizer.

    The wrapped network lives in ``self.net`` and is created by
    ``build_model``. Training helpers operate either on full tensors
    (``fit``, ``fit_onestep``) or on a DataLoader (``fit_batch``).
    """
    def __init__(self, model_type='logistic',
                 num_classes=10,        # Number of classes
                 lr=0.01,               # Learning rate
                 device='cpu',
                 random_state=None,
                 input_dim=None,        # Required if using logistic regression
                 pretrained=False,
                 custom_init=True,      # Whether to apply custom initialization
                 momentum=0.0           # Use plain gradient descent by default
                 ):
        """
        Args:
        - model_type (str): Architecture name forwarded to build_model
        - num_classes (int): Number of output classes
        - lr (float): SGD learning rate
        - device: Device the module is moved to
        - random_state (int or None): If given, seeds NumPy and PyTorch
          before the network is built
        - input_dim (int or None): Forwarded to build_model
        - pretrained (bool): Forwarded to build_model; also disables custom_init
        - custom_init (bool): Apply init_weights to the freshly built network
        - momentum (float): SGD momentum
        """
        super(Optimizer, self).__init__()

        if random_state is not None:
            np.random.seed(random_state)
            torch.manual_seed(random_state)
            # str() so that 'cuda:0' and torch.device objects are recognised too
            if str(device).startswith('cuda'):
                torch.cuda.manual_seed_all(random_state)

        # Build model based on model_type
        self.net = build_model(
            model_type=model_type,
            num_classes=num_classes,
            input_dim=input_dim,
            pretrained=pretrained)

        # Apply custom initialization if required and not pretrained
        # (pretrained weights must not be overwritten)
        if custom_init and not pretrained:
            self.net.apply(init_weights)

        # Define loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.parameters(), lr=lr, momentum=momentum)

        self.device = device
        self.to(device)  # Move model to specified device

    def forward(self, x):
        """
        Forward pass: return the logits of the wrapped network
        """
        return self.net(x)

    def get_parameters(self):
        """
        Return detached copies of all parameters, in self.parameters() order
        """
        return [p.detach().clone() for p in self.parameters()]

    def set_parameters(self, new_params):
        """
        Copy new_params into the model's parameters in place.
        new_params must follow the order of self.parameters(); pairing is
        done with zip, so surplus entries on either side are ignored.
        """
        with torch.no_grad():
            for p, new_p in zip(self.parameters(), new_params):
                p.copy_(new_p)

    def set_learning_rates(self, new_lr):
        """
        Update learning rate of optimizer
        Args:
        - new_lr (float): New learning rate
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

    def fit_onestep(self, X, Y):
        """
        Perform one update step: forward + backward + parameter update.
        Returns the detached loss tensor of that step.
        """
        loss = self.compute_loss(X, Y)
        loss.backward()
        self.optimizer.step()
        return loss.detach()

    def compute_gradients(self, X, Y):
        """
        Backpropagate the loss on (X, Y) without updating the parameters.
        Returns the list of ``p.grad`` tensors (same order as
        self.parameters()). These are the live gradient buffers, not copies.
        """
        loss = self.compute_loss(X, Y)
        loss.backward()
        grad_list = [p.grad for p in self.parameters()]
        return grad_list

    def compute_loss(self, X, Y):
        """
        Clear the optimizer's gradients, run a forward pass and return the
        criterion loss (still attached to the autograd graph).
        """
        self.optimizer.zero_grad()
        logits = self.forward(X)
        loss = self.criterion(logits, Y)
        return loss

    def compute_gradients_loss(self, X, Y):
        """
        Same as compute_gradients, but additionally returns the detached loss.
        """
        loss = self.compute_loss(X, Y)
        loss.backward()
        grad_list = [p.grad for p in self.parameters()]
        return grad_list, loss.detach()

    def fit(self, X, Y, epochs=10, print_freq=100):
        """
        Train for multiple epochs; each epoch is one full-batch step on (X, Y).
        Every print_freq epochs, prints the loss averaged over those epochs.
        """
        self.train()
        # Accumulated over the whole print window (kept as a tensor so the
        # device is only synchronized when printing).
        running_loss = 0.0
        for epoch in range(epochs):
            start_time = time.time()
            running_loss += self.fit_onestep(X, Y)
            if (epoch+1) % print_freq == 0:
                avg_loss = (running_loss / print_freq).item()
                t_elapsed = time.time() - start_time  # duration of the last epoch
                print(f"\rEpoch [{epoch+1}/{epochs}], Avg Loss: {avg_loss:.4f}, TC: {t_elapsed:3.2f}s", end='')
                running_loss = 0.0  # start a new print window

    def predict_proba(self, X):
        """
        Output class probability distribution as a NumPy array.
        Switches the module to eval mode (and leaves it there).
        """
        self.eval()
        with torch.no_grad():
            logits = self.forward(X)
            proba = torch.softmax(logits, dim=1)
        return proba.cpu().numpy()

    def predict(self, X):
        """
        Output predicted class labels
        """
        proba = self.predict_proba(X)
        pred = np.argmax(proba, axis=1)
        return pred

    def evaluate(self, X_test, y_test):
        """
        Compute accuracy and cross-entropy loss on (X_test, y_test).

        The forward pass runs in eval mode so that layers such as BatchNorm
        or Dropout behave deterministically and do not absorb evaluation data
        into their running statistics. The previous train/eval mode is
        restored afterwards, so this is safe to call in the middle of training.

        Returns:
        - acc (float): Fraction of correctly classified samples
        - val (float): Mean cross-entropy loss
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(X_test)  # shape: (N, num_classes)
                val = F.cross_entropy(outputs, y_test).item()
                preds = outputs.argmax(dim=1)   # shape: (N,)
                correct = (preds == y_test).sum().item()
                acc = correct / len(y_test)
        finally:
            self.train(was_training)
        return acc, val

    # ---- DataLoader-based (mini-batch) training and prediction ----

    def fit_batch(self, dataloader, epochs=10, print_freq=1):
        """
        Train over multiple epochs using DataLoader.
        Every print_freq epochs, prints the per-batch loss averaged over all
        batches seen since the previous print.
        """
        self.train()
        n_batch = len(dataloader)
        running_loss = 0.0  # accumulated over the whole print window
        for epoch in range(epochs):
            start_time = time.time()
            for (Xm, Ym) in dataloader:
                Xm, Ym = Xm.to(self.device), Ym.to(self.device)
                running_loss += self.fit_onestep(Xm, Ym)
            if (epoch+1) % print_freq == 0:
                loss_val = (running_loss / (n_batch * print_freq)).item()  # Sync once
                t_elapsed = time.time() - start_time  # duration of the last epoch
                remaining = t_elapsed * (epochs - epoch - 1)
                print(f"\rEpoch [{epoch+1}/{epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining/60:3.2f}min", end='')
                running_loss = 0.0  # start a new print window

    def predproba_batch(self, dataloader):
        """
        Predict class probability distribution for entire DataLoader.
        Labels yielded by the loader are ignored.
        """
        self.eval()
        all_proba = []
        with torch.no_grad():
            for Xm, _ in dataloader:
                Xm = Xm.to(self.device)
                logits = self.forward(Xm)
                proba = torch.softmax(logits, dim=1)
                all_proba.append(proba.cpu().numpy())
        return np.concatenate(all_proba, axis=0)

    def pred_batch(self, dataloader):
        """
        Predict class labels for entire DataLoader
        """
        proba = self.predproba_batch(dataloader)
        # For each sample, choose the class with highest probability
        return np.argmax(proba, axis=1)
    
class DFLOptimizerBase:
    """
    Decentralized federated learning (DFL) trainer that simulates all workers
    in a single process.

    Each worker owns one ``Optimizer`` model. In every round a worker first
    replaces its parameters by the average over its neighbor list and then
    performs one local gradient step on its own data.
    """
    def __init__(self, neighbors, lr,
                 model_type='logistic',
                 n_classes=10,
                 n_workers=10,
                 epochs=10,
                 device='cpu',
                 random_state=None,
                 input_dim=None,      # specify this if using logistic regression
                 pretrained=False,    # use pretrained model if using resnet18
                 custom_init=True     # whether to use custom initialization
                 ):
        """
        - neighbors: list of list, each worker's neighbor list (the indices
          averaged in _aggregate; a worker is only included if it lists itself)
        - lr: list or scalar, learning rate for each worker
        - model_type: 'logistic', 'lenet5', or 'resnet18'
        - n_classes: number of output classes
        - n_workers: number of workers / models
        - epochs: number of DFL rounds performed by refit
        - device, random_state, input_dim, pretrained, custom_init:
          forwarded to every worker's Optimizer
        """
        self.neighbors = neighbors
        self.lr = lr
        self.model_type = model_type
        self.n_classes = n_classes
        self.n_workers = n_workers
        self.epochs = epochs
        self.device = device
        self.random_state = random_state
        self.input_dim = input_dim
        self.pretrained = pretrained
        self.custom_init = custom_init

    def _initialize_models(self):
        """
        Create one Optimizer per worker and store them in self.models_.
        All workers receive the same random_state.
        """
        # If self.lr is a scalar, then all workers share the same learning rate
        if isinstance(self.lr, (float, int)):
            lr_list = [self.lr] * self.n_workers
        else:
            lr_list = self.lr
        self.models_ = []
        for m in range(self.n_workers):
            model = Optimizer(
                model_type=self.model_type,
                num_classes=self.n_classes,
                lr=lr_list[m],
                device=self.device,
                random_state=self.random_state,
                input_dim=self.input_dim,
                pretrained=self.pretrained,
                custom_init=self.custom_init
            )
            self.models_.append(model)
        return self

    def _initialize_history(self):
        """
        Reset the validation history recorded during refit
        """
        self.history_ = {'loss': [], 'acc': []}

    def set_learning_rates(self, new_lr):
        """
        Update the learning rate for each worker
        (new_lr is an iterable with one entry per worker)
        """
        for model, lr in zip(self.models_, new_lr):
            model.set_learning_rates(lr)
        return self

    def fit(self, Xs, ys, X_val=None, y_val=None, print_freq=100):
        """
        Build fresh models, clear the history and run refit
        """
        self._initialize_models()
        self._initialize_history()
        self.refit(Xs, ys, X_val=X_val, y_val=y_val, print_freq=print_freq)
        return self

    def _aggregate(self, m=0):
        """
        Aggregate neighbor models' parameters by averaging.
        Reads the snapshot stored in self.param_lists by _update.
        """
        param_lists = [self.param_lists[i] for i in self.neighbors[m]]
        averaged_params = [
            torch.mean(torch.stack(params, dim=0), dim=0)
            for params in zip(*param_lists)
        ]
        return averaged_params

    def _update(self, Xs, ys):
        """
        Perform one round of DFL updates for all workers (first aggregate, then fit_onestep).
        Returns the sum of the workers' losses divided by n_workers.
        """
        loss_vals = []
        # Snapshot taken once per round so every worker averages the
        # parameters from before this round's local updates.
        self.param_lists = self.get_parameters()
        for m, (Xm, ym) in enumerate(zip(Xs, ys)):
            # 1) Aggregate from neighbors
            averaged_params = self._aggregate(m=m)
            self.models_[m].set_parameters(averaged_params)
            # 2) Local update
            loss_val = self.models_[m].fit_onestep(Xm, ym)
            loss_vals.append(loss_val)

        loss_tensor = torch.stack(loss_vals)  # shape: [n_workers]
        avg_loss = loss_tensor.sum() / self.n_workers
        return avg_loss

    def refit(self, Xs, ys, X_val=None, y_val=None, print_freq=100):
        """
        Perform multiple rounds of training after initialization.

        Every print_freq rounds the loss averaged over those rounds is
        printed and, when X_val is given, per-worker validation accuracy and
        loss are appended to self.history_.
        """
        # Set all models to training mode
        for model in self.models_:
            model.train()
        # Accumulated over the whole print window (stays a tensor so the
        # device is only synchronized when printing).
        running_loss = 0.0
        for epoch in range(self.epochs):
            start_time = time.time()
            running_loss += self._update(Xs, ys)

            if (epoch + 1) % print_freq == 0:
                if X_val is not None:
                    accs, losss = self.evaluate(X_val, y_val)
                    self.history_['acc'].append(accs)
                    self.history_['loss'].append(losss)
                loss_val = (running_loss / print_freq).item()
                t_elapsed = time.time() - start_time  # duration of the last round
                remaining = t_elapsed * (self.epochs - epoch - 1)
                print(f"\rEpoch [{epoch+1}/{self.epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining/60:3.2f}min", end='')
                running_loss = 0.0  # start a new print window
        return self

    def get_parameters(self):
        """
        Return one parameter list per worker
        """
        return [model.get_parameters() for model in self.models_]

    def set_parameters(self, new_params):
        """
        Assign one parameter list to each worker (paired with zip)
        """
        for model, params in zip(self.models_, new_params):
            model.set_parameters(params)

    def predict_proba(self, X):
        """
        Return every worker's class probabilities for X
        """
        return [model.predict_proba(X) for model in self.models_]

    def predict(self, X):
        """
        Return every worker's predicted labels for X
        """
        return [model.predict(X) for model in self.models_]

    def evaluate(self, X_val, y_val):
        """
        Evaluate every worker on the same (X_val, y_val).
        Returns two lists: accuracies and losses, one entry per worker.
        """
        accs = []
        losss = []
        for model in self.models_:
            acc, loss = model.evaluate(X_val, y_val)
            accs.append(acc)
            losss.append(loss)
        return accs, losss

    def save_history(self, save_path='.'):
        """
        Write the validation history to '<save_path>_valloss.csv' and
        '<save_path>_valacc.csv' (save_path is used as a filename prefix).
        The frames are transposed before writing, so each row is a worker
        and each column an evaluation point.
        Returns (df_acc, df_loss).
        """
        df_loss = pd.DataFrame(self.history_['loss']).T
        df_acc = pd.DataFrame(self.history_['acc']).T
        df_loss.to_csv(f"{save_path}_valloss.csv", index=False)
        df_acc.to_csv(f"{save_path}_valacc.csv", index=False)
        return df_acc, df_loss


class DFLOptimizerInit(DFLOptimizerBase):
    """
    DFL trainer in which every worker starts from the same constant learning rate.
    """
    def __init__(self, neighbors, lr_constant,
                 model_type='lenet5',
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True):
        """
        This class sets lr_constant as the learning rate for all workers;
        every other argument is handed to the base class unchanged.
        """
        shared_settings = dict(
            model_type=model_type,
            n_classes=n_classes,
            n_workers=n_workers,
            epochs=epochs,
            device=device,
            random_state=random_state,
            input_dim=input_dim,
            pretrained=pretrained,
            custom_init=custom_init,
        )
        # The scalar lr_constant becomes the base class's `lr`
        super().__init__(neighbors, lr_constant, **shared_settings)
        # Kept separately so the constant survives later changes to self.lr
        self.lr_constant = lr_constant


class DFLOptimizer(DFLOptimizerInit):
    """
    DFL trainer with per-worker adaptive learning rates.

    Learning rates are ``lr_constant`` scaled by per-worker weights derived
    from the workers' gradients via ``process_weights`` (defined elsewhere).
    ``refit_loss`` additionally computes a weighted, neighbor-averaged
    validation loss that is stored in ``self.loss_``.
    """
    def __init__(self, neighbors, lr_constant,
                 model_type='lenet5',
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True,
                 cn=5):
        """
        Same arguments as the parent class, plus:
        - cn: second argument passed to process_weights when computing the
          adaptive learning rates (meaning defined by that helper)
        """
        super().__init__(
            neighbors=neighbors,
            lr_constant=lr_constant,
            model_type=model_type,
            n_classes=n_classes,
            n_workers=n_workers,
            epochs=epochs,
            device=device,
            random_state=random_state,
            input_dim=input_dim,
            pretrained=pretrained,
            custom_init=custom_init
        )
        self.cn = cn
        
    def _compute_adaptive_lr(self, Xs, ys):   
        """
        Compute each worker's gradient on its own data, turn the gradients
        into per-worker weights and set learning rates to lr_constant * weights.

        Side effects: overwrites self.lr with the adaptive learning rates,
        sets self.max_weights and updates every worker's optimizer.
        """
        grad = []
        for model, X, y in zip(self.models_, Xs, ys):
            grad_m = model.compute_gradients(X, y)
            # Flatten all parameter gradients of this worker into one vector
            grad.append(torch.cat([g.view(-1) for g in grad_m], dim=0))
        # One weight per worker; presumably based on gradient magnitude,
        # but that depends on process_weights -- TODO confirm
        weights = np.array([process_weights(row, self.cn) for row in grad])
        # Normalize by the largest weight
        weights = weights / np.max(weights)
        # NOTE(review): this is taken AFTER the normalization above, so for
        # positive weights it is always 1.0 -- confirm whether the
        # pre-normalization maximum was intended.
        self.max_weights = np.max(weights)
        adaptive_lr = self.lr_constant * weights
        self.lr = adaptive_lr
        self.set_learning_rates(adaptive_lr)
        return self
    
    def adapt_refit(self, Xs, ys, print_freq=100):
        """
        Dynamically compute gradients and update learning rates, then perform multiple rounds of DFL training
        """
        self._compute_adaptive_lr(Xs, ys)
        self.refit(Xs, ys, print_freq=print_freq)
        return self
    
    def refit_loss(self, X_val, y_val, print_freq=100):
        """
        Compute a weighted validation loss per worker and store it in self.loss_.

        X_val and y_val are iterated in lockstep, one entry per worker.
        Each worker's loss is scaled by weights = lr / lr_constant, then the
        losses and weight statistics are repeatedly averaged with the
        row-normalized matrix W_self (at most self.epochs iterations, stopping
        once the largest loss change is below 1e-8). Finally the loss is
        divided by the averaged weights and a penalty term ``ub`` is added.

        Results are stored on the instance (weights, loss_, weights_mean,
        weights_squared, n_iter_loss, A, W_self); nothing is returned.
        print_freq is accepted but not used.
        """
        self.weights = self.lr / self.lr_constant
        losses = []
        for m, (Xm, ym) in enumerate(zip(X_val, y_val)):
            # NOTE(review): the loss is evaluated in train mode and with
            # autograd enabled; compute_loss also clears the optimizer's grads.
            self.models_[m].train()
            loss = self.models_[m].compute_loss(Xm, ym)
            losses.append(loss.item())
        self.loss_ = self.weights * np.array(losses)
        # "* 1.0" produces independent copies of the weight arrays
        self.weights_mean = self.weights * 1.0
        self.weights_squared = (self.weights ** 2) * 1.0
        self.n_iter_loss = 0

        # presumably the adjacency matrix of the neighbor graph -- defined by
        # get_adj_from_neib elsewhere; TODO confirm
        self.A = get_adj_from_neib(self.neighbors, M=self.n_workers)
        # Row-normalize so that each row sums to 1
        self.W_self = self.A / self.A.sum(axis=1, keepdims=True)  # Include self-loss to avoid oscillations
        for n_iter in range(self.epochs):
            prev_loss = self.loss_ * 1.0; prev_w_mean = self.weights_mean * 1.0
            prev_w_squared = self.weights_squared * 1.0
            # One averaging step over the graph for each tracked quantity
            self.loss_ = self.W_self @ prev_loss
            self.weights_mean = self.W_self @ prev_w_mean
            self.weights_squared = self.W_self @ prev_w_squared
            # Convergence is judged on the loss only
            dist = np.max(np.abs(prev_loss - self.loss_))
            self.n_iter_loss += 1  
            if dist < 1e-8:
                break
        # NOTE(review): len(X_val) is the number of entries in X_val, which is
        # iterated per worker above -- so this is (#entries / n_workers), not a
        # per-worker sample count. Confirm the intended quantity.
        nn_samples = int(len(X_val) / self.n_workers)
        # Ratio of averaged squared weights to squared averaged weights
        weights_squared = self.weights_squared / (self.weights_mean ** 2)
        # 1.64 is presumably a one-sided 95% normal quantile -- TODO confirm
        ub = 1.64 * np.sqrt(weights_squared / nn_samples)
        self.loss_ /= self.weights_mean
        self.loss_ += ub

class DFLOptimizerBasePara:
    """
    DFL trainer for a single worker running as its own process.

    Unlike the in-process simulation, this class owns exactly one model
    (worker ``idx``). Parameters are exchanged with neighbors through
    ``param_path`` using upload_parameters / download_parameters, keyed by
    worker index and iteration counter.
    """
    def __init__(self, neighbors, lr,
                 n_workers=10,
                 model_type='lenet5',  # default model is lenet5
                 n_classes=10,
                 epochs=10,
                 device='cpu',
                 random_state=None,
                 param_path='.',
                 pretrained=True,
                 custom_init=True,
                 idx=0):
        """
        - neighbors: list of lists, each worker's neighbor list
        - lr: list or scalar, learning rate for each worker
        - n_workers: total number of workers
        - model_type, n_classes, device, random_state: forwarded to Optimizer
        - epochs: number of passes over the dataloader in refit
        - param_path: location used to exchange parameters with neighbors
        - pretrained, custom_init: stored on the instance; note that
          _initialize_models does not forward them to Optimizer
        - idx: index of the worker this instance represents
        """
        self.neighbors = neighbors
        self.lr = lr
        self.n_workers = n_workers
        self.model_type = model_type
        self.n_classes = n_classes
        self.epochs = epochs
        self.device = device
        self.random_state = random_state
        self.pretrained = pretrained
        self.custom_init = custom_init
        self.param_path = param_path  # path to save model parameters
        self.idx = idx  # index of the current worker
        self.n_iter = 0  # iteration counter

    def _initialize_models(self):
        """
        Create this worker's Optimizer and store it in self.model_
        """
        # If self.lr is a scalar, not a list, then all workers use the same learning rate
        if isinstance(self.lr, (float, int)):
            lr_list = [self.lr] * self.n_workers
        else:
            lr_list = self.lr
        idx = self.idx
        # NOTE(review): pretrained / custom_init / input_dim are not passed
        # here, so Optimizer's own defaults apply. Forwarding them would
        # change the default initialization, so it is left as-is.
        self.model_ = Optimizer(model_type=self.model_type,
            num_classes=self.n_classes,
            lr=lr_list[idx], device=self.device,
            random_state=self.random_state)
        return self

    def set_learning_rates(self, new_lr):
        """
        Update this worker's learning rate
        (new_lr holds one entry per worker; only entry idx is used)
        """
        idx = self.idx
        self.model_.set_learning_rates(new_lr[idx])
        return self

    def fit(self, dataloader, print_freq=1):
        """
        Build a fresh model and run refit on the dataloader
        """
        self._initialize_models()
        self.refit(dataloader, print_freq=print_freq)
        return self

    def _aggregate(self, m=0):
        """
        Aggregate parameters from neighbors by averaging.
        Downloads the parameters of every neighbor of worker m at the
        current iteration counter.
        """
        param_lists = [download_parameters(param_path=self.param_path,
                            m=i, n_iter=self.n_iter) for i in self.neighbors[m]]
        averaged_params = [
            torch.mean(torch.stack(params, dim=0), dim=0)
            for params in zip(*param_lists)
        ]
        return averaged_params

    def _update(self, Xm, ym):
        """
        Perform one DFL update (first aggregate, then fit one step locally).
        Uploads the resulting parameters under the incremented iteration
        counter and returns the loss of the local step.
        """
        idx = self.idx
        # 1) Aggregate from neighbors (nothing has been uploaded before the
        #    first iteration, so aggregation is skipped then)
        if self.n_iter > 0:
            for m in self.neighbors[idx]:
                check_param_path(param_path=self.param_path, m=m, n_iter=self.n_iter)
            averaged_params = self._aggregate(m=idx)
            self.model_.set_parameters(averaged_params)
        # 2) Local update
        loss_val = self.model_.fit_onestep(Xm, ym)
        param_m = self.model_.get_parameters()
        self.n_iter += 1
        upload_parameters(param_m, param_path=self.param_path, m=idx, n_iter=self.n_iter)
        return loss_val

    def refit(self, dataloader, print_freq=100):
        """
        Can be used for repeated DFL training after initialization.
        Every print_freq epochs, prints the per-batch loss averaged over all
        batches seen since the previous print.
        """
        # Set the model to training mode
        self.model_.train()
        n_batch = len(dataloader)
        # Accumulated over the whole print window (stays a tensor so the
        # device is only synchronized when printing).
        running_loss = 0.0
        for epoch in range(self.epochs):
            start_time = time.time()
            for (Xm, ym) in dataloader:
                Xm = Xm.to(self.device)
                ym = ym.to(self.device)
                running_loss += self._update(Xm, ym)
            if (epoch+1) % print_freq == 0:
                loss_val = (running_loss / (n_batch * print_freq)).item()  # Synchronize
                t_elapsed = time.time() - start_time  # duration of the last epoch
                remaining = t_elapsed * (self.epochs - epoch - 1)
                print(f"\rEpoch [{epoch+1}/{self.epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining/60:3.2f}min", end='')
                running_loss = 0.0  # start a new print window
        return self


class DFLOptimizerInitPara(DFLOptimizerBasePara):
    """
    Single-worker DFL trainer whose learning rate is one shared constant.
    """
    def __init__(self, neighbors, lr_constant,
                 n_workers=10, model_type='lenet5',  # default model is lenet5
                 n_classes=10, epochs=10,
                 device='cpu', random_state=None,
                 param_path='.',
                 pretrained=True,
                 custom_init=True,
                 idx=0):
        """
        In this class, lr_constant is used as the learning rate for all workers;
        every other argument is handed to the base class unchanged.
        """
        shared_settings = dict(
            n_workers=n_workers,
            model_type=model_type,
            n_classes=n_classes,
            epochs=epochs,
            device=device,
            random_state=random_state,
            param_path=param_path,
            pretrained=pretrained,
            custom_init=custom_init,
            idx=idx,
        )
        # The scalar lr_constant becomes the base class's `lr`
        super().__init__(neighbors, lr_constant, **shared_settings)
        # Kept separately so the constant survives later changes to self.lr
        self.lr_constant = lr_constant





