import torch
import torch.optim as optim
import torch.nn as nn
import copy
from torch.optim.lr_scheduler import StepLR


class client_FedCIL:
    """Federated-learning client that trains a local generator on its own data.

    On construction the client deep-copies the given model, builds a shuffled
    ``DataLoader`` over its training set, and creates an SGD optimizer plus a
    ``StepLR`` scheduler for the copied model. The code in this class only
    drives the *generator* object (see :meth:`train`); the model, optimizer
    and scheduler are created here but not stepped by any method below.
    """

    def __init__(self, model, train_dataset, name, args, **kwargs):
        """Set up the client's model copy, data loader and optimizer.

        Args:
            model: Template network; a deep copy is stored so the caller's
                instance is never modified.
            train_dataset: Dataset wrapped in the client's ``DataLoader``.
            name: Identifier stored on the client as ``self.name``.
            args: Configuration object; ``batch_size`` and ``lr_local`` are
                read from it.
            **kwargs: Optional ``g`` entry holding the generator object to
                train. It is stored as-is (not copied) and defaults to
                ``None``.
        """
        self.args = args
        self.name = name
        self.model = copy.deepcopy(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # drop_last=True: an incomplete final batch is discarded, so every
        # batch handed to the generator has exactly ``batch_size`` samples.
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            drop_last=True,
        )
        self.optimizer = optim.SGD(
            self.model.parameters(), lr=self.args.lr_local, momentum=0.9
        )
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.96)

        # Only written via set_theta_reg(); not read anywhere in this class.
        self.theta_reg = None
        # No stored weights yet, so initialize_model() falls back to the
        # random initialisation. The weights are re-initialised in place,
        # which keeps the optimizer's parameter references valid.
        self.init_weights = None
        self.initialize_model()

        self.generator = kwargs.get('g', None)
        # Forwarded unchanged to the generator's batch step in train();
        # presumably assigned by the caller before training -- TODO confirm.
        self.available_labels = None
        self.available_labels_current = None
        # When set, train() draws extra samples from it for every batch.
        self.last_generator = None

    def set_theta_reg(self, theta_reg):
        """Store ``theta_reg`` on the client."""
        self.theta_reg = theta_reg

    def _random_init_weights(self, m):
        """Kaiming-normal init for Conv2d/Linear weights, zeros for biases.

        Intended as the callback for ``nn.Module.apply``; modules of any
        other type are left untouched.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def initialize_model(self):
        """Load ``self.init_weights`` if set, otherwise randomly initialise."""
        if self.init_weights is not None:
            self.model.load_state_dict(self.init_weights)
        else:
            self.model.apply(self._random_init_weights)

    def update_model(self, server_generator):
        """Replace the local generator with a deep copy of the server's."""
        self.generator = copy.deepcopy(server_generator)

    def train(self, local_round, generator_server, classes_so_far):
        """Run one pass over the local data, training the generator per batch.

        Args:
            local_round: Forwarded to the generator's batch step as
                ``glob_iter_``.
            generator_server: Server-side generator; switched to eval mode
                and forwarded to every batch step.
            classes_so_far: Forwarded to every batch step and, when
                ``self.last_generator`` is set, to its ``sample`` call.

        Returns:
            None. The per-batch return value of ``train_a_batch_all`` is
            discarded.
        """
        generator_server.eval()

        device = self.device
        self.generator.generator.to(device).train()
        self.generator.critic.to(device).train()

        # The loader is iterated exactly once per call (a single local epoch).
        for x, y in self.train_loader:
            x, y = x.to(device), y.to(device)

            if self.last_generator is not None:
                # Extra batch drawn from the previously stored generator.
                x_, y_ = self.last_generator.sample(
                    self.args.batch_size, classes_so_far
                )
                x_, y_ = x_.to(device), y_.to(device)
            else:
                x_ = y_ = None

            self.generator.train_a_batch_all(
                available_labels=self.available_labels,
                generator_server=generator_server,
                glob_iter_=local_round,
                x=x,
                y=y,
                x_=x_,
                y_=y_,
                classes_so_far=classes_so_far,
            )








