import logging
import os

import torch
import wandb
from torch import nn
import numpy as np
import time
from thop import profile

from fedml_api.model.cv.darts import utils
from fedml_api.model.cv.darts.model_search_workshop_code import ModelForModelSizeMeasure
from fedml_api.model.cv.darts.model_search_workshop_code import Network

class FedNAS_Local_Train:
    """Fine-tune and evaluate each client's personalised model on local data.

    For every client index in ``range(args.client_num_in_total)`` the class
    loads the saved architecture parameters and, when present, the saved
    weights from ``args.path_of_local_model``, runs one pass over that
    client's training loader, and records the top-1 validation accuracy.
    """

    def __init__(self, args, train_data_local_dict, test_data_local_dict, device):
        """Store the run configuration and the per-client data loaders.

        Args:
            args: Configuration object. Read here: ``client_num_in_total``,
                ``epochs_for_train``, ``lr``, ``layers``. Read by the other
                methods: ``path_of_local_model``, ``learning_rate``,
                ``momentum``, ``weight_decay``, ``grad_clip``,
                ``report_freq`` and ``is_debug_mode``.
            train_data_local_dict: Mapping from client index to that
                client's iterable of ``(input, target)`` training batches.
                Must contain key ``0``.
            test_data_local_dict: Same mapping for the validation batches.
                Must contain key ``0``.
            device: Device the model and the batches are moved to.

        Raises:
            KeyError: if either data dictionary has no entry for client 0.
        """
        self.args = args
        self.client_num = args.client_num_in_total
        # NOTE(review): ``epoch`` and ``lr`` are stored but no method of this
        # class reads them; the optimizer uses ``args.learning_rate``.
        self.epoch = args.epochs_for_train
        self.lr = args.lr
        self._layers = args.layers
        self.device = device
        self._criterion = nn.CrossEntropyLoss()
        # Passed as ``C`` and ``num_classes`` to the model constructor.
        # Presumably the initial channel count and a 10-class dataset -- confirm.
        self._C = 16
        self._num_classes = 10
        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.valid_acc_dict = dict()
        # Start on client 0's loaders; update_dataset() switches clients.
        self.train_local = self.train_data_local_dict[0]
        self.test_local = self.test_data_local_dict[0]

    def local_train_and_infer(self):
        """Train and validate the personalised model of every client.

        Returns:
            dict: Maps each client index to the value returned by
            ``local_infer`` for that client (``top1.avg / 100.0``).

        Raises:
            FileNotFoundError: if a client's architecture file is missing.
        """
        for client_idx in range(self.client_num):
            # Rebuild the client's model from its saved architecture
            # parameters, then overlay whatever saved weights exist.
            alphas_normal, alphas_reduce = self.load_locally_adaptive_model(client_idx)
            model = ModelForModelSizeMeasure(self._C, self._num_classes, self._layers,
                                             self._criterion, alphas_normal, alphas_reduce)
            self.load_locally_adaptive_weights(client_idx, model)

            optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay)
            criterion = nn.CrossEntropyLoss()
            train_queue, valid_queue = self.update_dataset(client_idx)
            self.train(train_queue, model, criterion, optimizer, self.device)

            with torch.no_grad():
                self.valid_acc_dict[client_idx] = self.local_infer(
                    client_idx, valid_queue, model, criterion, self.device)
            # Drop the reference before the next client's model is built.
            del model

        logging.info("Model returned")
        return self.valid_acc_dict

    def load_locally_adaptive_model(self, client_index):
        """Load the saved architecture parameters of one client.

        Args:
            client_index: Index used to build the file name
                ``arch_params_client_number<idx>.pth``.

        Returns:
            tuple: The first two entries of the loaded object; the caller
            uses them as ``alphas_normal`` and ``alphas_reduce``.

        Raises:
            FileNotFoundError: if the architecture file does not exist.
                (A subclass of ``Exception``, so existing handlers still
                catch it.)
        """
        my_dir_arch_params = (self.args.path_of_local_model
                              + '/arch_params_client_number' + str(client_index) + '.pth')
        if not os.path.exists(my_dir_arch_params):
            raise FileNotFoundError(
                " Personal model of client number %d does not exist " % client_index)
        arch_params = torch.load(my_dir_arch_params)
        logging.info(" Personal Architecture of Client number %d Loaded ", client_index)
        return arch_params[0], arch_params[1]

    def load_locally_adaptive_weights(self, client_index, model):
        """Load the saved weights of one client into ``model`` if they exist.

        A missing file is not an error: the model keeps its current weights
        and a warning is logged.

        Args:
            client_index: Index used to build the file name
                ``model_params_client_number<idx>.pth``.
            model: Module whose ``load_state_dict`` receives the checkpoint.
        """
        my_dir_model_params = (self.args.path_of_local_model
                               + '/model_params_client_number' + str(client_index) + '.pth')
        if os.path.exists(my_dir_model_params):
            # strict=False: load only the keys that exist in this model.
            model.load_state_dict(torch.load(my_dir_model_params), strict=False)
            logging.info(" Personal Model of Client number %d Loaded ", client_index)
        else:
            logging.warning(" Personal model of client number %d does not exist ", client_index)

    def update_dataset(self, client_index):
        """Select the loaders of ``client_index`` as the current ones.

        Args:
            client_index: Key into both data dictionaries.

        Returns:
            tuple: ``(train_loader, test_loader)`` of the selected client.
        """
        self.client_index = client_index
        self.train_local = self.train_data_local_dict[client_index]
        self.test_local = self.test_data_local_dict[client_index]
        return self.train_local, self.test_local

    def train(self, train_local, model, criterion, optimizer, device):
        """Run one pass of SGD training over ``train_local``.

        Args:
            train_local: Sized iterable of ``(input, target)`` batches.
            model: Module to train; it is moved to ``device``.
            criterion: Loss callable applied to ``(logits, target)``.
            optimizer: Optimizer over the model's parameters.
            device: Device the model and every batch are moved to.
        """
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        logging.info("local_train. Number of batches = %d", len(train_local))

        # Loop-invariant setup, done once instead of on every batch.
        model.to(device)
        model.train()

        # Iterate the loader that was passed in, not self.train_local.
        for step, (inputs, target) in enumerate(train_local):
            batch_size = inputs.size(0)
            inputs = inputs.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, target)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.args.grad_clip)
            optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            if step % self.args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            # Debug mode stops after two batches.
            if step == 1 and self.args.is_debug_mode:
                break

    def local_infer(self, client_index, valid_queue, model, criterion, device):
        """Evaluate ``model`` on ``valid_queue`` and return its top-1 accuracy.

        Args:
            client_index: Client identifier, used in log messages only.
            valid_queue: Iterable of ``(input, target)`` validation batches.
            model: Module to evaluate; it is moved to ``device``.
            criterion: Loss callable applied to ``(logits, target)``.
            device: Device the model and every batch are moved to.

        Returns:
            float: ``top1.avg / 100.0``. In debug mode the value is returned
            after two batches and the timing log lines are skipped.
        """
        logging.info("local_infer. client_index = %d started.", client_index)
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        model.to(device)

        # Parameter count in millions, leaving out "auxiliary" parameters.
        model_size = sum(
            v.numel() for name, v in model.named_parameters() if "auxiliary" not in name
        ) / 1e6
        logging.info("Model size = %F client_index = %d started.", model_size, client_index)
        model.eval()
        start_time = time.time()

        # Evaluation needs no autograd graph, whatever the caller's context.
        with torch.no_grad():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs = inputs.to(device)
                target = target.to(device)
                logits = model(inputs)
                loss = criterion(logits, target)

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                batch_size = inputs.size(0)

                objs.update(loss.item(), batch_size)
                top1.update(prec1.item(), batch_size)
                top5.update(prec5.item(), batch_size)

                if step % self.args.report_freq == 0:
                    logging.info('client_index = %d, valid %03d %e %f %f', client_index,
                                 step, objs.avg, top1.avg, top5.avg)

                # Debug mode stops after two batches.
                if step == 1 and self.args.is_debug_mode:
                    logging.info('client_index = %d, valid %03d %e %f %f', client_index,
                                 step, objs.avg, top1.avg, top5.avg)
                    return top1.avg / 100.0

        end_time = time.time()
        logging.info("Inference time cost: %d" % (end_time - start_time))
        logging.info("local_infer. client_index = %d finished." % client_index)
        return top1.avg / 100.0

        # return top1.avg / 100.0, objs.avg / 100.0, loss






