import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
# Print up to 30 items at the start and end of every dimension when showing
# tensors / arrays (process-wide setting for both libraries).
torch.set_printoptions(edgeitems=30)
np.set_printoptions(edgeitems=30)

TOP_K = 1

class DatasetSplit(Dataset):
    """A view of ``dataset`` restricted to the samples at positions ``idxs``.

    Each item is returned as an ``(image, label)`` pair, both converted with
    ``torch.as_tensor``.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Copy the indices and normalise them to plain Python ints (they may
        # arrive as numpy integers, tensors, a set, ...).
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.as_tensor(image), torch.as_tensor(label)

    def add_data(self, idxs: list):
        """Append more sample indices, normalised to ``int`` just like ``__init__``."""
        self.idxs.extend(int(i) for i in idxs)


class Client:
    """A federated-learning client that trains a local model on its own data shard.

    Local training minimises cross-entropy plus a FedProx-style proximal term
    ``(mu / 2) * ||w - w_global||^2``. That term pulls the local weights toward
    the global model passed to :meth:`train`.
    """

    def __init__(self, device, local_model, train_dataset, test_dataset, train_idxs, args, logger=None):
        """
        Args:
            device: torch device that batches are moved to.
            local_model: the client's ``torch.nn.Module``; optimised in place by SGD.
            train_dataset: dataset indexed by ``train_idxs`` to form this client's shard.
            test_dataset: stored as ``self.test_dataset``; not read by this class.
            train_idxs: indices into ``train_dataset`` owned by this client.
            args: config namespace. Reads ``lr``, ``momentum``, ``local_bs`` and
                ``local_ep``. Also reads ``mu`` (proximal coefficient, default 0.01)
                and ``num_workers`` (DataLoader workers, default 8) if present.
            logger: optional logger; stored but not read by this class.
        """
        self.device = device
        self.args = args
        self.local_model = local_model
        self.logger = logger
        self.data_num = len(train_idxs)
        self.origin_data_num = len(train_idxs)
        self.trainingLoss = None
        self.testingLoss = None
        self.testingAcc = None
        self.test_dataset = test_dataset
        self.trainloader, self.validloader, self.trainloader_full = self.train_val_test(
            train_dataset, list(train_idxs))
        self.optimizer = torch.optim.SGD(local_model.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

    def train_val_test(self, train_dataset, train_idxs):
        """Build the data loaders for this client's shard.

        Returns:
            A ``(trainloader, validloader, trainloader_full)`` tuple:
            - trainloader: shuffled, batch size ``args.local_bs``.
            - validloader: always ``None``; no validation split is made.
            - trainloader_full: unshuffled, serving the whole shard (as of
              construction) in a single batch.
        """
        num_workers = getattr(self.args, 'num_workers', 8)
        trainloader = DataLoader(DatasetSplit(train_dataset, train_idxs),
                                 batch_size=self.args.local_bs, shuffle=True, num_workers=num_workers)

        validloader = None
        # DataLoader rejects batch_size=0, so an empty shard still gets batch size 1.
        trainloader_full = DataLoader(DatasetSplit(train_dataset, train_idxs),
                                      batch_size=max(1, len(train_idxs)), shuffle=False)
        return trainloader, validloader, trainloader_full

    def train(self, epoch, global_model):
        """Run ``args.local_ep`` epochs of local training with a proximal term.

        Args:
            epoch: current global round; accepted for interface compatibility, unused.
            global_model: model whose parameters anchor the proximal term. It is
                treated as a constant, so no gradients are accumulated into it.

        Returns:
            The per-batch loss averaged within each epoch, then across epochs.
            The value is also stored in ``self.trainingLoss``. Returns 0.0 when
            no batch was trained (e.g. every batch had a single sample).
        """
        self.local_model.train()
        mu = getattr(self.args, 'mu', 0.01)
        # The global weights are a fixed reference point. Detaching them keeps
        # backward() from writing gradients into global_model's parameters.
        global_params = [p.detach() for p in global_model.parameters()]

        epoch_loss = []
        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                # Single-sample batches are skipped (presumably to avoid BatchNorm
                # failures in training mode — confirm).
                if images.shape[0] == 1:
                    continue

                self.local_model.zero_grad()
                log_probs = self.local_model(images)

                # FedProx proximal term: squared L2 distance to the global weights.
                proximal_term = sum(
                    (w - w_t).pow(2).sum()
                    for w, w_t in zip(self.local_model.parameters(), global_params)
                )
                loss = self.criterion(log_probs, labels.long()) + (mu / 2) * proximal_term

                loss.backward()
                self.optimizer.step()

                batch_loss.append(loss.item())
            # An epoch without any usable batch contributes nothing to the average.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        self.trainingLoss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
        return self.trainingLoss

    def add_generate_datas(self, data_idxs: list):
        """Append extra sample indices to the training loader's dataset and refresh ``data_num``.

        NOTE(review): only ``self.trainloader`` sees the new indices;
        ``self.trainloader_full`` still serves the original shard. Confirm
        whether that is intended.
        """
        self.trainloader.dataset.add_data(data_idxs)
        self.data_num = len(self.trainloader.dataset)

    
