import argparse
import json
import os
import time
import itertools
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from transformers import AutoTokenizer, TrainingArguments, Trainer

from util import *


# Global switch for step learning-rate decay.  When truthy, lr_schedule()
# multiplies the current rate by 0.1 on every 50th epoch (epoch 0 excluded).
# Set to True to enable; decay is currently disabled.
LR_DECAY = False

def main():
    """Build the configuration, then set up and run one clustering experiment."""
    config = get_config()

    # The training seed always mirrors the data seed here, so any value
    # supplied through --train-seed is overwritten.
    config.update(train_seed=config['data_seed'])

    print("config:",config)

    experiment = TrainMNISTCluster(config)
    experiment.setup()
    experiment.run()


def get_config():
    """Assemble the run configuration from ``config.json`` and the command line.

    Precedence, lowest to highest: values in ``config.json`` (read from the
    current working directory), command-line arguments (their defaults
    included, so they always replace same-named JSON keys), and finally the
    JSON object passed through ``--config-override``.

    Returns:
        dict: the merged configuration; the ``config_override`` key itself is
        never present in the result.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--project-dir", "--projdir", type=str, default="output/ifca")
    cli.add_argument("--inference-mode", type=int, default=0)
    cli.add_argument("--lr", type=float, default=0.1)
    cli.add_argument("--data-seed", type=int, default=0)
    cli.add_argument("--train-seed", type=int, default=0)
    cli.add_argument("--config-override", type=str, default="")
    cli_values = vars(cli.parse_args())

    # Base layer: the JSON file; command-line values are layered on top.
    with open("config.json", "r") as config_file:
        config = json.load(config_file)
    config.update(cli_values)

    # Top layer: an optional JSON object given on the command line.
    override_text = config['config_override']
    if override_text != "":
        print(override_text)
        overrides = json.loads(override_text)
        del config['config_override']
        config.update(overrides)
    else:
        del config['config_override']

    return config


class TrainMNISTCluster(object):
    def __init__(self, config):
        self.config = config

        # assert self.config['m'] % self.config['p'] == 0

    def setup(self):
        """Create the output directory, logging writer, datasets and models.

        Sets ``epoch`` to -1 (the value reported for the evaluation that
        ``run`` performs before any training round) and ``lr`` to ``None``
        so that ``lr_schedule`` picks the rate up from the config on first use.
        """
        project_dir = self.config['project_dir']

        self.epoch = -1
        os.makedirs(project_dir, exist_ok=True)

        # All artefacts of a run live under the project directory.
        self.result_fname = os.path.join(project_dir, 'results.pickle')
        self.checkpoint_fname = os.path.join(project_dir, 'checkpoint.pt')
        self.writer = SummaryWriter(os.path.join(project_dir, 'runs'))

        self.setup_datasets()
        self.setup_models()
        self.lr = None


    def setup_datasets(self):
        """Load both corpora and split them across clients for train and test.

        For each of the two splits the Yelp loader supplies the first corpus
        and the IMDB loader the second.  ``self.dataset[split]`` ends up as a
        dict with keys ``data_indices``, ``cluster_assign``, ``dataset``,
        ``X`` and ``y``; the last three are ``(first, second)`` tuples.
        Training indices are shuffled per class, test indices are not.
        """
        cfg = self.config

        # Makes the per-class permutations in _setup_dataset reproducible.
        np.random.seed(cfg['data_seed'])

        self.dataset = {}

        for split_name, is_train in (('train', True), ('test', False)):
            first_ds, first_X, first_y = self._load_yelp(train=is_train)
            second_ds, second_X, second_y = self._load_imdb(train=is_train)

            first_classes = self.build_class_dict(first_y)
            second_classes = self.build_class_dict(second_y)

            client_indices, client_clusters = self._setup_dataset(
                cfg['p'], cfg['m'], first_classes, second_classes,
                random=is_train,
            )

            self.dataset[split_name] = {
                'data_indices': client_indices,
                'cluster_assign': client_clusters,
                'dataset': (first_ds, second_ds),
                'X': (first_X, second_X),
                'y': (first_y, second_y),
            }

      

    def build_class_dict(self, labels):
        """Group sample positions by their integer class label.

        Args:
            labels: iterable of labels; each one is passed through ``int``.

        Returns:
            dict: maps each label to the list of positions carrying it, in
            increasing position order; keys appear in first-seen order.
        """
        classes2indx = {}
        for position, raw_label in enumerate(labels):
            classes2indx.setdefault(int(raw_label), []).append(position)
        return classes2indx


    def _setup_dataset(self, p, m, classes2indx1, classes2indx2, random = True):
        """Assign class-balanced index shards of the two corpora to clients.

        The first corpus is cut into as many shards as there are clients in
        ``per_cluster_list[0]``, the second into as many as there are in
        ``per_cluster_list[1]``.  The shards are concatenated (first corpus
        first) and each client id listed in ``per_cluster_list`` picks the
        shard at that position.

        Args:
            p: number of clusters to walk through in ``per_cluster_list``.
            m: expected total number of clients (checked by assertion).
            classes2indx1: label -> sample indices for the first corpus.
            classes2indx2: label -> sample indices for the second corpus.
            random: when true, each class's indices are permuted with
                ``np.random`` before being chunked.

        Returns:
            tuple: ``(data_indices, cluster_assign)``.  ``cluster_assign`` is
            an integer array with one cluster index per client.
            ``data_indices`` holds one row of sample indices per client: a
            regular 2-D array when all clients have equally many samples,
            otherwise a 1-D object array of per-client lists.
        """
        per_cluster_list = self.config['per_cluster_list']

        # The first corpus is sharded before the second so the sequence of
        # np.random draws stays the same as before the refactoring.
        shards = self._class_balanced_shards(
            classes2indx1, len(per_cluster_list[0]), random)
        shards = shards + self._class_balanced_shards(
            classes2indx2, len(per_cluster_list[1]), random)

        data_indices = []
        cluster_assign = []
        for p_i in range(p):
            per_cluster = per_cluster_list[p_i]  # pre-defined ground-truth client-cluster assignment
            print("client idx for cluster {}: {}".format(p_i, per_cluster))
            data_indices.extend(shards[m_i] for m_i in per_cluster)
            cluster_assign.extend(p_i for _ in per_cluster)

        if len({len(row) for row in data_indices}) > 1:
            # Clients hold different numbers of samples.  np.array() refuses
            # such ragged input on NumPy >= 1.24, so build the object array
            # (one list per client) explicitly.
            packed = np.empty(len(data_indices), dtype=object)
            for row_i, row in enumerate(data_indices):
                packed[row_i] = row
            data_indices = packed
        else:
            data_indices = np.array(data_indices)
        cluster_assign = np.array(cluster_assign)

        assert data_indices.shape[0] == cluster_assign.shape[0]
        assert data_indices.shape[0] == m

        return data_indices, cluster_assign

    def _class_balanced_shards(self, classes2indx, num_shards, shuffle):
        """Give shard ``i`` the ``i``-th chunk of every class's index list.

        ``chunkify`` is not defined in this file (presumably it arrives via
        ``from util import *``); it is expected to return at least
        ``num_shards`` list chunks.
        """
        shards = [[] for _ in range(num_shards)]
        for class_indices in classes2indx.values():
            if shuffle:
                class_indices = list(np.random.permutation(class_indices))
            chunks = chunkify(class_indices, num_shards)
            shards = [shards[i] + chunks[i] for i in range(num_shards)]
        return shards
        


    def _load_MNIST(self, train=True):
        """Download MNIST if needed and return its images and labels.

        Args:
            train: truthy for the training split, falsy for the test split.

        Returns:
            tuple: ``(X, y)`` where ``X`` is the dataset's ``data`` attribute
            divided by 255.0 and ``y`` is its ``targets`` attribute.
        """
        to_tensor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
        mnist_dataset = datasets.MNIST(
            root='./data',
            train=bool(train),
            download=True,
            transform=to_tensor,
        )

        # Per the original author's note: batches sampled through a loader
        # come out in 0~1 thanks to ToTensor, but the raw ``data`` attribute
        # read below is still 0~255 -- hence the explicit rescaling.
        loader = DataLoader(mnist_dataset)
        raw_images = loader.dataset.data
        y = loader.dataset.targets

        X = raw_images / 255.0

        return X, y

    def _load_imdb(self, train=True):
        """Load one IMDB split and tokenize it with ``bert-base-uncased``.

        The dataset is loaded once and only the requested split is
        tokenized, the same way ``_load_yelp`` does it.

        Args:
            train: truthy for the ``train`` split, falsy for ``test``.

        Returns:
            tuple: ``(tokenized_split, X, y)`` where ``X`` is the split's
            ``text`` column and ``y`` its ``label`` column.
        """
        # Local import: at module level the name ``datasets`` refers to
        # ``torchvision.datasets``.
        from datasets import load_dataset

        split_name = 'train' if train else 'test'
        raw_split = load_dataset("imdb")[split_name]

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True)

        imdb_dataset = raw_split.map(tokenize_function, batched=True)
        X = raw_split['text']
        y = raw_split['label']

        return imdb_dataset, X, y

    def _load_sst2(self, train=True):
        """Load one GLUE SST-2 split and tokenize it with ``bert-base-uncased``.

        The dataset is loaded once and only the requested split is
        tokenized, the same way ``_load_yelp`` does it.

        Args:
            train: truthy for the ``train`` split, falsy for ``validation``.

        Returns:
            tuple: ``(tokenized_split, X, y)`` where ``X`` is the split's
            ``sentence`` column and ``y`` its ``label`` column.
        """
        # Local import: at module level the name ``datasets`` refers to
        # ``torchvision.datasets``.
        from datasets import load_dataset

        split_name = 'train' if train else 'validation'
        raw_split = load_dataset("glue", "sst2")[split_name]

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        def tokenize_function(examples):
            return tokenizer(examples["sentence"], padding="max_length", truncation=True)

        sst2_dataset = raw_split.map(tokenize_function, batched=True)
        X = raw_split['sentence']
        y = raw_split['label']

        return sst2_dataset, X, y


    def _load_yelp(self, train=True):
        """Load one Yelp-polarity split and tokenize it with ``bert-base-uncased``.

        The whole dataset is shuffled with a fixed seed (42) first.  For
        training only the first 56000 shuffled rows are kept; the test split
        is used in full.

        Args:
            train: truthy for the training subset, falsy for the test split.

        Returns:
            tuple: ``(tokenized_split, X, y)`` where ``X`` is the ``text``
            column and ``y`` the ``label`` column of the tokenized split.
        """
        from datasets import load_dataset

        shuffled = load_dataset("yelp_polarity").shuffle(seed=42)
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True)

        if train:
            selected = shuffled['train'].select(range(56000))
        else:
            selected = shuffled['test']

        tokenized = selected.map(tokenize_function, batched=True)
        X = tokenized['text']
        y = tokenized['label']

        return tokenized, X, y


    def setup_models(self):
        """Seed the RNGs and create one model per cluster plus the loss.

        Populates ``self.models`` with ``config['p']`` freshly initialised
        ``SimpleLinear`` networks of hidden width ``config['h1']`` and sets
        ``self.criterion`` to a cross-entropy loss.
        """
        seed = self.config['train_seed']
        np.random.seed(seed)
        torch.manual_seed(seed)

        num_clusters = self.config['p']

        # NOTE(review): SimpleLinear reshapes its input to 28*28 features,
        # while train()/get_inference_stats() hand these models to a Trainer
        # together with tokenized text datasets -- confirm the intended model.
        self.models = [
            SimpleLinear(h1=self.config['h1']) for _ in range(num_clusters)
        ]

        self.criterion = torch.nn.CrossEntropyLoss()

        # import ipdb; ipdb.set_trace()


    def run(self):
        """Evaluate once, then alternate training rounds with evaluation.

        An initial evaluation (reported under the epoch value set by
        ``setup``) yields the first cluster assignment.  Each following
        epoch trains with the assignment from the previous evaluation,
        re-evaluates on train and test data, and appends a record to the
        results list.  Every 10th epoch and at the last epoch the results
        are pickled to ``self.result_fname`` and a checkpoint is saved.

        The summary writer is closed even if a round raises, so events
        logged up to that point are not lost.
        """
        num_epochs = self.config['num_epochs']
        results = []

        try:
            # Baseline evaluation before any training.
            result = {'epoch': self.epoch}
            result['train'] = self._timed_test(train=True)
            self.print_epoch_stats(result['train'])
            result['test'] = self._timed_test(train=False)
            self.print_epoch_stats(result['test'])
            results.append(result)

            # this will be used in the next epoch
            cluster_assign = result['train']['cluster_assign']

            for epoch in range(self.epoch + 1, num_epochs):
                self.epoch = epoch

                lr = self.lr_schedule(epoch)
                result = {'epoch': epoch, 'lr': lr}

                # train() returns nothing; only its duration is recorded.
                t0 = time.time()
                self.train(cluster_assign, lr = lr)
                train_time = time.time() - t0

                res = self._timed_test(train=True)
                res['train_time'] = train_time
                res['lr'] = lr
                result['train'] = res
                self.print_epoch_stats(res)

                result['test'] = self._timed_test(train=False)
                self.print_epoch_stats(result['test'])

                results.append(result)

                # this will be used in the next epoch's training round
                cluster_assign = result['train']['cluster_assign']

                if epoch % 10 == 0 or epoch == num_epochs - 1:
                    with open(self.result_fname, 'wb') as outfile:
                        pickle.dump(results, outfile)
                        print(f'result written at {self.result_fname}')
                    self.save_checkpoint(epoch)
                    print(f'checkpoint written at {self.checkpoint_fname}')
        finally:
            self.writer.close()

    def _timed_test(self, train):
        """Run ``test`` on one split and store the wall time as ``infer_time``."""
        t0 = time.time()
        res = self.test(train=train)
        res['infer_time'] = time.time() - t0
        return res


    def lr_schedule(self, epoch):
        """Return the learning rate for ``epoch`` and remember it in ``self.lr``.

        The rate starts from ``config['lr']`` the first time this is called.
        When the module-level ``LR_DECAY`` switch is on, it is multiplied by
        0.1 at every non-zero epoch divisible by 50.
        """
        current = self.config['lr'] if self.lr is None else self.lr

        at_decay_boundary = epoch != 0 and epoch % 50 == 0
        if at_decay_boundary and LR_DECAY:
            current = current * 0.1

        self.lr = current
        return current


    def print_epoch_stats(self, res):
        if res['is_train']:
            data_str = 'tr'
        else:
            data_str = 'tst'

        if 'train_time' in res:
            time_str = f"{res['train_time']:.3f}sec(train) {res['infer_time']:.3f}sec(infer)"
        else:
            time_str = f"{res['infer_time']:.3f}sec"

        if 'lr' in res:
            lr_str = f" lr {res['lr']:4f}"
        else:
            lr_str = ""

        str0 = f"Epoch {self.epoch} {data_str}: l {res['loss']:.3f} a {res['acc']:.3f} clct{res['cl_ct']}{lr_str} {time_str}"

        print(str0)

        prefix= 'Train' if res['is_train']==True else 'Test'
        self.writer.add_scalar("{}/Loss".format(prefix), res['loss'], self.epoch)
        self.writer.add_scalar("{}/Acc".format(prefix), res['acc'], self.epoch)
        self.writer.add_scalar("{}/ClusterAssignAcc".format(prefix), res['cl_acc'],  self.epoch)
        for cluster_indx in range(len(res['cl_ct'])):
           self.writer.add_scalar("{}/Cluster{}Ct".format(prefix,cluster_indx), res['cl_ct'][cluster_indx] , self.epoch)



    def train(self, cluster_assign, lr):
        """Run one round: local training per client, then per-cluster averaging.

        Every client gets a deep copy of the model of its assigned cluster
        and trains it ``config['tau']`` times with a freshly built Trainer.
        The trained copies are then grouped by cluster, and each non-empty
        group is averaged into that cluster's model.

        Args:
            cluster_assign: indexable by client id; gives the cluster index
                whose model the client starts from.
            lr: accepted but never read in this method.  NOTE(review): the
                TrainingArguments below are built without it -- confirm
                whether it was meant to be forwarded.

        Returns:
            None.
        """
        VERBOSE = 0

        cfg = self.config
        m, p, tau = cfg['m'], cfg['p'], cfg['tau']

        # ---- local updates -------------------------------------------------
        local_start = time.time()
        print("train cluster_assign",cluster_assign)

        per_cluster_models = [[] for _ in range(p)]
        for m_i in range(m):
            if VERBOSE and m_i % 100 == 0:
                print(f'm {m_i}/{m} processing \r', end ='')

            client_data = self.load_data(m_i)
            training_args = TrainingArguments(
                output_dir=f"train_client_{m_i}",
                num_train_epochs=1,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=32,
                local_rank=-1,
            )

            p_i = cluster_assign[m_i]
            client_model = copy.deepcopy(self.models[p_i])

            for _ in range(tau):
                trainer = Trainer(
                    model=client_model,
                    args=training_args,
                    train_dataset=client_data,
                    compute_metrics=self.criterion,
                )
                train_log = trainer.train()
                print(train_log)

            # Clients are visited in id order, so each group keeps that order.
            per_cluster_models[p_i].append(client_model)

        if VERBOSE:
            print(f'local update {time.time()-local_start:.3f}sec')

        # ---- global (per-cluster) updates -----------------------------------
        global_start = time.time()
        for p_i, cluster_models in enumerate(per_cluster_models):
            if not cluster_models:
                # no client chose this cluster; its model stays as it is
                continue
            self.global_param_update(cluster_models, self.models[p_i])

        if VERBOSE:
            print(f'global update {time.time()-global_start:.3f}sec')

    
    def get_inference_stats(self, train = True):
        """Evaluate every cluster model on every client and re-cluster clients.

        Each client is assigned to the cluster whose model gives it the
        lowest evaluation loss; the reported loss and accuracy are those of
        the chosen model.

        Args:
            train: evaluate the ``config['m']`` training clients when true,
                otherwise the ``config['m_test']`` test clients.

        Returns:
            dict with keys ``cluster_assign`` (list with one cluster index
            per client), ``num_data`` (total number of evaluated samples),
            ``loss`` (mean of the per-client minimum losses), ``acc``
            (correct predictions over all samples), ``cl_ct`` (number of
            clients per cluster) and ``is_train``.
        """
        cfg = self.config
        m = cfg['m'] if train else cfg['m_test']
        p = cfg['p']

        num_data = 0
        losses = {}
        corrects = {}
        for m_i in range(m):
            eval_dataset = self.load_data(m_i, train=train)
            n_samples = len(eval_dataset)
            print(n_samples)

            for p_i in range(p):
                training_args = TrainingArguments(output_dir=f"get_stat_client_{m_i}_{p_i}",
                                                  num_train_epochs=1,
                                                  per_device_train_batch_size=32,
                                                  per_device_eval_batch_size=32,
                                                  local_rank=-1
                                                  )
                trainer = Trainer(
                    model=self.models[p_i],
                    args=training_args,
                    train_dataset=eval_dataset,
                    eval_dataset=eval_dataset,
                    # The Trainer calls this hook with a single prediction
                    # object and expects a dict of metrics back; a loss
                    # module is not a valid hook.
                    compute_metrics=self._accuracy_metric,
                )
                eval_log = trainer.evaluate()
                print(eval_log)

                losses[(m_i, p_i)] = eval_log['eval_loss']
                # Turn the accuracy fraction back into a count so that
                # clients of different sizes are weighted correctly below.
                corrects[(m_i, p_i)] = eval_log['eval_accuracy'] * n_samples

            num_data += n_samples

        # calculate loss and cluster the machines
        cluster_assign = []
        for m_i in range(m):
            machine_losses = [losses[(m_i, p_i)] for p_i in range(p)]
            # plain int keeps the assignment printable and picklable as-is
            cluster_assign.append(int(np.argmin(machine_losses)))

        # loss / correct count of the chosen model for every client
        min_losses = [losses[(m_i, p_i)] for m_i, p_i in enumerate(cluster_assign)]
        min_corrects = [corrects[(m_i, p_i)] for m_i, p_i in enumerate(cluster_assign)]

        loss = np.mean(min_losses)
        acc = np.sum(min_corrects) / num_data

        # The order of learned clusters may differ from the dataset's cluster
        # order, so a cluster-assignment accuracy ('cl_acc') is not computed.
        assigned = np.array(cluster_assign)
        cl_ct = [np.sum(assigned == p_i) for p_i in range(p)]

        res = {}  # results
        res['cluster_assign'] = cluster_assign
        res['num_data'] = num_data
        res['loss'] = loss
        res['acc'] = acc
        res['cl_ct'] = cl_ct
        res['is_train'] = train

        return res

    def _accuracy_metric(self, eval_pred):
        """Trainer metrics hook: fraction of argmax predictions equal to the labels."""
        logits = eval_pred.predictions
        if isinstance(logits, (tuple, list)):
            # some models return extra outputs next to the logits
            logits = logits[0]
        predicted = np.argmax(logits, axis=-1)
        return {'accuracy': float(np.mean(predicted == eval_pred.label_ids))}

    def n_correct(self, y_logit, y):
        _, predicted = torch.max(y_logit.data, 1)
        correct = (predicted == y).sum().item()

        return correct

    def load_data(self, m_i, train=True):
        # this part is very fast since its just rearranging models
        cfg = self.config
        m1 = len(cfg['per_cluster_list'][0])

        if train:
            dataset = self.dataset['train']
        else:
            dataset = self.dataset['test']

        indices = dataset['data_indices'][m_i]
        p_i = dataset['cluster_assign'][m_i] # the cluster number -- 0~3 / 0~1/ 0 for k=4/2/1

        # print(indices)
        # print(dataset['dataset'][0])

        if m_i < m1:
            dataset = dataset['dataset'][0].select(indices)
        else:
            dataset = dataset['dataset'][1].select(indices)

        return dataset


    def local_param_update(self, model, lr):

        # gradient update manually

        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data -= lr * param.grad

        model.zero_grad()



    def global_param_update(self, local_models, global_model):

        # average of each weight

        weights = {}

        for m_i, local_model in enumerate(local_models):
            for name, param in local_model.named_parameters():
                if name not in weights:
                    weights[name] = torch.zeros_like(param.data)

                weights[name] += param.data

        for name, param in global_model.named_parameters():
            weights[name] /= len(local_models)
            param.data = weights[name]

    def test(self, train=False):
        """Evaluate on the training clients (``train`` true) or test clients.

        Thin wrapper that returns the record built by ``get_inference_stats``.
        """
        stats = self.get_inference_stats(train=train)
        return stats

    def save_checkpoint(self, epoch =0 ):
        models_to_save = [model.state_dict() for model in self.models]
        torch.save({'models':models_to_save, 'epoch': epoch}, self.checkpoint_fname)


class SimpleLinear(torch.nn.Module):
    """Two-layer fully connected classifier over flattened 28x28 inputs.

    Layout: 784 -> ``h1`` -> 10, with a ReLU after the first layer and a
    log-softmax over the 10 outputs.
    """

    def __init__(self, h1=2048):
        """Create the two linear layers; ``h1`` is the hidden width."""
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, h1)
        self.fc2 = torch.nn.Linear(h1, 10)

    def forward(self, x):
        """Return per-class log-probabilities, one row per flattened input."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

if __name__ == '__main__':
    # Run the experiment and report the total wall time in hours and seconds.
    started_at = time.time()
    main()
    elapsed = time.time() - started_at
    elapsed_hours = elapsed / float(3600)
    print("---train cluster Ended in %0.2f hour (%.3f sec) " % (elapsed_hours, elapsed))