import logging
import os

import torch
import wandb
from torch import nn
import numpy as np
import time
from thop import profile

from fedml_api.distributed.fednas_extension.utils import adjust_learning_rate
from fedml_api.model.cv.darts import utils
from fedml_api.model.cv.darts.architect import Architect
from fedml_api.model.cv.darts.utils import KL_Loss


class FedNASTrainer(object):
    def __init__(self, client_index,
                 train_data_num, train_data_local_num_dict,
                 train_data_local_dict, test_data_local_dict,
                 model, device, args, teacher_model, local_model):
        """Set up the trainer for the client identified by ``client_index``.

        Args:
            client_index: Key selecting this client's entries in the three
                per-client dictionaries.
            train_data_num: Stored as ``all_train_data_num``; presumably the
                total number of training samples over all clients -- TODO
                confirm against the caller.
            train_data_local_num_dict: Maps a client index to its local
                sample count.
            train_data_local_dict: Maps a client index to its iterable of
                training batches.
            test_data_local_dict: Maps a client index to its iterable of
                test batches.
            model: Network that is searched/trained; moved to ``device`` here.
            device: Device used for the model and for every batch.
            args: Hyper-parameter namespace; ``args.temperature`` is read here.
            teacher_model: Optional model forwarded to ``Architect``.
            local_model: Model persisted by ``save_per_train_model``; this
                constructor does not move it to ``device``.
        """
        # Identity of the client and the auxiliary models.
        self.client_index = client_index
        self.teacher_model = teacher_model
        self.local_model = local_model

        # Data of the currently selected client.
        self.train_local = train_data_local_dict[client_index]
        self.test_local = test_data_local_dict[client_index]
        self.local_sample_number = train_data_local_num_dict[client_index]
        self.all_train_data_num = train_data_num

        # Full lookup tables, kept so update_dataset() can switch clients.
        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict

        # Runtime configuration and loss functions.
        self.device = device
        self.args = args
        self.KL_loss = KL_Loss(args.temperature)
        self.criterion = nn.CrossEntropyLoss()

        # Only the main model is placed on the device; local_model is left
        # wherever the caller created it.
        self.model = model
        self.model.to(device)

        # Training progress (index of the current federated round).
        self.round_index = 0
        # TODO: check value
        self.g_lambda = 0.5

    def update_training_progress(self, round_index):
        """Remember which federated round is running.

        ``round_index`` is later compared against
        ``args.frequency_of_the_test`` to decide whether to evaluate.
        """
        self.round_index = round_index

    def update_dataset(self, client_index):
        """Point the trainer at another client's data.

        Looks ``client_index`` up in the dictionaries stored by the
        constructor and replaces the active train loader, sample count and
        test loader with that client's entries.
        """
        idx = client_index
        self.client_index = idx
        # Same lookup order as the tables were consulted originally:
        # train loader, sample count, then test loader.
        self.train_local = self.train_data_local_dict[idx]
        self.local_sample_number = self.train_data_local_num_dict[idx]
        self.test_local = self.test_data_local_dict[idx]

    def update_model(self, weights):
        logging.info("update_model. client_index = %d" % self.client_index)
        self.model.load_state_dict(weights)

    def update_arch(self, alphas):
        logging.info("update_arch. client_index = %d" % self.client_index)
        for a_g, model_arch in zip(alphas, self.model.arch_parameters()):
            model_arch.data.copy_(a_g.data)

    def update_base_model(self, weights):
        logging.info("update_model. client_index = %d" % self.client_index)
        self.model.load_state_dict(weights, strict=False)

    def get_model_parameters(self):
        arch_parameters = self.model.arch_parameters()
        arch_params = list(map(id, arch_parameters))

        parameters = self.model.parameters()
        weight_params = filter(lambda p: id(p) not in arch_params,
                               parameters)
        return weight_params

    # local search
    def search(self):
        """Run one federated round of local architecture search.

        Trains the model for ``args.epochs`` epochs with ``_local_search``
        (SGD on the non-architecture parameters plus an ``Architect`` step
        per batch) and optionally evaluates on the local test data:

        * ``args.local_finetune == 'False'``: evaluate before training.
        * ``args.local_finetune == 'True'``: evaluate after training, but only
          when ``round_index`` is a multiple of ``args.frequency_of_the_test``.

        On rounds that are not test rounds the reported accuracy is ``0.0``.

        Returns:
            tuple: ``(weights, alphas, local_sample_number, mean_train_acc,
            mean_train_loss, acc_local_model_on_local_data, client_index,
            flops, params)``. ``weights`` and ``alphas`` are taken after the
            model was moved to the CPU. ``flops`` and ``params`` are
            placeholders that are always zero (no profiling is performed).

        Raises:
            Exception: If ``args.stage`` is not ``'fednas_search'``.
        """
        # Validate the stage up front so an unsupported configuration fails
        # before any training work is done (the exception itself is unchanged).
        if self.args.stage != 'fednas_search':
            raise Exception("abnormal branch")

        self.model.to(self.device)
        self.model.train()
        if self.teacher_model is not None:
            self.teacher_model.to(self.device)

        # SGD only receives the network weights, i.e. every parameter that is
        # not one of the architecture parameters.
        arch_param_ids = {id(p) for p in self.model.arch_parameters()}
        weight_params = [p for p in self.model.parameters()
                         if id(p) not in arch_param_ids]

        # weight optimizer: SGD with momentum + L2 norm
        w_optimizer = torch.optim.SGD(
            weight_params,
            self.args.learning_rate,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay)

        # alpha optimizer
        architect_opt = Architect(self.model, self.criterion, self.args, self.device, self.teacher_model)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            w_optimizer, float(self.args.epochs), eta_min=self.args.learning_rate_min)

        local_avg_train_acc = []
        local_avg_train_loss = []

        # Default so the return value is always defined, even when
        # local_finetune is neither 'True' nor 'False'.
        acc_local_model_on_local_data = 0.0
        if self.args.local_finetune == 'False':
            with torch.no_grad():
                acc_local_model_on_local_data, _, _ = self.local_infer(
                    self.test_local, self.model, self.criterion)
            logging.info(
                    'Before fine-tunning: client_idx = %d, acc_global_model_on_local_test %f' % (
                    self.client_index, acc_local_model_on_local_data))

        for epoch in range(self.args.epochs):
            # local_infer() switches the model to eval mode; make sure the
            # search epoch really runs in training mode.
            self.model.train()
            train_acc, _, train_loss = self._local_search(epoch, self.train_local, self.model, architect_opt,
                                                          self.criterion, w_optimizer)
            logging.info("Back from local Search")
            logging.info('client_idx = %d, epoch = %d, local search_acc %f' % (self.client_index, epoch, train_acc))
            local_avg_train_acc.append(train_acc)
            local_avg_train_loss.append(train_loss.cpu())

            logging.info('client_idx = %d, epoch %d' % (self.client_index, epoch))
            # TODO: trying workshop paper scheduler for weight
            scheduler.step()
            # get_last_lr() is the supported way to read the current rate;
            # get_lr() is only meant to be called from inside step().
            lr = scheduler.get_last_lr()[0]
            logging.info('client_idx = %d, epoch %d lr %e' % (self.client_index, epoch, lr))

        if self.round_index % self.args.frequency_of_the_test == 0:
            if self.args.local_finetune == 'True':
                with torch.no_grad():
                    acc_local_model_on_local_data, _, _ = self.local_infer(
                        self.test_local, self.model, self.criterion)
                logging.info(
                    'After fine-tuning: client_idx = %d, validation accuracy %f' % (self.client_index, acc_local_model_on_local_data))
        else:
            # Not a test round: report 0.0 regardless of earlier measurements.
            acc_local_model_on_local_data = 0.0

        # No profiling is performed; both statistics are reported as zero.
        total_params = 0.0
        total_flops = 0.0

        # Move the model to the CPU once and read both results from it.
        cpu_model = self.model.cpu()
        weights = cpu_model.state_dict()
        alphas = cpu_model.arch_parameters()

        # 1e6 for unit conversion, self.args.layers because there 8 cells
        return weights, alphas, self.local_sample_number, \
               sum(local_avg_train_acc) / len(local_avg_train_acc), \
               sum(local_avg_train_loss) / len(local_avg_train_loss), acc_local_model_on_local_data, self.client_index, \
               (total_flops * self.args.layers / 1e6), (total_params * self.args.layers / 1e6)

    def _local_search(self, epoch, train_queue, model, architect_opt, criterion, w_optimizer):
        """Run one search epoch over ``train_queue``.

        For every batch the architect performs ``step_v2`` on the batch and
        then the weight optimizer takes one clipped SGD step on the
        ``criterion`` loss.

        Args:
            epoch: Epoch number, used for logging only.
            train_queue: Iterable of ``(input, target)`` batches.
            model: Network being searched.
            architect_opt: Object exposing ``step_v2(input, target)``.
            criterion: Loss applied to ``(logits, target)``.
            w_optimizer: Optimizer for the network weights.

        Returns:
            tuple: ``(top1.avg / 100.0, objs.avg / 100.0, loss)`` where
            ``loss`` is the loss tensor of the last processed batch, or
            ``None`` when ``train_queue`` yielded no batch.
            NOTE(review): the averaged loss is divided by 100 as well, which
            looks unintended but is kept for compatibility with callers.
        """
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        loss = None
        logging.info("Local Search")
        for step, (inputs, target) in enumerate(train_queue):
            # Callers interleave this method with local_infer(), which leaves
            # the model in eval mode; force training mode for every step, the
            # same way local_train() does.
            model.train()
            n = inputs.size(0)

            inputs = inputs.to(self.device)
            target = target.to(self.device)

            # Architecture step first, then the weight step on the same batch.
            architect_opt.step_v2(inputs, target)
            w_optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, target)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.args.grad_clip)
            w_optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                logging.info('epoch = %d, client_index = %d, search %03d %e %f %f', epoch, self.client_index,
                             step, objs.avg, top1.avg, top5.avg)
            # Debug mode stops after the second batch.
            if step == 1 and self.args.is_debug_mode:
                break
        return top1.avg / 100.0, objs.avg / 100.0, loss

    def local_fine_tune(self):
        """Fine-tune the best global model on this client's data.

        Loads the weights and architecture parameters stored at
        ``args.path_of_best_global_model`` and
        ``args.path_of_best_global_arch_parameter``, evaluates them on the
        local test data, then runs ``args.epochs_for_local_fine_tuning``
        epochs of ``_local_search``. After every epoch the model is evaluated
        and, whenever the accuracy improves, saved with
        ``save_personal_model(args.local_adapted_local_models)``.

        Returns:
            tuple: ``(weights, alphas, local_sample_number, mean_train_acc,
            mean_train_loss, acc_global_model_on_local_data,
            acc_personalized_model_on_local_data_best, pi_parameters,
            genotype, client_index)``; ``pi_parameters`` is always ``0.0``.
        """
        logging.info("Start local fine-tuning")
        # load the best global model
        self.model.load_state_dict(torch.load(self.args.path_of_best_global_model, map_location=self.device))
        best_alphas = torch.load(self.args.path_of_best_global_arch_parameter, map_location=self.device)
        for a_g, model_arch in zip(best_alphas, self.model.arch_parameters()):
            model_arch.data.copy_(a_g.data)
        self.model.to(self.device)
        self.model.train()

        # SGD only receives the parameters that are not architecture parameters.
        arch_param_ids = {id(p) for p in self.model.arch_parameters()}
        weight_params = [p for p in self.model.parameters()
                         if id(p) not in arch_param_ids]

        optimizer = torch.optim.SGD(
            weight_params,
            self.args.learning_rate,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay)

        # alpha optimizer
        architect = Architect(self.model, self.criterion, self.args, self.device, self.teacher_model)

        # NOTE(review): the cosine period is args.epochs while the loop below
        # runs args.epochs_for_local_fine_tuning epochs -- confirm this
        # mismatch is intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.args.epochs, eta_min=self.args.learning_rate_min)

        local_avg_train_acc = []
        local_avg_train_loss = []

        # Accuracy of the freshly loaded global model, measured before any
        # local update is applied.
        with torch.no_grad():
            acc_global_model_on_local_data, _, _ = self.local_infer(self.test_local, self.model,
                                                                    self.criterion)
        logging.info(
            'Client_idx = %d, acc_global_model_on_local_test %f' % (self.client_index, acc_global_model_on_local_data))

        # record the best
        acc_personalized_model_on_local_data_best = 0.0
        for epoch in range(self.args.epochs_for_local_fine_tuning):
            # The previous iteration may have left the model on the CPU
            # (save_personal_model) and in eval mode (local_infer); restore
            # both before training, otherwise batches on self.device meet a
            # CPU model and the epoch trains in eval mode.
            self.model.to(self.device)
            self.model.train()

            # training
            train_acc, _, train_loss = self._local_search(epoch, self.train_local,
                                                          self.model, architect, self.criterion,
                                                          optimizer)
            logging.info('client_idx = %d, epoch = %d, local search_acc %f' % (self.client_index, epoch, train_acc))
            local_avg_train_acc.append(train_acc)
            local_avg_train_loss.append(train_loss.cpu())

            scheduler.step()
            lr = scheduler.get_last_lr()
            logging.info('client_idx = %d, epoch %d lr %s' % (self.client_index, epoch, str(lr)))

            with torch.no_grad():
                acc_personalized_model_on_local_data, _, _ = self.local_infer(self.test_local,
                                                                              self.model,
                                                                              self.criterion)
            if acc_personalized_model_on_local_data > acc_personalized_model_on_local_data_best:
                acc_personalized_model_on_local_data_best = acc_personalized_model_on_local_data
                # save the best fine-tuned model
                self.save_personal_model(self.args.local_adapted_local_models)
            logging.info('client_idx = %d, acc_personalized_model_on_local_data_best %f' % (
                self.client_index, acc_personalized_model_on_local_data_best))

        # Move the model to the CPU once and read both results from it.
        cpu_model = self.model.cpu()
        weights = cpu_model.state_dict()
        alphas = cpu_model.arch_parameters()
        pi_parameters = 0.0
        # genotype() returns three values; only the genotype is reported.
        genotype, _, _ = self.model.genotype()
        return weights, alphas, self.local_sample_number, \
               sum(local_avg_train_acc) / len(local_avg_train_acc), \
               sum(local_avg_train_loss) / len(local_avg_train_loss), \
               acc_global_model_on_local_data, acc_personalized_model_on_local_data_best, \
               pi_parameters, genotype, self.client_index

    def train(self):
        """Train all model parameters locally for ``args.epochs`` epochs.

        Uses SGD with a cosine-annealed learning rate and ``local_train`` for
        each epoch. Evaluation on the local test data follows
        ``args.local_finetune``:

        * ``'False'``: evaluate before training.
        * ``'True'``: evaluate after training, but only when ``round_index``
          is a multiple of ``args.frequency_of_the_test``.

        On rounds that are not test rounds the reported accuracy is ``0.0``.

        Returns:
            tuple: ``(weights, local_sample_number, mean_train_acc,
            mean_train_loss, acc_personalized_model_on_local_data,
            client_index)``. ``weights`` is the state dict taken after the
            model was moved to the CPU.
        """
        self.model.to(self.device)
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            self.args.learning_rate,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay)
        # adding scheduler from workshop paper
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.args.epochs, eta_min=self.args.learning_rate_min)

        local_avg_train_acc = []
        local_avg_train_loss = []

        # Default so the return value is always defined, even when
        # local_finetune is neither 'True' nor 'False'.
        acc_personalized_model_on_local_data = 0.0
        # global model's accuracy on local test data
        if self.args.local_finetune == 'False':
            with torch.no_grad():
                acc_personalized_model_on_local_data, _, _ = self.local_infer(
                    self.test_local, self.model, self.criterion)
            logging.info(
                'Before Fine-tunning: client_idx = %d, acc_global_model_on_local_test %f' % (self.client_index, acc_personalized_model_on_local_data))

        for _ in range(self.args.epochs):
            # training (local_train switches the model to train mode itself)
            train_acc, _, train_loss = self.local_train(self.train_local, self.test_local,
                                                        self.model, self.criterion,
                                                        optimizer)
            local_avg_train_acc.append(train_acc)
            local_avg_train_loss.append(train_loss.cpu())
            scheduler.step()

        weights = self.model.cpu().state_dict()
        # global model's accuracy on local test data
        if self.round_index % self.args.frequency_of_the_test == 0:
            if self.args.local_finetune == 'True':
                with torch.no_grad():
                    acc_personalized_model_on_local_data, _, _ = self.local_infer(
                        self.test_local, self.model, self.criterion)
                logging.info(
                    'After Fine-tunning: client_idx = %d, acc_global_model_on_local_test %f' % (self.client_index, acc_personalized_model_on_local_data))
        else:
            # Not a test round: report 0.0 regardless of earlier measurements.
            acc_personalized_model_on_local_data = 0.0

        return weights, self.local_sample_number, \
               sum(local_avg_train_acc) / len(local_avg_train_acc), \
               sum(local_avg_train_loss) / len(local_avg_train_loss), \
               acc_personalized_model_on_local_data, self.client_index

    def local_train(self, train_queue, valid_queue, model, criterion, optimizer):
        """Train ``model`` for one pass over ``train_queue``.

        Args:
            train_queue: Sized iterable of ``(input, target)`` batches.
            valid_queue: Unused; kept so the signature stays compatible.
            model: Network to train.
            criterion: Loss applied to ``(logits, target)``.
            optimizer: Optimizer stepped once per batch after gradient
                clipping with ``args.grad_clip``.

        Returns:
            tuple: ``(top1.avg, objs.avg, loss)`` where ``loss`` is the loss
            tensor of the last processed batch, or ``None`` when
            ``train_queue`` yielded no batch.
            NOTE(review): unlike ``_local_search`` and ``local_infer`` the
            averages are not divided by 100 here -- confirm callers expect
            this scale.
        """
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        # Defined up front so an empty train_queue does not leave the return
        # value unbound (matches _local_search and local_infer).
        loss = None

        logging.info("local_train. Number of batches = %d" % len(train_queue))
        for step, (inputs, target) in enumerate(train_queue):
            model.train()
            n = inputs.size(0)

            inputs = inputs.to(self.device)
            target = target.to(self.device)

            optimizer.zero_grad()

            logits = model(inputs)
            loss = criterion(logits, target)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.args.grad_clip)
            optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            # Debug mode stops after the second batch.
            if step == 1 and self.args.is_debug_mode:
                break

        return top1.avg, objs.avg, loss


    def local_infer(self, valid_queue, model, criterion):
        """Evaluate ``model`` on ``valid_queue``.

        The model is moved to ``self.device`` and switched to eval mode; it
        is left in eval mode when the method returns, so callers that
        continue training must call ``model.train()`` again.

        Args:
            valid_queue: Iterable of ``(input, target)`` batches.
            model: Network to evaluate.
            criterion: Loss applied to ``(logits, target)``.

        Returns:
            tuple: ``(top1.avg / 100.0, objs.avg / 100.0, loss)`` where
            ``loss`` is the loss tensor of the last processed batch, or
            ``None`` when ``valid_queue`` yielded no batch.
            NOTE(review): the averaged loss is divided by 100 as well, which
            looks unintended but is kept for compatibility with callers.
        """
        logging.info("local_infer. client_index = %d started." % self.client_index)
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        model.to(self.device)

        # Parameter count in millions, ignoring parameters whose name
        # contains "auxiliary". The builtin sum() is used because np.sum()
        # over a generator is deprecated.
        model_size = sum(v.numel() for name, v in model.named_parameters()
                         if "auxiliary" not in name) / 1e6
        logging.info("Model size = %F client_index = %d started.", model_size, self.client_index)
        model.eval()
        loss = None
        start_time = time.time()
        # Evaluation never needs gradients; disabling them here keeps the
        # method safe even if a caller forgets its own no_grad() block.
        with torch.no_grad():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs = inputs.to(self.device)
                target = target.to(self.device)
                logits = model(inputs)
                loss = criterion(logits, target)

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = inputs.size(0)

                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.args.report_freq == 0:
                    logging.info('client_index = %d, valid %03d %e %f %f', self.client_index,
                                 step, objs.avg, top1.avg, top5.avg)

                # Debug mode stops after the second batch.
                if step == 1 and self.args.is_debug_mode:
                    break

        end_time = time.time()
        logging.info("Inference time cost: %d" % (end_time - start_time))
        logging.info("local_infer. client_index = %d finished." % self.client_index)
        return top1.avg / 100.0, objs.avg / 100.0, loss

    def global_train(self):
        """Fine-tune all model parameters locally and report test accuracy.

        Trains for ``args.epochs_for_local_fine_tuning`` epochs with
        ``local_train`` and evaluates on the local test data according to
        ``args.local_finetune``:

        * ``'True'``: evaluate after training and return that accuracy.
        * ``'False'``: evaluate before training; that accuracy is returned
          because no later measurement exists.
        * anything else: no evaluation is run and ``0.0`` is returned.

        Returns:
            The accuracy described above (``top1.avg / 100.0`` as produced by
            ``local_infer``, or ``0.0``).
        """
        self.model.to(self.device)
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            self.args.learning_rate,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay)
        # adding scheduler from workshop paper
        # NOTE(review): the cosine period is args.epochs while the loop below
        # runs args.epochs_for_local_fine_tuning epochs -- confirm this
        # mismatch is intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.args.epochs, eta_min=self.args.learning_rate_min)

        # Default so the return value is always defined; previously the
        # method raised UnboundLocalError unless local_finetune was 'True'.
        global_model_acc_fine_tunned = 0.0

        # global model's accuracy on local test data
        if self.args.local_finetune == 'False':
            with torch.no_grad():
                acc_personalized_model_on_local_data, _, _ = self.local_infer(
                    self.test_local, self.model, self.criterion)
            logging.info(
                'Before Fine-tunning: client_idx = %d, acc_global_model_on_local_test %f' % (self.client_index, acc_personalized_model_on_local_data))
            # Only measurement available in this mode.
            global_model_acc_fine_tunned = acc_personalized_model_on_local_data

        for _ in range(self.args.epochs_for_local_fine_tuning):
            # training (local_train switches the model to train mode itself)
            self.local_train(self.train_local, self.test_local,
                             self.model, self.criterion,
                             optimizer)
            scheduler.step()

        # Kept so the model's placement after the call is unchanged: it ends
        # on the CPU unless the evaluation below moves it back to the device.
        self.model.cpu()
        # global model's accuracy on local test data
        if self.args.local_finetune == 'True':
            with torch.no_grad():
                global_model_acc_fine_tunned, _, _ = self.local_infer(
                    self.test_local, self.model, self.criterion)
            logging.info(
                'After Fine-tunning: client_idx = %d, acc_global_model_on_local_test %f' % (self.client_index, global_model_acc_fine_tunned))
        return global_model_acc_fine_tunned

    # after searching, infer() function is used to infer the searched architecture
    def infer(self):
        self.model.to(self.device)
        self.model.eval()

        test_correct = 0.0
        test_loss = 0.0
        test_sample_number = 0.0
        test_data = self.train_local
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                x = x.to(self.device)
                target = target.to(self.device)

                pred = self.model(x)
                loss = self.criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                correct = predicted.eq(target).sum()

                test_correct += correct.item()
                test_loss += loss.item() * target.size(0)
                test_sample_number += target.size(0)
            logging.info("client_idx = %d, local_train_loss = %s" % (self.client_index, test_loss))
        return test_correct / test_sample_number, test_loss

    def save_personal_model(self, path):
        my_dir_model_params = path + '/model_params_client_number' + str(
            self.client_index) + '.pth'
        if os.path.exists(my_dir_model_params):  # checking if there is a file with this name
            os.remove(my_dir_model_params)  # deleting the file
        torch.save(self.model.cpu().state_dict(), my_dir_model_params)  # save the model
        logging.info(" Personal Model of Client number %d saved " % self.client_index)

        my_dir_arch_params = path + '/arch_params_client_number' + str(
            self.client_index) + '.pth'
        if os.path.exists(my_dir_arch_params):  # checking if there is a file with this name
            os.remove(my_dir_arch_params)  # deleting the file
        torch.save(self.model.cpu().arch_parameters(), my_dir_arch_params)  # save the model
        logging.info(" Personal architecture of Client number %d saved " % self.client_index)

    def save_per_train_model(self):
        my_dir_model_params = self.args.path_of_local_model + '/model_params_client_number' + str(
            self.client_index) + '.pth'
        if os.path.exists(my_dir_model_params):  # checking if there is a file with this name
            os.remove(my_dir_model_params)  # deleting the file
        torch.save(self.local_model.cpu().state_dict(), my_dir_model_params)  # save the model
        logging.info(" Personal Model of Client number %d saved " % self.client_index)


    def load_per_train_model(self):
        my_dir_model_params = self.args.path_of_local_model + '/model_params_client_number' + str(
            self.client_index) + '.pth'
        if os.path.exists(my_dir_model_params):  # checking if there is a file with this name
            os.remove(my_dir_model_params)  # deleting the file
        torch.save(self.local_model.cpu().state_dict(), my_dir_model_params)  # save the model
        logging.info(" Personal Model of Client number %d saved " % self.client_index)

    def load_personal_model(self, path):
        my_dir_model_params = path + '/model_params_client_number' + str(
            self.client_index) + '.pth'
        if os.path.exists(my_dir_model_params):  # checking if there is a file with this name
            self.model.load_state_dict(torch.load(my_dir_model_params))  # if yes load it
            logging.info(" Personal Model of Client number %d Loaded " % self.client_index)
        else:
            logging.info(" Personal model of client number %d does not exist " % self.client_index)

        my_dir_arch_params = path + '/arch_params_client_number' + str(
            self.client_index) + '.pth'
        if os.path.exists(my_dir_arch_params):  # checking if there is a file with this name
            arch_params = torch.load(my_dir_arch_params)
            for a_g, model_arch in zip(arch_params, self.model.arch_parameters()):
                model_arch.data.copy_(a_g.data)
            logging.info(" Personal Architecture of Client number %d Loaded " % self.client_index)
        else:
            logging.info(" Personal model of client number %d does not exist " % self.client_index)

