from copy import deepcopy
import torch
import numpy as np
import sys

class Benchmark_Local_Train():
    
    def __init__(self, num_nodes, near_center_idx, far_center_idx, bm_type, loss_function, val_size, test_size):
        
        self.num_nodes = num_nodes
        self.near_center_idx = near_center_idx
        self.far_center_idx = far_center_idx

        if bm_type == "near":
            self.center_idx = near_center_idx
        else:
            self.center_idx = far_center_idx

        self.loss_function = loss_function
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.val_size = val_size
        self.test_size = test_size
    
    # initialize dictionaries for storing the results
    def Initialize_Results(self):
        """Reset the metric history and the processed-points counter.

        ``self.metric_dict`` maps every tracked metric name to its own empty
        list; ``self.current_num_points`` restarts at zero.
        """
        metric_names = (
            "num_points",
            "near_val_loss", "near_top1_val_acc", "near_top2_val_acc",
            "near_test_loss", "near_top1_test_acc", "near_top2_test_acc",
            "far_val_loss", "far_top1_val_acc", "far_top2_val_acc",
            "far_test_loss", "far_top1_test_acc", "far_top2_test_acc",
        )

        # a fresh list per key, so the histories never alias each other
        self.metric_dict = {metric_name: [] for metric_name in metric_names}

        self.current_num_points = 0

    # initialize the model with the same parameters as bi-level
    def Initialize_Variables(self, model, train_loaders, val_loaders, test_loaders):

        self.model = model # theta variable and initialize theta

        self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)      
    
    # input a gradient in vector form and return it in list form with model parameter like
    def Grad_Vec_to_List(self, model, grad_vec):
        
        pointer = 0
        grad_list = []
        for model_param in model.parameters():
            num_param = model_param.numel()
            grad_list.append(grad_vec[pointer:pointer+num_param].view_as(model_param).data)
            pointer += num_param
           
        return grad_list
    
    # gradient descent step
    def Gradient_Descent(self, grad_list, lr):
       
        model = self.model
        model_params_generator = model.parameters()
       
        model_indication = deepcopy(model)
        model_indication = model_indication.to(self.device)
       
        with torch.no_grad():
            # each parameter in the model do gradient descent
            for (idx, param) in enumerate(model_indication.parameters()):
                    param = next(model_params_generator)
                    param.data -= lr * grad_list[idx]
    
    # compute the gradient of theta w.r.t. the whole training data (here is validation data)
    def Compute_Full_Gradient_Val(self, val_loader, indication):
        
        # get theta or theta_reference
        if indication == "ref":
            model = self.model_ref
            model.zero_grad()
        else:
            model = self.model
            model.zero_grad()
        
        # get the model parameter generator
        model_params_generator = model.parameters()
        model_params = list(model_params_generator)

        full_dataset_size = 0 # prepare full_batch size for the weighted sum on mini_batches

        for step, batch_data in enumerate(val_loader):        
            
            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            full_dataset_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )

            outputs = model(inputs)            
            loss = self.loss_function(outputs, labels)

            # get mini_batch gradient
            mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))
            
           # mini_batch gradient is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch gradient
            if step == 0:
                full_gradient = mini_batch_grad * int(mini_batch_size)   
            else:
                full_gradient += mini_batch_grad * int(mini_batch_size)
            
        # after all mini_batches, re-assign the weight for full_batch gradient
        # full_batch size is the size of training dateset
        full_gradient /= int(full_dataset_size)

        return full_gradient
    
    # SGD for updating theta
    def Update_Model_SGD(self, val_loaders, test_loaders, num_epoch, num_batches, communicate_period, lr_inner):
        # arguments: D^{train}, T_{1}, tau, eta
       
        step_count = 0 # count in t in { 1, 2, ..., T_{1} }
       
        for i in range(num_epoch): # for t = 0, 1, ..., T_{1}-1
                   
            # note that here T_{1} is number of epoches * number of batches
            for step, batch_data in enumerate(val_loaders[self.center_idx]): 
                
                step_count += 1
       
                model = self.model # prepare theta
                model.zero_grad()
                model_params_generator = model.parameters()
                model_params = list(model_params_generator)
                
                inputs, labels = (
                        batch_data[0].to(self.device),
                        batch_data[1].to(self.device),
                    )
                mini_batch_size = batch_data[0].size(dim=0)

                # apply theta
                outputs = model(inputs)
                loss = self.loss_function(outputs, labels)
                    
                # prepare node mini-batch gradient in vector form (if not, will get a tuple)
                # this is sgd-gradient
                mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))
                self.current_num_points += mini_batch_size
                
                # transfer the sgd-gradient from vector form to list form with model parameter like
                grad_list = self.Grad_Vec_to_List(model=self.model, grad_vec=mini_batch_grad)

                # perform gradient descent on theta
                self.Gradient_Descent(grad_list=grad_list, lr=lr_inner)

            # perform weighted fed_avg algorithm if conditions are satisfied 
            # always communicate after the last iteration     
            if step_count % (5*communicate_period) == 0 or (i+1)*(step+1) == num_epoch*num_batches:

                self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

    # SVRG for updating theta
    def Update_Model_SVRG(self, val_loaders, test_loaders, num_epoch, num_batches, communicate_period, lr_inner):
        # arguments: D^{val}, T_{1}, tau, eta, q
        self.model_ref = deepcopy(self.model) # initialize theta_reference
        # False means that full gradient of the current model_reference has not been computed
        ref_indicator = "False" 

        step_count = 0

        for i in range(num_epoch): # for t = 0, 1, ..., T_{1}-1

            # note that here T_{1} is number of epoches * number of batches
            for step, batch_data in enumerate(val_loaders[self.center_idx]): 
                
                step_count += 1
       
                model = self.model # prepare theta
                model.zero_grad()
                model_params_generator = model.parameters()
                model_params = list(model_params_generator)
                
                model_ref = self.model_ref # prepare theta_reference
                model_ref.zero_grad()
                model_ref_params_generator = model_ref.parameters()
                model_ref_params = list(model_ref_params_generator)

                inputs, labels = (
                        batch_data[0].to(self.device),
                        batch_data[1].to(self.device),
                    )
                mini_batch_size = batch_data[0].size(dim=0)

                # apply theta
                outputs = model(inputs)
                loss = self.loss_function(outputs, labels)

                # prepare node mini-batch gradient in vector form (if not, will get a tuple)
                mini_batch_grad = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss, model_params, create_graph=False))
                
                # apply theta_reference
                outputs_ref = model_ref(inputs)
                loss_ref = self.loss_function(outputs_ref, labels)
                
                # prepare node mini-batch gradient in vector form (if not, will get a tuple)
                mini_batch_grad_ref = torch.nn.utils.parameters_to_vector(torch.autograd.grad(loss_ref, model_ref_params, create_graph=False))
                self.current_num_points += mini_batch_size * 2

                # compute full gradient on theta_reference
                if ref_indicator == "False":

                    full_gradient = self.Compute_Full_Gradient_Val(val_loader=val_loaders[self.center_idx], indication="ref")
                    ref_indicator = "True"

                    self.current_num_points += self.val_size
                
                # compute svrg-gradient
                mini_batch_grad = mini_batch_grad - mini_batch_grad_ref + full_gradient

                # transfer the svrg-gradient from vector form to list form with model parameter like
                grad_list = self.Grad_Vec_to_List(model=self.model, grad_vec=mini_batch_grad)
                
                # update theta_reference with probability q for the node
                # q == 1/num_batches
                prob_sample = np.random.uniform()

                # update theta_reference with probability q
                if prob_sample < 1 / num_batches:
                    self.model_ref = deepcopy(self.model)
                    ref_indicator = "False"
                else:
                    ref_indicator = "True"
                
                # perform gradient descent on theta
                self.Gradient_Descent(grad_list=grad_list, lr=lr_inner)

                if step_count % (5*communicate_period) == 0 or (i+1)*(step+1) == num_epoch*num_batches:
                    self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

    # evaluate validation and testing loss, and validation and testing accuracy (training accuracy is disabled)
    # def Evaluation(self, train_loaders, val_loaders, test_loaders):
    def Evaluation(self, val_loaders, test_loaders):

        # get theta which is fixed in this method
        model = self.model
        
        '''
        # training accuracy in bilevel problem
        # this is training accuracy in the inner level of the bilevel problem
        full_train_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        full_top1_train_acc = 0
        full_top2_train_acc = 0

        for node_idx in range(self.num_nodes):
                
            for step, batch_data in enumerate(train_loaders[node_idx]):        
                
                model.zero_grad()

                mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
                full_train_set_size += int(mini_batch_size)

                inputs, labels = (
                        batch_data[0].to(self.device),
                        batch_data[1].to(self.device),
                    )
                
                outputs = model(inputs)

                # get predicted class labels
                values, top1_predicts = torch.max(outputs, dim=1)
                top1_predicts = top1_predicts.long().to(self.device)

                mini_batch_top1_train_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

                values, top2_predicts = torch.topk(outputs, k=2, dim=1)
                top2_predicts = top2_predicts.long().to(self.device)
                mini_batch_top2_train_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                        .to(self.device).nonzero().reshape(-1).size(dim=0))

                # mini_batch loss is averaged according to this mini_batch, so need to
                # re-assign the weight for mini_batch loss 
                full_top1_train_acc += mini_batch_top1_train_acc 
                full_top2_train_acc += mini_batch_top2_train_acc 

        full_top1_train_acc /= int(full_train_set_size)
        full_top2_train_acc /= int(full_train_set_size)
        self.metric_dict["top1_train_acc"].append(full_top1_train_acc)
        self.metric_dict["top2_train_acc"].append(full_top2_train_acc)
        print("full_top1_train_acc: ", full_top1_train_acc)
        print("full_top2_train_acc: ", full_top2_train_acc)
        print("-" * 30)     
        '''

        # validation loss and validation accuracy in local train
        near_full_val_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        near_full_val_loss = 0
        near_full_top1_val_acc = 0
        near_full_top2_val_acc = 0

        for step, batch_data in enumerate(val_loaders[self.near_center_idx]):        
            
            model.zero_grad()

            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            near_full_val_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)

            # get mini_batch loss
            mini_batch_val_loss = self.loss_function(outputs, labels)
            mini_batch_val_loss = mini_batch_val_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_val_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_val_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                     .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            near_full_val_loss += mini_batch_val_loss * int(mini_batch_size)
            near_full_top1_val_acc += mini_batch_top1_val_acc 
            near_full_top2_val_acc += mini_batch_top2_val_acc 

        near_full_val_loss /= int(near_full_val_set_size)
        near_full_top1_val_acc /= int(near_full_val_set_size)
        near_full_top2_val_acc /= int(near_full_val_set_size)
        self.metric_dict["near_val_loss"].append(near_full_val_loss)
        self.metric_dict["near_top1_val_acc"].append(near_full_top1_val_acc)
        self.metric_dict["near_top2_val_acc"].append(near_full_top2_val_acc)
        print("near_full_val_loss: ", near_full_val_loss) 
        print("near_full_top1_val_acc: ", near_full_top1_val_acc)
        print("near_full_top2_val_acc: ", near_full_top2_val_acc)
        print("-" * 30)
        
        far_full_val_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        far_full_val_loss = 0
        far_full_top1_val_acc = 0
        far_full_top2_val_acc = 0

        for step, batch_data in enumerate(val_loaders[self.far_center_idx]):        
            
            model.zero_grad()

            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            far_full_val_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)

            # get mini_batch loss
            mini_batch_val_loss = self.loss_function(outputs, labels)
            mini_batch_val_loss = mini_batch_val_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_val_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_val_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                     .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            far_full_val_loss += mini_batch_val_loss * int(mini_batch_size)
            far_full_top1_val_acc += mini_batch_top1_val_acc 
            far_full_top2_val_acc += mini_batch_top2_val_acc 

        far_full_val_loss /= int(far_full_val_set_size)
        far_full_top1_val_acc /= int(far_full_val_set_size)
        far_full_top2_val_acc /= int(far_full_val_set_size)
        self.metric_dict["far_val_loss"].append(far_full_val_loss)
        self.metric_dict["far_top1_val_acc"].append(far_full_top1_val_acc)
        self.metric_dict["far_top2_val_acc"].append(far_full_top2_val_acc)
        print("far_full_val_loss: ", far_full_val_loss) 
        print("far_full_top1_val_acc: ", far_full_top1_val_acc)
        print("far_full_top2_val_acc: ", far_full_top2_val_acc)
        print("-" * 30)

        # testing loss and testing accuracy in bilevel problem
        near_full_test_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        near_full_test_loss = 0
        near_full_top1_test_acc = 0
        near_full_top2_test_acc = 0

        for step, batch_data in enumerate(test_loaders[self.near_center_idx]):        
            
            model.zero_grad()
            
            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            near_full_test_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)
            # get mini_batch loss
            mini_batch_test_loss = self.loss_function(outputs, labels)
            mini_batch_test_loss = mini_batch_test_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_test_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_test_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                      .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            near_full_test_loss += mini_batch_test_loss * int(mini_batch_size)
            near_full_top1_test_acc += mini_batch_top1_test_acc 
            near_full_top2_test_acc += mini_batch_top2_test_acc 

        near_full_test_loss /= int(near_full_test_set_size)
        near_full_top1_test_acc /= int(near_full_test_set_size)
        near_full_top2_test_acc /= int(near_full_test_set_size)
        self.metric_dict["near_test_loss"].append(near_full_test_loss)
        self.metric_dict["near_top1_test_acc"].append(near_full_top1_test_acc)
        self.metric_dict["near_top2_test_acc"].append(near_full_top2_test_acc)
        print("near_full_test_loss: ", near_full_test_loss) 
        print("near_full_top1_test_acc: ", near_full_top1_test_acc)
        print("near_full_top2_test_acc: ", near_full_top2_test_acc)
        print("-" * 30)
        
        far_full_test_set_size = 0 # prepare full_batch size for the weighted sum on mini_batches
        far_full_test_loss = 0
        far_full_top1_test_acc = 0
        far_full_top2_test_acc = 0

        for step, batch_data in enumerate(test_loaders[self.far_center_idx]):        
            
            model.zero_grad()
            
            mini_batch_size = batch_data[0].size(dim=0) # get mini_batch size
            far_full_test_set_size += int(mini_batch_size)

            inputs, labels = (
                    batch_data[0].to(self.device),
                    batch_data[1].to(self.device),
                )
            
            outputs = model(inputs)
            # get mini_batch loss
            mini_batch_test_loss = self.loss_function(outputs, labels)
            mini_batch_test_loss = mini_batch_test_loss.item()
            
            # get predicted class labels
            values, top1_predicts = torch.max(outputs, dim=1)
            top1_predicts = top1_predicts.long().to(self.device)

            mini_batch_top1_test_acc = int((top1_predicts == labels).to(self.device).nonzero().reshape(-1).size(dim=0))

            values, top2_predicts = torch.topk(outputs, k=2, dim=1)
            top2_predicts = top2_predicts.long().to(self.device)
            mini_batch_top2_test_acc = int(torch.tensor([labels[i] in top2_predicts[i] for i in range(top2_predicts.size(dim=0))])\
                                      .to(self.device).nonzero().reshape(-1).size(dim=0))

            # mini_batch loss is averaged according to this mini_batch, so need to
            # re-assign the weight for mini_batch loss 
            far_full_test_loss += mini_batch_test_loss * int(mini_batch_size)
            far_full_top1_test_acc += mini_batch_top1_test_acc 
            far_full_top2_test_acc += mini_batch_top2_test_acc 

        far_full_test_loss /= int(far_full_test_set_size)
        far_full_top1_test_acc /= int(far_full_test_set_size)
        far_full_top2_test_acc /= int(far_full_test_set_size)
        self.metric_dict["far_test_loss"].append(far_full_test_loss)
        self.metric_dict["far_top1_test_acc"].append(far_full_top1_test_acc)
        self.metric_dict["far_top2_test_acc"].append(far_full_top2_test_acc)
        print("far_full_test_loss: ", far_full_test_loss) 
        print("far_full_top1_test_acc: ", far_full_top1_test_acc)
        print("far_full_top2_test_acc: ", far_full_top2_test_acc)
        print("-" * 30)
        sys.stdout.flush()

        self.metric_dict["num_points"].append(self.current_num_points)
    
    # return all evaluation results after all stages
    def Output_Results(self):
        """Return the metric dictionary itself (not a copy)."""
        collected_metrics = self.metric_dict
        return collected_metrics