import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
import moco.builder
import moco.loader
from models.SaveModel import SaveModel
import os
# Number of items shown at each end of every dimension when numpy arrays /
# torch tensors are printed in summarised form (applied to both libraries).
_PRINT_EDGE_ITEMS = 30
np.set_printoptions(edgeitems=_PRINT_EDGE_ITEMS)
torch.set_printoptions(edgeitems=_PRINT_EDGE_ITEMS)

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``k`` of the split is item ``idxs[k]`` of the wrapped dataset,
    re-packed as an ``(image, label)`` tuple (the wrapped item must unpack
    into exactly two values). Indices are converted with ``int`` when the
    split is built, so e.g. numpy integer indices are stored as Python ints.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label

class Client:
    """One federated-learning participant: a MoCo encoder plus a local classifier.

    Each client owns:
      * ``moco`` -- the ``moco.builder.MoCo`` model built around ``local_model``;
        its ``global_prompt`` attribute is set to the (possibly ``None``)
        prompt passed in by the caller.
      * ``personalized_prompt`` -- a learnable tensor of shape
        ``(1, args.prompt_num_tokens_p, args.feature_dim)`` created on
        ``device``, or ``None`` when ``args.prompt_num_tokens_p <= 0``.
      * ``classifier`` -- ``nn.Linear(args.feature_dim, args.num_classes)``.

    ``train`` and the ``inference*`` methods move ``moco`` and ``classifier``
    to ``device`` for the duration of the call and back to ``'cpu'`` when they
    finish (also on error). ``personalized_prompt`` is never moved.
    """

    def __init__(self, device, local_model, global_prompt, train_dataset, test_dataset, train_idxs, test_idxs, args, index, logger=None):
        """Build the model parts, data loaders, optimizers and loss.

        Args:
            device: torch device used by ``train`` / ``inference*``.
            local_model: backbone handed to ``moco.builder.MoCo``.
            global_prompt: prompt shared across clients, or ``None``. It is
                attached as ``self.moco.global_prompt`` and gets its own
                optimizer only when not ``None``.
            train_dataset, test_dataset: datasets indexed by ``train_idxs`` /
                ``test_idxs`` respectively.
            args: experiment namespace (optimizer, learning rates, epochs, ...).
            index: client id, used in the checkpoint file name.
            logger: optional logger; stored but not used by this class.

        Raises:
            NotImplementedError: if ``args.optimizer`` is not ``'sgd'``.
        """
        self.device = device
        self.args = args
        self.index = index
        self.local_model = local_model
        self.logger = logger

        self.moco = moco.builder.MoCo(
            self.local_model,
            args,
            args.moco_dim,
            args.moco_k,
            args.moco_m,
            args.moco_t,
            args.mlp,
            args.num_heads,
        )
        self.moco.global_prompt = global_prompt
        if self.args.prompt_num_tokens_p > 0:
            # The randn values are overwritten by the Xavier init below; the
            # CPU randn + .to() is kept so the global RNG stream is unchanged.
            self.personalized_prompt = nn.Parameter(
                torch.randn(1, self.args.prompt_num_tokens_p, args.feature_dim).to(self.device))
            torch.nn.init.xavier_uniform_(self.personalized_prompt, gain=1)
        else:
            self.personalized_prompt = None
        self.classifier = nn.Linear(args.feature_dim, args.num_classes)

        # Metrics filled in by train() / inference().
        self.trainingLoss_contrastive = None
        self.trainingLoss_classification = None
        self.trainAcc = None
        self.testingLoss = None
        self.testingAcc = None

        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            train_dataset, list(train_idxs), test_dataset, list(test_idxs))

        self._build_optimizers()

        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

    def _build_optimizers(self):
        """Create one SGD optimizer (most with a StepLR scheduler) per trainable part.

        Raises:
            NotImplementedError: if ``args.optimizer`` is anything but ``'sgd'``.
        """
        args = self.args
        if args.optimizer != 'sgd':
            raise NotImplementedError(
                f"optimizer {args.optimizer!r} is not supported; only 'sgd' is implemented")

        def step_lr(optimizer):
            return torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.decay_step, gamma=args.lr_decay)

        self.optimizer_encoder = torch.optim.SGD(
            self.moco.encoder_q.parameters(), lr=args.lr, momentum=args.momentum)
        self.scheduler_encoder = step_lr(self.optimizer_encoder)

        if args.use_attention == 1:
            # Plain SGD (no momentum), unlike the encoder / classifier optimizers.
            self.optimizer_attention = torch.optim.SGD(
                self.moco.attention.parameters(), lr=args.attention_lr)
            self.scheduler_attention = step_lr(self.optimizer_attention)

        if self.moco.global_prompt is not None:
            self.optimizer_prompt_g = torch.optim.SGD([self.moco.global_prompt], lr=args.lr)
            self.scheduler_prompt_g = step_lr(self.optimizer_prompt_g)

        if args.proj:
            # No scheduler: fc_q keeps args.mlp_lr for the whole run.
            self.optimizer_fc_q = torch.optim.SGD(
                self.moco.fc_q.parameters(), lr=args.mlp_lr, momentum=args.momentum)

        self.optimizer_classifier = torch.optim.SGD(
            self.classifier.parameters(), lr=args.classification_lr, momentum=args.momentum)
        # NOTE(review): scheduler_classifier and scheduler_prompt_p are created
        # but never stepped in this class -- confirm whether that is intended.
        self.scheduler_classifier = step_lr(self.optimizer_classifier)

        if self.personalized_prompt is not None:
            self.optimizer_prompt_p = torch.optim.SGD(
                [self.personalized_prompt], lr=args.classification_lr)
            self.scheduler_prompt_p = step_lr(self.optimizer_prompt_p)

    @staticmethod
    def _test_batch_size(num_samples):
        """Batch size for the test loader: a tenth of the split, but never 0."""
        return max(1, num_samples // 10)

    def train_val_test(self, train_dataset, train_idxs, test_dataset, test_idxs):
        """Return ``(trainloader, validloader, testloader)`` for this client's indices.

        ``validloader`` is always ``None``: no validation split is made. The
        test loader serves roughly a tenth of the test split per batch (at
        least one sample, so splits smaller than 10 no longer yield an
        invalid ``batch_size=0``).
        """
        trainloader = DataLoader(DatasetSplit(train_dataset, train_idxs),
                                 batch_size=self.args.local_bs, shuffle=True,
                                 num_workers=self.args.num_workers)
        validloader = None
        testloader = DataLoader(DatasetSplit(test_dataset, test_idxs),
                                batch_size=self._test_batch_size(len(test_idxs)),
                                shuffle=False)
        return trainloader, validloader, testloader

    def _views_to_device(self, images, labels):
        """Move the first three views of a batch and its labels to ``self.device``.

        ``images`` is indexed as [0] query view, [1] key view, [2] the view fed
        to the classifier (inferred from usage in the training steps).
        """
        return (images[0].to(self.device), images[1].to(self.device),
                images[2].to(self.device), labels.to(self.device))

    @staticmethod
    def _count_correct(logits, labels):
        """Number of rows whose arg-max logit equals the (long-cast) label."""
        _, pred_labels = torch.max(logits, 1)
        pred_labels = pred_labels.view(-1)
        return torch.sum(torch.eq(pred_labels, labels.long())).item()

    @staticmethod
    def _mean(values):
        """Arithmetic mean of ``values``; 0.0 for an empty sequence."""
        return sum(values) / len(values) if values else 0.0

    def _contrastive_phase_step(self, images, labels):
        """One phase-1 batch: update encoder (+ attention, fc_q, personalized prompt).

        Both the contrastive loss and the classification loss are
        back-propagated. The classifier's gradients are not applied here; they
        are cleared by the phase-2 ``zero_grad`` before its next step.

        Returns:
            ``(logits, labels)`` with labels already on ``self.device``.
        """
        args = self.args
        im_q, im_k, im_cls, labels = self._views_to_device(images, labels)

        self.optimizer_encoder.zero_grad()
        if args.use_attention == 1:
            self.optimizer_attention.zero_grad()
        if args.proj:
            self.optimizer_fc_q.zero_grad()
        if self.personalized_prompt is not None:
            self.optimizer_prompt_p.zero_grad()

        output, target, _, _ = self.moco(im_q=im_q, im_k=im_k)
        contrastive_loss = self.criterion(output, target) * args.lamda1
        contrastive_loss.backward()

        # NOTE(review): judging by the flag name, if_update_encoder=False
        # presumably keeps this gradient out of the encoder -- confirm in moco.builder.
        features = self.moco.get_feature(img_q=im_cls, personalized_prompt=self.personalized_prompt,
                                         if_update_encoder=False)
        logits = self.classifier(features)
        classification_loss = self.criterion(logits, labels) * args.lamda
        classification_loss.backward()

        self.optimizer_encoder.step()
        if args.use_attention == 1:
            self.optimizer_attention.step()
        if args.proj:
            self.optimizer_fc_q.step()
        if self.personalized_prompt is not None:
            self.optimizer_prompt_p.step()
        return logits, labels

    def _classification_phase_step(self, images, labels):
        """One phase-2 batch: update encoder, attention, classifier and global prompt.

        The contrastive loss is back-propagated only when a global prompt
        exists; otherwise it is computed for reporting only.

        Returns:
            ``(contrastive_loss, classification_loss, logits, labels)``; the two
            losses are Python floats, labels are on ``self.device``.
        """
        args = self.args
        im_q, im_k, im_cls, labels = self._views_to_device(images, labels)

        self.optimizer_encoder.zero_grad()
        if args.use_attention == 1:
            self.optimizer_attention.zero_grad()
        self.optimizer_classifier.zero_grad()

        output, target, _, _ = self.moco(im_q=im_q, im_k=im_k, if_update_encoder=False)
        contrastive_loss = self.criterion(output, target) * args.lamda1
        if self.moco.global_prompt is not None:
            self.optimizer_prompt_g.zero_grad()
            contrastive_loss.backward()

        features = self.moco.get_feature(img_q=im_cls, personalized_prompt=self.personalized_prompt)
        logits = self.classifier(features)
        classification_loss = self.criterion(logits, labels) * args.lamda
        classification_loss.backward()

        self.optimizer_encoder.step()
        if args.use_attention == 1:
            self.optimizer_attention.step()
        self.optimizer_classifier.step()
        if self.moco.global_prompt is not None:
            self.optimizer_prompt_g.step()
        return contrastive_loss.item(), classification_loss.item(), logits, labels

    def train(self):
        """Run one local round: a contrastive phase followed by a classification phase.

        Phase 1 runs ``args.contrastive_ep`` epochs of
        ``_contrastive_phase_step``. Phase 2 runs ``args.local_ep`` epochs of
        ``_classification_phase_step``.

        Side effects: sets ``trainingLoss_contrastive`` and
        ``trainingLoss_classification`` to the mean over phase-2 epochs of the
        per-epoch mean batch loss (0.0 when there are no phase-2 batches), and
        ``trainAcc`` to the training accuracy accumulated over every batch of
        both phases (0.0 when no batch was seen).
        """
        self.moco.train()
        self.classifier.train()
        self.moco.to(self.device)
        self.classifier.to(self.device)
        try:
            # NOTE(review): the schedulers are stepped before any optimizer.step()
            # of the round. PyTorch >= 1.1 expects the opposite order and warns
            # that the first schedule value is skipped. Kept to preserve the
            # existing learning-rate schedule.
            self.scheduler_encoder.step()
            if self.args.use_attention == 1:
                self.scheduler_attention.step()
            if self.moco.global_prompt is not None:
                self.scheduler_prompt_g.step()

            correct, total = 0.0, 0.0

            # Phase 1: losses are not recorded, only accuracy.
            for _ in range(self.args.contrastive_ep):
                for images, labels in self.trainloader:
                    logits, labels = self._contrastive_phase_step(images, labels)
                    correct += self._count_correct(logits, labels)
                    total += len(labels)

            # Phase 2: per-epoch mean losses feed the reported training losses.
            epoch_loss_contrastive, epoch_loss_classification = [], []
            for _ in range(self.args.local_ep):
                batch_loss_contrastive, batch_loss_classification = [], []
                for images, labels in self.trainloader:
                    contrastive_loss, classification_loss, logits, labels = \
                        self._classification_phase_step(images, labels)
                    batch_loss_contrastive.append(contrastive_loss)
                    batch_loss_classification.append(classification_loss)
                    correct += self._count_correct(logits, labels)
                    total += len(labels)
                epoch_loss_contrastive.append(self._mean(batch_loss_contrastive))
                epoch_loss_classification.append(self._mean(batch_loss_classification))

            self.trainingLoss_contrastive = self._mean(epoch_loss_contrastive)
            self.trainingLoss_classification = self._mean(epoch_loss_classification)
            self.trainAcc = correct / total if total else 0.0
        finally:
            self.moco.to('cpu')
            self.classifier.to('cpu')

    def inference_contrastive(self):
        """Placeholder: always returns ``(0, 0)`` as ``(accuracy, loss)``."""
        accuracy, loss = 0, 0
        return accuracy, loss

    def inference(self):
        """Evaluate on this client's test split using its personalized prompt.

        Stores the results in ``testingAcc`` / ``testingLoss``.

        Returns:
            ``(accuracy, loss)`` as computed by ``inference_specific_dataset``.
        """
        accuracy, loss = self.inference_specific_dataset(self.testloader, prompt=self.personalized_prompt)
        self.testingAcc, self.testingLoss = accuracy, loss
        return accuracy, loss

    def inference_train(self):
        """Placeholder: always returns ``(0, 0)`` as ``(accuracy, loss)``."""
        accuracy, loss = 0, 0
        return accuracy, loss

    def inference_specific_dataset(self, dataloader, prompt=None):
        """Evaluate ``moco`` + ``classifier`` on ``dataloader`` without gradients.

        Args:
            dataloader: iterable of ``(images, labels)`` batches.
            prompt: passed to ``moco.get_feature`` as ``personalized_prompt``.

        Returns:
            ``(accuracy, loss)``. ``loss`` is the unweighted mean of per-batch
            losses. Both are 0.0 when ``dataloader`` yields no batches.
        """
        self.moco.eval()
        self.classifier.eval()
        self.moco.to(self.device)
        self.classifier.to(self.device)
        loss_sum, total, correct = 0.0, 0.0, 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for images, labels in dataloader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    features = self.moco.get_feature(img_q=images, personalized_prompt=prompt)
                    outputs = self.classifier(features)
                    loss_sum += self.criterion(outputs, labels.long()).item()
                    correct += self._count_correct(outputs, labels)
                    total += len(labels)
                    num_batches += 1
        finally:
            self.moco.to('cpu')
            self.classifier.to('cpu')

        accuracy = correct / total if total else 0.0
        loss = loss_sum / num_batches if num_batches else 0.0
        return accuracy, loss

    def save_model(self):
        """Save moco, personalized prompt and classifier to ``<root_file>/save_models/client_<index>.pth``.

        The whole ``SaveModel`` object is pickled with ``torch.save``; the
        directory is created if missing.
        """
        # 'classificer' (sic) is the keyword SaveModel expects.
        model = SaveModel(MoCo=self.moco, personalized_prompt=self.personalized_prompt, classificer=self.classifier)
        model_dir = os.path.join(self.args.root_file, 'save_models')
        os.makedirs(model_dir, exist_ok=True)
        model_file_path = os.path.join(model_dir, f'client_{self.index}.pth')
        torch.save(model, model_file_path)




