import pdb
import numpy as np
import torch
import logging
from data_split import *
import copy
import torch.optim as optim
import sys
import torch.nn as nn
from tqdm import tqdm
from utils import *
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def local_train(args, train_dataset, test_loader, user_groups, models, model_folder, model_name, device):
    """Train every client's model locally and save its weights to disk.

    For each client index ``idx`` in ``range(args.n_clients)``:
    ``models[idx]`` is trained in place by ``SingleLLocalUpdate`` on the
    samples ``user_groups[idx]`` of ``train_dataset``. The model is then moved
    to the CPU and its ``state_dict`` is saved to
    ``<model_folder>/<model_name>_<idx>.pth``.

    Args:
        args: config namespace; ``n_clients`` is read here, the rest is
            forwarded to ``SingleLLocalUpdate``.
        train_dataset: full training dataset shared by all clients.
        test_loader: loader used for test-accuracy evaluation.
        user_groups: indexable by client id; ``user_groups[idx]`` holds the
            sample indices owned by client ``idx``.
        models: indexable by client id; each entry is trained in place and is
            left on the CPU afterwards.
        model_folder: directory for the saved checkpoints (created if missing).
        model_name: filename prefix for the checkpoints.
        device: device used for training.
    """
    # `os` is not imported explicitly at the top of this file (it is only
    # reachable through the star imports), so import it here instead of
    # relying on that.
    import os

    # torch.save does not create missing directories; an empty folder means
    # "current directory", for which makedirs would raise.
    if model_folder:
        os.makedirs(model_folder, exist_ok=True)

    for idx in range(args.n_clients):
        logger.info('Training client %s' % str(idx))
        print('Training client %s' % str(idx))

        local_client = SingleLLocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], test=test_loader,
                                          device=device)
        local_client.update_weights(models[idx])

        # Save on CPU so the checkpoint does not carry device placement.
        save_model_path = os.path.join(model_folder, f'{model_name}_{idx}.pth')
        models[idx].to('cpu')
        torch.save(models[idx].state_dict(), save_model_path)


class SingleLLocalUpdate(object):
    """Local (single-client) SGD training on a subset of a dataset.

    The model passed to ``update_weights`` is trained in place; nothing is
    returned.
    """

    # Length of the two index tensors passed to the model on every forward
    # call. 197 presumably corresponds to a ViT at 224px with 16px patches
    # (14 * 14 patches + 1 class token) — TODO confirm against the model.
    num_tokens = 197

    def __init__(self, args, dataset, idxs, test, device, num_tokens=197):
        """
        Args:
            args: config namespace; reads ``bs`` and ``num_workers`` here and
                ``lr``, ``local_epochs``, ``interval`` during training.
            dataset: full training dataset.
            idxs: indices of the samples owned by this client.
            test: loader used for test-accuracy evaluation.
            device: device the model is trained on.
            num_tokens: length of the index sequences fed to the model
                (defaults to the previously hard-coded 197).
        """
        self.args = args
        # DataLoader / DatasetSplit are not imported explicitly in this file;
        # presumably they come from the star imports at the top.
        self.train_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.bs, shuffle=True,
                                       num_workers=args.num_workers)
        self.test_loader = test
        self.device = device
        self.num_tokens = num_tokens

    def update_weights(self, model):
        """Train ``model`` in place for ``args.local_epochs`` epochs.

        Uses SGD (``args.lr``, momentum 0.9, weight decay 1e-4) with a
        cross-entropy loss on ``output[0]``. The model is called as
        ``model(images, ids, ids)``, where each ``ids`` is a
        ``(batch, num_tokens)`` LongTensor whose rows are
        ``0 .. num_tokens - 1``.

        Every ``args.interval`` epochs (when ``interval`` is truthy),
        ``compute_accuracy`` is run on the train and test loaders.
        """
        model.to(self.device)
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss().to(self.device)

        # Build the 0..num_tokens-1 row once, directly on the device, instead
        # of Python list -> CPU tensor -> device copy twice per batch.
        base_ids = torch.arange(self.num_tokens, dtype=torch.long, device=self.device)
        interval = self.args.interval

        for epoch in range(self.args.local_epochs):
            model.train()
            epoch_loss_collector = []
            train_bar = tqdm(self.train_loader, file=sys.stdout)
            for images, labels in train_bar:
                images, labels = images.to(self.device), labels.to(self.device)
                model.zero_grad()
                batch_size = images.shape[0]
                # Two distinct tensors (as before) so the model cannot alias
                # them; repeat() yields a fresh contiguous copy each time.
                output = model(images, base_ids.repeat(batch_size, 1), base_ids.repeat(batch_size, 1))
                loss = criterion(output[0], labels)
                loss.backward()
                optimizer.step()
                epoch_loss_collector.append(loss.item())

            # An empty client loader previously raised ZeroDivisionError.
            if epoch_loss_collector:
                epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
            else:
                epoch_loss = float('nan')
            logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
            print('Epoch: %d Loss: %f' % (epoch, epoch_loss))

            # interval of 0/None disables periodic evaluation instead of
            # crashing on the modulo.
            if interval and (epoch + 1) % interval == 0:
                # Results were previously bound to unused locals; compute_accuracy
                # receives an `info` label and presumably reports the result
                # itself — confirm.
                compute_accuracy(model, self.train_loader, info='train data', device=self.device)
                compute_accuracy(model, self.test_loader, info='test data', device=self.device)

