import argparse
from cgi import print_form
import json
import os
from re import M
import time
import itertools
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from datasets import load_metric
from torch.utils.data import DataLoader, Dataset, TensorDataset
import math
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

from train_cluster_mnist import  TrainMNISTCluster,SimpleLinear

from util import *


# Module-level flag, presumably toggling learning-rate decay (flip to True to
# enable). NOTE(review): not referenced anywhere else in the visible file.
LR_DECAY: bool = False

def main():
    """Entry point: load the config, build the AFL experiment, then train or run inference.

    ``train_seed`` is always overwritten with ``data_seed`` before the
    experiment is constructed. A non-zero ``inference_mode`` skips training and
    only evaluates the (possibly checkpointed) model.
    """
    config = get_config()
    print("config:", config)
    # Tie the training seed to the data seed (this overrides --train-seed).
    config['train_seed'] = config['data_seed']

    experiment = TrainMNISTClusterAFL(config)
    experiment.setup()

    if config['inference_mode'] != 0:
        print("inference mode")
        experiment.inference()
        return

    experiment.run()


def get_config():
    """Assemble the run configuration.

    Layering order (later wins):
      1. ``config_afl.json`` read from the current working directory;
      2. every command-line option. argparse defaults are included, so they
         always replace the corresponding JSON values;
      3. ``--config-override``, a JSON object string, applied last.

    The ``config_override`` key itself is removed from the returned dict.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--project-dir", "--projdir", type=str, default="output/qffl")
    cli.add_argument("--lr", type=float, default=5e-5)
    cli.add_argument("--learning_rate_lambda", type=float, default=5e-6)
    cli.add_argument("--train-seed", type=int, default=0)
    cli.add_argument("--config-override", type=str, default="")
    cli.add_argument("--inference-mode", type=int, default=0)
    cli.add_argument("--checkpoint", type=int, default=0)
    cli_values = vars(cli.parse_args())

    with open("config_afl.json", "r") as fh:
        config = json.load(fh)
    config.update(cli_values)

    override = config.pop('config_override')
    if override != "":
        print(override)
        config.update(json.loads(override))

    return config


class TrainMNISTClusterAFL(TrainMNISTCluster):
    """Agnostic-federated-learning trainer for a ``bert-base-uncased`` classifier.

    Data handling and the outer loop (``setup_datasets``, ``load_data``,
    ``run``, ...) are presumably inherited from ``TrainMNISTCluster``; they are
    not defined here. This subclass swaps in a BERT model, aggregates client
    models with per-client weights ``dynamic_lambdas``, and updates those
    weights by projected ascent on the clients' training losses.
    """

    def setup(self):
        """Create output paths, the TensorBoard writer, datasets, models and the AFL state."""
        self.epoch = -1
        os.makedirs(self.config['project_dir'], exist_ok=True)

        self.result_fname = os.path.join(self.config['project_dir'], 'results_single.pickle')
        self.checkpoint_fname = os.path.join(self.config['project_dir'], 'checkpoint_single.pt')
        if self.config['checkpoint']:
            # Resume from the file save_checkpoint() writes the averaged result_model to.
            self.load_checkpoint_fname = os.path.join(self.config['project_dir'], 'checkpoint_single.pt')

        self.writer = SummaryWriter(os.path.join(self.config['project_dir'], 'runs'))
        self.setup_datasets()
        self.setup_models()

        # Step size for the ascent step on the mixture weights.
        self.learning_rate_lambda = self.config['learning_rate_lambda']

        # Mixture weights over the m training clients, initialised uniformly.
        self.dynamic_lambdas = torch.ones(self.config['m']) * 1.0 / self.config['m']

        self.lr = None
        # Number of completed aggregation rounds (drives the running mean in result_model).
        self.round = 0

    def setup_models(self):
        """Seed RNGs, build ``p`` BERT classifiers and the metric callback, and optionally load a checkpoint."""
        np.random.seed(self.config['train_seed'])
        torch.manual_seed(self.config['train_seed'])

        p = self.config['p']

        # One global model per cluster; this AFL variant only ever uses cluster 0.
        self.models = [
            AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
            for _ in range(p)
        ]
        metric = load_metric("accuracy")

        def compute_metrics(eval_pred):
            # Trainer passes (logits, labels); report accuracy of the arg-max class.
            logits, labels = eval_pred
            predictions = np.argmax(logits, axis=-1)
            return metric.compute(predictions=predictions, references=labels)

        # NOTE: despite its name this is Trainer's metrics callback, not a loss.
        self.criterion = compute_metrics
        self.agg_criterion = nn.NLLLoss()

        if self.config['checkpoint']:
            print('loading checkpoint from', self.load_checkpoint_fname)
            states = torch.load(self.load_checkpoint_fname)
            # Only model 0 is restored; the saved file holds result_model (the running average).
            self.models[0].load_state_dict(states['models'][0])
            self.epoch = states['epoch']

        # Running average of the global model across rounds (updated in train()).
        self.result_model = copy.deepcopy(self.models)

    def inference(self):
        """Evaluate ``result_model[0]`` on each of the ``m_test`` test clients.

        Prints per-client logs and the aggregate loss (unweighted mean over
        clients) and accuracy (weighted by sample count). Writes per-client stats
        to ``<project_dir>/inference_stats.json``.
        """
        cfg = self.config
        m = cfg['m_test']
        num_data = 0
        losses = []
        corrects = []
        num_samples = []
        accuracys = []
        for m_i in range(m):
            dataset = self.load_data(m_i, train=False)

            training_args = TrainingArguments(output_dir=f"inference_client_{m_i}",
                                              num_train_epochs=1,
                                              per_device_train_batch_size=32,
                                              per_device_eval_batch_size=32,
                                              local_rank=-1
                                              )
            trainer = Trainer(
                model=self.result_model[0],
                args=training_args,
                eval_dataset=dataset,
                compute_metrics=self.criterion,
            )
            eval_log = trainer.evaluate()

            print(eval_log)
            print(len(dataset))

            # Debug output only: raw logits for this client (a second pass over the data).
            prediction = trainer.predict(dataset)
            print(prediction)
            y_logit = torch.tensor(prediction.predictions)
            print(y_logit.shape, type(y_logit))

            n_samples = len(dataset)
            losses.append(eval_log['eval_loss'])
            corrects.append(eval_log['eval_accuracy'] * n_samples)
            num_samples.append(n_samples)
            accuracys.append(eval_log['eval_accuracy'])
            num_data += n_samples

        loss = np.mean(losses)  # unweighted mean over clients
        acc = np.sum(corrects) / num_data  # sample-weighted accuracy
        print("loss", loss, "acc", acc)

        inference_stats = {'losses': losses, 'corrects': corrects, 'num_samples': num_samples, 'accuracys': accuracys}
        print(inference_stats)
        with open(os.path.join(self.config['project_dir'], 'inference_stats.json'), 'w') as out:
            out.write(json.dumps(inference_stats))

    def get_inference_stats(self, train=True):
        """Evaluate ``result_model[0]`` on every client of the train or test split.

        Args:
            train: evaluate the ``m`` training clients if True, otherwise the
                ``m_test`` test clients.

        Returns:
            dict with the (forced all-zero) ``cluster_assign``, total ``num_data``,
            ``loss`` (unweighted mean of per-client losses), ``acc``
            (sample-weighted accuracy), ``cl_acc`` (agreement with
            ``dataset['cluster_assign']``), ``cl_ct`` (clients per cluster) and
            ``is_train``.
        """
        cfg = self.config
        if train:
            m = cfg['m']
            dataset = self.dataset['train']
        else:
            m = cfg['m_test']
            dataset = self.dataset['test']

        p = cfg['p']

        num_data = 0
        losses = {}
        corrects = {}
        p_i = 0  # only cluster 0 is ever evaluated
        for m_i in range(m):
            eval_dataset = self.load_data(m_i, train=train)
            print(len(eval_dataset))

            training_args = TrainingArguments(output_dir=f"get_stat_client_{m_i}_{p_i}",
                                              num_train_epochs=1,
                                              per_device_train_batch_size=32,
                                              per_device_eval_batch_size=32,
                                              local_rank=-1,
                                              )
            trainer = Trainer(
                model=self.result_model[0],
                args=training_args,
                train_dataset=eval_dataset,
                eval_dataset=eval_dataset,
                compute_metrics=self.criterion,
            )
            eval_log = trainer.evaluate()
            print(eval_log)

            # Weight by THIS client's sample count. `dataset` is the split-level
            # container (indexed by 'cluster_assign' below), so its length is not
            # the number of samples that were just evaluated.
            n_samples = len(eval_dataset)
            losses[(m_i, p_i)] = eval_log['eval_loss']
            corrects[(m_i, p_i)] = eval_log['eval_accuracy'] * n_samples
            num_data += n_samples

        # AFL trains a single model, so every client is assigned to cluster 0.
        cluster_assign = [0] * m

        min_losses = [losses[(m_i, c_i)] for m_i, c_i in enumerate(cluster_assign)]
        min_corrects = [corrects[(m_i, c_i)] for m_i, c_i in enumerate(cluster_assign)]

        loss = np.mean(min_losses)
        acc = np.sum(min_corrects) / num_data

        # Agreement of the forced assignment with the dataset's recorded cluster labels.
        cl_acc = np.mean(np.array(cluster_assign) == np.array(dataset['cluster_assign']))
        cl_ct = [np.sum(np.array(cluster_assign) == c_i) for c_i in range(p)]

        res = {}
        res['cluster_assign'] = cluster_assign
        res['num_data'] = num_data
        res['loss'] = loss
        res['acc'] = acc
        res['cl_acc'] = cl_acc
        res['cl_ct'] = cl_ct
        res['is_train'] = train

        return res

    def train(self, cluster_assign, lr):
        """Run one AFL round: local training, lambda-weighted aggregation, lambda update.

        Args:
            cluster_assign: cluster index per training client (all zeros in this variant).
            lr: learning rate for the local Trainer runs, also used as the server step size.
        """
        cfg = self.config
        m = cfg['m']
        p = cfg['p']
        tau = cfg['tau']
        print("train qffl!! cluster_assign", cluster_assign)

        # ---- local updates: each client fine-tunes a copy of its cluster's model ----
        updated_models = []
        client_loss = []
        for m_i in range(m):
            dataset = self.load_data(m_i)
            p_i = cluster_assign[m_i]
            model = copy.deepcopy(self.models[p_i])
            local_train_loss = 0
            for _ in range(tau):
                # Each step is a fresh Trainer (new optimizer/scheduler) running one epoch.
                training_args = TrainingArguments(output_dir=f"train_client_{m_i}",
                                                  num_train_epochs=1,
                                                  per_device_train_batch_size=32,
                                                  per_device_eval_batch_size=32,
                                                  local_rank=-1,
                                                  learning_rate=lr
                                                  )
                trainer = Trainer(
                    model=model,
                    args=training_args,
                    train_dataset=dataset,
                    compute_metrics=self.criterion,
                )
                train_log = trainer.train()
                print(train_log)
                local_train_loss += train_log.training_loss

            model.zero_grad()
            updated_models.append(model)
            # Mean training loss over the tau epochs; the epsilon keeps it strictly positive.
            client_loss.append(local_train_loss / float(tau) + 1e-10)

        lr_server = lr

        # ---- group local models by cluster ----
        local_models = [[] for _ in range(p)]
        for m_i in range(m):
            local_models[cluster_assign[m_i]].append(updated_models[m_i])

        for p_i, models in enumerate(local_models):
            if not models:
                continue
            # NOTE(review): global_grad_update_afl indexes the lambdas by position
            # within `models`, and the lambdas are updated once per non-empty cluster.
            # Both are only right when all clients share one cluster (the case here).
            global_model = self.models[p_i]
            initial_state_dict = global_model.state_dict()
            for local_model in models:
                # grad := (global - local) / lr_server, so that a step of size
                # lr_server moves the global model onto the local one.
                self.differentiate_learner(
                    target=local_model,
                    reference_state_dict=initial_state_dict,
                    coeff=1 / lr_server,
                )

            self.global_grad_update_afl(models, global_model, self.dynamic_lambdas, lr_server)

            # Projected ascent on the mixture weights: higher-loss clients gain weight.
            self.dynamic_lambdas = [
                lmb_i + self.learning_rate_lambda * loss_i
                for lmb_i, loss_i in zip(self.dynamic_lambdas, client_loss)
            ]
            self.dynamic_lambdas = self.project(self.dynamic_lambdas)

            # result_model keeps the running mean of the global model over rounds.
            # Build the state dict once (after the update) instead of once per parameter.
            global_state = global_model.state_dict()
            for name, param in self.result_model[p_i].named_parameters():
                param.data = (param.data.cuda() * self.round + global_state[name].data.cuda()) / (self.round + 1)

        self.round += 1

    def project(self, p):
        """Euclidean projection of ``p`` onto the probability simplex.

        Returns ``w`` minimising ``||w - p||_2`` subject to ``w >= 0`` and
        ``sum(w) == 1``, via the sort / cumulative-sum threshold rule.

        Args:
            p: 1-D tensor, or a sequence of Python floats / 0-d tensors.

        Returns:
            1-D float32 ``torch.Tensor`` with the same length as ``p``.
        """
        v = np.array([float(x) for x in p], dtype=np.float64)
        u = np.sort(v)[::-1]
        css = np.cumsum(u)
        ks = np.arange(1, len(u) + 1)
        # rho = largest k with u_k + (1 - sum_{j<=k} u_j) / k > 0; k = 1 always qualifies.
        rho = ks[u + (1.0 - css) / ks > 0][-1]
        lmbd = (1.0 - css[rho - 1]) / rho
        # Always float32 (matches the initial torch.ones lambdas) instead of a dtype
        # that depended on whether any entry got clipped to zero.
        return torch.from_numpy(np.maximum(v + lmbd, 0.0).astype(np.float32))

    def global_grad_update_afl(self, local_models, global_model, aggregation_weights, lr):
        """Apply one SGD step of size ``lr`` to ``global_model`` using weighted client grads.

        ``aggregation_weights`` is renormalised to sum to one. Entry ``i``
        weights the ``.grad`` tensors of ``local_models[i]``, which
        ``differentiate_learner`` sets beforehand.
        """
        aggregation_weights = aggregation_weights / torch.sum(aggregation_weights)
        print("aggregation_weights", aggregation_weights)
        grads = {}
        for m_i, local_model in enumerate(local_models):
            weight = aggregation_weights[m_i]
            for name, param in local_model.named_parameters():
                if name not in grads:
                    grads[name] = torch.zeros_like(param.data)
                if param.grad is not None:
                    grads[name] += param.grad * weight

        for name, param in global_model.named_parameters():
            if param.requires_grad:
                # Bring the aggregated update onto the global model's device.
                param.data = param.data - grads[name].to(param.device) * lr

    def differentiate_learner(self, target, reference_state_dict, coeff=1.):
        """Set ``target``'s gradients to ``coeff * (reference - target)``.

        Only float32 entries of ``target.state_dict(keep_vars=True)`` are
        touched. The reference tensors are moved to each target tensor's device,
        so this works whether the target lives on CPU or GPU.

        :param target: model whose ``.grad`` fields are overwritten
        :type target: torch.nn.Module
        :param reference_state_dict: state dict to diff against
        :type reference_state_dict: OrderedDict[str, Tensor]
        :param coeff: scale applied to the difference, default 1.
        :type coeff: float
        """
        target_state_dict = target.state_dict(keep_vars=True)

        for key, tensor in target_state_dict.items():
            if tensor.data.dtype == torch.float32:
                reference = reference_state_dict[key].data.to(tensor.device)
                tensor.grad = coeff * (reference - tensor.data)

    def save_checkpoint(self, epoch=0):
        """Save the averaged ``result_model`` to ``checkpoint_fname`` and the raw global models to ``checkpoint_trueglobal.pt``."""
        models_to_save = [model.state_dict() for model in self.result_model]
        torch.save({'models': models_to_save, 'epoch': epoch}, self.checkpoint_fname)

        models_to_save = [model.state_dict() for model in self.models]
        torch.save({'models': models_to_save, 'epoch': epoch}, os.path.join(self.config['project_dir'], 'checkpoint_trueglobal.pt'))


if __name__ == '__main__':
    # Time the whole run and report it both in hours and in seconds.
    t_begin = time.time()
    main()
    elapsed = time.time() - t_begin
    hours = elapsed / 3600.0
    print("---train cluster single Ended in %0.2f hour (%.3f sec) " % (hours, elapsed))