import argparse
import shutil
from datetime import datetime
import copy
from threading import Thread, Lock
from collections import defaultdict
import math
from tqdm import tqdm, trange
import copy

import yaml
from prompt_toolkit import prompt
from tqdm import tqdm

from helper import Helper
from utils.utils import *

# Process-wide autograd anomaly detection. PyTorch documents this as a
# debugging aid that slows down execution; it stays enabled here as before.
torch.autograd.set_detect_anomaly(True)

# Module constant (not referenced elsewhere in the visible source).
NUM_SAMPLES = 100
logger = logging.getLogger('logger')

def train(hlpr: Helper, epoch, model, optimizer, train_loader, attack=False, ratio=None, report=False):
    """Run one pass of optimisation over ``train_loader``.

    For every batch the gradients are cleared with ``model.zero_grad()``, the
    loss is obtained from ``hlpr.attack.compute_blind_loss`` (forwarding
    ``attack`` and ``ratio`` unchanged), back-propagated, and one
    ``optimizer.step()`` is taken.

    Args:
        hlpr: project helper providing ``task.criterion``, ``task.get_batch``
            and ``attack.compute_blind_loss``.
        epoch: accepted for interface compatibility; not used here.
        model: module being trained; switched to ``train()`` mode.
        optimizer: optimizer stepped once per batch.
        train_loader: iterable of raw batches passed to ``task.get_batch``.
        attack: forwarded to ``compute_blind_loss``.
        ratio: forwarded to ``compute_blind_loss``.
        report: accepted for interface compatibility; not used here.

    Returns:
        None.
    """
    loss_fn = hlpr.task.criterion
    model.train()

    for step, raw_batch in enumerate(train_loader):
        prepared = hlpr.task.get_batch(step, raw_batch)
        model.zero_grad()
        batch_loss = hlpr.attack.compute_blind_loss(model, loss_fn, prepared, attack, ratio)
        batch_loss.backward()
        optimizer.step()


def test(hlpr: Helper, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` with the task's metric objects.

    A deep copy of ``hlpr.task.metrics`` is reset and then fed every batch's
    outputs and labels, so the task's own metric objects are left untouched.

    Args:
        hlpr: project helper providing the ``task`` metric API.
        model: module to evaluate; switched to ``eval()`` mode.
        test_loader: iterable of raw batches passed to ``task.get_batch``.

    Returns:
        The pair returned by ``hlpr.task.get_metrics`` (named accuracy and
        loss by the caller-facing variable names; presumably in that order).
    """
    model.eval()

    task = hlpr.task
    eval_metrics = copy.deepcopy(task.metrics)
    task.reset_metrics(eval_metrics)

    with torch.no_grad():
        for step, raw_batch in enumerate(test_loader):
            prepared = task.get_batch(step, raw_batch)
            task.accumulate_metrics(eval_metrics, outputs=model(prepared.inputs), labels=prepared.labels)
        acc, loss = task.get_metrics(eval_metrics)

    return acc, loss


def model2vec(network):
    """Flatten every parameter of ``network`` into a single 1-D NumPy array.

    Parameters are visited in ``network.parameters()`` order and detached and
    moved to the CPU before conversion.

    Args:
        network: a ``torch.nn.Module``.

    Returns:
        A 1-D array. The dtype is at least float64, because the empty float64
        seed promotes every chunk. The array is empty for a module with no
        parameters.
    """
    chunks = [p.detach().flatten().cpu().numpy() for p in network.parameters()]
    # Concatenate once. Growing the array inside the loop re-copied every
    # previously collected value on each iteration (quadratic in model size).
    # The empty float64 seed reproduces the original dtype promotion and keeps
    # the parameter-less case valid.
    return np.concatenate([np.array([])] + chunks)


def fl_run(hlpr: Helper):
    """Run the federated training loop.

    A fresh model is built with ``hlpr.task.build_model()`` and stored on
    ``hlpr.task.model``. Then, for ``hlpr.params.epochs + 1`` rounds:

    1. ``hlpr.task.sample_users_for_round()`` picks the round's participants.
    2. Participants are processed in batches of at most
       ``hlpr.params.max_threads`` concurrently running ``ClientThread``s.
    3. Every returned gradient dict is applied to the global parameters as
       ``w <- w - lr * sum_k grad_k`` (client gradients are summed, not
       averaged). The result is loaded back with ``strict=False``.

    A round with no sampled participants is logged and skipped instead of
    crashing on ``grads[0]`` / ``max(ranks)``.

    Args:
        hlpr: project helper; reads ``hlpr.task`` and
            ``hlpr.params.{epochs, max_threads, lr}``.
    """
    hlpr.task.model = hlpr.task.build_model()
    # NOTE(review): curr_sigma is never updated, so every round passes sigma=1.
    curr_sigma = 1
    # A pool size below 1 previously left the batching loop spinning forever.
    pool_size = max(1, hlpr.params.max_threads)

    for epoch in range(hlpr.params.epochs + 1):
        global_model = hlpr.task.model
        grads, utilities, accs, ranks, alphas = [], [], [], [], []
        round_participants = hlpr.task.sample_users_for_round()
        # len() rather than truthiness so array-like participant lists work too.
        if len(round_participants) == 0:
            logger.warning('Epoch: {}, no participants sampled; skipping round.'.format(epoch))
            continue

        for start in range(0, len(round_participants), pool_size):
            threads = []
            for user in round_participants[start:start + pool_size]:
                thread = ClientThread(user, hlpr, global_model, r=1, sigma=curr_sigma)
                threads.append(thread)
                thread.start()
            for thread in threads:
                # NOTE(review): five values are unpacked here, but ClientThread.run
                # stores a 4-tuple (grad, utility, acc, s); reconcile the two.
                grad, utility, acc, rank, alpha = thread.join()
                grads.append(grad)
                utilities.append(utility)
                accs.append(acc)
                ranks.append(rank)
                alphas.append(alpha)

        logger.info(', '.join(map(str, utilities)))
        logger.info(', '.join(map(str, accs)))

        # Build the state dict once instead of once per parameter name. With
        # PyTorch's default keep_vars=False its tensors are detached views that
        # share storage with the live parameters, so sub_() updates them in place.
        current_state = global_model.state_dict()
        lr = hlpr.params.lr
        new_state_dict = {name: current_state[name] for name in grads[0]}
        for name, tensor in new_state_dict.items():
            for grad in grads:
                tensor.sub_(grad[name] * lr)

        global_model.load_state_dict(new_state_dict, strict=False)

        logger.warning('Epoch: {}, Beta: {}, Sum of Utilities: {:.3f}'.format(epoch, max(ranks), sum(utilities)))



    


class ClientThread(Thread):
    """Worker thread that performs one client's local work for a round.

    The thread computes the client's averaged training gradient on a private
    copy of the global model and evaluates that copy on the client's test
    data. The result is handed back through ``join()``. Any exception raised
    inside ``run()`` is stored and re-raised by ``join()``, so the caller sees
    the real error instead of a failure to unpack ``None``.

    Args:
        user: client object providing ``train_loader`` and ``test_loader``.
        hlpr: project helper; only ``hlpr.task.get_batch(i, data)`` is used.
        global_model: model the client starts from. It is deep-copied because
            ``run()`` calls ``zero_grad()``/``backward()`` and toggles
            ``train()``/``eval()``. Several threads doing that on one shared
            module interleave and corrupt each other's gradients.
            NOTE(review): train-mode buffer updates (e.g. BatchNorm running
            statistics) therefore stay on the copy and no longer reach the
            global model. Confirm that is acceptable.
        r: stored as ``self.r``; not used by this class.
        sigma: stored as ``self.sigma``; not used by this class.
    """

    def __init__(self, user, hlpr, global_model, r, sigma=1):
        super().__init__()
        self.user = user
        self.hlpr = hlpr
        self.model = copy.deepcopy(global_model)
        self.r = r
        self.sigma = sigma
        self._return = None
        self._error = None  # exception raised in run(), re-raised by join()

    def run(self):
        """Compute gradient and utility; store the result tuple or the error."""
        try:
            # Parameters do not change between gradient computation and
            # evaluation, so a single gradient pass is enough.
            grad = self._compute_grad()
            utility, acc = self.test(self.model)

            # NOTE(review): unfinished strategy update. `a`, `d_s`, `c` and
            # `delta` are not defined in the visible source. Because `s` is
            # assigned below, it is a local name that is read before
            # assignment, so these lines always raise. Left as-is pending the
            # intended definition.
            d = (a(s + d_s) - a(s)) / d_s
            if not (s + d > c or s + d < 0):
                s = s + delta * d
            self._return = grad, utility, acc, s
        except Exception as err:
            self._error = err

    def _compute_grad(self):
        """Return ``{param_name: gradient summed over batches / number of batches}``.

        Uses a hard-coded ``torch.nn.CrossEntropyLoss`` on every batch of
        ``self.user.train_loader``. Trainable parameters that received no
        gradient (unused in the forward pass, or an empty loader) map to a
        zero tensor instead of crashing on ``None / int``.
        """
        criterion = torch.nn.CrossEntropyLoss()
        self.model.train()
        self.model.zero_grad()
        for i, data in enumerate(self.user.train_loader):
            batch = self.hlpr.task.get_batch(i, data)
            loss = criterion(self.model(batch.inputs), batch.labels)
            loss.backward()  # accumulates into .grad across batches

        num_batches = len(self.user.train_loader)
        grad = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.grad is None or num_batches == 0:
                grad[name] = torch.zeros_like(param)
            else:
                grad[name] = param.grad / num_batches
        return grad

    def test(self, model):
        """Evaluate ``model`` on ``self.user.test_loader``.

        Returns:
            ``(utility, acc)``, where ``utility = 1 - mean per-batch
            cross-entropy loss`` and ``acc`` is the fraction of correctly
            classified samples.

        Raises:
            ZeroDivisionError: if the test loader yields no batches.
        """
        criterion = torch.nn.CrossEntropyLoss()
        loss_sum = 0.0
        correct, total = 0, 0
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(self.user.test_loader):
                batch = self.hlpr.task.get_batch(i, data)
                logits = model(batch.inputs)
                loss_sum += criterion(logits, batch.labels).item()
                preds = torch.argmax(torch.softmax(logits, dim=1), dim=1)
                correct += int((preds == batch.labels).sum())
                total += batch.labels.size(0)
        # Average over the *test* batches that were summed (previously this
        # divided by the number of train batches).
        mean_loss = loss_sum / len(self.user.test_loader)
        utility = 1 - mean_loss
        acc = correct / total

        return utility, acc

    def join(self, *args):
        """Wait for the thread; re-raise any error from ``run()``, else return its result."""
        Thread.join(self, *args)
        if self._error is not None:
            raise self._error
        return self._return



if __name__ == '__main__':
    # CLI: a YAML parameter file plus a run name (also used for TensorBoard).
    parser = argparse.ArgumentParser()
    parser.add_argument('--params', dest='params', default='fedavg.yaml')
    parser.add_argument('--name', dest='name', default='test', help='Tensorboard name')
    args = parser.parse_args()

    # NOTE(review): FullLoader resolves more YAML tags than yaml.safe_load;
    # prefer safe_load if the parameter file could come from an untrusted source.
    with open(args.params) as config_file:
        params = yaml.load(config_file, Loader=yaml.FullLoader)

    params.update(current_time=datetime.now().strftime('%b.%d_%H.%M.%S'),
                  name=args.name)

    helper = Helper(params)
    logger.warning(create_table(params))

    try:
        fl_run(helper)
    except KeyboardInterrupt:
        # On Ctrl-C, optionally remove the run's output folder (and TB logs).
        if not helper.params.log:
            logger.error("Aborted training. No output generated.")
        elif prompt('\nDelete the repo? (y/n): ') in ('Y', 'y', 'yes'):
            logger.error(f"Fine. Deleted: {helper.params.folder_path}")
            shutil.rmtree(helper.params.folder_path)
            if helper.params.tb:
                shutil.rmtree(f'runs/{args.name}')
        else:
            logger.error(f"Aborted training. Results: {helper.params.folder_path}. TB graph: {args.name}")
