import torch
import torch.optim as optim
import torch.nn as nn
import copy
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR


class client_FedProx():
    """Federated-learning client that trains a private model copy with a proximal penalty.

    The client owns a deep copy of the model handed to it, a shuffled
    ``DataLoader`` over its local dataset, an SGD optimizer (momentum 0.9)
    and a ``StepLR`` scheduler (``step_size=5``, ``gamma=0.96``).

    ``args`` must expose ``batch_size``, ``lr_local``, ``epochs`` and
    ``lambda_reg``.
    """

    def __init__(self, model, train_dataset, name, args):
        """Build the client state.

        Args:
            model: Template ``nn.Module``; it is deep-copied, never mutated.
            train_dataset: Dataset wrapped in a shuffled ``DataLoader`` with
                ``drop_last=True`` (a trailing partial batch is skipped).
            name: Identifier stored on the client as ``self.name``.
            args: Hyper-parameter namespace (see class docstring).
        """
        self.args = args
        self.model = copy.deepcopy(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            drop_last=True,
        )
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr_local, momentum=0.9)
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.96)
        self.theta_reg = None
        self.init_weights = None
        # init_weights is still None here, so this re-initialises every
        # Conv2d/Linear layer of the copied model (other layers keep the
        # values copied from `model`).
        self.initialize_model()
        self.name = name

    def set_theta_reg(self, theta_reg):
        """Store ``theta_reg`` on the client (no method of this class reads it)."""
        self.theta_reg = theta_reg

    def _random_init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights; zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def initialize_model(self):
        """Load ``self.init_weights`` if set, otherwise randomly re-initialise."""
        if self.init_weights is not None:
            self.model.load_state_dict(self.init_weights)
        else:
            self.model.apply(self._random_init_weights)

    def update_model(self, new_model):
        """Overwrite the local weights with ``new_model``'s state dict."""
        self.model.load_state_dict(new_model.state_dict())

    def _proximal_anchors(self, global_model):
        """Return the proximal anchor as a list of detached tensors on ``self.device``.

        ``global_model`` may be ``None`` (anchor to the weights the local
        model currently holds), an ``nn.Module``, a state-dict-like ``dict``
        keyed by parameter name, or any iterable of tensors ordered like
        ``self.model.parameters()``.

        Raises:
            KeyError: a ``dict`` anchor lacks one of the model's parameter names.
            ValueError: the anchor does not provide one tensor per parameter.
        """
        local_params = list(self.model.parameters())
        if global_model is None:
            source = local_params
        elif isinstance(global_model, nn.Module):
            source = list(global_model.parameters())
        elif isinstance(global_model, dict):
            # A state dict also carries buffers; select parameters by name.
            source = [global_model[key] for key, _ in self.model.named_parameters()]
        else:
            # Materialise once: a generator would otherwise be exhausted
            # after the first batch and silently disable the penalty.
            source = list(global_model)
        if len(source) != len(local_params):
            raise ValueError(
                "global_model provides %d tensors but the local model has %d parameters"
                % (len(source), len(local_params))
            )
        # detach: no gradient may flow back into the global model.
        # clone: the anchor must not alias weights the optimizer updates.
        return [t.detach().clone().to(self.device) for t in source]

    def train(self, theta_reg=None, first_round=False, global_model=None):
        """Run ``args.epochs`` local epochs of SGD with a proximal penalty.

        The per-batch loss is cross-entropy plus
        ``args.lambda_reg / 10 * sum(||theta - theta_t||^2) / 2``, where
        ``theta_t`` are the anchor weights resolved from ``global_model``.
        The learning-rate scheduler advances once per call.

        Args:
            theta_reg: Unused; accepted for interface compatibility.
            first_round: Unused; accepted for interface compatibility.
            global_model: Proximal anchor: ``None``, an ``nn.Module``, a
                state-dict-like ``dict``, or an iterable of tensors. With
                ``None`` the anchor is the local weights at call time.

        Raises:
            KeyError / ValueError: see ``_proximal_anchors``.
        """
        self.model = self.model.to(self.device)
        self.model.train()
        anchors = self._proximal_anchors(global_model)
        criterion = nn.CrossEntropyLoss()
        for _ in range(self.args.epochs):
            for images, labels in self.train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                # Squared L2 distance to the anchor, summed over all parameters.
                proximal_term = sum(
                    ((theta - theta_t).pow(2).sum()
                     for theta, theta_t in zip(self.model.parameters(), anchors)),
                    0.0,
                )
                loss = criterion(outputs, labels) + self.args.lambda_reg / 10 * proximal_term / 2
                loss.backward()
                self.optimizer.step()
        self.scheduler.step()



