from datetime import datetime
import pytz
import os
import torch
from scipy import stats
import numpy as np
from scipy import io
import torch
import torchvision
import torchvision.transforms as transforms
import copy


def save_img(plt, name, opt):
    """Save the current figure as a timestamped JPEG inside the results tree.

    The file is written to
    ``<result_path>/<dataset>/<model>/<n>_agents/rate_<sample>/<MM-DD>/<name>-<MM-DD-HH-MM>.jpg``
    where the date parts come from the current time in the
    ``US/East-Indiana`` timezone.

    Args:
        plt: object exposing ``savefig(path, dpi=..., bbox_inches=...)``
            (presumably ``matplotlib.pyplot`` -- confirm against callers).
        name: file-name stem.
        opt: options object providing ``result_path``, ``dataset``, ``model``,
            ``n_agents`` and ``sample``.
    """
    current_datetime = datetime.now(pytz.timezone('US/East-Indiana'))
    current_date_time = current_datetime.strftime("%m-%d-%H-%M")
    current_date = current_datetime.strftime("%m-%d")
    dataset_name = opt.dataset
    agent_number = "%d_agents" % opt.n_agents
    sample_rate = "rate_%.2f" % opt.sample
    dir_name = os.path.join(opt.result_path, dataset_name, opt.model, agent_number, sample_rate, current_date)
    filename = name + '-' + current_date_time + '.jpg'
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(dir_name, exist_ok=True)
    plt.savefig(os.path.join(dir_name, filename), dpi=300, bbox_inches="tight")

def save_acc(dictionary, name, opt):
    """Save ``dictionary`` with ``np.save`` as ``<name>.npy`` in the results tree.

    The target directory is
    ``new_<result_path>/<dataset>/<model>/<n>_agents/rate_<sample>/<MM-DD>``;
    note that the ``new_`` prefix is prepended to the ``opt.result_path``
    string itself. The date is taken in the ``US/East-Indiana`` timezone.
    Unlike the image/plot savers, the file name carries no timestamp, so a
    second call on the same day with the same ``name`` overwrites the file.

    Args:
        dictionary: object handed to ``np.save``.
        name: file-name stem.
        opt: options object providing ``result_path``, ``dataset``, ``model``,
            ``n_agents`` and ``sample``.
    """
    current_datetime = datetime.now(pytz.timezone('US/East-Indiana'))
    current_date = current_datetime.strftime("%m-%d")
    dataset_name = opt.dataset
    agent_number = "%d_agents" % opt.n_agents
    sample_rate = "rate_%.2f" % opt.sample
    dir_name = os.path.join('new_' + opt.result_path, dataset_name, opt.model, agent_number, sample_rate, current_date)
    filename = name + '.npy'
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(dir_name, exist_ok=True)
    np.save(os.path.join(dir_name, filename), dictionary)



def save_plt(plt_list, name, opt):
    """Serialize ``plt_list`` with ``torch.save`` to a timestamped ``.pth`` file.

    The file is written to
    ``<result_path>/<MM-DD>/<name>-<MM-DD-HH-MM>.pth`` with the date parts
    taken in the ``US/East-Indiana`` timezone.

    Args:
        plt_list: object handed to ``torch.save``.
        name: file-name stem.
        opt: options object providing ``result_path``.
    """
    current_datetime = datetime.now(pytz.timezone('US/East-Indiana'))
    current_date_time = current_datetime.strftime("%m-%d-%H-%M")
    current_date = current_datetime.strftime("%m-%d")
    dir_name = os.path.join(opt.result_path, current_date)
    filename = name + '-' + current_date_time + '.pth'
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(dir_name, exist_ok=True)
    torch.save(plt_list, os.path.join(dir_name, filename))


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory samples, switched on ``dataset_name``.

    Supported names:

    * ``'MNIST'`` / ``'mnist'`` / ``'synt'`` / ``'emnist'``: samples and labels
      are converted to float tensors; items are rows of the sample tensor.
    * ``'CIFAR10'`` / ``'CIFAR100'`` / ``'TinyImage'``: samples are kept as
      passed in and labels are cast with ``astype('float32')``. Indexing moves
      axis 0 to the end and applies ``ToTensor``; the augmentation code assumes
      3-channel channel-first images of side 32 (CIFAR) or 64 (TinyImage).
    * ``'shakespeare'``: samples become long tensors, labels float tensors.

    Passing a ``bool`` as ``data_y`` (the default ``True``) marks the data as
    unlabelled; items are then the sample alone instead of ``(sample, label)``.
    """

    # Names handled as plain feature tensors. Both spellings of MNIST are
    # accepted so construction and indexing always agree.
    _VECTOR_NAMES = ('MNIST', 'mnist', 'synt', 'emnist')
    # Image side length used by the random-crop augmentation, per dataset.
    _IMAGE_SIDES = {'CIFAR10': 32, 'CIFAR100': 32, 'TinyImage': 64}

    def __init__(self, data_x, data_y=True, train=False, dataset_name=''):
        """Store the samples/labels in the representation required by the name.

        Args:
            data_x: samples (array-like; image datasets are indexed as numpy
                arrays later on).
            data_y: labels, or a ``bool`` when there are none.
            train: enables random flip/crop for the image datasets.
            dataset_name: one of the names listed on the class.
        """
        self.name = dataset_name
        labelled = not isinstance(data_y, bool)

        if self.name in self._VECTOR_NAMES:
            self.X_data = torch.tensor(data_x).float()
            self.y_data = torch.tensor(data_y).float() if labelled else data_y

        elif self.name in self._IMAGE_SIDES:
            self.train = train
            self.transform = transforms.Compose([transforms.ToTensor()])
            self.X_data = data_x
            self.y_data = data_y.astype('float32') if labelled else data_y

        elif self.name == 'shakespeare':
            self.X_data = torch.tensor(data_x).long()
            self.y_data = torch.tensor(data_y).float() if labelled else data_y

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.X_data)

    @staticmethod
    def _augment(img, side, pad=4):
        """Randomly flip along axis 2, then randomly zero-pad and crop back.

        Each step is applied with probability 0.5. The draws happen in the
        order flip, crop decision, crop offsets.
        """
        if np.random.rand() > .5:
            img = np.flip(img, axis=2).copy()  # Horizontal flip
        if np.random.rand() > .5:
            # Random cropping: embed in a zero canvas, then cut a side x side window.
            padded = np.zeros((3, side + pad * 2, side + pad * 2)).astype(np.float32)
            padded[:, pad:-pad, pad:-pad] = img
            dim_1, dim_2 = np.random.randint(pad * 2 + 1, size=2)
            img = padded[:, dim_1:dim_1 + side, dim_2:dim_2 + side]
        return img

    def __getitem__(self, idx):
        """Return ``sample`` (unlabelled) or ``(sample, label)`` for ``idx``.

        Raises:
            ValueError: if the dataset name is not one of the supported names.
        """
        if self.name in self._VECTOR_NAMES:
            sample = self.X_data[idx, :]

        elif self.name in self._IMAGE_SIDES:
            img = self.X_data[idx]
            if self.train:
                img = self._augment(img, self._IMAGE_SIDES[self.name])
            # ToTensor is fed a channel-last array, hence the axis move.
            img = np.moveaxis(img, 0, -1)
            sample = self.transform(img)

        elif self.name == 'shakespeare':
            sample = self.X_data[idx]

        else:
            raise ValueError("Unsupported dataset name: %r" % (self.name,))

        if isinstance(self.y_data, bool):
            return sample
        return sample, self.y_data[idx]
        

class DatasetObject:
    """Split a dataset across ``n_client`` clients and cache the split on disk.

    The split is stored under ``<data_path>/<name>`` as ``clnt_x.npy``,
    ``clnt_y.npy``, ``tst_x.npy`` and ``tst_y.npy``; ``name`` encodes the
    dataset, client count, rule, rule argument and (when non-zero) the
    unbalance sigma. When that directory already exists the arrays are loaded
    instead of being rebuilt. Only ``'CIFAR100'`` can be built from scratch.

    After construction the instance exposes ``clnt_x``, ``clnt_y``, ``tst_x``
    and ``tst_y`` (plus ``trn_x``/``trn_y`` when the split was freshly built).
    """

    def __init__(self, dataset, n_client, rule, unbalanced_sgm=0, rule_arg=''):
        """Record the configuration and build or load the split.

        Args:
            dataset: dataset name (``'CIFAR100'`` is the only one that can be
                prepared here).
            n_client: number of clients to split the training data across.
            rule: ``'Dirichlet'`` or ``'iid'``.
            unbalanced_sgm: sigma of the lognormal used for per-client sizes;
                ``0`` gives equal sizes.
            rule_arg: Dirichlet concentration when ``rule == 'Dirichlet'``;
                formatted into the cache name.
        """
        self.dataset = dataset
        self.n_client = n_client

        self.rule = rule
        self.rule_arg = rule_arg
        rule_arg_str = rule_arg if isinstance(rule_arg, str) else '%.3f' % rule_arg
        self.name = "%s_%d_%s_%s" % (self.dataset, self.n_client, self.rule, rule_arg_str)
        self.name += '_%f' % unbalanced_sgm if unbalanced_sgm != 0 else ''
        self.unbalanced_sgm = unbalanced_sgm
        self.data_path = '../data'

        self.set_data()

    def set_data(self):
        """Build the per-client split (or load it from the cache) and print sizes.

        Raises:
            ValueError: when the split has to be built and the rule or the
                dataset is not supported.
        """
        cache_dir = '%s/%s' % (self.data_path, self.name)

        # Prepare data if not ready
        if not os.path.exists(cache_dir):
            if self.rule not in ('Dirichlet', 'iid'):
                raise ValueError("Unsupported partition rule: %r" % (self.rule,))

            trn_x, trn_y, tst_x, tst_y = self._load_raw()

            # Shuffle Data
            rand_perm = np.random.permutation(len(trn_y))
            trn_x = trn_x[rand_perm]
            trn_y = trn_y[rand_perm]

            self.trn_x = trn_x
            self.trn_y = trn_y
            self.tst_x = tst_x
            self.tst_y = tst_y

            clnt_data_list = self._client_sizes(len(trn_y))

            if self.rule == 'Dirichlet':
                clnt_x, clnt_y = self._split_dirichlet(trn_x, trn_y, clnt_data_list)
            elif self.dataset == 'CIFAR100' and self.unbalanced_sgm == 0:
                clnt_x, clnt_y = self._split_iid_per_class(trn_x, trn_y)
            else:
                clnt_x, clnt_y = self._split_iid(trn_x, trn_y, clnt_data_list)

            self.clnt_x = clnt_x
            self.clnt_y = clnt_y

            # Save data
            os.makedirs(cache_dir, exist_ok=True)
            np.save('%s/clnt_x.npy' % cache_dir, clnt_x)
            np.save('%s/clnt_y.npy' % cache_dir, clnt_y)
            np.save('%s/tst_x.npy' % cache_dir, tst_x)
            np.save('%s/tst_y.npy' % cache_dir, tst_y)

        else:
            print("Data is already downloaded in the folder.")
            # NOTE(review): allow_pickle=True is needed for object arrays
            # (unequal client sizes); only load caches from a trusted location.
            self.clnt_x = np.load('%s/clnt_x.npy' % cache_dir, allow_pickle=True)
            self.clnt_y = np.load('%s/clnt_y.npy' % cache_dir, allow_pickle=True)
            # The cached split decides the client count.
            self.n_client = len(self.clnt_x)

            self.tst_x = np.load('%s/tst_x.npy' % cache_dir, allow_pickle=True)
            self.tst_y = np.load('%s/tst_y.npy' % cache_dir, allow_pickle=True)

            if self.dataset == 'CIFAR100':
                self.channels = 3
                self.width = 32
                self.height = 32
                self.n_cls = 100

        count = sum(self.clnt_y[clnt].shape[0] for clnt in range(self.n_client))
        print('Total Amount:%d' % count)
        print("Test" + ' Amount:%d' % self.tst_y.shape[0])

    def _load_raw(self):
        """Download and return ``(trn_x, trn_y, tst_x, tst_y)`` as numpy arrays.

        Also sets ``channels``, ``width``, ``height`` and ``n_cls``.
        """
        if self.dataset != 'CIFAR100':
            raise ValueError("Unsupported dataset: %r (only 'CIFAR100' can be prepared)" % (self.dataset,))

        print(self.dataset)
        # mean and std are validated here: https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                                             std=[0.2675, 0.2565, 0.2761])])

        trnset = torchvision.datasets.CIFAR100(root='%s/Raw' % self.data_path,
                                               train=True, download=True, transform=transform)
        tstset = torchvision.datasets.CIFAR100(root='%s/Raw' % self.data_path,
                                               train=False, download=True, transform=transform)
        # One batch covering the whole split, so a single draw yields everything.
        trn_load = torch.utils.data.DataLoader(trnset, batch_size=50000, shuffle=False, num_workers=0)
        tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=0)
        self.channels = 3
        self.width = 32
        self.height = 32
        self.n_cls = 100

        # labels are of shape (n_data,)
        trn_x, trn_y = next(iter(trn_load))
        tst_x, tst_y = next(iter(tst_load))

        trn_x = trn_x.numpy()
        trn_y = trn_y.numpy().reshape(-1, 1)
        tst_x = tst_x.numpy()
        tst_y = tst_y.numpy().reshape(-1, 1)
        return trn_x, trn_y, tst_x, tst_y

    def _client_sizes(self, n_total):
        """Return an int array with the number of samples for each client."""
        n_data_per_clnt = int(n_total / self.n_client)
        if self.unbalanced_sgm == 0:
            return (np.ones(self.n_client) * n_data_per_clnt).astype(int)

        # Draw from lognormal distribution
        clnt_data_list = np.random.lognormal(mean=np.log(n_data_per_clnt),
                                             sigma=self.unbalanced_sgm, size=self.n_client)
        clnt_data_list = (clnt_data_list / np.sum(clnt_data_list) * n_total).astype(int)
        diff = np.sum(clnt_data_list) - n_total

        # Add/Subtract the excess number starting from first client
        if diff != 0:
            for clnt_i in range(self.n_client):
                if clnt_data_list[clnt_i] > diff:
                    clnt_data_list[clnt_i] -= diff
                    break
        return clnt_data_list

    @staticmethod
    def _stack_clients(arrays):
        """Combine per-client arrays into one array.

        Equal-length inputs become a regular array. Unequal lengths are stored
        in a 1-D object array, because ``np.asarray`` refuses ragged input on
        recent NumPy versions.
        """
        if len({len(arr) for arr in arrays}) <= 1:
            return np.asarray(arrays)
        ragged = np.empty(len(arrays), dtype=object)
        for pos, arr in enumerate(arrays):
            ragged[pos] = arr
        return ragged

    def _split_dirichlet(self, trn_x, trn_y, clnt_data_list):
        """Fill each client by sampling classes from its own Dirichlet prior.

        ``clnt_data_list`` is consumed (decremented to zero) in the process.
        """
        cls_priors = np.random.dirichlet(alpha=[self.rule_arg] * self.n_cls, size=self.n_client)
        prior_cumsum = np.cumsum(cls_priors, axis=1)
        idx_list = [np.where(trn_y == i)[0] for i in range(self.n_cls)]
        cls_amount = [len(idx_list[i]) for i in range(self.n_cls)]

        clnt_x = [np.zeros((clnt_data_list[clnt__], self.channels, self.height, self.width)).astype(np.float32)
                  for clnt__ in range(self.n_client)]
        clnt_y = [np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]

        # Running total of unfilled slots; mirrors np.sum(clnt_data_list).
        remaining = int(np.sum(clnt_data_list))
        while remaining != 0:
            curr_clnt = np.random.randint(self.n_client)
            if remaining % 200 == 0 or remaining <= 100:
                print('Remaining Data: %d' % remaining, '\t\t\t\t\t\r', end='')
            # If current node is full resample a client
            if clnt_data_list[curr_clnt] <= 0:
                continue
            clnt_data_list[curr_clnt] -= 1
            remaining -= 1
            slot = clnt_data_list[curr_clnt]
            curr_prior = prior_cumsum[curr_clnt]
            if remaining < 1600:
                # Near the end, take the first class that still has samples
                # instead of sampling from the prior.
                for cls_label in range(self.n_cls):
                    if cls_amount[cls_label] <= 0:
                        continue
                    cls_amount[cls_label] -= 1
                    src = idx_list[cls_label][cls_amount[cls_label]]
                    clnt_x[curr_clnt][slot] = trn_x[src]
                    clnt_y[curr_clnt][slot] = trn_y[src]
                    break
            else:
                while True:
                    cls_label = np.argmax(np.random.uniform() <= curr_prior)
                    # Redraw class label if trn_y is out of that class
                    if cls_amount[cls_label] <= 0:
                        continue
                    cls_amount[cls_label] -= 1
                    src = idx_list[cls_label][cls_amount[cls_label]]
                    clnt_x[curr_clnt][slot] = trn_x[src]
                    clnt_y[curr_clnt][slot] = trn_y[src]
                    break

        clnt_x = self._stack_clients(clnt_x)
        clnt_y = self._stack_clients(clnt_y)

        # Report how far the realized class frequencies are from the priors.
        cls_means = np.zeros((self.n_client, self.n_cls))
        for clnt in range(self.n_client):
            for cls in range(self.n_cls):
                cls_means[clnt, cls] = np.mean(clnt_y[clnt] == cls)
        prior_real_diff = np.abs(cls_means - cls_priors)
        print('--- Max deviation from prior: %.4f' % np.max(prior_real_diff))
        print('--- Min deviation from prior: %.4f' % np.min(prior_real_diff))
        return clnt_x, clnt_y

    def _split_iid_per_class(self, trn_x, trn_y):
        """Balanced CIFAR100 split: every client gets the same count of each class.

        The constants 100 (classes) and 500 (training samples per class) are
        hard-coded for CIFAR100.
        """
        assert len(trn_y) // 100 % self.n_client == 0
        idx = np.argsort(trn_y[:, 0])
        n_data_per_clnt = len(trn_y) // self.n_client
        # clnt_x dtype needs to be float32, the same as weights
        clnt_x = np.zeros((self.n_client, n_data_per_clnt, 3, 32, 32), dtype=np.float32)
        clnt_y = np.zeros((self.n_client, n_data_per_clnt, 1), dtype=np.float32)
        trn_x = trn_x[idx]  # 50000*3*32*32
        trn_y = trn_y[idx]
        n_cls_sample_per_device = n_data_per_clnt // 100
        for i in range(self.n_client):  # devices
            for j in range(100):  # class
                dst = slice(n_cls_sample_per_device * j, n_cls_sample_per_device * (j + 1))
                src = slice(500 * j + n_cls_sample_per_device * i,
                            500 * j + n_cls_sample_per_device * (i + 1))
                clnt_x[i, dst] = trn_x[src]
                clnt_y[i, dst] = trn_y[src]
        return clnt_x, clnt_y

    def _split_iid(self, trn_x, trn_y, clnt_data_list):
        """Cut the shuffled data into consecutive chunks of the given sizes."""
        bounds = np.concatenate(([0], np.cumsum(clnt_data_list)))
        clnt_x = [trn_x[bounds[i]:bounds[i + 1]] for i in range(self.n_client)]
        clnt_y = [trn_y[bounds[i]:bounds[i + 1]] for i in range(self.n_client)]
        return self._stack_clients(clnt_x), self._stack_clients(clnt_y)


def DataLoader(opt):
    """Build the CIFAR-10 train and test loaders described by ``opt``.

    The training set uses random crop (32, padding 4) and random horizontal
    flip before normalization; the test set is only normalized. The data is
    downloaded to ``opt.data_path`` if missing.

    Args:
        opt: options object providing ``data_path``, ``dataset`` and ``batch``
            (training batch size).

    Returns:
        ``(trainloader, testloader)``; the test loader uses a batch size of 100.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'CIFAR10'``.
    """
    datapath = opt.data_path
    if opt.dataset != 'CIFAR10':
        # Previously this fell through to the return and failed with an
        # UnboundLocalError; say what is actually wrong instead.
        raise ValueError("Unsupported dataset: %r (only 'CIFAR10' is available)" % (opt.dataset,))

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=datapath, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=opt.batch, shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(
        root=datapath, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=4)

    return trainloader, testloader


# Number of character cells used to draw the progress bar.
TOTAL_BAR_LENGTH = 65.
import shutil
import sys
import time

# Terminal width in columns. shutil.get_terminal_size falls back to 80 columns
# when no terminal is attached (pipes, notebooks, batch jobs), where parsing
# the output of `stty size` used to fail at import time.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns
# Timestamps used by progress_bar: last update and start of the current bar.
last_time = time.time()
begin_time = last_time
def format_time(seconds):
    """Render a duration in seconds using its two largest non-zero units.

    Units are days (``D``), hours (``h``), minutes (``m``), seconds (``s``)
    and milliseconds (``ms``); e.g. ``3661`` gives ``'1h1m'``. When no unit
    is positive the result is ``'0ms'``.
    """
    remainder = seconds
    n_days = int(remainder / 3600 / 24)
    remainder = remainder - n_days * 3600 * 24
    n_hours = int(remainder / 3600)
    remainder = remainder - n_hours * 3600
    n_minutes = int(remainder / 60)
    remainder = remainder - n_minutes * 60
    n_seconds = int(remainder)
    remainder = remainder - n_seconds
    n_millis = int(remainder * 1000)

    units = (
        (n_days, 'D'),
        (n_hours, 'h'),
        (n_minutes, 'm'),
        (n_seconds, 's'),
        (n_millis, 'ms'),
    )
    # Keep only the two most significant positive components.
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0][:2]
    return ''.join(shown) or '0ms'


def progress_bar(current, total, msg=None):
    """Draw a one-line text progress bar on stdout.

    Args:
        current: zero-based index of the step just finished; ``0`` restarts
            the total-time clock.
        total: number of steps.
        msg: optional text appended after the timing information.

    The line ends with a carriage return while steps remain and with a
    newline on the last step. Updates the module-level ``last_time`` and
    ``begin_time``.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH * current / total)
    unfilled = int(TOTAL_BAR_LENGTH - filled) - 1

    sys.stdout.write(' [' + '=' * filled + '>' + '.' * unfilled + ']')

    now = time.time()
    step_elapsed = now - last_time
    last_time = now
    total_elapsed = now - begin_time

    pieces = [
        '  Step: %s' % format_time(step_elapsed),
        ' | Tot: %s' % format_time(total_elapsed),
    ]
    if msg:
        pieces.append(' | ' + msg)
    msg = ''.join(pieces)

    # Pad the rest of the terminal line with blanks.
    padding = ' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3)
    sys.stdout.write(msg + padding)

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()

def get_mdl_params(model_list, n_par=None):
    """Flatten every model's ``state_dict`` into one row of a float32 matrix.

    Entries are concatenated in ``state_dict()`` order, so buffers stored in
    the state dict are included alongside the parameters.

    Args:
        model_list: sequence of objects exposing ``state_dict()``.
        n_par: number of scalars per model; when ``None`` it is counted from
            ``model_list[0]`` (which therefore must exist).

    Returns:
        ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype float32.
    """
    if n_par is None:
        n_par = sum(tensor.numel() for tensor in model_list[0].state_dict().values())

    param_mat = np.zeros((len(model_list), n_par), dtype='float32')
    for row, mdl in enumerate(model_list):
        offset = 0
        for tensor in mdl.state_dict().values():
            flat = tensor.data.cpu().numpy().reshape(-1)
            param_mat[row, offset:offset + len(flat)] = flat
            offset += len(flat)
    # param_mat is freshly allocated here, so no defensive copy is needed.
    return param_mat

def transform_param(mdl, params, device):
    """Return a copy of ``mdl``'s state dict filled from the flat vector ``params``.

    ``params`` is consumed in ``state_dict()`` order; each entry takes as many
    values as it has elements and is reshaped to the entry's shape. The model
    itself is left untouched.

    Args:
        mdl: object exposing ``state_dict()``.
        params: flat array supporting slicing and ``reshape``.
        device: device the values are moved to before being copied in.

    Returns:
        The new state dict (not loaded into ``mdl``).
    """
    new_state = copy.deepcopy(mdl.state_dict())

    offset = 0
    for entry in new_state.values():
        count = entry.numel()
        chunk = params[offset:offset + count].reshape(entry.shape)
        entry.data.copy_(torch.tensor(chunk).to(device))
        offset += count

    return new_state


def partition_dirichlet(Y, n_clients, alpha, seed):
    """Split example indices across clients with Dirichlet class proportions.

    One Dirichlet vector is drawn per client. The draws are normalized per
    class so the client shares of each class sum to one, then scaled by the
    class size and rounded. Each client receives consecutive slices of every
    class's index list. Because of the rounding, a class's per-client counts
    may not add up exactly to the class size.

    Args:
        Y: array of labels; any set of label values is accepted (they need
            not be ``0..K-1``).
        n_clients: number of clients.
        alpha: Dirichlet concentration, shared by all classes.
        seed: ``random_state`` passed to the Dirichlet sampler.

    Returns:
        List with one array of example indices per client.
    """
    # Use the label values actually present rather than assuming 0..K-1.
    classes, ex_per_class = np.unique(Y, return_counts=True)
    n_classes = len(classes)
    print(f"Found {n_classes} classes")
    rv_tr = stats.dirichlet.rvs(np.repeat(alpha, n_classes), size=n_clients, random_state=seed)
    rv_tr = rv_tr / rv_tr.sum(axis=0)
    rv_tr = (rv_tr * ex_per_class).round().astype(int)
    class_to_idx = [np.where(Y == label)[0] for label in classes]

    clients = []
    curr_start = np.zeros(n_classes).astype(int)
    for client_classes in rv_tr:
        curr_end = curr_start + client_classes
        client_idx = np.concatenate(
            [class_to_idx[c][curr_start[c]:curr_end[c]] for c in range(n_classes)])
        curr_start = curr_end
        clients.append(client_idx)
    return clients

