import torch.nn as nn
import torch
from torchvision import transforms
import numpy as np
from torch.nn import functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torch.optim as optim
from myNetwork import *
from torch.utils.data import DataLoader
from scipy.spatial.distance import pdist
from sklearn.cluster import KMeans, DBSCAN
from sklearn import preprocessing
from itertools import combinations
import random
from Fed_utils import * 


class PCLoss(nn.Module):
    """Contrastive loss between sample features and per-class proxies.

    For every sample, a logit row is assembled from (in this order):

    1. the cosine similarity to the proxy of its own class (the positive),
    2. the cosine similarities to all other class proxies,
    3. the cosine similarities to the other samples in the batch, where
       same-class pairs and every similarity below ``1e-6`` are replaced by
       ``-inf`` so they do not contribute to the softmax.

    The loss is the cross-entropy of the scaled logits with the positive
    (column 0) as the target.

    Args:
        num_classes: number of classes; targets are compared against the
            indices ``0 .. num_classes - 1``.
        scale: value the logits are divided by before the softmax.
        device: device on which ``self.label`` is stored. ``forward`` itself
            runs on whatever device its inputs live on.
    """

    def __init__(self, num_classes, scale, device):
        super().__init__()
        self.soft_plus = nn.Softplus()
        self.num_classes = num_classes
        # Kept as a public attribute for backward compatibility; forward()
        # rebuilds the class indices on the device of its inputs.
        self.label = torch.arange(num_classes, dtype=torch.long).to(device)
        self.scale = scale
        self.device = device

    def forward(self, feature, target, proxy):
        """Compute the loss.

        Args:
            feature: (N, dim) sample embeddings.
            target: (N,) class index of each sample; every value must lie in
                ``[0, num_classes)``, otherwise the positive/negative split
                below cannot be reshaped to one row per sample.
            proxy: (C, dim) class proxies, with ``C == num_classes``.

        Returns:
            Scalar tensor holding the mean loss over the batch.
        """
        feature = F.normalize(feature, p=2, dim=1)
        # Cosine similarity of every sample to every proxy: (N, C).
        proxy_sim = F.linear(feature, F.normalize(proxy, p=2, dim=1))

        # Build the class indices next to the targets so the loss does not
        # depend on the device that was passed to the constructor.
        class_ids = torch.arange(self.num_classes, device=target.device)
        is_positive = target.unsqueeze(1) == class_ids.unsqueeze(0)  # (N, C)

        # One positive per row -> (N, 1); the remaining entries -> (N, C-1).
        pos = torch.masked_select(proxy_sim, is_positive).unsqueeze(1)
        neg = torch.masked_select(proxy_sim, ~is_positive).view(feature.size(0), -1)

        # Sample-to-sample cosine similarities: (N, N).
        sample_sim = torch.matmul(feature, feature.transpose(1, 0))
        same_class = target.unsqueeze(1) == target.unsqueeze(0)

        # Zero out same-class pairs (including the diagonal), then drop every
        # entry below 1e-6 from the softmax by setting it to -inf.
        sample_sim = sample_sim * ~same_class
        sample_sim = sample_sim.masked_fill(sample_sim < 1e-6, float("-inf"))

        logits = torch.cat([pos, neg, sample_sim], dim=1)
        # The positive always sits in column 0.
        zeros = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits / self.scale, zeros)

class FedNovel_model:
    """Local trainer that wraps a ``network`` and its per-task training data.

    The object owns the current model, an optional previous ("old") model
    used for exponential-moving-average blending of the feature weights, and
    the data loader for the task that is currently being trained.

    Args:
        args: configuration object. This class reads ``args.out_file``,
            ``args.ema_alpha``, ``args.T`` and ``args.task_classes`` and also
            forwards ``args`` to ``network``.
        numclass: forwarded to ``network`` as its first argument.
        feature_extractor: backbone forwarded to ``network``.
        batch_size: batch size of the training ``DataLoader``.
        task_classes: indexable; ``task_classes[0]`` becomes ``self.numclass``
            when a new task starts.
        epochs: number of local epochs run by ``train``.
        learning_rate: SGD learning rate.
        train_set: dataset exposing ``getTrainData(data_ids)`` and a
            ``TrainLabels`` attribute.
        device: device the model and the batches are moved to.
    """

    def __init__(self, args, numclass, feature_extractor, batch_size, task_classes, epochs, learning_rate, train_set, device):
        super().__init__()
        self.args = args
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.model = network(numclass, feature_extractor, args)

        self.numclass = 0
        self.old_model = None
        self.train_dataset = train_set

        self.batchsize = batch_size
        self.task_classes = task_classes
        self.learned_classes = []

        self.train_loader = None
        self.current_class = None
        self.task_id = -1
        self.device = device

    # get incremental train data
    def beforeTrain(self, task_id_new, data_ids=None):
        """Prepare the training loader for a task.

        When ``task_id_new`` differs from the stored task id, the id is
        updated and ``self.numclass`` is set to ``self.task_classes[0]``.
        The training loader is rebuilt on every call.
        """
        if task_id_new != self.task_id:
            self.task_id = task_id_new
            self.numclass = self.task_classes[0]
        self.train_loader = self._get_train_and_test_dataloader(data_ids, False)

    def _get_train_and_test_dataloader(self, data_ids, mix):
        """Select the training data and return a shuffled ``DataLoader``.

        Also records the set of labels present in the selection in
        ``self.current_class``. ``mix`` is accepted but not used.
        """
        self.train_dataset.getTrainData(data_ids)
        self.current_class = set(self.train_dataset.TrainLabels)

        return DataLoader(
            dataset=self.train_dataset,
            shuffle=True,
            batch_size=self.batchsize,
            num_workers=8,
            pin_memory=True,
        )

    # train model
    def train(self, ep_g, model_old):
        """Run ``self.epochs`` epochs of SGD on the current training loader.

        Args:
            ep_g: not used by this method.
            model_old: if not ``None``, replaces ``self.old_model``.

        When an old model is available, the weights of ``self.model.feature``
        are blended with those of the old model after every epoch:
        ``alpha * old + (1 - alpha) * current`` with
        ``alpha = self.args.ema_alpha``.
        """
        self.model = model_to_device(self.model, False, self.device)
        opt = optim.SGD(self.model.parameters(), lr=self.learning_rate, weight_decay=0.00001)
        self.model.train()

        if model_old is not None:
            self.old_model = model_old

        if self.old_model is not None:
            log_print('load old model', self.args.out_file)
            self.old_model = model_to_device(self.old_model, False, self.device)
            self.old_model.eval()

        for _ in range(self.epochs):
            # Each batch is (index, images, label); images holds two entries
            # that are both moved to the device.
            for _, imgs, label in self.train_loader:
                imgs = (imgs[0].to(self.device), imgs[1].to(self.device))
                label = label.to(self.device)
                loss_value = self._compute_loss(imgs, label)
                opt.zero_grad()
                loss_value.backward()
                opt.step()

            # EMA of the feature weights towards the old model, once per epoch.
            if self.old_model is not None:
                alpha = self.args.ema_alpha
                model_s = self.old_model.feature.state_dict()
                model_d = self.model.feature.state_dict()
                # Overwrite values in place so the state dict object (and any
                # metadata attached to it) is the one handed back to the model.
                for k in model_d.keys():
                    model_d[k] = alpha * model_s[k] + (1 - alpha) * model_d[k]

                self.model.feature.load_state_dict(model_d)

    def _compute_loss(self, imgs, label):
        """Return the training loss for one batch.

        Task 0: cross-entropy on the classifier output plus ``0.1`` times the
        ``PCLoss`` between the encoded features and the classifier weights.

        Later tasks: a softmax (temperature ``0.07``) over the logits from
        index ``self.model.old_class_num`` onwards weights the Euclidean
        distance between each feature and the corresponding classifier
        weights; the weighted distances are summed per sample and averaged.
        ``label`` is not used in this branch.
        """
        features = self.model.feature_extractor(imgs[0])
        output = self.model.fc(features)

        # NOTE(review): the second view is forwarded but its result is not
        # used by either loss below. It is kept because removing the forward
        # pass could change layer statistics updated in train mode -- confirm
        # whether it can be dropped.
        features2 = self.model.feature_extractor(imgs[1])
        output2 = self.model.fc(features2)  # noqa: F841

        if self.task_id == 0:
            loss_cur = F.cross_entropy(output, label)
            pc_criterion = PCLoss(self.numclass, self.args.T, self.device)
            loss_pc = pc_criterion(self.model.encoder(features), label, self.model.fc.weight)
            return loss_cur + 0.1 * loss_pc

        ## SWL
        first_new = self.model.old_class_num
        sim_mat = output[:, first_new:] / 0.07
        s_dist = F.softmax(sim_mat, dim=1)
        cost_mat = self.EuclideanDistances(features, self.model.fc.weight[first_new:])
        loss_gca = (cost_mat * s_dist).sum(1).mean()
        return loss_gca

    def EuclideanDistances(self, a, b):
        """Pairwise Euclidean distances between the rows of ``a`` and ``b``.

        Args:
            a: (m, d) tensor.
            b: (n, d) tensor.

        Returns:
            (m, n) tensor whose ``[i, j]`` entry is ``||a[i] - b[j]||``.
            The squared distance is clamped to ``1e-12`` before the square
            root, so coincident rows give ``1e-6`` instead of ``0``.
        """
        sum_sq_a = (a ** 2).sum(dim=1).unsqueeze(1)  # m->[m, 1]
        sum_sq_b = (b ** 2).sum(dim=1).unsqueeze(0)  # n->[1, n]
        sq_dist = sum_sq_a + sum_sq_b - 2 * a.mm(b.t())
        # Rounding can push the expanded form slightly below zero (NaN after
        # sqrt), and sqrt has an infinite derivative at exactly zero (NaN
        # gradients). Clamping to a tiny positive value avoids both.
        return torch.sqrt(sq_dist.clamp(min=1e-12))

    def local_clustering_start(self):
        """Cluster the features of the current training data with K-Means.

        Features of the first image of every batch entry are extracted
        without gradient tracking and clustered into
        ``self.args.task_classes[0]`` clusters.

        Returns:
            The ``cluster_centers_`` array of the fitted ``KMeans`` object.
        """
        # NOTE(review): the model's train/eval mode is not changed here, so
        # features are computed in whatever mode the model is currently in --
        # confirm whether eval mode is intended.
        novel_fea = []
        with torch.no_grad():
            for _, images, _ in self.train_loader:
                fea = self.model.feature_extractor(images[0].to(self.device))
                novel_fea.append(fea.cpu())

        novel_fea = torch.cat(novel_fea, dim=0).numpy()

        km = KMeans(n_clusters=self.args.task_classes[0]).fit(novel_fea)
        return km.cluster_centers_