import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models  # Required for using ResNet18
from util_real import process_weights  # Your utility function
from mobilenetv2 import MobileNetV2
from util_real import upload_parameters, download_parameters, check_param_path, get_adj_from_neib
import itertools
import pandas as pd
 
def grid_search(Xs, ys, X_val, y_val, base_model, init_parameters,
                param_grid, refit=True):
    """
    Exhaustively search ``param_grid`` for the setting with the lowest validation loss.

    For every combination in the Cartesian product of the ``param_grid`` values, the
    values are assigned as attributes on ``base_model``. The model is then reset to
    ``init_parameters``, trained with ``adapt_refit(Xs, ys)`` and scored with
    ``refit_loss(X_val, y_val)``. The score of a combination is the minimum entry of
    ``base_model.loss_`` (after ``squeeze``).

    Args:
    - Xs, ys: training data passed to ``base_model.adapt_refit``
    - X_val, y_val: validation data passed to ``base_model.refit_loss``
    - base_model: object exposing ``set_parameters`` / ``adapt_refit`` / ``refit_loss``
      and the ``loss_`` / ``weights`` attributes (e.g. DFLOptimizer)
    - init_parameters: parameters restored before each candidate is trained
    - param_grid: dict mapping attribute name -> iterable of candidate values
    - refit: if True, assign the best combination back onto ``base_model``.
      This only sets attributes; the model is NOT retrained.

    Returns:
    - losses: list with the squeezed ``loss_`` of every candidate, in grid order
    - best_params: dict of the best combination, or None if no candidate scored
      below infinity (e.g. some value list is empty, or every loss was NaN)
    - best_weights: copy of ``base_model.weights`` for the best candidate, or None
    - base_model: the same (mutated) model instance
    """
    best_loss = np.inf
    # Pre-bind so an empty grid / all-NaN losses return None instead of crashing.
    best_params = None
    best_weights = None
    losses = []

    keys = list(param_grid.keys())
    for values in itertools.product(*param_grid.values()):
        param_dict = dict(zip(keys, values))
        # Dynamically assign the candidate hyper-parameters to the model.
        for key, value in param_dict.items():
            setattr(base_model, key, value)
        base_model.set_parameters(init_parameters)
        base_model.adapt_refit(Xs, ys)

        base_model.refit_loss(X_val, y_val)
        loss = base_model.loss_.squeeze()
        losses.append(loss)

        candidate = np.min(loss)
        # A NaN candidate compares False and therefore never becomes the best.
        if candidate < best_loss:
            best_loss = candidate
            best_params = param_dict.copy()
            best_weights = base_model.weights.copy()

    if refit and best_params is not None:
        for key, value in best_params.items():
            setattr(base_model, key, value)
    return losses, best_params, best_weights, base_model


def split_bucket_indices(Xs_all, ys_all, val_ratio=0.2, random_state=42):
    """
    Randomly split the samples of every bucket into train / validation parts.

    Args:
    - Xs_all: sequence of per-bucket feature tensors (indexable by an index array)
    - ys_all: sequence of per-bucket label tensors, aligned with ``Xs_all``
    - val_ratio: proportion of each bucket's samples put into the validation split
    - random_state: seed passed to ``np.random.seed``. This reseeds numpy's GLOBAL
      RNG as a side effect.

    Returns:
    - (Xs_train, ys_train, Xs_valid, ys_valid): each is ``torch.stack`` of the
      per-bucket splits, so the leading dimension indexes the bucket. Stacking
      requires every bucket to yield splits of the same size.

    Raises:
    - ValueError: if ``Xs_all`` and ``ys_all`` hold different numbers of buckets, or
      a bucket's features and labels have different lengths (``zip`` / indexing
      would otherwise silently truncate or misalign them).
    """
    if len(Xs_all) != len(ys_all):
        raise ValueError(
            f"Xs_all has {len(Xs_all)} buckets but ys_all has {len(ys_all)}.")

    np.random.seed(random_state)
    Xs_train_all, ys_train_all = [], []
    Xs_valid_all, ys_valid_all = [], []

    for b, (X, y) in enumerate(zip(Xs_all, ys_all)):
        n = len(X)
        if len(y) != n:
            raise ValueError(
                f"Bucket {b}: {n} feature rows but {len(y)} labels.")
        indices = np.arange(n)
        np.random.shuffle(indices)
        split = int(n * (1 - val_ratio))
        train_idx = indices[:split]
        valid_idx = indices[split:]
        Xs_train_all.append(X[train_idx])
        Xs_valid_all.append(X[valid_idx])
        ys_train_all.append(y[train_idx])
        ys_valid_all.append(y[valid_idx])

    return (torch.stack(Xs_train_all), torch.stack(ys_train_all),
            torch.stack(Xs_valid_all), torch.stack(ys_valid_all))


class LogisticRegressionModel(nn.Module):
    """
    Multinomial logistic regression: one affine layer that produces class logits.

    Input shape: [batch_size, input_dim]; output: raw logits of shape
    [batch_size, num_classes] (no softmax is applied here).
    """

    def __init__(self, input_dim, num_classes=10):
        super().__init__()
        # The attribute name fixes the parameter names ("linear.weight", "linear.bias").
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits


class LeNet5(nn.Module):
    """
    LeNet-5 style CNN with ReLU activations and average pooling.

    Sized for single-channel 28x28 inputs ([batch_size, 1, 28, 28], e.g. MNIST):
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, hence the 16*4*4 flatten size.
    Returns raw logits of shape [batch_size, n_classes].
    """

    def __init__(self, n_classes=10):
        super().__init__()
        # Registration order fixes the order of self.parameters(); keep it stable.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)        # 1x28x28 -> 6x24x24
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)       # 6x12x12 -> 16x8x8
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)  # halves height and width
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, n_classes)

    def forward(self, x):
        # Feature extractor: (conv -> ReLU -> avg-pool), applied twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        # Classifier head: two hidden ReLU layers, then raw logits.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)


class TwoLayerNet(nn.Module):
    """
    Fully connected network with a single hidden ReLU layer.

    Input shape: [batch_size, input_dim]; output: raw logits [batch_size, num_classes].

    Args:
    - input_dim: number of input features (default 512)
    - num_classes: number of output classes
    - hidden_dim: width of the hidden layer (default 256, the previously
      hard-coded value)
    """

    def __init__(self, input_dim=512, num_classes=10, hidden_dim=256):
        super(TwoLayerNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)


def build_model(model_type, num_classes=10, input_dim=None, pretrained=False):
    """
    Build a classification network by name.

    Args:
    - model_type: case-insensitive; one of 'logistic', 'lenet5', 'resnet18', 'nn'
    - num_classes: number of output classes
    - input_dim: number of input features. Required for 'logistic'. Optional for
      'nn': when None, TwoLayerNet's own default input size is used. Ignored for
      the other types.
    - pretrained: for 'resnet18', load torchvision's pretrained weights

    Returns:
    - an nn.Module that produces raw logits

    Raises:
    - ValueError: unknown model_type, or 'logistic' without input_dim
    """
    kind = model_type.lower()

    if kind == 'logistic':
        if input_dim is None:
            raise ValueError("Input_dim must be specified for logistic regression.")
        return LogisticRegressionModel(input_dim=input_dim, num_classes=num_classes)

    if kind == 'lenet5':
        return LeNet5(n_classes=num_classes)

    if kind == 'resnet18':
        # NOTE: `pretrained=` is torchvision's legacy argument (newer releases
        # prefer `weights=`); kept for compatibility with older torchvision.
        model = models.resnet18(pretrained=pretrained)
        # Replace the classification head so it outputs num_classes logits.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if kind == 'nn':
        # Forwarding input_dim=None would crash inside nn.Linear; callers such as
        # the parameter-server workers never pass input_dim, so fall back to
        # TwoLayerNet's default input size instead.
        if input_dim is None:
            return TwoLayerNet(num_classes=num_classes)
        return TwoLayerNet(input_dim=input_dim, num_classes=num_classes)

    raise ValueError(f"Unknown model_type: {model_type}")


def init_weights(m):
    """
    Re-initialise a single module in place; meant to be used with ``Module.apply``.

    Conv2d and Linear layers get Xavier-uniform weights and, when they have one,
    a zero bias. Any other module type is left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


class Optimizer(nn.Module):
    """
    Local learner: a classification network plus its loss and SGD optimizer.

    The network is created by ``build_model``. Training uses
    ``nn.CrossEntropyLoss`` and ``optim.SGD`` over all parameters of this module.

    Args:
    - model_type: forwarded to ``build_model`` ('logistic', 'lenet5', 'resnet18', 'nn')
    - num_classes: number of output classes
    - lr: SGD learning rate
    - device: device the module is moved to (dataloader batches are moved here too)
    - random_state: if not None, seeds numpy and torch (and all CUDA devices when
      device == 'cuda') before the network is built
    - input_dim: forwarded to ``build_model``
    - pretrained: forwarded to ``build_model``; when True, ``custom_init`` is skipped
    - custom_init: apply ``init_weights`` (Xavier weights, zero biases)
    - momentum: SGD momentum
    """
    def __init__(self, model_type='logistic',
                 num_classes=10,
                 lr=0.01,
                 device='cpu',
                 random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True,
                 momentum=0.0):
        super(Optimizer, self).__init__()

        if random_state is not None:
            np.random.seed(random_state)
            torch.manual_seed(random_state)
            if device == 'cuda':
                torch.cuda.manual_seed_all(random_state)

        self.net = build_model(model_type, num_classes, input_dim, pretrained)

        # Keep pretrained weights intact; only re-initialise freshly built networks.
        if custom_init and not pretrained:
            self.net.apply(init_weights)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.parameters(), lr=lr, momentum=momentum)

        self.device = device
        self.to(device)

    def forward(self, x):
        return self.net(x)

    def get_parameters(self):
        """Return detached copies of all parameters, in ``self.parameters()`` order."""
        return [p.detach().clone() for p in self.parameters()]

    def set_parameters(self, new_params):
        """Copy ``new_params`` (same order/shapes as ``get_parameters``) into the parameters in place."""
        with torch.no_grad():
            for p, new_p in zip(self.parameters(), new_params):
                p.copy_(new_p)

    def set_learning_rates(self, new_lr):
        """Set the learning rate of every optimizer parameter group to ``new_lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

    def fit_onestep(self, X, Y):
        """Take one SGD step on (X, Y) and return the detached loss computed before the step."""
        loss = self.compute_loss(X, Y)
        loss.backward()
        self.optimizer.step()
        return loss.detach()

    def compute_gradients(self, X, Y):
        """Back-propagate the loss on (X, Y) and return the parameters' ``.grad`` tensors (not copies); no step is taken."""
        loss = self.compute_loss(X, Y)
        loss.backward()
        return [p.grad for p in self.parameters()]

    def compute_loss(self, X, Y):
        """Zero the gradients, then return the graph-attached cross-entropy loss on (X, Y)."""
        self.optimizer.zero_grad()
        logits = self.forward(X)
        loss = self.criterion(logits, Y)
        return loss

    def compute_gradients_loss(self, X, Y):
        """Like ``compute_gradients`` but also return the detached loss."""
        loss = self.compute_loss(X, Y)
        loss.backward()
        return [p.grad for p in self.parameters()], loss.detach()

    def fit(self, X, Y, epochs=10, print_freq=100):
        """
        Full-batch training: one ``fit_onestep`` on (X, Y) per epoch.

        Every ``print_freq`` epochs, prints the mean loss over those ``print_freq``
        epochs and the duration of the last epoch.
        """
        self.train()
        # Accumulate over the whole print window (not just the last epoch),
        # since the printed value is divided by print_freq.
        running_loss = 0.0
        for epoch in range(epochs):
            start_time = time.time()
            running_loss += self.fit_onestep(X, Y)
            if (epoch + 1) % print_freq == 0:
                avg_loss = (running_loss / print_freq).item()
                t_elapsed = time.time() - start_time
                print(f"\rEpoch [{epoch + 1}/{epochs}], Avg Loss: {avg_loss:.4f}, TC: {t_elapsed:3.2f}s", end='')
                running_loss = 0.0

    def predict_proba(self, X):
        """Return softmax class probabilities for X as a numpy array (mode is restored afterwards)."""
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(X)
                proba = torch.softmax(logits, dim=1)
        finally:
            self.train(was_training)
        return proba.cpu().numpy()

    def predict(self, X):
        """Return the arg-max class index for every row of X."""
        proba = self.predict_proba(X)
        pred = np.argmax(proba, axis=1)
        return pred

    def evaluate(self, X_test, y_test):
        """
        Return ``(accuracy, mean cross-entropy)`` on (X_test, y_test).

        Runs in eval mode, so dropout is disabled and batch-norm uses (and does
        not update) its running statistics. The previous train/eval mode is
        restored afterwards, so this is safe to call in the middle of training.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(X_test)
                val = F.cross_entropy(outputs, y_test).item()
                preds = outputs.argmax(dim=1)
                correct = (preds == y_test).sum().item()
                acc = correct / len(y_test)
        finally:
            self.train(was_training)
        return acc, val

    def fit_batch(self, dataloader, epochs=10, print_freq=1):
        """
        Mini-batch training: one ``fit_onestep`` per batch of ``dataloader`` per epoch.

        Every ``print_freq`` epochs, prints the mean batch loss over those epochs,
        the duration of the last epoch and an ETA based on it.
        """
        self.train()
        n_batch = len(dataloader)
        # Accumulate across the print window; the divisor is n_batch * print_freq.
        running_loss = 0.0
        for epoch in range(epochs):
            start_time = time.time()
            for (Xm, Ym) in dataloader:
                Xm, Ym = Xm.to(self.device), Ym.to(self.device)
                running_loss += self.fit_onestep(Xm, Ym)
            if (epoch + 1) % print_freq == 0:
                loss_val = (running_loss / (n_batch * print_freq)).item()
                t_elapsed = time.time() - start_time
                remaining = t_elapsed * (epochs - epoch - 1)
                print(f"\rEpoch [{epoch + 1}/{epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining / 60:3.2f}min", end='')
                running_loss = 0.0

    def predproba_batch(self, dataloader):
        """Return softmax probabilities for every batch of ``dataloader``, concatenated (mode is restored afterwards)."""
        was_training = self.training
        self.eval()
        all_proba = []
        try:
            with torch.no_grad():
                for Xm, _ in dataloader:
                    Xm = Xm.to(self.device)
                    logits = self.forward(Xm)
                    proba = torch.softmax(logits, dim=1)
                    all_proba.append(proba.cpu().numpy())
        finally:
            self.train(was_training)
        return np.concatenate(all_proba, axis=0)

    def pred_batch(self, dataloader):
        """Return the arg-max class index for every sample of ``dataloader``."""
        proba = self.predproba_batch(dataloader)
        return np.argmax(proba, axis=1)
    
class DFLOptimizerBase:
    """
    Simulated decentralized federated learning (DFL) over ``n_workers`` local models.

    In each round, every worker m replaces its parameters with the uniform average
    of the parameter snapshots of ``neighbors[m]`` and then takes one local SGD
    step on its own data. All snapshots are taken before any worker is updated
    in that round.
    """
    def __init__(self, neighbors, lr, 
                 model_type='logistic',  # Default to logistic; other options: lenet5, resnet18
                 n_classes=10,
                 n_workers=10,
                 epochs=10,
                 device='cpu', 
                 random_state=None,
                 input_dim=None,      # Required if using logistic regression
                 pretrained=False,    # If using resnet18, you can choose pretrained
                 custom_init=True     # Whether to use custom initialization
                 ):
        """
        - neighbors: list of lists; neighbors[m] holds the worker indices whose models
          worker m averages (include m itself if its own model should count)
        - lr: scalar (shared by all workers) or sequence of per-worker learning rates
        - model_type: 'logistic', 'lenet5', 'resnet18' or 'nn'
        - n_classes, input_dim, pretrained, custom_init: forwarded to each Optimizer
        - n_workers: number of local models
        - epochs: number of DFL rounds run by ``refit``
        - device: torch device used by every model
        - random_state: seed passed to every worker's Optimizer
        """
        self.neighbors = neighbors
        self.lr = lr
        self.model_type = model_type
        self.n_classes = n_classes
        self.n_workers = n_workers
        self.epochs = epochs
        self.device = device
        self.random_state = random_state
        self.input_dim = input_dim
        self.pretrained = pretrained
        self.custom_init = custom_init
        
    def _initialize_models(self):
        """Build one Optimizer per worker into ``self.models_``; all share the same random_state."""
        # np.isscalar also accepts numpy scalars (e.g. np.float32), which are not
        # Python float/int subclasses; a scalar rate is shared by every worker.
        if np.isscalar(self.lr):
            lr_list = [self.lr] * self.n_workers
        else:
            lr_list = self.lr
        self.models_ = [
            Optimizer(
                model_type=self.model_type,
                num_classes=self.n_classes,
                lr=lr_list[m],
                device=self.device,
                random_state=self.random_state,
                input_dim=self.input_dim,
                pretrained=self.pretrained,
                custom_init=self.custom_init
            )
            for m in range(self.n_workers)
        ]
        return self
    
    def _initialize_history(self):
        """Reset the validation history (one entry per evaluation round)."""
        self.history_ = {'loss': [], 'acc': []}
    
    def set_learning_rates(self, new_lr):
        """
        Update the learning rate of each worker; ``new_lr`` is a per-worker sequence.
        """
        for model, lr in zip(self.models_, new_lr):
            model.set_learning_rates(lr)
        return self
    
    def fit(self, Xs, ys, X_val=None, y_val=None, print_freq=100):
        """Create fresh models and history, then run ``refit``."""
        self._initialize_models()
        self._initialize_history()
        self.refit(Xs, ys, X_val=X_val, y_val=y_val, print_freq=print_freq)
        return self
    
    def _aggregate(self, m=0):
        """
        Return the element-wise mean of the snapshots ``self.param_lists[i]`` for i in neighbors[m].
        """
        param_lists = [self.param_lists[i] for i in self.neighbors[m]]
        averaged_params = [
            torch.mean(torch.stack(params, dim=0), dim=0) 
            for params in zip(*param_lists)
        ]
        return averaged_params
        
    def _update(self, Xs, ys):
        """
        Perform one DFL round for all workers (aggregate first, then fit_onestep).

        Returns the mean local loss over the workers that were updated.
        """
        loss_vals = []
        # Snapshot every model first so all workers aggregate the same round's parameters.
        self.param_lists = self.get_parameters()
        for m, (Xm, ym) in enumerate(zip(Xs, ys)):
            # 1) Aggregate from neighbors
            self.models_[m].set_parameters(self._aggregate(m=m))
            # 2) Local update
            loss_vals.append(self.models_[m].fit_onestep(Xm, ym))

        # Average over the workers actually updated (len(Xs) may be < n_workers).
        return torch.stack(loss_vals).mean()
    
    def refit(self, Xs, ys, X_val=None, y_val=None, print_freq=100):
        """
        Run ``self.epochs`` DFL rounds on the current models.

        Every ``print_freq`` rounds: if X_val is given, the per-worker validation
        (acc, loss) is appended to ``self.history_``. Then the mean loss over those
        rounds, the last round's duration and an ETA are printed.
        """
        # Set all models to training mode
        for model in self.models_:
            model.train()
        # Accumulated (on the models' device) over the whole print window, since
        # the printed value is divided by print_freq.
        running_loss = 0.0
        for epoch in range(self.epochs):
            start_time = time.time()
            running_loss += self._update(Xs, ys)

            if (epoch + 1) % print_freq == 0:
                if X_val is not None:
                    accs, losss = self.evaluate(X_val, y_val)
                    self.history_['acc'].append(accs)
                    self.history_['loss'].append(losss)
                loss_val = (running_loss / print_freq).item()  # Sync once
                t_elapsed = time.time() - start_time
                remaining = t_elapsed * (self.epochs - epoch - 1)
                print(f"\rEpoch [{epoch+1}/{self.epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining/60:3.2f}min", end='')
                running_loss = 0.0
        return self
    
    def get_parameters(self):
        """Return a list (one per worker) of detached parameter copies."""
        return [model.get_parameters() for model in self.models_]
            
    def set_parameters(self, new_params):
        """Load per-worker parameter lists into the corresponding models."""
        for model, params in zip(self.models_, new_params):
            model.set_parameters(params)
    
    def predict_proba(self, X):
        """Return a list with every worker's class probabilities for X."""
        return [model.predict_proba(X) for model in self.models_]
    
    def predict(self, X):
        """Return a list with every worker's predicted labels for X."""
        return [model.predict(X) for model in self.models_]
    
    def evaluate(self, X_val, y_val):
        """Evaluate every worker on the same (X_val, y_val); return (accs, losses) lists."""
        accs = []
        losss = []
        for model in self.models_:
            acc, loss = model.evaluate(X_val, y_val)
            accs.append(acc)
            losss.append(loss)
        return accs, losss
    
    def save_history(self, save_path='.'):
        """
        Write the validation history to ``{save_path}_valloss.csv`` / ``{save_path}_valacc.csv``.

        ``save_path`` is a filename prefix, not a directory. After transposing,
        rows are workers and columns are evaluation rounds. Returns (df_acc, df_loss).
        """
        df_loss = pd.DataFrame(self.history_['loss']).T
        df_acc = pd.DataFrame(self.history_['acc']).T
        df_loss.to_csv(f"{save_path}_valloss.csv", index=False)
        df_acc.to_csv(f"{save_path}_valacc.csv", index=False)
        return df_acc, df_loss


class DFLOptimizerInit(DFLOptimizerBase):
    """
    DFL optimizer in which every worker starts from the same constant learning rate.

    Identical to DFLOptimizerBase, except that the default model is 'lenet5' and
    ``lr_constant`` is also kept as an attribute (DFLOptimizer rescales it).
    """
    def __init__(self, neighbors, lr_constant, 
                 model_type='lenet5', 
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True):
        # lr_constant doubles as the base class's shared (scalar) learning rate.
        base_kwargs = dict(
            model_type=model_type,
            n_classes=n_classes,
            n_workers=n_workers,
            epochs=epochs,
            device=device,
            random_state=random_state,
            input_dim=input_dim,
            pretrained=pretrained,
            custom_init=custom_init,
        )
        super().__init__(neighbors, lr_constant, **base_kwargs)
        self.lr_constant = lr_constant

class DFLOptimizer(DFLOptimizerInit):
    """
    DFL optimizer with gradient-driven per-worker learning rates.

    ``adapt_refit`` rescales each worker's learning rate from its local gradient
    and then trains. ``refit_loss`` estimates a penalised, weighted validation loss
    at every worker by averaging over the communication graph.

    Extra argument:
    - cn: forwarded to ``process_weights`` when mapping a gradient to a weight
      (its exact meaning is defined in util_real — confirm there)
    """
    def __init__(self, neighbors, lr_constant,
                 model_type='lenet5',
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True,
                 cn=5):
        super().__init__(
            neighbors=neighbors,
            lr_constant=lr_constant,
            model_type=model_type,
            n_classes=n_classes,
            n_workers=n_workers,
            epochs=epochs,
            device=device,
            random_state=random_state,
            input_dim=input_dim,
            pretrained=pretrained,
            custom_init=custom_init
        )
        self.cn = cn
        
    def _compute_adaptive_lr(self, Xs, ys):   
        """
        Set each worker's learning rate from its current local gradient.

        For every worker, the full gradient on (X, y) is flattened and mapped to a
        scalar weight by ``process_weights(grad, self.cn)``. The weights are divided
        by their maximum (so the largest is 1), and worker m gets
        ``lr_constant * weight_m``.

        Side effects: sets ``self.max_weights`` (the maximum raw weight, i.e. the
        normaliser) and ``self.lr`` (per-worker rates), and updates every worker's
        optimizer.
        """
        grads = []
        for model, X, y in zip(self.models_, Xs, ys):
            grad_m = model.compute_gradients(X, y)
            grads.append(torch.cat([g.view(-1) for g in grad_m], dim=0))
        raw_weights = np.array([process_weights(row, self.cn) for row in grads])
        # Record the normaliser BEFORE normalising; afterwards the max is always 1.
        self.max_weights = np.max(raw_weights)
        weights = raw_weights / self.max_weights
        adaptive_lr = self.lr_constant * weights
        self.lr = adaptive_lr
        self.set_learning_rates(adaptive_lr)
        return self
    
    def adapt_refit(self, Xs, ys, print_freq=100):
        """
        Dynamically compute gradients and update learning rates,
        then perform multiple rounds of DFL training.
        """
        self._compute_adaptive_lr(Xs, ys)
        self.refit(Xs, ys, print_freq=print_freq)
        return self
    
    def refit_loss(self, X_val, y_val, print_freq=100):
        """
        Estimate a penalised, weighted validation loss at every worker (stored in ``self.loss_``).

        Requires ``adapt_refit`` / ``_compute_adaptive_lr`` to have run, so that
        ``self.lr`` holds the per-worker learning rates. ``X_val`` / ``y_val`` are
        iterated per worker: element m is worker m's validation data.

        Steps:
        1. ``self.weights = self.lr / self.lr_constant``; worker m's validation loss
           is multiplied by its weight.
        2. The weighted losses, weights and squared weights are repeatedly
           multiplied by the row-normalised adjacency matrix ``W_self``. This runs
           at most ``self.epochs`` times and stops once the loss vector changes by
           less than 1e-8.
        3. ``loss_`` becomes graph-averaged weighted loss / graph-averaged weight,
           plus ``1.64 * sqrt(avg(w^2) / avg(w)^2 / nn_samples)``. Here
           ``nn_samples`` is the number of validation samples per worker (1.64 is
           presumably the one-sided 95% normal quantile — confirm).

        ``print_freq`` is accepted for interface compatibility and is unused.
        """
        self.weights = self.lr / self.lr_constant
        losses = []
        for m, (Xm, ym) in enumerate(zip(X_val, y_val)):
            # NOTE(review): the loss is taken in train mode; models with dropout or
            # batch-norm would be stochastic here and update running statistics.
            self.models_[m].train()
            # A validation loss needs no autograd graph.
            with torch.no_grad():
                loss = self.models_[m].compute_loss(Xm, ym)
            losses.append(loss.item())
        self.loss_ = self.weights * np.array(losses)
        self.weights_mean = self.weights * 1.0
        self.weights_squared = (self.weights ** 2) * 1.0
        self.n_iter_loss = 0

        self.A = get_adj_from_neib(self.neighbors, M=self.n_workers)
        # Also consider each node's own loss to avoid periodicity
        self.W_self = self.A / self.A.sum(axis=1, keepdims=True)
        
        for n_iter in range(self.epochs):
            prev_loss = self.loss_ * 1.0
            prev_w_mean = self.weights_mean * 1.0
            prev_w_squared = self.weights_squared * 1.0
            self.loss_ = self.W_self @ prev_loss
            self.weights_mean = self.W_self @ prev_w_mean
            self.weights_squared = self.W_self @ prev_w_squared
            dist = np.max(np.abs(prev_loss - self.loss_))
            self.n_iter_loss += 1  
            if dist < 1e-8:
                break
        
        # The first axis of X_val indexes workers (it is zipped per worker above),
        # so X_val.shape[0] / n_workers would always be ~1. Count the samples.
        n_total = sum(len(Xm) for Xm in X_val)
        nn_samples = max(1, n_total // self.n_workers)
        ratio = self.weights_squared / (self.weights_mean ** 2)
        ub = 1.64 * np.sqrt(ratio / nn_samples)
        self.loss_ /= self.weights_mean
        self.loss_ += ub

class DFLOptimizerBasePara:
    """
    A single DFL worker (index ``idx``) that exchanges parameters through storage.

    Each update takes these steps:
    1. From the second update on, call ``check_param_path`` for each neighbor
       (presumably blocks until that neighbor's file for the current iteration
       exists — confirm in util_real).
    2. Average the neighbors' parameters from ``download_parameters``.
    3. Take one local SGD step.
    4. Publish the new parameters with ``upload_parameters`` under the incremented
       iteration counter.
    """
    def __init__(self, neighbors, lr, 
                 n_workers=10,
                 model_type='lenet5',  # Default to using lenet5
                 n_classes=10,
                 epochs=10,
                 device='cpu', 
                 random_state=None,
                 param_path='.',
                 pretrained=True,
                 custom_init=True,
                 idx=0):
        """
        - neighbors: list of lists; neighbors[m] holds the workers averaged by worker m
        - lr: scalar (shared) or per-worker sequence of learning rates; only
          entry ``idx`` is used
        - param_path: location where parameters are uploaded / downloaded
        - idx: index of the worker owned by this object
        """
        self.neighbors = neighbors
        self.lr = lr
        self.n_workers = n_workers
        self.model_type = model_type
        self.n_classes = n_classes
        self.epochs = epochs
        self.device = device
        self.random_state = random_state
        self.pretrained = pretrained
        self.custom_init = custom_init
        self.param_path = param_path  # Path where parameters are stored
        self.idx = idx  # The index of the current worker
        self.n_iter = 0  # Number of completed local updates
        
    def _initialize_models(self):
        """
        Build this worker's Optimizer into ``self.model_``.
        """
        # np.isscalar also accepts numpy scalars; a scalar rate is shared by all workers.
        if np.isscalar(self.lr):
            lr_list = [self.lr] * self.n_workers
        else:
            lr_list = self.lr
        idx = self.idx
        # NOTE(review): self.pretrained / self.custom_init (and any input_dim) are not
        # forwarded, so Optimizer's defaults apply (pretrained=False,
        # custom_init=True). Forwarding them would change current default behaviour.
        self.model_ = Optimizer(model_type=self.model_type,
                                num_classes=self.n_classes,
                                lr=lr_list[idx],
                                device=self.device,
                                random_state=self.random_state)
        return self
    
    def set_learning_rates(self, new_lr):
        """
        Set this worker's learning rate to ``new_lr[self.idx]``.
        """
        self.model_.set_learning_rates(new_lr[self.idx])
        return self
    
    def fit(self, dataloader, print_freq=1):
        """Build the local model, then run ``refit`` on ``dataloader``."""
        self._initialize_models()
        self.refit(dataloader, print_freq=print_freq)
        return self

    def _aggregate(self, m=0):
        """
        Average the parameters that neighbors[m] uploaded for iteration ``self.n_iter``.
        """
        param_lists = [download_parameters(param_path=self.param_path,
                                           m=i, n_iter=self.n_iter) for i in self.neighbors[m]]
        averaged_params = [
            torch.mean(torch.stack(params, dim=0), dim=0) 
            for params in zip(*param_lists)
        ]
        return averaged_params
        
    def _update(self, Xm, ym):
        """
        Perform one DFL update (aggregate first, then fit_onestep) and return the local loss.
        """
        idx = self.idx

        # 1) Aggregate from neighbors (nothing has been uploaded before the first step).
        if self.n_iter > 0:
            for m in self.neighbors[idx]:
                check_param_path(param_path=self.param_path, m=m, n_iter=self.n_iter)
            self.model_.set_parameters(self._aggregate(m=idx))

        # 2) Local update, then publish under the incremented iteration counter.
        loss_val = self.model_.fit_onestep(Xm, ym)
        param_m = self.model_.get_parameters()
        self.n_iter += 1
        upload_parameters(param_m, param_path=self.param_path, m=idx, n_iter=self.n_iter)
        return loss_val
    
    def refit(self, dataloader, print_freq=100):
        """
        Run ``self.epochs`` passes over ``dataloader`` with one DFL update per batch.

        Every ``print_freq`` epochs, prints the mean batch loss over those epochs,
        the last epoch's duration and an ETA.
        """
        # Set the model to training mode
        self.model_.train()
        n_batch = len(dataloader)
        # Accumulated over the whole print window; divided by n_batch * print_freq.
        running_loss = 0.0
        for epoch in range(self.epochs):
            start_time = time.time()
            for (Xm, ym) in dataloader:
                Xm = Xm.to(self.device)
                ym = ym.to(self.device)
                running_loss += self._update(Xm, ym)
            if (epoch + 1) % print_freq == 0:
                loss_val = (running_loss / (n_batch * print_freq)).item()  # Synchronize once
                t_elapsed = time.time() - start_time
                remaining = t_elapsed * (self.epochs - epoch - 1)
                print(f"\rEpoch [{epoch+1}/{self.epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining/60:3.2f}min", end='')
                running_loss = 0.0
        return self

class DFLOptimizerInitPara(DFLOptimizerBasePara):
    """
    Storage-based DFL worker whose learning rate is the single constant ``lr_constant``.

    Behaves exactly like DFLOptimizerBasePara and additionally keeps
    ``lr_constant`` as an attribute.
    """
    def __init__(self, neighbors, lr_constant, 
                 n_workers=10, model_type='lenet5',  # Default to using lenet5
                 n_classes=10, epochs=10,
                 device='cpu', random_state=None,
                 param_path='.',
                 pretrained=True,
                 custom_init=True,
                 idx=0):
        # The shared constant is handed to the base class as its (scalar) lr.
        super().__init__(
            neighbors, lr_constant,
            n_workers=n_workers, model_type=model_type,
            n_classes=n_classes, epochs=epochs,
            device=device, random_state=random_state,
            param_path=param_path, pretrained=pretrained,
            custom_init=custom_init, idx=idx,
        )
        self.lr_constant = lr_constant


