import numpy as np
import torch
import copy
from torch.utils.data import DataLoader
from sim.utils.utils import AverageMeter, accuracy
import random

# Compute the softmax
def all_softmax(logits, labels, temperature):
    """Return the temperature-scaled softmax of ``logits`` over dim 1.

    Args:
        logits: Tensor of raw scores with at least two dims; the softmax is
            taken over dim 1.
        labels: Unused; kept so the signature stays unchanged for callers.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        Tensor with the same shape as ``logits`` whose slices along dim 1
        sum to 1.
    """
    # A hand-rolled exp(x) / sum(exp(x)) overflows to inf / inf = NaN for
    # large logits; torch.softmax evaluates the same quantity stably.
    return torch.softmax(logits / temperature, dim=1)

def compute_loss(z_l, z_g, y, temperature):
    """Mean KL divergence KL(q_g || q_l) between temperature-softened outputs.

    Args:
        z_l: Logits whose softened distribution is the second KL argument.
        z_g: Logits whose softened distribution is the first KL argument.
        y: Labels, forwarded unchanged to ``all_softmax``.
        temperature: Softmax temperature applied to both sets of logits.

    Returns:
        Scalar tensor: the per-row KL divergence averaged over dim 0.
    """
    eps = 1e-8  # keeps log() finite when a probability is exactly zero
    local_probs = all_softmax(z_l, y, temperature)
    global_probs = all_softmax(z_g, y, temperature)

    log_ratio = torch.log(global_probs + eps) - torch.log(local_probs + eps)
    per_sample_kl = (global_probs * log_ratio).sum(dim=1)
    return per_sample_kl.mean()

# Compute the not true softmax
def not_true_softmax(logits, labels, temperature):
    """Return the temperature-scaled softmax over every class except the label.

    Args:
        logits: 2-D tensor of raw scores (rows indexed by sample, dim 1 by class).
        labels: 1-D integer tensor with one class index per row of ``logits``.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        Tensor shaped like ``logits``. The entry at each row's label is exactly
        0 and the remaining entries of the row sum to 1. A row needs at least
        one class besides its label, otherwise that row is NaN.
    """
    # Same mask construction as before: 1 everywhere, 0 at the label column.
    keep = torch.ones_like(logits).scatter_(1, labels.unsqueeze(1), 0)
    # Sending the label logit to -inf removes it from the normalisation and
    # lets torch.softmax do the stable max-shifted computation. The previous
    # exp(x) * mask form overflowed for large logits (inf / inf, inf * 0).
    masked = (logits / temperature).masked_fill(keep == 0, float("-inf"))
    return torch.softmax(masked, dim=1)

def compute_ntd_loss(z_l, z_g, y, temperature):
    """Mean KL divergence KL(q_g || q_l) restricted to the not-true classes.

    Args:
        z_l: Logits whose not-true distribution is the second KL argument.
        z_g: Logits whose not-true distribution is the first KL argument.
        y: Labels, forwarded to ``not_true_softmax`` to pick the masked class.
        temperature: Softmax temperature applied to both sets of logits.

    Returns:
        Scalar tensor: the per-row KL divergence averaged over dim 0.
    """
    eps = 1e-8  # keeps log() finite at the masked (zero-probability) class
    local_probs = not_true_softmax(z_l, y, temperature)
    global_probs = not_true_softmax(z_g, y, temperature)

    log_ratio = torch.log(global_probs + eps) - torch.log(local_probs + eps)
    per_sample_kl = (global_probs * log_ratio).sum(dim=1)
    return per_sample_kl.mean()

###### CLIENT ######
class SFLClient():
    def __init__(self):
        """Create a client; criterion, datasets and optimizer kit are attached later via the setup_* methods."""
        super().__init__()
    
    def setup_criterion(self, criterion) -> None:
        """Store the loss callable; the local-update and evaluation methods call it as ``criterion(output, target)``."""
        self.criterion = criterion

    def setup_train_dataset(self, dataset) -> None:
        """Store ``dataset`` on the client as ``train_feddataset``."""
        self.train_feddataset = dataset
    
    def setup_test_dataset(self, dataset) -> None:
        """Store ``dataset`` on the client as ``test_dataset``."""
        self.test_dataset = dataset

    def setup_optim_kit(self, optim_kit) -> None:
        """Store the optimizer kit.

        The other methods of this class read ``optim_kit.batch_size``,
        ``optim_kit.optim`` (called with the model parameters) and
        ``optim_kit.settings`` (passed to ``optim`` as keyword arguments).
        """
        self.optim_kit = optim_kit

    def local_update_step(self, local_dataset, model, num_steps, device, **kwargs):
        data_loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)
        model = model.to(device)
        prev_model = copy.deepcopy(model)
        model.train()
        
        step_count = 0
        while(True):
            for input, target in data_loader:
                input = input.to(device)
                target = target.to(device)
                output = model(input)
                loss = self.criterion(output, target.view(-1))
                optimizer.zero_grad()
                loss.backward()

                if 'clip' in kwargs.keys() and kwargs['clip'] > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=kwargs['clip'])

                optimizer.step()
            step_count += 1
            if (step_count >= num_steps):
                break
        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
            assert step_count == num_steps            
            # add log
            local_log = {}
            local_log = {'total_norm': total_norm} if 'clip' in kwargs.keys() and kwargs['clip'] > 0 else local_log
            return curr_vec, delta_vec
        
    def local_update_step_flash(self, local_dataset, model, num_steps, device, **kwargs):
        data_loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)
        model = model.to(device)
        prev_model = copy.deepcopy(model)
        model.train()
        prev_loss = -1
        step_count = 0
        while(True):
            for input, target in data_loader:
                input = input.to(device)
                target = target.to(device)
                output = model(input)
                loss = self.criterion(output, target.view(-1))
                optimizer.zero_grad()
                loss.backward()
                if 'clip' in kwargs.keys() and kwargs['clip'] > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=kwargs['clip'])

                optimizer.step()
            if prev_loss != -1:
                print('prev_loss: {}, loss: {}'.format(prev_loss, loss.item()))
                if prev_loss - loss.item() < 0.001 / (step_count + 1):
                    break
            prev_loss = loss.item()
            step_count += 1
            if (step_count >= num_steps):
                break
        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
            # assert step_count == num_steps            
            # add log
            local_log = {}
            local_log = {'total_norm': total_norm} if 'clip' in kwargs.keys() and kwargs['clip'] > 0 else local_log
            return curr_vec, delta_vec
        
    def local_update_step_omd(self, local_dataset, model, num_steps, device, **kwargs):
        data_loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)
        model = model.to(device)
        prev_model = copy.deepcopy(model)
        model.train()
        
        step_count = 0
        while(True):
            for input, target in data_loader:
                input = input.to(device)
                target = target.to(device)
                output = model(input)
                loss = self.criterion(output, target.view(-1)) + 0.01 * torch.norm(torch.nn.utils.parameters_to_vector(model.parameters()) - torch.nn.utils.parameters_to_vector(prev_model.parameters()), 2)
                optimizer.zero_grad()
                loss.backward()

                if 'clip' in kwargs.keys() and kwargs['clip'] > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=kwargs['clip'])

                optimizer.step()
            step_count += 1
            if (step_count >= num_steps):
                break
        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
            assert step_count == num_steps            
            # add log
            local_log = {}
            local_log = {'total_norm': total_norm} if 'clip' in kwargs.keys() and kwargs['clip'] > 0 else local_log
            return curr_vec, delta_vec
        
    def local_update_step_kd(self, local_dataset, model, g_model, num_steps, device, a, b, **kwargs):
        """Train ``model`` with two distillation terms added to the criterion.

        Each mini-batch loss is
        ``criterion + a * T^2 * KL(prev || model) + b * T^2 * KL(global || model)``
        with ``T = 2.0``, where ``prev`` is a frozen copy of ``model`` taken at
        entry and ``global`` a frozen copy of ``g_model``; both KL terms come
        from ``compute_loss``.

        Args:
            local_dataset: Dataset yielding ``(input, target)`` pairs.
            model: Module to train; moved to ``device`` and updated in place.
            g_model: Module copied at entry and used as the second teacher;
                the caller's instance is not modified.
            num_steps: Number of full passes over the data loader (counted per
                pass, not per mini-batch). At least one pass is always run.
            device: Device the model, the teacher copies and every batch live on.
            a: Weight of the term against the copy of ``model``.
            b: Weight of the term against the copy of ``g_model``.
            **kwargs: Optional ``clip``; when > 0, gradients are clipped to
                that max norm before each optimizer step.

        Returns:
            Tuple ``(curr_vec, delta_vec)``: the flattened parameters after
            training and their difference from the parameters at entry.
        """
        # Same device handling as local_update_step: without it a model that
        # is not already on `device` fails on the first forward pass.
        model = model.to(device)
        loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)

        # Copies are taken before model.train(), so they keep whatever
        # train/eval mode the originals had at entry.
        prev_model = copy.deepcopy(model)
        prev_g_model = copy.deepcopy(g_model).to(device)
        model.train()

        temperature = 2.0
        clip = kwargs.get('clip', 0)
        epochs_done = 0
        while True:
            for inputs, targets in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                # The teachers are never optimised; skipping their autograd
                # graphs leaves the gradients of `model` unchanged.
                with torch.no_grad():
                    z_p = prev_model(inputs)
                    z_g = prev_g_model(inputs)
                cl_loss = compute_loss(output, z_p, targets, temperature)
                server_loss = compute_loss(output, z_g, targets, temperature)
                loss = (
                    self.criterion(output, targets.view(-1))
                    + a * cl_loss * (temperature ** 2)
                    + b * server_loss * (temperature ** 2)
                )
                optimizer.zero_grad()
                loss.backward()
                if clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
                optimizer.step()
            epochs_done += 1
            if epochs_done >= num_steps:
                break
        # Only fails for num_steps < 1 or a non-integer num_steps.
        assert epochs_done == num_steps

        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
        return curr_vec, delta_vec

    def local_update_step_ewc(self, local_dataset, model, g_model, num_steps, device, a, fisher_information, **kwargs):
        data_loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)

        prev_model = copy.deepcopy(model)
        next_fisher_information = {}
        for param_name, param in model.named_parameters():
            next_fisher_information[param_name] = torch.zeros_like(param)
        model.train()
        
        step_count = 0
        while(True):
            for input, target in data_loader:
                input = input.to(device)
                target = target.to(device)
                output = model(input)
                ewc_loss = 0.0
                for param_name, param in model.named_parameters():
                    if param_name in fisher_information:
                        fisher = fisher_information[param_name]
                        prev_param = prev_model.state_dict()[param_name]
                        
                        ewc_loss += (fisher * (param - prev_param).pow(2)).sum()

                loss = self.criterion(output, target.view(-1)) + a * ewc_loss
                optimizer.zero_grad()
                loss.backward()

                for param_name, param in model.named_parameters():
                    next_fisher_information[param_name] += param.grad ** 2

                if 'clip' in kwargs.keys() and kwargs['clip'] > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=kwargs['clip'])

                optimizer.step()
            step_count += 1
            if (step_count >= num_steps):
                break
        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
            assert step_count == num_steps            
            # add log
            local_log = {}
            local_log = {'total_norm': total_norm} if 'clip' in kwargs.keys() and kwargs['clip'] > 0 else local_log
            return curr_vec, next_fisher_information

    def local_update_step_prox(self, local_dataset, model, g_model, num_steps, device, a, **kwargs):
        data_loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)

        prev_model = copy.deepcopy(model)
        prev_g_model = copy.deepcopy(g_model)
        model.train()
        
        step_count = 0
        while(True):
            for input, target in data_loader:
                input = input.to(device)
                target = target.to(device)
                output = model(input)
                reg_term = torch.norm(torch.nn.utils.parameters_to_vector(model.parameters()) - torch.nn.utils.parameters_to_vector(prev_g_model.parameters()), 2)
                loss = self.criterion(output, target.view(-1)) + a * reg_term / 2
                optimizer.zero_grad()
                loss.backward()

                if 'clip' in kwargs.keys() and kwargs['clip'] > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=kwargs['clip'])

                optimizer.step()
            step_count += 1
            if (step_count >= num_steps):
                break
        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
            assert step_count == num_steps            
            # add log
            local_log = {}
            local_log = {'total_norm': total_norm} if 'clip' in kwargs.keys() and kwargs['clip'] > 0 else local_log
            return curr_vec, local_log
        
    def compute_fisher_matrix(self, model, dataloader, criterion, device):
        fisher_matrix = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
        
        model.eval()
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            model.zero_grad()
            
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            for name, param in model.named_parameters():
                fisher_matrix[name] += (param.grad ** 2) / len(dataloader)

        return fisher_matrix

    def local_update_step_curv(self, local_dataset, model, g_model, num_steps, device, a, **kwargs):
        data_loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)
        fisher_information = self.compute_fisher_matrix(g_model, data_loader, self.criterion, device)
        prev_model = copy.deepcopy(model)
        prev_g_model = copy.deepcopy(g_model)
        model.train()
        
        step_count = 0
        while(True):
            for input, target in data_loader:
                input = input.to(device)
                target = target.to(device)
                output = model(input)
                ewc_loss = 0.0
                for param_name, param in model.named_parameters():
                    if param_name in fisher_information:
                        fisher = fisher_information[param_name]
                        prev_param = prev_g_model.state_dict()[param_name]

                        ewc_loss += (fisher * (param - prev_param).pow(2)).sum()

                loss = self.criterion(output, target.view(-1)) + a * ewc_loss
                optimizer.zero_grad()
                loss.backward()

                if 'clip' in kwargs.keys() and kwargs['clip'] > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=kwargs['clip'])

                optimizer.step()
            step_count += 1
            if (step_count >= num_steps):
                break
        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
            prev_vec = torch.nn.utils.parameters_to_vector(prev_model.parameters())
            delta_vec = curr_vec - prev_vec
            assert step_count == num_steps            
            # add log
            local_log = {}
            local_log = {'total_norm': total_norm} if 'clip' in kwargs.keys() and kwargs['clip'] > 0 else local_log
            return curr_vec, local_log    

    def local_update_step_ntd(self, local_dataset, model, g_model, num_steps, device, a, **kwargs):
        """Train ``model`` with a not-true distillation term from ``g_model``.

        Each mini-batch loss is ``criterion + a * T^2 * ntd`` with ``T = 2.0``,
        where ``ntd`` is ``compute_ntd_loss`` between the outputs of ``model``
        and those of a frozen copy of ``g_model``.

        Args:
            local_dataset: Dataset yielding ``(input, target)`` pairs.
            model: Module to train; moved to ``device`` and updated in place.
            g_model: Module copied at entry and used as the teacher; the
                caller's instance is not modified.
            num_steps: Number of full passes over the data loader (counted per
                pass, not per mini-batch). At least one pass is always run.
            device: Device the model, the teacher copy and every batch live on.
            a: Weight of the distillation term.
            **kwargs: Optional ``clip``; when > 0, gradients are clipped to
                that max norm before each optimizer step.

        Returns:
            Tuple ``(curr_vec, local_log)``: the flattened parameters after
            training, and a dict that holds ``'total_norm'`` (the gradient norm
            reported by the last clipping call) when clipping is enabled and
            is empty otherwise.
        """
        # Same device handling as local_update_step.
        model = model.to(device)
        loader = DataLoader(local_dataset, batch_size=self.optim_kit.batch_size, shuffle=True)
        optimizer = self.optim_kit.optim(model.parameters(), **self.optim_kit.settings)

        # The copy keeps whatever train/eval mode g_model had at entry.
        prev_g_model = copy.deepcopy(g_model).to(device)
        model.train()

        temperature = 2.0
        clip = kwargs.get('clip', 0)
        local_log = {}
        epochs_done = 0
        while True:
            for inputs, targets in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                # Flatten once so the criterion and the label mask inside the
                # NTD loss see the same 1-D labels.
                flat_targets = targets.view(-1)
                output = model(inputs)
                loss = self.criterion(output, flat_targets)
                # The teacher is never optimised; skipping its autograd graph
                # leaves the gradients of `model` unchanged.
                with torch.no_grad():
                    z_g = prev_g_model(inputs)
                ntd_loss = compute_ntd_loss(output, z_g, flat_targets, temperature)
                loss = loss + a * ntd_loss * (temperature ** 2)
                optimizer.zero_grad()
                loss.backward()
                if clip > 0:
                    local_log['total_norm'] = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), max_norm=clip
                    )
                optimizer.step()
            epochs_done += 1
            if epochs_done >= num_steps:
                break
        # Only fails for num_steps < 1 or a non-integer num_steps.
        assert epochs_done == num_steps

        with torch.no_grad():
            curr_vec = torch.nn.utils.parameters_to_vector(model.parameters())
        return curr_vec, local_log

    def local_update_epoch(self, client_model,data, epoch, batchsize):
        """Placeholder for an epoch-based local update; it currently does nothing."""
        return None

    def evaluate_dataset(self, model, dataset, device):
        '''Evaluate ``model`` on ``dataset`` without gradient tracking.

        The model is switched to eval mode and left there. Batches are read in
        dataset order with the optimizer kit's batch size.

        Returns:
            Tuple ``(losses, top1, top5)`` of ``AverageMeter`` objects, each
            updated once per batch with the batch value and the batch size.
        '''
        loader = DataLoader(dataset, batch_size=self.optim_kit.batch_size, shuffle=False)

        model.eval()
        with torch.no_grad():
            loss_meter = AverageMeter()
            top1_meter = AverageMeter()
            top5_meter = AverageMeter()
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                logits = model(inputs)
                batch_loss = self.criterion(logits, targets.view(-1))
                acc1, acc5 = accuracy(logits, targets.view(-1), topk=[1,5])
                batch_size = targets.size(0)
                loss_meter.update(batch_loss.item(), batch_size)
                top1_meter.update(acc1.item(), batch_size)
                top5_meter.update(acc5.item(), batch_size)
            return loss_meter, top1_meter, top5_meter

###### SERVER ######
class SFLServer():
    """Server side: holds the global model, samples clients, averages updates.

    Attributes set by the methods below:
        global_model: Module registered through ``setup_model``.
        temp_model: Module registered through ``setup_temp_model``.
        lr: Server step size taken from ``setup_optim_settings``.
        delta_avg: Most recent result of ``aggregate_update``; read by
            ``global_update``.
    """

    def __init__(self):
        super().__init__()

    def setup_model(self, model):
        """Register ``model`` as the global model."""
        self.global_model = model

    def setup_temp_model(self, model):
        """Register ``model`` as the temporary model."""
        self.temp_model = model

    def setup_optim_settings(self, **settings):
        """Read the server step size from ``settings['lr']`` (required key)."""
        self.lr = settings['lr']

    def select_clients_prob(self, num_clients, probs):
        """Sample a non-empty subset of client indices.

        Client ``i`` is included independently with probability ``probs[i]``.
        The draw is repeated up to 1000 times until at least one client is
        chosen; after that a single client is drawn with ``probs`` as weights.

        Args:
            num_clients: Number of clients; indices are ``0..num_clients - 1``.
            probs: Sequence with one inclusion probability per client.

        Returns:
            List of selected client indices in increasing order (a single
            index when the weighted fallback is used).
        """
        indices = list(range(num_clients))

        for _ in range(1000):
            chosen = [i for i in indices if random.random() < probs[i]]
            if chosen:
                return chosen

        # random.choices already returns a one-element list for k=1.
        return random.choices(indices, weights=probs, k=1)

    def aggregate_update(self, local_delta_list):
        """Average the clients' update vectors with equal weights.

        Args:
            local_delta_list: Non-empty sequence of equally shaped tensors.

        Returns:
            The element-wise mean as a new tensor. The same tensor is also
            stored as ``self.delta_avg`` so that ``global_update`` can use it.

        Raises:
            IndexError: If ``local_delta_list`` is empty.
        """
        with torch.no_grad():
            delta_avg = torch.zeros_like(local_delta_list[0])
            for local_delta in local_delta_list:
                delta_avg.add_(local_delta)
            delta_avg.div_(len(local_delta_list))
        # global_update reads self.delta_avg, but nothing in this class used
        # to assign it; keep the latest average here as well as returning it.
        self.delta_avg = delta_avg
        return delta_avg

    def global_update(self):
        """Return ``flatten(global_model) + lr * delta_avg`` as a new vector.

        The global model itself is not modified; loading the returned vector
        back into it is left to the caller.
        """
        with torch.no_grad():
            current = torch.nn.utils.parameters_to_vector(self.global_model.parameters())
            return current + self.lr * self.delta_avg

 
        