import argparse
import json
import os
import time
import itertools
import pickle
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.tensorboard import SummaryWriter

import numpy as np

from train_cluster_cifar import *
from util import *
from cifar_resnet import resnet


# Module-level flag, currently disabled.
# NOTE(review): nothing in this file reads LR_DECAY; presumably it was meant to
# toggle learning-rate decay — confirm whether any other module depends on it.
# LR_DECAY = True
LR_DECAY = False

def main():
    """Build the config, set up the local-training experiment and run it.

    When ``config['inference_mode']`` is 0 the experiment is trained via
    ``run()``; any other value switches to ``inference()``.
    """
    config = get_config()
    print("config:",config)

    # The training seed is always taken from the data seed, which replaces
    # whatever ``--train-seed`` supplied.
    # NOTE(review): confirm that overriding the CLI value is intentional.
    config['train_seed'] = config['data_seed']

    experiment = TrainCIFARClusterLocal(config)
    experiment.setup()

    if config['inference_mode'] != 0:
        print("inference mode")
        experiment.inference()
        return

    experiment.run()


def get_config():
    """Assemble the run configuration from ``config.json`` and the command line.

    Precedence, lowest to highest:
      1. values read from ``config.json`` in the current working directory,
      2. command-line arguments (their defaults included, so they always
         replace same-named keys from the JSON file),
      3. the JSON object passed via ``--config-override``, if non-empty.

    The ``config_override`` key itself is removed from the returned dict.

    Returns:
        dict: the merged configuration.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--project-dir", "--projdir", type=str, default="output/localtrain")
    cli.add_argument("--lr", type=float, default=0.1)
    cli.add_argument("--train-seed", type=int, default=0)
    cli.add_argument("--config-override", type=str, default="")
    cli.add_argument("--inference-mode", type=int, default=0)
    cli.add_argument("--checkpoint", type=int, default=0)
    parsed = cli.parse_args()

    # Start from the JSON file, then layer the parsed arguments on top.
    with open("config.json", "r") as config_file:
        config = json.load(config_file)
    config.update(vars(parsed))

    # The raw override string never stays in the final config.
    override_text = config.pop('config_override')
    if override_text != "":
        print(override_text)
        config.update(json.loads(override_text))

    return config

class TrainCIFARClusterLocal(TrainCIFARCluster):
    """Experiment that keeps one independent ResNet-20 per client.

    Client ``m_i`` is trained and evaluated only with ``self.models[m_i]``;
    this class performs no averaging across models. The ``cluster_assign``
    values passed around are dummies (an empty list).

    Config keys read in this class: ``project_dir``, ``checkpoint``,
    ``train_seed``, ``device``, ``m``, ``m_test``, ``num_epochs``, ``tau``.
    """

    def setup(self):
        """Create the output directory, output paths, writer, datasets and models."""
        project_dir = self.config['project_dir']
        os.makedirs(project_dir, exist_ok=True)

        self.result_fname = os.path.join(project_dir, 'results_local.pickle')
        self.checkpoint_fname = os.path.join(project_dir, 'checkpoint_local.pt')
        if self.config['checkpoint']:
            # Resuming reads from the same path as ``checkpoint_fname``.
            self.load_checkpoint_fname = self.checkpoint_fname

        self.writer = SummaryWriter(os.path.join(project_dir, 'runs'))

        self.setup_datasets()
        self.setup_models()

        self.epoch = None
        self.lr = None

    def setup_models(self):
        """Seed the RNGs, build ``m`` ResNet-20 models and optionally restore them.

        When ``config['checkpoint']`` is truthy, the state dicts stored under
        the ``'models'`` key of the checkpoint file are loaded into the models.
        """
        seed = self.config['train_seed']
        np.random.seed(seed)
        torch.manual_seed(seed)

        m = self.config['m']
        device = self.config['device']

        self.models = [resnet(depth=20, num_classes=10).to(device) for _ in range(m)]
        self.criterion = torch.nn.CrossEntropyLoss()

        if self.config['checkpoint']:
            print('loading checkpoint from',self.load_checkpoint_fname )

            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be restored onto the device configured for this run.
            states = torch.load(self.load_checkpoint_fname, map_location=device)

            for m_i in range(m):
                self.models[m_i].load_state_dict(states['models'][m_i])

    def _evaluate_clients(self, m, train):
        """Evaluate clients ``0..m-1``, each with its own local model.

        Args:
            m: number of clients to evaluate.
            train: forwarded to ``self.load_data`` to pick the data split.

        Returns:
            Three parallel lists with one entry per client: the mean of the
            per-batch losses, the number of correct predictions, and
            ``len(loader.dataset)``.
        """
        # NOTE(review): the models' train/eval mode is not changed here —
        # confirm which mode is intended during evaluation.
        losses = []
        corrects = []
        num_samples = []
        for m_i in range(m):
            loader = self.load_data(m_i, train=train)
            model = self.models[m_i]

            loss_sum = 0.0
            n_corrects = 0
            n_batches = 0
            # Only scalar loss/accuracy values are kept, so no autograd graph
            # is needed for these forward passes.
            with torch.no_grad():
                for X_batch, y_batch in loader:
                    y_logit = model(X_batch)
                    loss_sum += self.criterion(y_logit, y_batch).detach().cpu().item()
                    n_corrects += self.n_correct(y_logit, y_batch)
                    n_batches += 1

            losses.append(loss_sum / n_batches)
            corrects.append(n_corrects)
            num_samples.append(len(loader.dataset))

        return losses, corrects, num_samples

    def _timed_test(self, train):
        """Call ``self.test(train=train)`` and store its wall time as ``infer_time``."""
        t0 = time.time()
        res = self.test(train=train)
        res['infer_time'] = time.time() - t0
        return res

    def inference(self):
        """Evaluate the ``m_test`` test clients and dump per-client statistics.

        Prints the mean of the per-client losses and the overall accuracy,
        then writes ``inference_stats.json`` into ``project_dir``.
        """
        losses, corrects, num_samples = self._evaluate_clients(
            self.config['m_test'], train=False)
        accuracys = [n_ok / n_total for n_ok, n_total in zip(corrects, num_samples)]

        loss = np.mean(losses)
        acc = np.sum(corrects) / sum(num_samples)
        print("loss", loss, "acc",  acc)

        inference_stats ={'losses':losses, 'corrects':corrects , 'num_samples':num_samples,  'accuracys':accuracys}
        print(inference_stats)
        with open(os.path.join(self.config['project_dir'], 'inference_stats.json'),'w') as out:
            out.write(json.dumps(inference_stats))

    def run(self):
        """Train for ``num_epochs`` epochs, evaluating before and after each one.

        An initial evaluation is recorded as epoch ``-1``. After every epoch a
        checkpoint is saved; the accumulated results are pickled to
        ``self.result_fname`` on the final epoch only.
        """
        num_epochs = self.config['num_epochs']
        print("num_epochs", num_epochs)

        results = []

        # Baseline evaluation before any training.
        self.epoch = -1
        result = {'epoch': -1}

        result['train'] = self._timed_test(train=True)
        self.print_epoch_stats(result['train'])

        result['test'] = self._timed_test(train=False)
        self.print_epoch_stats(result['test'])
        results.append(result)

        # this will be used in next epoch
        cluster_assign = result['train']['cluster_assign']

        for epoch in range(num_epochs):
            self.epoch = epoch

            lr = self.lr_schedule(epoch)
            result = {'epoch': epoch, 'lr': lr}

            t0 = time.time()
            self.train(cluster_assign, lr=lr)
            train_time = time.time() - t0

            res = self._timed_test(train=True)
            res['train_time'] = train_time
            res['lr'] = lr
            result['train'] = res
            self.print_epoch_stats(res)

            res = self._timed_test(train=False)
            result['test'] = res
            self.print_epoch_stats(res)

            results.append(result)

            # this will be used in next epoch's gradient update
            cluster_assign = result['train']['cluster_assign']

            if epoch == num_epochs - 1:
                with open(self.result_fname, 'wb') as outfile:
                    pickle.dump(results, outfile)
                    print(f'result written at {self.result_fname}')
            self.save_checkpoint(epoch)
            print(f'checkpoint written at {self.checkpoint_fname}')

    def get_inference_stats(self, train = True):
        """Return aggregate loss/accuracy of the local models on one split.

        Args:
            train: if true, evaluate ``config['m']`` clients on training data,
                otherwise ``config['m_test']`` clients on test data.

        Returns:
            dict with ``loss`` (mean of per-client mean batch losses), ``acc``
            (total correct / total samples), ``is_train``, and the dummy
            entries ``cluster_assign`` (``[]``), ``cl_acc`` (``0``) and
            ``cl_ct`` (``[]``).
        """
        m = self.config['m'] if train else self.config['m_test']

        losses, corrects, num_samples = self._evaluate_clients(m, train=train)

        res = {} # results
        res['cluster_assign'] = [] # dummy
        res['loss'] = np.mean(losses)
        res['acc'] = np.sum(corrects) / sum(num_samples)
        res['cl_acc'] = 0 # dummy
        res['cl_ct'] = [] # dummy
        res['is_train'] = train

        return res

    def train(self, cluster_assign, lr):
        """Run ``tau`` passes of local updates over every client's training data.

        Args:
            cluster_assign: dummy; not used by this implementation.
            lr: learning rate handed to ``self.local_param_update``.
        """
        VERBOSE = 0

        m = self.config['m']
        tau = self.config['tau']

        # run local update
        t0 = time.time()

        for m_i in range(m):
            if VERBOSE and m_i % 100 == 0: print(f'm {m_i}/{m} processing \r', end ='')

            loader = self.load_data(m_i)
            model = self.models[m_i]

            for _ in range(tau):
                for X_batch, y_batch in loader:
                    y_logit = model(X_batch) # the corresponding local model
                    loss = self.criterion(y_logit, y_batch)
                    model.zero_grad()
                    loss.backward()
                    self.local_param_update(model, lr)
            # Leave no stale gradients on the model once its updates are done.
            model.zero_grad()

        t1 = time.time()
        if VERBOSE: print(f'local update {t1-t0:.3f}sec')




if __name__ == '__main__':
    # Time the whole run and report the elapsed wall time in hours and seconds.
    began_at = time.time()
    main()
    elapsed = time.time() - began_at
    print("---train cluster single Ended in %0.2f hour (%.3f sec) " % (elapsed / 3600.0, elapsed))