import argparse
import json
import os
import time
import itertools
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torch.utils.tensorboard import SummaryWriter
import numpy as np

from util import *
from load_data import *


# When True, TrainCIFARCluster.lr_schedule multiplies the learning rate by 0.1
# at every epoch that is a non-zero multiple of 50. Disabled by default.
LR_DECAY: bool = False

def main():
    """Build the run config, then set up and run the clustered training experiment.

    Note: ``train_seed`` is always overwritten with ``data_seed`` here, so the
    ``--train-seed`` command-line flag has no effect on the run.
    """
    cfg = get_config()
    # Tie the model-initialisation seed to the data seed (overrides --train-seed).
    cfg['train_seed'] = cfg['data_seed']
    print("config:", cfg)

    experiment = TrainCIFARCluster(cfg)
    experiment.setup()
    experiment.run()


def get_config():
    """Return the run configuration as a dict.

    Settings are read from ``config.json`` in the current working directory,
    then overwritten by the command-line flags (dashes become underscores, e.g.
    ``--data-seed`` -> ``data_seed``). If ``--config-override`` is a non-empty
    JSON object string, its keys are applied last. The ``config_override`` key
    itself is never present in the returned dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--project-dir", "--projdir", type=str, default="output/ifca")
    parser.add_argument("--inference-mode", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--data-seed", type=int, default=0)
    parser.add_argument("--train-seed", type=int, default=0)
    parser.add_argument("--config-override", type=str, default="")
    cli_args = vars(parser.parse_args())

    # Base settings from the JSON file; CLI flags take precedence over them.
    with open("config.json", "r") as fh:
        config = json.load(fh)
    config.update(cli_args)

    # The override string is always present (argparse default ""); strip it out
    # and, when non-empty, apply it on top of everything else.
    override_text = config.pop('config_override')
    if override_text != "":
        print(override_text)
        config.update(json.loads(override_text))

    return config


class TrainCIFARCluster:
    """Clustered federated training (IFCA-style) of ResNet-20 models on CIFAR-10.

    ``config['p']`` cluster models are trained over ``config['m']`` simulated
    clients. Each round, every client is assigned to the cluster model with the
    lowest loss on its data. It then runs ``config['tau']`` local passes of plain
    SGD on a copy of that model. Each cluster model is finally replaced by the
    parameter average of the local copies assigned to it. Client data whose
    ground-truth cluster id is odd is rotated by 180 degrees (see ``load_data``).
    """

    def __init__(self, config):
        """Store the experiment configuration.

        Args:
            config: dict of run settings. ``m`` (number of training clients)
                must be divisible by ``p`` (number of clusters).
        """
        self.config = config

        # The balanced split hands every cluster exactly m // p clients.
        assert self.config['m'] % self.config['p'] == 0

    def setup(self):
        """Create the output directory and TensorBoard writer, then build data and models."""
        os.makedirs(self.config['project_dir'], exist_ok=True)

        self.result_fname = os.path.join(self.config['project_dir'], 'results.pickle')
        self.checkpoint_fname = os.path.join(self.config['project_dir'], 'checkpoint.pt')
        self.writer = SummaryWriter(os.path.join(self.config['project_dir'], 'runs'))
        self.setup_datasets()
        self.setup_models()

        self.epoch = None  # current epoch; used as the TensorBoard step when logging
        self.lr = None  # current learning rate; initialised lazily by lr_schedule

    def setup_datasets(self):
        """Load the CIFAR-10 train/test splits and partition each across clients.

        Populates ``self.dataset['train']`` and ``self.dataset['test']``. Each is a
        dict with ``data_indices`` (one index set per client), ``cluster_assign``
        (ground-truth cluster id per client), ``X`` (images) and ``y`` (labels).
        The train split is shuffled per class. The test split keeps index order
        (``random=False``) and uses ``config['m_test']`` clients.
        """
        CIFAR_TRAINSET_DATA_SIZE = 50000
        CIFAR_TESTSET_DATA_SIZE = 10000

        # Seed once, before any draw, so the partition is reproducible per data_seed.
        np.random.seed(self.config['data_seed'])

        cfg = self.config
        self.dataset = {}

        X, y = self._load_CIFAR(train=True)
        data_indices, cluster_assign = self._setup_dataset(
            CIFAR_TRAINSET_DATA_SIZE, cfg['p'], cfg['m'], self.build_class_dict(y))
        self.dataset['train'] = {
            'data_indices': data_indices,
            'cluster_assign': cluster_assign,
            'X': X,
            'y': y,
        }

        X, y = self._load_CIFAR(train=False)
        data_indices, cluster_assign = self._setup_dataset(
            CIFAR_TESTSET_DATA_SIZE, cfg['p'], cfg['m_test'], self.build_class_dict(y),
            random=False)
        self.dataset['test'] = {
            'data_indices': data_indices,
            'cluster_assign': cluster_assign,
            'X': X,
            'y': y,
        }

    def build_class_dict(self, labels):
        """Map each integer label to the list of positions where it occurs in ``labels``.

        Keys appear in order of first occurrence. Positions are listed in
        ascending order.
        """
        classes2indx = {}
        for ind, label in enumerate(labels):
            classes2indx.setdefault(int(label), []).append(ind)
        return classes2indx

    @staticmethod
    def _class_balanced_shards(classes2indx, n_shards, random):
        """Split every class's indices into ``n_shards`` chunks and merge them per shard."""
        shards = [[] for _ in range(n_shards)]
        for key in classes2indx.keys():
            if random:
                per_class = list(np.random.permutation(classes2indx[key]))
            else:
                per_class = classes2indx[key]
            per_class_chunks = chunkify(per_class, n_shards)
            # Keep `+` (not extend) so semantics match whatever chunkify returns.
            shards = [shards[i] + per_class_chunks[i] for i in range(n_shards)]
        return shards

    @staticmethod
    def _to_index_array(per_client_indices):
        """Stack per-client index lists into a NumPy array.

        Falls back to a 1-D object array (one int64 index array per client) when
        the lists have different lengths, because NumPy >= 1.24 refuses to build
        ragged arrays implicitly.
        """
        try:
            return np.array(per_client_indices)
        except ValueError:
            stacked = np.empty(len(per_client_indices), dtype=object)
            for i, idx in enumerate(per_client_indices):
                stacked[i] = np.asarray(idx, dtype=np.int64)
            return stacked

    def _setup_dataset(self, num_data, p, m, classes2indx, random=True):
        """Partition sample indices into ``m`` class-balanced client shards.

        Args:
            num_data: total number of samples (not used by the split itself).
            p: number of clusters.
            m: number of clients to create.
            classes2indx: mapping label -> list of sample indices.
            random: shuffle each class's indices before splitting. When False,
                the original index order is kept.

        Returns:
            ``(data_indices, cluster_assign)``: an array whose first axis has one
            index set per client, and the ground-truth cluster id of each client.
            Both have length ``m``.
        """
        data_indices = []
        cluster_assign = []

        if self.config["unbalanced_cluster"] == True:
            # All m clients share one class-balanced split of the data. The
            # pre-defined config['per_cluster_list'][p_i] lists which clients
            # belong to cluster p_i.
            shards = self._class_balanced_shards(classes2indx, m, random)
            for p_i in range(p):
                per_cluster = self.config['per_cluster_list'][p_i]
                print("client idx for cluster {}: {}".format(p_i, per_cluster))
                data_indices += [shards[m_i] for m_i in per_cluster]
                cluster_assign += [p_i for _ in per_cluster]
        else:
            m_per_cluster = m // p  # m is number of clients, p is number of clusters
            for p_i in range(p):
                # Each cluster gets its own class-balanced split of the full data.
                data_indices += self._class_balanced_shards(classes2indx, m_per_cluster, random)
                cluster_assign += [p_i for _ in range(m_per_cluster)]

        data_indices = self._to_index_array(data_indices)
        cluster_assign = np.array(cluster_assign)
        assert data_indices.shape[0] == cluster_assign.shape[0]
        assert data_indices.shape[0] == m

        return data_indices, cluster_assign

    def _load_CIFAR(self, train=True):
        """Load a whole CIFAR-10 split as tensors on ``config['device']``.

        Returns:
            ``(X, y)``. ``X`` is float with shape (N, 3, 32, 32), scaled to [0, 1].
            ``y`` holds the integer class labels.

        NOTE: the raw ``.data`` array is read directly, so no torchvision
        transform (e.g. crop/flip augmentation) is applied to these samples.
        """
        cifar_dataset = datasets.CIFAR10(root="./data", train=train, download=True)
        # .data is (N, 32, 32, 3); move channels first for the conv models.
        X = torch.tensor(cifar_dataset.data).to(self.config['device']).permute(0, 3, 1, 2)
        y = torch.tensor(cifar_dataset.targets).to(self.config['device'])
        X = X / 255.0
        return X, y

    def setup_models(self):
        """Create ``config['p']`` ResNet-20 models (10 classes) and the loss criterion."""
        np.random.seed(self.config['train_seed'])
        torch.manual_seed(self.config['train_seed'])

        p = self.config['p']
        self.models = [resnet(depth=20, num_classes=10).to(self.config['device']) for _ in range(p)]
        self.criterion = torch.nn.CrossEntropyLoss()

    def _timed_test(self, train):
        """Run ``test`` on one split and record its wall time under ``'infer_time'``."""
        t0 = time.time()
        res = self.test(train=train)
        res['infer_time'] = time.time() - t0
        return res

    def run(self):
        """Train for ``config['num_epochs']`` rounds, evaluating both splits after each.

        Epoch -1 evaluates the freshly initialised models. Its train-split
        cluster assignment seeds the first round. After every round, the
        train split is re-clustered and that assignment drives the next round.
        Results are pickled to ``self.result_fname``, and a checkpoint is saved,
        every 10 epochs and after the final epoch.
        """
        num_epochs = self.config['num_epochs']
        results = []

        # Epoch -1: evaluate the untrained models to obtain an initial clustering.
        self.epoch = -1
        result = {'epoch': -1}
        result['train'] = self._timed_test(train=True)
        self.print_epoch_stats(result['train'])
        result['test'] = self._timed_test(train=False)
        self.print_epoch_stats(result['test'])
        results.append(result)

        cluster_assign = result['train']['cluster_assign']

        for epoch in range(num_epochs):
            self.epoch = epoch
            lr = self.lr_schedule(epoch)
            result = {'epoch': epoch, 'lr': lr}

            t0 = time.time()
            self.train(cluster_assign, lr=lr)
            train_time = time.time() - t0

            res = self._timed_test(train=True)
            res['train_time'] = train_time
            res['lr'] = lr
            result['train'] = res
            self.print_epoch_stats(res)

            result['test'] = self._timed_test(train=False)
            self.print_epoch_stats(result['test'])

            results.append(result)

            # The re-clustering of the train split is used by the next round.
            cluster_assign = res['cluster_assign']

            if epoch % 10 == 0 or epoch == num_epochs - 1:
                with open(self.result_fname, 'wb') as outfile:
                    pickle.dump(results, outfile)
                    print(f'result written at {self.result_fname}')
                self.save_checkpoint(epoch)
                print(f'checkpoint written at {self.checkpoint_fname}')

        self.writer.close()

    def lr_schedule(self, epoch):
        """Return the learning rate for ``epoch``.

        Starts at ``config['lr']``. When the module flag ``LR_DECAY`` is set, the
        rate is multiplied by 0.1 at every non-zero multiple of 50.
        """
        if self.lr is None:
            self.lr = self.config['lr']

        if epoch % 50 == 0 and epoch != 0 and LR_DECAY:
            self.lr = self.lr * 0.1

        return self.lr

    def print_epoch_stats(self, res):
        """Print a one-line summary of ``res`` and log its scalars to TensorBoard.

        ``res`` is a dict produced by ``get_inference_stats``, optionally extended
        with ``infer_time``, ``train_time`` and ``lr``. The cluster-assignment
        accuracy is logged only when ``res`` contains ``cl_acc``.
        """
        data_str = 'tr' if res['is_train'] else 'tst'

        if 'train_time' in res:
            time_str = f"{res['train_time']:.3f}sec(train) {res['infer_time']:.3f}sec(infer)"
        else:
            time_str = f"{res['infer_time']:.3f}sec"

        lr_str = f" lr {res['lr']:4f}" if 'lr' in res else ""

        str0 = f"Epoch {self.epoch} {data_str}: l {res['loss']:.3f} a {res['acc']:.3f} clct{res['cl_ct']}{lr_str} {time_str}"
        print(str0)

        prefix = 'Train' if res['is_train'] == True else 'Test'
        self.writer.add_scalar("{}/Loss".format(prefix), res['loss'], self.epoch)
        self.writer.add_scalar("{}/Acc".format(prefix), res['acc'], self.epoch)
        # get_inference_stats cannot compute cl_acc (learned cluster ids need not
        # match ground-truth ids), so only log it when someone supplies it.
        if 'cl_acc' in res:
            self.writer.add_scalar("{}/ClusterAssignAcc".format(prefix), res['cl_acc'], self.epoch)
        for cluster_indx, count in enumerate(res['cl_ct']):
            self.writer.add_scalar("{}/Cluster{}Ct".format(prefix, cluster_indx), count, self.epoch)

    def train(self, cluster_assign, lr):
        """Run one federated round.

        Every training client copies the model of its assigned cluster and runs
        ``config['tau']`` passes of SGD over its local data. Each cluster model is
        then replaced by the parameter average of the copies assigned to it.
        Clusters with no assigned clients keep their current weights.

        Args:
            cluster_assign: cluster id for each of the ``config['m']`` training clients.
            lr: learning rate for the local SGD steps.
        """
        VERBOSE = 0

        cfg = self.config
        m = cfg['m']
        p = cfg['p']
        tau = cfg['tau']

        # Local updates, grouped by the cluster each client trained.
        t0 = time.time()
        print("train cluster_assign", cluster_assign)

        local_models = [[] for _ in range(p)]
        for m_i in range(m):
            if VERBOSE and m_i % 100 == 0:
                print(f'm {m_i}/{m} processing \r', end='')

            loader = self.load_data(m_i)

            p_i = cluster_assign[m_i]
            # Train a private copy so the cluster model stays fixed for this round.
            model = copy.deepcopy(self.models[p_i])

            for _ in range(tau):
                for X_batch, y_batch in loader:
                    loss = self.criterion(model(X_batch), y_batch)
                    model.zero_grad()
                    loss.backward()
                    self.local_param_update(model, lr)

            model.zero_grad()
            local_models[p_i].append(model)

        t1 = time.time()
        if VERBOSE:
            print(f'local update {t1-t0:.3f}sec')

        # Global update: average the local copies into each cluster model.
        t0 = time.time()
        for p_i, models in enumerate(local_models):
            if models:
                self.global_param_update(models, self.models[p_i])
        t1 = time.time()

        if VERBOSE:
            print(f'global update {t1-t0:.3f}sec')

    def get_inference_stats(self, train = True):
        """Evaluate every cluster model on every client and re-cluster the clients.

        Each client is assigned to the cluster model with the lowest mean
        per-batch loss on its data. Loss and accuracy are reported for that
        best model.

        Args:
            train: evaluate the training clients (``config['m']``) if True,
                otherwise the test clients (``config['m_test']``).

        Returns:
            dict with ``cluster_assign`` (list of int cluster ids, one per
            client), ``num_data``, ``loss`` (mean over clients of the best
            model's loss), ``acc`` (best-model correct predictions / num_data),
            ``cl_ct`` (number of clients per cluster) and ``is_train``.
        """
        cfg = self.config
        m = cfg['m'] if train else cfg['m_test']
        p = cfg['p']

        num_data = 0
        losses = {}
        corrects = {}
        # Pure evaluation: no autograd graph is needed. Models are NOT switched to
        # eval mode, so layers like BatchNorm (if present) behave as before.
        with torch.no_grad():
            for m_i in range(m):
                loader = self.load_data(m_i, train=train)  # client data, rotated per cluster
                for p_i in range(p):
                    loss_sum = 0.0
                    n_correct = 0
                    cnt = 0
                    for X_batch, y_batch in loader:
                        y_logit = self.models[p_i](X_batch)
                        loss_sum += self.criterion(y_logit, y_batch).item()
                        n_correct += self.n_correct(y_logit, y_batch)
                        cnt += 1
                    losses[(m_i, p_i)] = loss_sum / cnt
                    corrects[(m_i, p_i)] = n_correct

                num_data += len(loader.dataset)

        # Assign each client to the cluster whose model has the lowest loss on it.
        cluster_assign = [
            int(np.argmin([losses[(m_i, p_i)] for p_i in range(p)]))
            for m_i in range(m)
        ]

        # Report loss/accuracy of each client's best model.
        min_losses = [losses[(m_i, p_i)] for m_i, p_i in enumerate(cluster_assign)]
        min_corrects = [corrects[(m_i, p_i)] for m_i, p_i in enumerate(cluster_assign)]
        loss = np.mean(min_losses)
        acc = np.sum(min_corrects) / num_data

        # Learned cluster ids need not match the dataset's ground-truth ids, so a
        # direct cluster-assignment accuracy is not computed here.
        assign_arr = np.array(cluster_assign)
        cl_ct = [int(np.sum(assign_arr == p_i)) for p_i in range(p)]

        res = {}
        res['cluster_assign'] = cluster_assign
        res['num_data'] = num_data
        res['loss'] = loss
        res['acc'] = acc
        res['cl_ct'] = cl_ct
        res['is_train'] = train

        return res

    def n_correct(self, y_logit, y):
        """Return how many argmax predictions of ``y_logit`` (dim 1) equal labels ``y``."""
        predicted = y_logit.detach().argmax(dim=1)
        return (predicted == y).sum().item()

    def load_data(self, m_i, train=True):
        """Build a DataLoader over client ``m_i``'s samples.

        The client's images are indexed out of the preloaded split. They are
        rotated by ``k * 90`` degrees over the H, W axes, with ``k = 2`` when the
        client's ground-truth cluster id is odd and ``k = 0`` when it is even.
        Only the parity of the cluster id matters.
        """
        dataset = self.dataset['train'] if train else self.dataset['test']

        indices = dataset['data_indices'][m_i]
        p_i = dataset['cluster_assign'][m_i]  # ground-truth cluster id of this client

        X_batch = dataset['X'][indices]
        y_batch = dataset['y'][indices]

        k = (p_i % 2) * 2
        X_rotated = torch.rot90(X_batch, k=int(k), dims=(2, 3))
        client_dataset = local_dataset(X_rotated, y_batch)
        return DataLoader(client_dataset, batch_size=self.config['batch_size'])

    def local_param_update(self, model, lr):
        """Apply one manual SGD step (``param -= lr * grad``) and clear gradients.

        Parameters that are frozen or received no gradient (``grad is None``)
        are left unchanged.
        """
        for param in model.parameters():
            if param.requires_grad and param.grad is not None:
                param.data -= lr * param.grad

        model.zero_grad()

    def global_param_update(self, local_models, global_model):
        """Overwrite ``global_model``'s parameters with the mean of ``local_models``'.

        Only parameters are averaged. Buffers (e.g. BatchNorm running statistics,
        if the model has any) of ``global_model`` are left untouched.
        ``local_models`` must be non-empty.
        """
        sums = {}
        for local_model in local_models:
            for name, param in local_model.named_parameters():
                if name not in sums:
                    sums[name] = torch.zeros_like(param.data)
                sums[name] += param.data

        n_models = len(local_models)
        for name, param in global_model.named_parameters():
            sums[name] /= n_models
            param.data = sums[name]

    def test(self, train=False):
        """Evaluate via ``get_inference_stats``; defaults to the test split."""
        return self.get_inference_stats(train=train)

    def save_checkpoint(self, epoch=0):
        """Save every cluster model's ``state_dict`` and ``epoch`` to ``checkpoint_fname``."""
        models_to_save = [model.state_dict() for model in self.models]
        torch.save({'models': models_to_save, 'epoch': epoch}, self.checkpoint_fname)



if __name__ == '__main__':
    # Run the experiment and report total wall-clock time.
    t_start = time.time()
    main()
    elapsed = time.time() - t_start
    print(f"---train cluster Ended in {elapsed / 3600:.2f} hour ({elapsed:.3f} sec) ")