import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from optimizers.dv_optimizer import DVOptimizer


class GradientBase:
    """Shared setup and helpers for gradient-based training runners.

    Holds the training dataset, a model moved to the available device, the
    whole test split pre-loaded onto that device, and hyper-parameters read
    from ``args``. Subclasses implement :meth:`run` and :meth:`run_iter`.
    """

    def __init__(self, datas: torch.utils.data.TensorDataset, test_data: torch.utils.data.TensorDataset, model: nn.Module, args: dict):
        """Store configuration, move the model to the device and cache the test split.

        Args:
            datas: Training dataset; iterated by :meth:`find_lr`.
            test_data: Evaluation dataset; loaded as one batch and kept on
                ``self.device`` as ``self.test_data`` / ``self.test_labels``.
            model: Network to train; moved to ``self.device``.
            args: Configuration read by attribute access (e.g. an
                ``argparse.Namespace``), despite the ``dict`` annotation. Must
                provide ``use_momentum``, ``lr``, ``random_seed``,
                ``warmup_ratio`` and ``num_players``; ``dataset_size`` is
                optional (defaults to 1). ``model`` (a name string) is read by
                :meth:`_init_weights`. If ``lr`` is ``None`` a learning rate
                is chosen with :meth:`find_lr`.
        """
        self.args = args
        self.use_momentum = args.use_momentum
        self.lr = args.lr
        self.random_seed = args.random_seed
        self.warmup_ratio = args.warmup_ratio
        self.datas = datas
        self.num_players = args.num_players
        self.model = model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.malicious_players = None
        # Batch size used when iterating `datas`; falls back to 1 when absent.
        self.dataset_size = getattr(args, "dataset_size", 1)
        # Load the whole test split as a single batch and keep it on the device.
        loader = torch.utils.data.DataLoader(test_data, batch_size=len(test_data))
        inputs, labels = next(iter(loader))
        self.test_data = inputs.to(self.device)
        self.test_labels = labels.to(self.device)
        # Fix the global torch RNG so the weight re-initialisation in find_lr
        # (and later torch randomness) is reproducible.
        # NOTE(review): seeds with a hard-coded 0 rather than self.random_seed — confirm intended.
        torch.manual_seed(0)

        if self.lr is None:
            self.lr = self.find_lr()

    def find_lr(self):
        """Choose a learning rate with a coarse log-scale sweep.

        For each candidate ``lr = 10 ** -e``, ``e`` in ``0, 0.3, ..., 2.7``,
        a deep copy of ``self.model`` is re-initialised with
        :meth:`_init_weights`, trained for one pass over ``self.datas`` with
        plain SGD (batches of ``self.dataset_size``), and scored by
        cross-entropy on the cached test split in eval mode.

        Returns:
            The candidate with the lowest test loss, or 0 if no candidate
            produced a loss below infinity (e.g. every run diverged to NaN).
        """
        best_loss = float("inf")
        best_lr = 0
        # The loader does not depend on the candidate, so build it once.
        data_loader = torch.utils.data.DataLoader(self.datas, batch_size=self.dataset_size)
        # Integer-indexed grid: same exponents as np.arange(0, 3, 0.3) (ten
        # values, 0.0 .. 2.7) without relying on a floating-point step.
        for exponent in 0.3 * np.arange(10):
            lr = 10 ** (-exponent)
            model = copy.deepcopy(self.model).apply(self._init_weights)
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0, weight_decay=0)
            for player_data, labels in data_loader:
                optimizer.zero_grad()
                player_data = player_data.to(self.device)
                labels = labels.to(self.device)
                train_loss = F.cross_entropy(model(player_data), labels)
                train_loss.backward()
                optimizer.step()
            # Score in eval mode so BatchNorm uses running statistics and
            # dropout is disabled.
            model.eval()
            with torch.no_grad():
                test_loss = F.cross_entropy(model(self.test_data), self.test_labels).item()
            # A NaN loss compares False here, so diverged runs are skipped.
            if test_loss < best_loss:
                best_loss = test_loss
                best_lr = lr
        print(f"Found best lr: {best_lr}")
        return best_lr

    def get_optimizer(self, model: nn.Module, momentum: float = 0.9, weight_decay: float = 0, num_iters: int = 0) -> torch.optim.Optimizer:
        """Build the optimizer for ``model`` using ``self.lr``.

        When ``self.use_momentum`` is set, returns a ``DVOptimizer`` with the
        given ``momentum``, ``num_iters`` and a warm-up length of
        ``self.warmup_ratio * num_iters``; ``weight_decay`` is NOT forwarded
        on this path. Otherwise returns plain SGD (momentum 0) with
        ``weight_decay``.
        """
        if self.use_momentum:
            return DVOptimizer(model.parameters(), lr=self.lr, momentum=momentum, num_iters=num_iters, warmup_epochs=(self.warmup_ratio * num_iters))
        else:
            return torch.optim.SGD(model.parameters(), lr=self.lr, momentum=0, weight_decay=weight_decay)

    @torch.no_grad()
    def _init_weights(self, m) -> None:
        """Re-initialise a single module in place; meant for ``nn.Module.apply``.

        Conv2d and Linear weights are drawn from N(0, 0.01^2) and their biases
        zeroed; BatchNorm2d scale is set to 1 and shift to 0. Conv2d and
        BatchNorm2d modules are left untouched when ``self.args.model`` starts
        with ``"resnet"``. Parameters that do not exist (``bias=False``,
        ``affine=False``) are skipped.
        """
        if isinstance(m, nn.Conv2d):
            if self.args.model.startswith("resnet"):
                return
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            if self.args.model.startswith("resnet"):
                return
            # BatchNorm2d(affine=False) has weight and bias set to None.
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
            # nn.Linear(bias=False) has bias set to None.
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def run(self, num_iters: int, clipping_norm: float = 0, epsilon: float = 20, delta: float = 1e-5):
        """Run the full training procedure; must be implemented by subclasses."""
        raise NotImplementedError

    def run_iter(self, step: int, model: nn.Module, optimizer: torch.optim.Optimizer, permutation: np.ndarray):
        """Run one training iteration; must be implemented by subclasses."""
        raise NotImplementedError
