import os
import math
import copy
import time
import torch
import numpy as np
import torch.nn as nn
import torch.cuda as cuda
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from torch.nn.utils.clip_grad import clip_grad_norm_
from utils import get_optimizer, get_net_builder
from datasets import fetch_dataset, split_dataset, SubDataset, fetch_os_dataset
from utils import AverageMeter, mixup_data, make_batchnorm_stats

class ServerBase(object):
    """Base server of the federated semi-supervised training loop.

    The server keeps the labeled set (``lb_set``), splits the unlabeled set
    (``ulb_set``) across the clients and owns the global model. Each round
    (see ``run``) it selects clients, lets them train from the current global
    weights, aggregates the uploaded models, trains on its own labeled data
    and evaluates on the test split.
    """

    def __init__(self, args, Client) -> None:
        """Build the datasets, the global model, its optimizers and the clients.

        Args:
            args: Run configuration object; the attributes read are the ones
                assigned below plus those consumed by ``make_dataset``.
            Client: Client class, instantiated in ``make_client`` as
                ``Client(args, client_index, client_subset)``.
        """
        self.args = args
        self.net = args.net
        self.algorithm = args.algorithm
        # Use the first GPU when present, otherwise fall back to the CPU so
        # the server can still be constructed on a CPU-only host.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.warmup_epochs = args.warmup_epochs
        self.global_rounds = args.global_rounds
        self.current_round = 0
        self.num_clients = args.num_clients
        self.num_join_clients = int(self.num_clients * args.join_ratio)
        self.clients = []
        self.logger = args.logger
        self.printer = args.printer
        self.save_dir = args.save_dir
        self.exp_tag = args.exp_tag
        self.load_path = args.load_path
        self.data_shape = args.data_shape
        self.num_classes = args.num_classes
        self.clip_grad = args.clip_grad
        self.selection = None
        self.agg = args.agg
        self.best_acc = 0
        self.selected_clients = []
        self.local_steps = args.s_local_steps
        self.batch_size = args.batch_size
        self.ce_loss = nn.CrossEntropyLoss()
        self.sBN = args.sBN
        # NOTE: make_dataset may overwrite self.num_classes, so it has to run
        # before make_model.
        self.data_idx, self.dataset = self.make_dataset(args)
        self.train_loader = DataLoader(self.dataset['lb_set'], batch_size=self.batch_size, shuffle=True)
        self.global_model = self.make_model()
        # Optimizer that applies the aggregated client update (see `aggregate`).
        self.global_optimizer = get_optimizer(self.global_model, optim_name='SGD', lr=1, momentum=0.5, weight_decay=0, nesterov=False)
        # Optimizer for the server's own supervised training on the labeled set.
        self.optimizer = get_optimizer(self.global_model, optim_name=args.optim, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.global_rounds, eta_min=0.001)
        self.make_client(args, Client)
        if self.sBN:
            self.batchnorm_dataset = self.make_norm_dataset()

    def make_model(self):
        """Build the network named by ``self.net`` and move it to ``self.device``."""
        self.printer.debug(f'make model: {self.net}')
        net_builder = get_net_builder(self.net)
        model = net_builder(self.data_shape, self.num_classes).to(self.device)
        return model

    def make_norm_dataset(self):
        """Return labeled + unlabeled data merged into one dataset.

        The result is a deep copy of the labeled set whose ``data`` and
        ``targets`` are extended with those of the unlabeled set and whose
        transform is the test-set transform. It is passed to
        ``make_batchnorm_stats`` in ``run``.
        """
        dataset = copy.deepcopy(self.dataset['lb_set'])
        dataset.data = np.concatenate([dataset.data, self.dataset['ulb_set'].data], axis=0)
        dataset.targets = np.concatenate([dataset.targets, self.dataset['ulb_set'].targets], axis=0)
        dataset.transform = self.dataset['test'].transform
        return dataset

    def make_dataset(self, args):
        """Fetch train/test data and split the unlabeled set across clients.

        When ``args.num_seen_class < args.num_classes`` the ``fetch_os_dataset``
        loader is used and ``self.num_classes`` is reduced to
        ``args.num_seen_class``; otherwise ``fetch_dataset`` is used.

        Returns:
            ``(data_idx, datasets)`` where ``data_idx`` is the result of
            ``split_dataset`` (indexed per client in ``make_client``) and
            ``datasets`` has the keys ``'lb_set'``, ``'ulb_set'`` and ``'test'``.
        """
        self.printer.debug('make dataset')
        if args.num_seen_class < args.num_classes:
            datasets = fetch_os_dataset(args.data_dir, args.dataset, args.num_labels, args.num_seen_class, close_set=args.close_train, train=True)
            testset = fetch_os_dataset(args.data_dir, args.dataset, args.num_labels, args.num_seen_class, close_set=args.close_test, train=False)
            self.num_classes = args.num_seen_class
        else:
            datasets = fetch_dataset(args.data_dir, args.dataset, args.num_labels, train=True)
            testset = fetch_dataset(args.data_dir, args.dataset, args.num_labels, train=False)
        lb_set = datasets['lb_set']
        ulb_set = datasets['ulb_set']
        data_idx = split_dataset(datasets['ulb_set'], args, self.num_clients)
        return data_idx, {
            'lb_set': lb_set,
            'ulb_set': ulb_set,
            'test': testset['ulb_set']
        }

    def make_client(self, args, clientObj):
        """Create one client per data split and record the label distribution.

        Fills ``self.clients`` and stores the ``(num_clients, num_classes)``
        per-client label count matrix in ``self.statistics``.
        """
        self.printer.debug('make clients')
        ulb_set = self.dataset['ulb_set']
        targets = ulb_set.targets
        # open-set case: the unlabeled set may hold more classes than the model predicts
        ulb_classes = max(self.num_classes, ulb_set.classes)
        matrix = np.zeros((self.num_clients, ulb_classes), dtype=np.int32)
        for i in range(self.num_clients):
            idx_i = np.array(self.data_idx[i], dtype=np.int32)
            # Index the targets once per client instead of once per class.
            client_targets = targets[idx_i]
            for j in range(ulb_classes):
                matrix[i, j] = np.sum(client_targets == j)
            client_set = SubDataset(ulb_set, idx_i)
            self.clients.append(clientObj(args, i, client_set))
        self.statistics = matrix
        self.printer.info(f'----------------label distribution of clients-----------------\n{matrix}')

    def select_clients(self, round_idx):
        """Sample ``self.num_join_clients`` distinct client indices uniformly.

        ``round_idx`` is currently unused: a fresh sample is drawn on every call.
        """
        return list(np.random.choice(self.num_clients, self.num_join_clients, replace=False))

    @torch.no_grad()
    def aggregate(self, uploaded_models, weights):
        """Merge the uploaded client models into the global model.

        The weighted average of the client parameters is turned into a
        pseudo-gradient (``global - average``) and applied with
        ``self.global_optimizer``. The running statistics of every batch-norm
        layer are replaced by the weighted average of the client statistics.

        Args:
            uploaded_models: Client models with the same architecture as the
                global model, on the same device.
            weights: One aggregation weight per model (``aggregate_models``
                passes weights normalised to sum to 1).
        """
        self.printer.debug("------Base aggregate_models------")
        if len(uploaded_models) == 0:
            self.printer.info('no uploaded models, skip aggregation')
            return
        self.printer.info(f'aggregation weights: {weights}')

        # Weighted average of the client parameters.
        shadow_model = copy.deepcopy(self.global_model)
        for param in shadow_model.parameters():
            param.data.zero_()
        for w, client_model in zip(weights, uploaded_models):
            w = float(w)  # plain float keeps tensor arithmetic independent of numpy scalars
            for new_param, param in zip(client_model.parameters(), shadow_model.parameters()):
                param.data += w * new_param.data

        # Apply the averaged update through the server-side optimizer.
        self.global_optimizer.zero_grad()
        for new_param, param in zip(shadow_model.parameters(), self.global_model.parameters()):
            param.grad = (param.data - new_param.data).detach()
        self.global_optimizer.step()

        # update batchnorm statistics (all batch-norm flavours, not only 2d)
        bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
        for i, (w, client_model) in enumerate(zip(weights, uploaded_models)):
            w = float(w)
            for gmodule, lmodule in zip(self.global_model.modules(), client_model.modules()):
                if not isinstance(gmodule, bn_types):
                    continue
                # Layers built with track_running_stats=False carry no statistics.
                if gmodule.running_mean is None or lmodule.running_mean is None:
                    continue
                if i == 0:
                    gmodule.running_mean.copy_(lmodule.running_mean * w)
                    gmodule.running_var.copy_(lmodule.running_var * w)
                else:
                    gmodule.running_mean.add_(lmodule.running_mean * w)
                    gmodule.running_var.add_(lmodule.running_var * w)

    def aggregate_models(self, round_idx):
        """Collect the models of the selected clients and aggregate them.

        Clients whose ``util`` is falsy are skipped. The weight of a client is
        1 (``'uniform'``), its ``util`` (``'weighted'``) or ``exp(-loss)`` of
        its model on the labeled set (``'loss'``); weights are normalised to
        sum to 1 before ``aggregate`` is called.

        Raises:
            ValueError: If ``self.agg`` is not one of the three methods above.
        """
        uploaded_models = []
        weights = []
        for cid in self.selected_clients:
            client = self.clients[cid]
            if not client.util:  # no local training happened
                continue
            model = copy.deepcopy(client.model).to(self.device)
            uploaded_models.append(model)
            if self.agg == 'uniform':
                weights.append(1)
            elif self.agg == 'weighted':
                weights.append(client.util)
            elif self.agg == 'loss':
                ls = self.test_on_trainset(model)
                weights.append(math.exp(-ls))
            else:
                raise ValueError(f'invalid aggregation method: {self.agg}')
        wsum = sum(weights)
        weights = [w / wsum for w in weights]
        self.aggregate(uploaded_models, weights)

    def training_stats(self, round_idx):
        """Print per-client training logs and log sample-weighted averages.

        Logs ``pseudo_label_acc``, ``masked_label_acc`` and ``util_ratio`` (in
        percent) for the selected clients. Nothing is logged when the selected
        clients hold no samples, since the averages would be undefined.
        """
        log_dict = {}
        pl_acc, mask_acc, util, samples = 0, 0, 0, 0
        for cid in self.selected_clients:
            log = self.clients[cid].logs
            log_dict[cid] = log
            pl_acc += log['pl_acc'].mean() * log['samples']
            mask_acc += (log['mask_acc'] * log['util']).mean() * log['samples']
            util += log['util'].mean() * log['samples']
            samples += log['samples']

        for cid, logs in log_dict.items():
            line = f'C{cid:>2d}:' + ''.join(
                f'|pl={pl:.2f}, util={ut:.2f}, mask={mk:.2f}|'
                for pl, ut, mk in zip(logs['pl_acc'], logs['util'], logs['mask_acc'])
            )
            self.printer.info(line)

        if samples == 0:
            # No client selected (or only empty clients): avoid dividing by zero.
            self.printer.info('no client samples this round, skip training statistics')
            return

        self.logger.log({'pseudo_label_acc': pl_acc / samples * 100}, step=round_idx)
        self.logger.log({'masked_label_acc': mask_acc / (util + 1e-8) * 100}, step=round_idx)
        self.logger.log({'util_ratio': util / samples * 100}, step=round_idx)

    def _supervised_step(self, x, y):
        """Run one cross-entropy optimisation step of the global model on (x, y).

        Returns:
            ``(logits, loss)`` computed before the parameter update.
        """
        self.optimizer.zero_grad()
        logits = self.global_model(x)
        loss = self.ce_loss(logits, y)
        loss.backward()
        if self.clip_grad > 0:
            clip_grad_norm_(self.global_model.parameters(), self.clip_grad)
        self.optimizer.step()
        return logits, loss

    def warmup(self):
        """Train the global model on the labeled set for ``warmup_epochs`` epochs."""
        self.global_model.train(True)
        for _ in range(self.warmup_epochs):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                self._supervised_step(x, y)

    @torch.no_grad()
    def test_on_trainset(self, model):
        """Return the average cross-entropy loss of ``model`` on the labeled set."""
        model.train(False)
        loader = DataLoader(self.dataset['lb_set'], batch_size=100, shuffle=False)
        ce_loss_meter = AverageMeter()
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            logits = model(x)
            loss = self.ce_loss(logits, y)
            ce_loss_meter.update(loss.item(), y.shape[0])
        return ce_loss_meter.avg

    def train(self, round):
        """Train the global model on the labeled set for ``local_steps`` epochs.

        Steps the learning-rate scheduler once and logs the average training
        loss and accuracy under the step ``round``.
        """
        st = time.time()
        self.global_model.train(True)
        ce_loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        for _ in range(self.local_steps):
            for data in self.train_loader:
                x, y = data['x'].to(self.device), data['y'].to(self.device)
                logits, loss = self._supervised_step(x, y)
                batch = y.shape[0]
                ce_loss_meter.update(loss.item(), batch)
                acc = (logits.argmax(dim=1) == y).float().mean().item()
                acc_meter.update(acc, batch)
        self.scheduler.step()
        self.logger.log({'train_loss': ce_loss_meter.avg}, step=round)
        self.logger.log({'server@train_acc': acc_meter.avg * 100}, step=round)
        self.printer.info(f"server train cost {(time.time() - st) / 60:.2f} min")

    @torch.no_grad()
    def _predict(self, loader):
        """Run the global model in inference mode over ``loader``.

        Returns:
            ``(y, logits)``: all targets and all logits, concatenated along dim 0.
        """
        self.printer.debug('-----------------testing-----------------')
        self.global_model.train(False)
        self.global_model.to(self.device)
        all_y, all_logits = [], []
        for data in loader:
            x, y = data['x'].to(self.device), data['y'].to(self.device)
            all_y.append(y)
            all_logits.append(self.global_model(x))
        return torch.cat(all_y, dim=0), torch.cat(all_logits, dim=0)

    @torch.no_grad()
    def test(self, loader):
        """Return the accuracy (in percent) of the global model on ``loader``."""
        y, logits = self._predict(loader)
        test_acc = accuracy_score(y.cpu().numpy(), logits.argmax(dim=1).cpu().numpy()) * 100
        return test_acc

    @torch.no_grad()
    def test_openset(self, loader, round=0):
        """Return ``(0, close_set_acc)`` for the global model on ``loader``.

        Only samples whose label is below ``self.num_classes`` count towards
        the accuracy (in percent). The first element is a constant 0; no
        open-set metric is computed here.
        """
        y, logits = self._predict(loader)
        seen_idxs = y < self.num_classes
        close_set_acc = accuracy_score(y[seen_idxs].cpu().numpy(), logits[seen_idxs].argmax(dim=1).cpu().numpy()) * 100
        return 0, close_set_acc

    def run(self):
        """Run warm-up followed by ``global_rounds`` federated rounds."""
        self.warmup()
        if self.sBN:
            make_batchnorm_stats(self.batchnorm_dataset, self.global_model, self.device)
        for round_idx in range(self.global_rounds):
            st_time = time.time()
            self.selected_clients = self.select_clients(round_idx)
            self.selected_clients.sort()
            lr = self.scheduler.get_last_lr()[0]
            model_dict = self.global_model.state_dict()

            # Every selected client starts from the current global weights.
            for cid in self.selected_clients:
                self.clients[cid].train(round_idx, lr, model_dict)

            self.aggregate_models(round_idx)
            self.training_stats(round_idx)
            self.train(round_idx)
            cuda.empty_cache()
            if self.sBN:
                make_batchnorm_stats(self.batchnorm_dataset, self.global_model, self.device)
            self.evaluate(round_idx)
            self.printer.info(f'{round_idx}/{self.global_rounds} cost: {(time.time() - st_time) / 60:.2f} min')
            self.printer.info('-' * 30)

    def evaluate(self, round_idx):
        """Evaluate the global model on the test set, log and checkpoint.

        Uses ``test_openset`` when ``args.close_test`` is falsy and ``test``
        otherwise, updates ``self.best_acc`` and calls ``save_model``.
        """
        st = time.time()
        print_log = 'Evaluation: '
        data = self.dataset['test']
        loader = DataLoader(data, batch_size=256)
        if not self.args.close_test:  # open-set test
            open_set_acc, acc = self.test_openset(loader, round=round_idx)
            print_log += f'openset_acc = {open_set_acc:.2f}% close_set_acc = {acc:.2f}% '
            self.logger.log({"open_set_acc": open_set_acc}, step=round_idx)
        else:
            acc = self.test(loader)
            print_log += f'test_acc = {acc:.2f}%'
        best = acc > self.best_acc
        if best:
            self.best_acc = acc
        log_dict = {'test_acc': acc, 'best_acc': self.best_acc}
        self.logger.log(log_dict, step=round_idx)
        print_log += f'best_acc: {self.best_acc:.2f}% '
        self.save_model(round_idx, best=best)
        print_log += f'cost: {(time.time() - st)/60:.2f} min'
        self.printer.info(print_log)

    def save_model(self, round, best=False):
        """Checkpoint the global model.

        Writes ``<net>_models_best.pth`` when ``best`` is true and
        ``<net>_models_final.pth`` on the last round, both inside
        ``self.save_dir`` (created if missing).
        """
        file_names = []
        if best:
            file_names.append(f'{self.net}_models_best.pth')
        if round == self.global_rounds - 1:
            file_names.append(f'{self.net}_models_final.pth')
        if not file_names:
            return
        if self.save_dir:
            # torch.save does not create missing directories.
            os.makedirs(self.save_dir, exist_ok=True)
        ckpt = {
            'ckpt_model': self.global_model.state_dict(),
        }
        for file_name in file_names:
            torch.save(ckpt, os.path.join(self.save_dir, file_name))


class ClientBase(object):
    """Base client of the federated semi-supervised training loop.

    A client holds one subset of the unlabeled data, its own model copy and
    optimizer. Its ``train`` method loads the weights passed in, fits the
    model on confidence-masked pseudo-labels and records statistics in
    ``self.logs`` and ``self.util``.
    """

    def __init__(self, args, id, trainset) -> None:
        """Store the configuration and build the local model and optimizer.

        Args:
            args: Run configuration object; the attributes read are the ones
                assigned below.
            id: Client identifier.
            trainset: Dataset holding this client's samples.
        """
        self.id = id
        self.args = args
        self.net = args.net
        self.exp_tag = args.exp_tag
        self.load_path = args.load_path
        self.global_rounds = args.global_rounds
        self.local_steps = args.local_steps
        self.trainset = trainset
        self.data_shape = args.data_shape
        # The previous conditional (num_classes if equal to num_seen_class,
        # else num_seen_class) always evaluated to num_seen_class.
        self.num_classes = args.num_seen_class
        self.batch_size = args.c_batch_size
        self.logger = args.logger
        self.printer = args.printer
        self.threshold = args.threshold
        self.clip_grad = args.clip_grad
        self.mixup = args.mixup
        self.num_samples = len(trainset)
        # Use the first GPU when present, otherwise fall back to the CPU.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = self.make_model()
        self.optimizer = get_optimizer(self.model, optim_name=args.optim, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        self.optimizer_dict = self.optimizer.state_dict()
        # Falsy until a training round produced usable samples; defined here
        # so the attribute exists even before the first call to `prepare`.
        self.util = False

    def make_model(self):
        """Build the network named by ``self.net`` and move it to ``self.device``."""
        self.printer.debug(f'make model: {self.net}')
        net_builder = get_net_builder(self.net)
        model = net_builder(self.data_shape, self.num_classes).to(self.device)
        return model

    def set_parameters(self, state_dict):
        """Load ``state_dict`` into the local model."""
        self.printer.debug('set parameters')
        self.model.load_state_dict(state_dict)

    def prepare(self, lr, state_dict):
        """Get ready for a round of local training.

        Moves the model to ``self.device``, loads ``state_dict``, restores the
        saved optimizer state with its learning rate set to ``lr`` and resets
        ``self.util`` to ``False``.
        """
        self.printer.debug('client preparation')
        self.model.to(self.device)
        self.set_parameters(state_dict)
        for group in self.optimizer_dict['param_groups']:
            group['lr'] = lr
        self.optimizer.load_state_dict(self.optimizer_dict)
        self.util = False

    def train(self, round_idx, lr, state_dict):
        """Train locally for ``local_steps`` epochs on pseudo-labels.

        For each batch the prediction on ``x`` gives a pseudo-label and a
        confidence; the cross-entropy of the prediction on ``x_s`` against
        that label is kept only where the confidence reaches
        ``self.threshold``. With ``self.mixup`` a mixup loss on ``x`` is added.
        ``x_s`` is presumably an augmented view of ``x`` -- confirm against
        the dataset class.

        Per-epoch averages are stored in ``self.logs``; ``self.util`` becomes
        mean mask ratio times the number of samples. The model is moved to
        the CPU at the end.

        Args:
            round_idx: Index of the current round (unused here).
            lr: Learning rate for this round.
            state_dict: Model weights to start from.
        """
        self.prepare(lr, state_dict)
        loader = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)
        pl_meter, mask_meter, util_meter, ls_meter = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
        pl_acc, mask_acc, util, ls = [], [], [], []
        for _ in range(self.local_steps):
            self.model.train(True)
            for data in loader:
                x = data['x'].to(self.device)
                x_s = data['x_s'].to(self.device)
                y = data['y'].to(self.device)
                self.optimizer.zero_grad()
                # Pseudo-labels carry no gradient.
                with torch.no_grad():
                    logits = self.model(x)
                    prob, pl = torch.max(logits.softmax(dim=1), dim=1)
                    mask = prob.ge(self.threshold)
                logits_s = self.model(x_s)
                ce_loss = F.cross_entropy(logits_s, pl, reduction='none')
                loss = (ce_loss * mask.float()).mean()
                if self.mixup:
                    mix_x, ya, yb, lam = mixup_data(x, pl)
                    logits_m = self.model(mix_x)
                    mixup_loss = lam * F.cross_entropy(logits_m, ya) + (1 - lam) * F.cross_entropy(logits_m, yb)
                    loss = loss + mixup_loss
                loss.backward()
                if self.clip_grad > 0:
                    clip_grad_norm_(self.model.parameters(), self.clip_grad)
                self.optimizer.step()

                batch = y.shape[0]
                pl_meter.update((pl == y).float().mean().item(), batch)
                # One reduction instead of iterating the mask element by element.
                num_masked = int(mask.sum().item())
                if num_masked > 0:
                    mask_meter.update((pl[mask] == y[mask]).float().mean().item(), num_masked)

                ls_meter.update(loss.item(), batch)
                util_meter.update(mask.float().mean().item(), batch)

            # Store this epoch's averages and start the next epoch from zero.
            for history, meter in ((pl_acc, pl_meter), (mask_acc, mask_meter), (ls, ls_meter), (util, util_meter)):
                history.append(meter.avg)
                meter.reset()

        self.logs = {
            'pl_acc': np.array(pl_acc),
            'mask_acc': np.array(mask_acc),
            'util': np.array(util),
            'loss': np.array(ls),
            'samples': len(self.trainset),
        }
        self.util = self.logs['util'].mean() * self.logs['samples']
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')
