import argparse
from cgi import print_form
import json
import os
from re import M
import time
import itertools
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset, TensorDataset
import math 
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from train_cluster_mnist import  TrainMNISTCluster,SimpleLinear

from util import *

# Module-level switch, presumably for learning-rate decay (flip to True to try it).
# NOTE(review): nothing in this file reads LR_DECAY; lr_schedule() is inherited
# from TrainMNISTCluster and presumably consults that module's own flag, so
# changing this value here may have no effect — confirm.
LR_DECAY = False

def main():
    """Parse the configuration, build the EM trainer, then train or run inference.

    The training seed is overwritten with the data seed before the trainer is
    built, so ``--train-seed`` has no effect on its own.
    """
    config = get_config()
    # Tie model initialisation to the data split so a single seed fixes the run.
    config['train_seed'] = config['data_seed']
    print("config:", config)

    trainer = TrainMNISTClusterEMAll(config)
    trainer.setup()

    if config['inference_mode'] != 0:
        print("inference mode")
        # Reuses the train-time soft assignment as the test assignment; setup()
        # only restores that assignment when --checkpoint is set.
        trainer.inference(save_fname='inference_stats.json')
        return

    trainer.run()


def get_config():
    """Build the run configuration from ``config.json`` and the command line.

    Precedence, lowest to highest:
      1. keys in ``config.json`` (read from the current working directory);
      2. every command-line option below -- including its default, so e.g. an
         ``lr`` in ``config.json`` is always replaced by ``--lr``'s value;
      3. keys of the JSON object given via ``--config-override``.

    The ``config_override`` entry itself is removed from the returned dict.
    """
    # (option strings, type, default); the first long option names the key.
    option_specs = (
        (("--project-dir", "--projdir"), str, "output/EM_all"),
        (("--lr",), float, 0.1),
        (("--data-seed",), int, 0),
        (("--train-seed",), int, 0),
        (("--config-override",), str, ""),
        (("--inference-mode",), int, 0),
        (("--checkpoint",), int, 0),
    )
    parser = argparse.ArgumentParser()
    for flags, value_type, default in option_specs:
        parser.add_argument(*flags, type=value_type, default=default)
    cli_args = vars(parser.parse_args())

    with open("config.json", "r") as config_file:
        file_config = json.load(config_file)

    # Command-line values win over config.json (key order as with dict.update).
    config = {**file_config, **cli_args}

    override_json = config.pop('config_override')
    if override_json != "":
        print(override_json)
        config.update(json.loads(override_json))

    return config


class TrainMNISTClusterEMAll(TrainMNISTCluster):
    """Clustered federated training on MNIST with soft (EM-style) cluster assignments.

    Each round, every client locally trains a copy of every cluster model, and
    each cluster model is replaced by the average of those copies weighted by
    the clients' soft assignment to that cluster (M step). Soft assignments are
    updated by multiplying them by ``exp(-loss)`` of each cluster model on the
    client's data and renormalising (E step).

    ``setup_datasets``, ``load_data``, ``n_correct``, ``local_param_update`` and
    ``lr_schedule`` are inherited from ``TrainMNISTCluster`` (not shown here).
    """

    def __init__(self, config):
        """Store ``config``; all other state is created by :meth:`setup`.

        NOTE(review): ``TrainMNISTCluster.__init__`` is not called here.
        """
        self.config = config

    def setup(self):
        """Create output paths, the TensorBoard writer, the datasets and the models."""
        self.epoch = -1
        project_dir = self.config['project_dir']
        os.makedirs(project_dir, exist_ok=True)

        self.result_fname = os.path.join(project_dir, 'results.pickle')
        self.checkpoint_fname = os.path.join(project_dir, 'checkpoint.pt')
        self.writer = SummaryWriter(os.path.join(project_dir, 'runs'))
        if self.config['checkpoint']:
            # Resume from the same file this run writes its checkpoints to.
            self.load_checkpoint_fname = os.path.join(project_dir, 'checkpoint.pt')

        # No soft assignment yet; setup_models() restores one from a checkpoint,
        # otherwise run() starts from a uniform assignment.
        self.soft_cluster_assign = None

        self.setup_datasets()
        self.setup_models()

        self.lr = None

    def setup_models(self):
        """Seed the RNGs and build ``p`` cluster models plus the loss functions.

        When ``config['checkpoint']`` is set, also restores every model's
        weights, the soft cluster assignment and the last finished epoch.
        """
        np.random.seed(self.config['train_seed'])
        torch.manual_seed(self.config['train_seed'])

        num_clusters = self.config['p']
        self.models = [SimpleLinear(h1=self.config['h1']) for _ in range(num_clusters)]

        self.criterion = torch.nn.CrossEntropyLoss()
        # Applied to log-probabilities of the soft-weighted ensemble prediction.
        self.agg_criterion = nn.NLLLoss()

        if self.config['checkpoint']:
            # NOTE: torch.load unpickles; only load checkpoints you produced.
            states = torch.load(self.load_checkpoint_fname)
            print('loading checkpoint from', self.load_checkpoint_fname, "trained epoch", states['epoch'])
            for p_i, model in enumerate(self.models):
                model.load_state_dict(states['models'][p_i])

            self.soft_cluster_assign = states['soft_assignment']
            self.epoch = states['epoch']
            print("start from epoch", self.epoch)

    def inference(self, save_fname='inference_stats.json'):
        """Evaluate the soft-weighted ensemble on every test client and save the stats.

        For test client ``m_i`` the ensemble prediction is
        ``sum_p pi[m_i][p] * softmax(model_p(X))`` with ``pi`` =
        ``self.soft_cluster_assign`` (restored from a checkpoint, or produced by
        :meth:`one_shot_cluster_assignment`). NOTE(review): when reusing the
        train assignment this indexes it with test-client indices, so it assumes
        ``m_test <= m`` and that test client i matches train client i -- confirm.

        Writes ``save_fname`` (JSON) under ``project_dir`` with per-client
        ensemble losses, correct counts, sample counts and accuracies, plus the
        per-cluster losses and correct counts.

        Raises:
            RuntimeError: if no soft cluster assignment is available.
        """
        cfg = self.config
        num_clients = cfg['m_test']
        num_clusters = cfg['p']

        soft_assign = getattr(self, 'soft_cluster_assign', None)
        if soft_assign is None:
            raise RuntimeError(
                "inference() needs a soft cluster assignment: load a checkpoint "
                "or call one_shot_cluster_assignment() first")

        losses = [[] for _ in range(num_clients)]    # losses[m_i][p_i]
        corrects = [[] for _ in range(num_clients)]  # corrects[m_i][p_i]
        num_samples = []
        agg_losses = []
        agg_corrects = []
        agg_accuracys = []
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for m_i in range(num_clients):
                X, y = self.load_data(m_i, train=False)
                logits = [self.models[p_i](X) for p_i in range(num_clusters)]
                for y_logit in logits:
                    losses[m_i].append(self.criterion(y_logit, y).item())
                    corrects[m_i].append(self.n_correct(y_logit, y))

                agg_loss, agg_correct = self._ensemble_stats(logits, soft_assign[m_i], y)
                n_samples = X.shape[0]
                num_samples.append(n_samples)
                agg_losses.append(agg_loss)
                agg_corrects.append(agg_correct)
                agg_accuracys.append(agg_correct / n_samples)

                print("client", m_i, "agg loss", agg_loss, "agg acc", agg_correct / n_samples)

        inference_stats = {
            'losses': agg_losses,
            'corrects': agg_corrects,
            'num_samples': num_samples,
            'accuracys': agg_accuracys,
            'per_cluster_loss': losses,
            'per_cluster_correct': corrects,
        }
        print(inference_stats)
        with open(os.path.join(cfg['project_dir'], save_fname), 'w') as out:
            out.write(json.dumps(inference_stats))

    def one_shot_cluster_assignment(self):
        """Set ``self.soft_cluster_assign`` for the ``m_test`` test clients with one E step.

        Starts from a uniform prior and weights each cluster by ``exp(-loss)``
        of its model on the client's test data.
        """
        cfg = self.config
        num_clients = cfg['m_test']
        num_clusters = cfg['p']
        uniform = [1 / num_clusters] * num_clusters

        new_soft_assign = []
        with torch.no_grad():
            for m_i in range(num_clients):
                X, y = self.load_data(m_i, train=False)
                cluster_losses = [
                    self.criterion(self.models[p_i](X), y).item()
                    for p_i in range(num_clusters)
                ]
                new_soft_assign.append(self._posterior_assignment(uniform, cluster_losses))
        self.soft_cluster_assign = new_soft_assign

    def run(self):
        """Train for the remaining epochs, evaluating and periodically checkpointing.

        First evaluates on the train split (which also performs an E step) and
        the test split, recorded as epoch ``self.epoch`` (-1 on a fresh start).
        Training then continues from ``self.epoch + 1`` up to
        ``num_epochs - 1``, so a fresh run trains exactly ``num_epochs`` epochs
        and a resumed run does not repeat the checkpointed epoch. Results are
        pickled and a checkpoint is saved every 10 epochs and after the last one.
        NOTE(review): on resume, ``results.pickle`` is rewritten with only the
        epochs evaluated in this session.
        """
        num_epochs = self.config['num_epochs']
        num_clusters = self.config['p']
        num_clients = self.config['m']

        if getattr(self, 'soft_cluster_assign', None) is None:
            # Fresh start: uniform assignment. A checkpointed assignment is kept.
            self.soft_cluster_assign = [[1 / num_clusters] * num_clusters for _ in range(num_clients)]

        results = []
        try:
            result = {'epoch': self.epoch}
            result['train'] = self._timed_test(train=True)  # also runs an E step
            self.print_epoch_stats(result['train'])
            result['test'] = self._timed_test(train=False)
            self.print_epoch_stats(result['test'])
            results.append(result)

            for epoch in range(self.epoch + 1, num_epochs):
                self.epoch = epoch

                lr = self.lr_schedule(epoch)
                result = {'epoch': epoch, 'lr': lr}

                t0 = time.time()
                self.train(lr=lr)
                train_time = time.time() - t0

                res = self._timed_test(train=True)
                res['train_time'] = train_time
                res['lr'] = lr
                result['train'] = res
                self.print_epoch_stats(res)

                res = self._timed_test(train=False)
                result['test'] = res
                self.print_epoch_stats(res)

                results.append(result)

                if epoch % 10 == 0 or epoch == num_epochs - 1:
                    with open(self.result_fname, 'wb') as outfile:
                        pickle.dump(results, outfile)
                        print(f'result written at {self.result_fname}')
                    self.save_checkpoint(epoch)
                    print(f'checkpoint written at {self.checkpoint_fname}')
        finally:
            self.writer.close()

    def _timed_test(self, train):
        """Evaluate one split via test() and record its wall time as ``infer_time``."""
        t0 = time.time()
        res = self.test(train=train)
        res['infer_time'] = time.time() - t0
        return res

    def print_epoch_stats(self, res):
        """Print a one-line summary of ``res`` and log loss/acc/cluster counts to TensorBoard."""
        data_str = 'tr' if res['is_train'] else 'tst'

        if 'train_time' in res:
            time_str = f"{res['train_time']:.3f}sec(train) {res['infer_time']:.3f}sec(infer)"
        else:
            time_str = f"{res['infer_time']:.3f}sec"

        lr_str = f" lr {res['lr']:4f}" if 'lr' in res else ""

        str0 = f"Epoch {self.epoch} {data_str}: l {res['loss']:.3f} a {res['acc']:.3f} clct{res['cl_ct']}{lr_str} {time_str}"
        print(str0)

        prefix = 'Train' if res['is_train'] else 'Test'
        self.writer.add_scalar("{}/Loss".format(prefix), res['loss'], self.epoch)
        self.writer.add_scalar("{}/Acc".format(prefix), res['acc'], self.epoch)
        for cluster_indx, count in enumerate(res['cl_ct']):
            self.writer.add_scalar("{}/Cluster{}Ct".format(prefix, cluster_indx), count, self.epoch)

    def train(self, lr):
        """Run one round: local training of every cluster model on every client, then M steps.

        For each cluster, every client trains a fresh copy of the cluster model
        for ``tau`` steps (via the inherited ``local_param_update``) and the
        copies are averaged into the cluster model by
        :meth:`global_param_update_soft`.
        """
        verbose = False

        cfg = self.config
        num_clients = cfg['m']
        num_clusters = cfg['p']
        tau = cfg['tau']

        t0 = time.time()
        for p_i in range(num_clusters):
            updated_models = []
            for m_i in range(num_clients):
                if verbose and m_i % 100 == 0:
                    print(f'm {m_i}/{num_clients} processing \r', end='')

                X, y = self.load_data(m_i)
                model = copy.deepcopy(self.models[p_i])
                for _ in range(tau):
                    loss = self.criterion(model(X), y)
                    model.zero_grad()
                    loss.backward()
                    self.local_param_update(model, lr)
                model.zero_grad()
                updated_models.append(model)

            self.global_param_update_soft(updated_models, self.models[p_i], p_i)

        if verbose:
            print(f'all cluster update {time.time() - t0:.3f}sec')

    @staticmethod
    def _posterior_assignment(prior, losses):
        """Return ``prior[k] * exp(-losses[k])`` normalised to sum to 1.

        Computed in log space, shifted by the largest term before
        exponentiating: mathematically identical to the naive product, but it
        cannot underflow to all zeros (and then divide by zero) when the losses
        are large. Clusters with zero prior weight stay at zero.
        """
        log_terms = [
            math.log(weight) - loss if weight > 0 else -math.inf
            for weight, loss in zip(prior, losses)
        ]
        peak = max(log_terms)
        terms = [math.exp(term - peak) for term in log_terms]
        total = sum(terms)
        return [term / total for term in terms]

    def _ensemble_stats(self, logits, weights, y):
        """Return ``(NLL loss, #correct)`` of the ``weights``-averaged softmax of ``logits``."""
        y_pred = 0.
        for weight, y_logit in zip(weights, logits):
            y_pred = y_pred + weight * F.softmax(y_logit, dim=1)
        # Lift exact zeros to the smallest normal float so log() stays finite
        # when a probability underflows (otherwise the loss becomes inf).
        y_pred = torch.clamp(y_pred, min=torch.finfo(y_pred.dtype).tiny, max=1.)
        agg_loss = self.agg_criterion(torch.log(y_pred), y).item()
        return agg_loss, self.n_correct(y_pred, y)

    def get_inference_stats_soft(self, train=True):
        """Run an E step on one split and evaluate the soft-weighted ensemble.

        Args:
            train: use the ``m`` training clients if True, else the ``m_test``
                test clients.

        On the train split the E step starts from ``self.soft_cluster_assign``
        and its result is stored back; on the test split it starts from a
        uniform prior (one-shot E step) and is not stored.

        Returns:
            dict with ``cluster_assign`` (argmax of each client's soft
            assignment), ``num_data``, ``loss`` (mean per-client ensemble NLL),
            ``acc`` (ensemble accuracy over all samples), ``cl_ct`` (clients per
            cluster under the argmax assignment) and ``is_train``.
        """
        cfg = self.config
        num_clients = cfg['m'] if train else cfg['m_test']
        num_clusters = cfg['p']

        all_logits = []  # all_logits[m_i][p_i]
        all_labels = []  # all_labels[m_i]
        losses = []      # losses[m_i][p_i]
        num_data = 0
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for m_i in range(num_clients):
                X, y = self.load_data(m_i, train=train)
                logits = [self.models[p_i](X) for p_i in range(num_clusters)]
                all_logits.append(logits)
                # Keep the labels from this same load so the ensemble below is
                # scored on exactly the batch the logits were computed from.
                all_labels.append(y)
                losses.append([self.criterion(y_logit, y).item() for y_logit in logits])
                num_data += X.shape[0]

        # E step: continue from the current assignment (train) or from uniform (test).
        if train:
            priors = self.soft_cluster_assign
        else:
            uniform = [1 / num_clusters] * num_clusters
            priors = [uniform] * num_clients
        new_soft_assign = [
            self._posterior_assignment(priors[m_i], losses[m_i])
            for m_i in range(num_clients)
        ]
        if train:
            self.soft_cluster_assign = new_soft_assign

        agg_losses = []
        agg_corrects = []
        with torch.no_grad():
            for logits, y, weights in zip(all_logits, all_labels, new_soft_assign):
                agg_loss, agg_correct = self._ensemble_stats(logits, weights, y)
                agg_losses.append(agg_loss)
                agg_corrects.append(agg_correct)

        loss = np.mean(agg_losses)
        acc = np.sum(agg_corrects) / num_data

        cluster_assign = [int(np.argmax(row)) for row in new_soft_assign]
        # Learned cluster indices may be any permutation of the dataset's
        # clusters, so only per-cluster counts are reported, not assignment accuracy.
        cl_ct = [cluster_assign.count(p_i) for p_i in range(num_clusters)]
        print("argmax hard cluster assigment", cluster_assign)

        res = {}
        res['cluster_assign'] = cluster_assign
        res['num_data'] = num_data
        res['loss'] = loss
        res['acc'] = acc
        res['cl_ct'] = cl_ct
        res['is_train'] = train
        return res

    def global_param_update_soft(self, local_models, global_model, cluster_idx):
        """M step: set ``global_model``'s parameters to a soft-weighted average of ``local_models``.

        Client ``i``'s weight is ``self.soft_cluster_assign[i][cluster_idx]``,
        normalised over the clients in ``local_models``. If all those weights
        are zero (no client belongs to this cluster) the global model is left
        unchanged instead of dividing by zero.
        """
        agg_weights = [
            self.soft_cluster_assign[client_i][cluster_idx]
            for client_i in range(len(local_models))
        ]
        agg_weights_sum = sum(agg_weights)
        if agg_weights_sum == 0:
            return

        weights = {}
        for agg_weight, local_model in zip(agg_weights, local_models):
            coeff = agg_weight / agg_weights_sum
            for name, param in local_model.named_parameters():
                if name not in weights:
                    weights[name] = torch.zeros_like(param.data)
                weights[name] += param.data * coeff

        for name, param in global_model.named_parameters():
            param.data = weights[name]

    def test(self, train=False):
        """Evaluate on the train or test split (see :meth:`get_inference_stats_soft`)."""
        return self.get_inference_stats_soft(train=train)

    def save_checkpoint(self, epoch=0):
        """Save all model state dicts, the soft assignment and ``epoch`` to ``checkpoint.pt``."""
        models_to_save = [model.state_dict() for model in self.models]
        torch.save(
            {'models': models_to_save, 'soft_assignment': self.soft_cluster_assign, 'epoch': epoch},
            self.checkpoint_fname,
        )




if __name__ == '__main__':
    # Time the whole run and report it in hours and seconds.
    started_at = time.time()
    main()
    elapsed = time.time() - started_at
    print("---train cluster Ended in %0.2f hour (%.3f sec) " % (elapsed / 3600.0, elapsed))