import argparse
import json
import os
import time
import itertools
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torch.utils.tensorboard import SummaryWriter
import numpy as np

from cifar_resnet import resnet
from train_cluster_cifar import *
from util import *



# Module-level boolean flag; by its name it presumably toggles learning-rate
# decay. NOTE(review): it is not read anywhere in the code visible here —
# confirm which module (if any) consults it before relying on it.
LR_DECAY: bool = False

def main():
    """Script entry point: build the config, set up the FedProx experiment,
    then either train it or (when ``inference_mode`` is non-zero) run inference."""
    config = get_config()
    print("config:", config)
    # The training seed is tied to the data seed, overriding --train-seed.
    config['train_seed'] = config['data_seed']

    experiment = TrainCIFARClusterFedprox(config)
    experiment.setup()

    if config['inference_mode'] != 0:
        print("inference mode")
        experiment.inference()
        return

    experiment.run()


def get_config(config_path="config.json", argv=None):
    """Build the run configuration from a JSON file plus command-line flags.

    Values are layered in this order (later layers win):
      1. the JSON object stored at ``config_path``;
      2. every argparse option defined below, *including its default* — e.g.
         an ``lr`` key in the JSON file is always replaced by ``--lr`` or its
         default of 0.1;
      3. the JSON object passed via ``--config-override``.

    Args:
        config_path: path of the base JSON config. Relative paths resolve
            against the current working directory.
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        dict: the merged configuration. The ``config_override`` key itself is
        removed.

    Raises:
        ValueError: if ``--config-override`` is not valid JSON or is not a
            JSON object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--project-dir", "--projdir", type=str, default="output/fedavg")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--mu", type=float, default=1)
    parser.add_argument("--train-seed", type=int, default=0)
    parser.add_argument("--config-override", type=str, default="")

    parser.add_argument("--inference-mode", type=int, default=0)
    parser.add_argument("--checkpoint", type=int, default=0)

    args = parser.parse_args(argv)

    # Read the base config, then let the parsed CLI arguments update it.
    with open(config_path, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)

    config.update(vars(args))

    # The override string is only an input channel; never keep it in the config.
    override_str = config.pop("config_override")
    if override_str:
        print(override_str)
        config_override = json.loads(override_str)
        if not isinstance(config_override, dict):
            raise ValueError(
                "--config-override must be a JSON object, got %r" % (override_str,)
            )
        config.update(config_override)

    return config


class TrainCIFARClusterFedprox(TrainCIFARCluster):
    """FedProx variant of ``TrainCIFARCluster`` centred on a single global model.

    Each client copies its assigned model and performs ``tau`` local steps.
    Each step takes one full pass over the client's data, adds the proximal
    penalty ``0.5 * mu * ||w - w_start||^2`` and applies one update via
    ``local_param_update``. The server then merges the clients of each model
    via the inherited ``global_param_update``. Inference statistics always
    use model 0 for every client.
    """

    def setup(self):
        """Create the output dir and writer, build datasets and models.

        Reads ``project_dir``, ``mu`` and ``checkpoint`` from ``self.config``.
        When ``checkpoint`` is truthy, ``setup_models`` restores model 0 from
        ``<project_dir>/checkpoint_single.pt``.
        """
        self.epoch = -1
        os.makedirs(self.config['project_dir'], exist_ok=True)

        self.mu = self.config['mu']
        self.result_fname = os.path.join(self.config['project_dir'], 'results_single.pickle')
        self.checkpoint_fname = os.path.join(self.config['project_dir'], 'checkpoint_single.pt')
        if self.config['checkpoint']:
            # Resume from the same path as self.checkpoint_fname.
            self.load_checkpoint_fname = self.checkpoint_fname

        self.writer = SummaryWriter(os.path.join(self.config['project_dir'], 'runs'))
        self.setup_datasets()
        self.setup_models()

        self.lr = None

    def setup_models(self):
        """Seed the RNGs, build ``p`` ResNet-20 models and optionally restore model 0.

        Only ``states['models'][0]`` and ``states['epoch']`` are read from the
        checkpoint.
        """
        np.random.seed(self.config['train_seed'])
        torch.manual_seed(self.config['train_seed'])

        device = self.config['device']
        self.models = [
            resnet(depth=20, num_classes=10).to(device)
            for _ in range(self.config['p'])
        ]

        self.criterion = torch.nn.CrossEntropyLoss()

        if self.config['checkpoint']:
            print('loading checkpoint from', self.load_checkpoint_fname)
            # Map tensors onto the configured device instead of a hard-coded
            # "cuda:3", which crashed on machines without that GPU.
            states = torch.load(self.load_checkpoint_fname, map_location=device)

            self.models[0].load_state_dict(states['models'][0])
            self.epoch = states['epoch']

    def _fedprox_eval_loader(self, model, loader):
        """Return ``(mean per-batch loss, total correct count)`` of ``model`` on ``loader``.

        The loss is averaged over batches, not samples. Raises
        ZeroDivisionError if ``loader`` yields no batches.
        """
        loss_sum = 0.0
        n_batches = 0
        n_correct = 0
        for X_batch, y_batch in loader:
            y_logit = model(X_batch)
            loss_sum += self.criterion(y_logit, y_batch).detach().cpu().item()
            n_correct += self.n_correct(y_logit, y_batch)
            n_batches += 1
        return loss_sum / n_batches, n_correct

    def inference(self):
        """Evaluate the global model (``self.models[0]``) on every test client.

        Prints the overall loss (mean of per-client losses) and accuracy (total
        correct / total samples), and writes the per-client ``losses``,
        ``corrects``, ``num_samples`` and ``accuracys`` lists to
        ``<project_dir>/inference_stats.json``.
        """
        cfg = self.config
        model = self.models[0]  # the global model

        # Evaluate in eval mode without autograd so that layers with
        # train-time behaviour (e.g. normalization running stats, if the
        # ResNet has them) are neither updated nor batch-dependent. Restore
        # the caller's mode afterwards.
        was_training = model.training
        model.eval()

        losses = []
        corrects = []
        num_samples = []
        accuracys = []
        num_data = 0
        try:
            with torch.no_grad():
                for m_i in range(cfg['m_test']):
                    loader = self.load_data(m_i, train=False)
                    mean_loss, n_correct = self._fedprox_eval_loader(model, loader)
                    n_samples = len(loader.dataset)
                    losses.append(mean_loss)
                    corrects.append(n_correct)
                    num_samples.append(n_samples)
                    accuracys.append(n_correct / n_samples)
                    num_data += n_samples
        finally:
            model.train(was_training)

        loss = np.mean(losses)
        acc = np.sum(corrects) / num_data
        print("loss", loss, "acc", acc)

        inference_stats = {'losses': losses, 'corrects': corrects, 'num_samples': num_samples, 'accuracys': accuracys}
        print(inference_stats)
        with open(os.path.join(cfg['project_dir'], 'inference_stats.json'), 'w') as out:
            out.write(json.dumps(inference_stats))

    def get_inference_stats(self, train = True):
        """Evaluate model 0 on every client and summarise the results.

        Every client is deliberately assigned to cluster 0, so ``cl_acc`` is
        the fraction of clients whose ground-truth ``cluster_assign`` is 0.

        Args:
            train: use the ``m`` training clients and ``self.dataset['train']``
                if True, otherwise the ``m_test`` test clients and
                ``self.dataset['test']``.

        Returns:
            dict with keys ``cluster_assign``, ``num_data``, ``loss`` (mean of
            per-client mean batch losses), ``acc`` (total correct / total
            samples), ``cl_acc``, ``cl_ct`` (clients per cluster) and
            ``is_train``.
        """
        cfg = self.config
        if train:
            m = cfg['m']
            dataset = self.dataset['train']
        else:
            m = cfg['m_test']
            dataset = self.dataset['test']

        p = cfg['p']

        p_i = 0  # forced: every client uses model/cluster 0
        cluster_assign = [p_i] * m

        client_losses = []
        client_corrects = []
        num_data = 0
        for m_i in range(m):
            loader = self.load_data(m_i, train=train)
            mean_loss, n_correct = self._fedprox_eval_loader(self.models[p_i], loader)
            client_losses.append(mean_loss)
            client_corrects.append(n_correct)
            num_data += len(loader.dataset)

        loss = np.mean(client_losses)
        acc = np.sum(client_corrects) / num_data

        # check cluster assignment acc
        cl_acc = np.mean(np.array(cluster_assign) == np.array(dataset['cluster_assign']))
        cl_ct = [np.sum(np.array(cluster_assign) == k) for k in range(p)]

        res = {}  # results
        res['cluster_assign'] = cluster_assign
        res['num_data'] = num_data
        res['loss'] = loss
        res['acc'] = acc
        res['cl_acc'] = cl_acc
        res['cl_ct'] = cl_ct
        res['is_train'] = train

        return res

    def train(self, cluster_assign, lr):
        """Run one FedProx round: proximal local updates, then per-cluster aggregation.

        Args:
            cluster_assign: sequence mapping each of the ``m`` clients to an
                index into ``self.models``.
            lr: learning rate forwarded to ``local_param_update``.
        """
        VERBOSE = 0

        cfg = self.config
        m = cfg['m']
        p = cfg['p']
        tau = cfg['tau']

        # run local update
        t0 = time.time()
        print("train fexprox!!!! cluster_assign",cluster_assign)

        updated_models = []
        for m_i in range(m):
            if VERBOSE and m_i % 100 == 0: print(f'm {m_i}/{m} processing \r', end ='')

            loader = self.load_data(m_i)

            p_i = cluster_assign[m_i]
            model = copy.deepcopy(self.models[p_i])
            # Frozen snapshot of the starting weights: the proximal anchor.
            # Detached clones are all the penalty needs; no second model copy.
            anchor_params = [param.detach().clone() for param in model.parameters()]

            for _ in range(tau):
                # One local step = one full pass over the client's data
                # (batch losses averaged), then a single parameter update.
                loss = 0.0
                cnt = 0
                for X_batch, y_batch in loader:
                    y_logit = model(X_batch)  # the corresponding local model
                    loss += self.criterion(y_logit, y_batch)
                    cnt += 1
                loss = loss / cnt
                loss_proximal = 0
                for pm, ps in zip(model.parameters(), anchor_params):
                    loss_proximal += torch.sum(torch.pow(pm - ps, 2))
                loss = loss + 0.5 * self.mu * loss_proximal
                model.zero_grad()
                loss.backward()
                self.local_param_update(model, lr)

            model.zero_grad()

            updated_models.append(model)

        t1 = time.time()
        if VERBOSE: print(f'local update {t1-t0:.3f}sec')

        # apply gradient update
        t0 = time.time()

        local_models = [[] for _ in range(p)]
        for m_i in range(m):
            local_models[cluster_assign[m_i]].append(updated_models[m_i])

        for p_i, models in enumerate(local_models):
            if models:
                self.global_param_update(models, self.models[p_i])  # p_i is the cluster id
        t1 = time.time()

        if VERBOSE: print(f'global update {t1-t0:.3f}sec')


if __name__ == '__main__':
    # Time the whole run and report it in hours and seconds.
    start_time = time.time()
    main()
    elapsed = time.time() - start_time
    hours = elapsed / 3600.0
    print("---train cluster single Ended in %0.2f hour (%.3f sec) " % (hours, elapsed))