#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics
import pdb
from tqdm import tqdm

class DatasetSplit(Dataset):
    """Expose only the entries of ``dataset`` whose positions are in ``idxs``.

    Position ``k`` of the split maps to ``dataset[idxs[k]]``.
    """

    def __init__(self, dataset, idxs):
        # Materialise the indices so any iterable can be addressed by position.
        self.idxs = [*idxs]
        self.dataset = dataset

    def __len__(self):
        """Return the number of selected entries."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at the ``item``-th selected index."""
        source_index = self.idxs[item]
        # Unpacking enforces that the underlying entry is exactly a pair.
        image, label = self.dataset[source_index]
        return image, label


class LocalUpdate(object):
    """SGD trainer over the subset of ``dataset`` selected by ``idxs``.

    The constructor reads ``args.device`` and ``args.local_bs``; ``train``
    also reads ``args.lr``, ``args.momentum``, ``args.local_ep`` and
    ``args.verbose``.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs with SGD.

        Returns:
            ``(net.state_dict(), mean_loss)``, where ``mean_loss`` is the
            average over epochs of each epoch's mean batch loss.
        """
        cfg = self.args
        loader = self.ldr_train
        progress_line = 'Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

        net.train()
        sgd = torch.optim.SGD(net.parameters(), lr=cfg.lr, momentum=cfg.momentum)

        mean_loss_per_epoch = []
        for epoch in tqdm(range(cfg.local_ep)):
            losses = []
            for step, (images, labels) in enumerate(loader):
                images = images.to(cfg.device)
                labels = labels.to(cfg.device)

                net.zero_grad()
                loss = self.loss_func(net(images), labels)
                loss.backward()
                sgd.step()

                # Progress line every tenth batch when verbose is on.
                if cfg.verbose and step % 10 == 0:
                    print(progress_line.format(
                        epoch, step * len(images), len(loader.dataset),
                        100. * step / len(loader), loss.item()))
                losses.append(loss.item())

            mean_loss_per_epoch.append(sum(losses) / len(losses))

        return net.state_dict(), sum(mean_loss_per_epoch) / len(mean_loss_per_epoch)


class ServerUpdate(object):
    """SGD trainer that iterates over the whole ``dataset`` it is given.

    The constructor reads ``args.device`` and ``args.local_bs``; ``train``
    also reads ``args.lr``, ``args.momentum``, ``args.local_ep`` and
    ``args.verbose``.
    """

    def __init__(self, args, dataset=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        criterion = nn.CrossEntropyLoss()
        self.loss_func = criterion.to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        self.ldr_train = DataLoader(
            dataset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs with SGD.

        Returns:
            ``(net.state_dict(), mean_loss)``, where ``mean_loss`` is the
            average over epochs of each epoch's mean batch loss.
        """
        settings = self.args
        batches = self.ldr_train

        net.train()
        optimizer = torch.optim.SGD(
            net.parameters(), lr=settings.lr, momentum=settings.momentum)

        epoch_means = []
        for epoch in tqdm(range(settings.local_ep)):
            running = []
            for batch_no, (images, labels) in enumerate(batches):
                images = images.to(settings.device)
                labels = labels.to(settings.device)

                net.zero_grad()
                outputs = net(images)
                loss = self.loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

                # Progress line every tenth batch when verbose is on.
                should_report = settings.verbose and batch_no % 10 == 0
                if should_report:
                    seen = batch_no * len(images)
                    percent = 100. * batch_no / len(batches)
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, seen, len(batches.dataset), percent, loss.item()))
                running.append(loss.item())

            epoch_means.append(sum(running) / len(running))

        return net.state_dict(), sum(epoch_means) / len(epoch_means)




class LocalUpdate_q_ourpre(object):
    """Single-sweep loss measurement over the subset selected by ``idxs``.

    No backward pass and no optimizer step are run, so this class never
    applies a gradient update to the model.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net):
        """Run one sweep over the loader and report the mean batch loss.

        ``net`` is put in train mode and its gradients are cleared before
        every forward pass, but the loss is never back-propagated.

        Returns:
            ``(net.state_dict(), mean_loss)`` for the single sweep.
        """
        cfg = self.args
        loader = self.ldr_train
        progress_line = 'Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

        net.train()

        sweep_means = []
        # A single pass, still wrapped in tqdm so the progress bar is shown.
        for sweep in tqdm(range(1)):
            losses = []
            for step, (images, labels) in enumerate(loader):
                images = images.to(cfg.device)
                labels = labels.to(cfg.device)

                net.zero_grad()
                loss = self.loss_func(net(images), labels)

                # Progress line every tenth batch when verbose is on.
                if cfg.verbose and step % 10 == 0:
                    print(progress_line.format(
                        sweep, step * len(images), len(loader.dataset),
                        100. * step / len(loader), loss.item()))
                losses.append(loss.item())

            sweep_means.append(sum(losses) / len(losses))

        return net.state_dict(), sum(sweep_means) / len(sweep_means)


class LocalUpdate_qfair(object):
    """SGD trainer that also reports the gradients left on the parameters.

    The constructor reads ``args.device`` and ``args.local_bs``; ``train``
    also reads ``args.lr``, ``args.momentum``, ``args.local_ep`` and
    ``args.verbose``.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss().to(self.args.device)
        self.selected_clients = []  # not read anywhere inside this class
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, idxs),
            batch_size=self.args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs with SGD.

        Returns:
            ``(state_dict, mean_loss, grads)``.  ``mean_loss`` is the average
            over epochs of each epoch's mean batch loss.  ``grads`` maps each
            parameter name to that parameter's ``.grad`` attribute as it
            stands after the last backward pass (i.e. the last mini-batch's
            gradient, not an accumulation); the values are the live ``.grad``
            tensors, not copies.
        """
        net.train()
        optimizer = torch.optim.SGD(
            net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        epoch_loss = []
        for epoch in tqdm(range(self.args.local_ep)):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)

                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)

                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # Pair every gradient with its own parameter name.  Zipping
        # state_dict() keys with parameters() is only correct for models
        # without buffers: state_dict() also lists buffers (e.g. batch-norm
        # running statistics), so the positional pairing drifts and gradients
        # end up stored under the wrong keys.
        Grad = {name: param.grad for name, param in net.named_parameters()}

        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), Grad






class LocalUpdate_meta_s(object):
    """SGD trainer that also records per-batch loss, hit count and batch size.

    The constructor reads ``args.device`` and ``args.local_bs``; ``train``
    also reads ``args.lr``, ``args.momentum``, ``args.local_ep`` and
    ``args.verbose``.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs with SGD.

        Returns:
            ``(state_dict, mean_loss, support_loss, support_correct,
            support_num_sample)``.  ``mean_loss`` is the average over epochs
            of each epoch's mean batch loss.  The three ``support_*`` lists
            hold one entry per batch (loss value, number of correct
            predictions, batch size) and are reset at the start of every
            epoch, so the returned lists describe the final epoch only.
        """
        cfg = self.args
        loader = self.ldr_train
        progress_line = 'Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

        net.train()
        sgd = torch.optim.SGD(net.parameters(), lr=cfg.lr, momentum=cfg.momentum)

        mean_loss_per_epoch = []
        for epoch in tqdm(range(cfg.local_ep)):
            losses = []
            support_loss = []
            support_correct = []
            support_num_sample = []
            for step, (images, labels) in enumerate(loader):
                images = images.to(cfg.device)
                labels = labels.to(cfg.device)
                batch_size = labels.size(0)

                net.zero_grad()
                logits = net(images)
                loss = self.loss_func(logits, labels)
                loss.backward()
                sgd.step()

                # Progress line every tenth batch when verbose is on.
                if cfg.verbose and step % 10 == 0:
                    print(progress_line.format(
                        epoch, step * len(images), len(loader.dataset),
                        100. * step / len(loader), loss.item()))
                losses.append(loss.item())

                # Predictions come from the logits computed before this step's update.
                predicted = logits.data.max(1, keepdim=True)[1]
                hits = predicted.eq(labels.data.view_as(predicted)).long().cpu().sum()

                support_correct.append(hits.item())
                support_loss.append(loss.item())
                support_num_sample.append(batch_size)

            mean_loss_per_epoch.append(sum(losses) / len(losses))

        return (
            net.state_dict(),
            sum(mean_loss_per_epoch) / len(mean_loss_per_epoch),
            support_loss,
            support_correct,
            support_num_sample,
        )

class LocalUpdate_meta_q(object):
    """One gradient-only sweep over the subset selected by ``idxs``.

    ``train`` computes the gradient of the sample-weighted mean loss with
    respect to the model parameters and returns it; it never runs an
    optimizer step.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net, s_loss, s_correct, s_num_sample):
        """Sweep the loader once and return summary statistics plus gradients.

        Args:
            net: model to evaluate; it is switched to train mode.
            s_loss: per-batch support losses, paired with ``s_num_sample``.
            s_correct: per-batch support hit counts.
            s_num_sample: per-batch support sizes.

        Returns:
            ``(stats, grads)``.  ``stats`` is a dict of support/query loss
            sums, hit counts and sample counts.  ``grads`` is the tuple from
            ``torch.autograd.grad`` for the sample-weighted mean query loss
            over ``list(net.parameters())``.
        """
        device = self.args.device
        net.train()

        losses, hit_counts, sizes = [], [], []
        # Sum of batch_loss * batch_size; kept as a tensor so it stays differentiable.
        weighted_total = 0.0
        for images, labels in self.ldr_train:
            images, labels = images.to(device), labels.to(device)

            logits = net(images)
            loss = self.loss_func(logits, labels)
            batch_size = labels.size(0)

            losses.append(loss.item())
            sizes.append(batch_size)
            weighted_total += loss * batch_size

            predicted = logits.data.max(1, keepdim=True)[1]
            hits = predicted.eq(labels.data.view_as(predicted)).long().cpu().sum()
            hit_counts.append(hits.item())

        spt_sz = np.sum(s_num_sample)
        qry_sz = np.sum(sizes)
        grads = torch.autograd.grad(weighted_total / qry_sz, list(net.parameters()))

        # Clear any gradients already stored on the parameters.
        for param in net.parameters():
            if param.grad is not None:
                param.grad.zero_()

        stats = {
            'support_loss_sum': np.dot(s_loss, s_num_sample),
            'query_loss_sum': np.dot(losses, sizes),
            'support_correct': np.sum(s_correct),
            'query_correct': np.sum(hit_counts),
            'support_num_samples': spt_sz,
            'query_num_samples': qry_sz,
        }
        return stats, grads


class LocalUpdate_meta_q_our(object):
    """One gradient-only sweep using a blended objective.

    The differentiated objective is
    ``(balance * weighted_loss_sum + (1 - balance) * std) / query_size``;
    no optimizer step is run.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net, s_loss, s_correct, s_num_sample, balance, std):
        """Sweep the loader once and return summary statistics plus gradients.

        Args:
            net: model to evaluate; it is switched to train mode.
            s_loss: per-batch support losses, paired with ``s_num_sample``.
            s_correct: per-batch support hit counts.
            s_num_sample: per-batch support sizes.
            balance: weight on the query loss sum; ``1 - balance`` weights ``std``.
            std: extra term added to the objective as-is
                (presumably a dispersion penalty from the caller -- TODO confirm).

        Returns:
            ``(stats, grads)``.  ``stats`` is a dict of support/query loss
            sums, hit counts and sample counts.  ``grads`` is the tuple from
            ``torch.autograd.grad`` for the blended objective over
            ``list(net.parameters())``.
        """
        device = self.args.device
        net.train()

        losses, hit_counts, sizes = [], [], []
        # Sum of batch_loss * batch_size; kept as a tensor so it stays differentiable.
        weighted_total = 0.0
        for images, labels in self.ldr_train:
            images, labels = images.to(device), labels.to(device)

            logits = net(images)
            loss = self.loss_func(logits, labels)
            batch_size = labels.size(0)

            losses.append(loss.item())
            sizes.append(batch_size)
            weighted_total += loss * batch_size

            predicted = logits.data.max(1, keepdim=True)[1]
            hits = predicted.eq(labels.data.view_as(predicted)).long().cpu().sum()
            hit_counts.append(hits.item())

        spt_sz = np.sum(s_num_sample)
        qry_sz = np.sum(sizes)
        blended = balance * weighted_total + (1 - balance) * std
        grads = torch.autograd.grad(blended / qry_sz, list(net.parameters()))

        # Clear any gradients already stored on the parameters.
        for param in net.parameters():
            if param.grad is not None:
                param.grad.zero_()

        stats = {
            'support_loss_sum': np.dot(s_loss, s_num_sample),
            'query_loss_sum': np.dot(losses, sizes),
            'support_correct': np.sum(s_correct),
            'query_correct': np.sum(hit_counts),
            'support_num_samples': spt_sz,
            'query_num_samples': qry_sz,
        }
        return stats, grads




class LocalUpdate_meta_q_our_3(object):
    """One gradient-only sweep using a power-scaled objective.

    The differentiated objective is
    ``weighted_loss_sum * loss_eval ** balance / query_size``;
    no optimizer step is run.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net, s_loss, s_correct, s_num_sample, balance, loss_eval):
        """Sweep the loader once and return summary statistics plus gradients.

        Args:
            net: model to evaluate; it is switched to train mode.
            s_loss: per-batch support losses, paired with ``s_num_sample``.
            s_correct: per-batch support hit counts.
            s_num_sample: per-batch support sizes.
            balance: exponent applied to ``loss_eval``.
            loss_eval: base of the scaling factor, passed to
                ``np.float_power`` (assumed to be a plain number -- TODO confirm).

        Returns:
            ``(stats, grads)``.  ``stats`` is a dict of support/query loss
            sums, hit counts and sample counts.  ``grads`` is the tuple from
            ``torch.autograd.grad`` for the scaled objective over
            ``list(net.parameters())``.
        """
        device = self.args.device
        net.train()

        losses, hit_counts, sizes = [], [], []
        # Sum of batch_loss * batch_size; kept as a tensor so it stays differentiable.
        weighted_total = 0.0
        for images, labels in self.ldr_train:
            images, labels = images.to(device), labels.to(device)

            logits = net(images)
            loss = self.loss_func(logits, labels)
            batch_size = labels.size(0)

            losses.append(loss.item())
            sizes.append(batch_size)
            weighted_total += loss * batch_size

            predicted = logits.data.max(1, keepdim=True)[1]
            hits = predicted.eq(labels.data.view_as(predicted)).long().cpu().sum()
            hit_counts.append(hits.item())

        spt_sz = np.sum(s_num_sample)
        qry_sz = np.sum(sizes)

        # Scale the weighted loss sum by loss_eval raised to the power `balance`.
        scale = np.float_power(loss_eval, balance)
        scaled_objective = weighted_total * scale
        grads = torch.autograd.grad(scaled_objective / qry_sz, list(net.parameters()))

        # Clear any gradients already stored on the parameters.
        for param in net.parameters():
            if param.grad is not None:
                param.grad.zero_()

        stats = {
            'support_loss_sum': np.dot(s_loss, s_num_sample),
            'query_loss_sum': np.dot(losses, sizes),
            'support_correct': np.sum(s_correct),
            'query_correct': np.sum(hit_counts),
            'support_num_samples': spt_sz,
            'query_num_samples': qry_sz,
        }
        return stats, grads

class LocalUpdate_meta_q_our_4(object):
    """One gradient-only sweep that also hands back the weighted loss sum.

    ``train`` computes the gradient of the sample-weighted mean loss with
    respect to the model parameters; it never runs an optimizer step.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.selected_clients = []  # not read anywhere inside this class
        self.loss_func = nn.CrossEntropyLoss().to(args.device)
        # Shuffled mini-batches; a trailing incomplete batch is dropped.
        subset = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(
            subset,
            batch_size=args.local_bs,
            shuffle=True,
            drop_last=True,
        )

    def train(self, net, s_loss, s_correct, s_num_sample):
        """Sweep the loader once and return statistics, gradients and loss sum.

        Args:
            net: model to evaluate; it is switched to train mode.
            s_loss: per-batch support losses, paired with ``s_num_sample``.
            s_correct: per-batch support hit counts.
            s_num_sample: per-batch support sizes.

        Returns:
            ``(stats, grads, loss_sum)``.  ``stats`` is a dict of
            support/query loss sums, hit counts and sample counts.  ``grads``
            is the tuple from ``torch.autograd.grad`` for the sample-weighted
            mean query loss over ``list(net.parameters())``.  ``loss_sum`` is
            the accumulated ``batch_loss * batch_size`` total, returned
            without detaching.
        """
        device = self.args.device
        net.train()

        losses, hit_counts, sizes = [], [], []
        # Sum of batch_loss * batch_size; kept as a tensor so it stays differentiable.
        weighted_total = 0.0
        for images, labels in self.ldr_train:
            images, labels = images.to(device), labels.to(device)

            logits = net(images)
            loss = self.loss_func(logits, labels)
            batch_size = labels.size(0)

            losses.append(loss.item())
            sizes.append(batch_size)
            weighted_total += loss * batch_size

            predicted = logits.data.max(1, keepdim=True)[1]
            hits = predicted.eq(labels.data.view_as(predicted)).long().cpu().sum()
            hit_counts.append(hits.item())

        spt_sz = np.sum(s_num_sample)
        qry_sz = np.sum(sizes)
        grads = torch.autograd.grad(weighted_total / qry_sz, list(net.parameters()))

        # Clear any gradients already stored on the parameters.
        for param in net.parameters():
            if param.grad is not None:
                param.grad.zero_()

        stats = {
            'support_loss_sum': np.dot(s_loss, s_num_sample),
            'query_loss_sum': np.dot(losses, sizes),
            'support_correct': np.sum(s_correct),
            'query_correct': np.sum(hit_counts),
            'support_num_samples': spt_sz,
            'query_num_samples': qry_sz,
        }
        return stats, grads, weighted_total



class Adam:
    """Hand-rolled Adam update applied to one tensor at a time.

    Moment estimates are stored per key ``i``, so a single instance can drive
    any number of tensors.  The step counter ``n`` is shared by all keys and
    is advanced only by :meth:`increase_n`.
    """

    def __init__(self, lr=0.01, betas=(0.9, 0.999), eps=1e-08):
        """
        :param lr: base step size.
        :param betas: ``(beta1, beta2)`` decay rates for the first- and
            second-moment running averages.
        :param eps: constant added to the denominator to avoid division by zero.
        """
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.m = dict()  # first-moment estimate per key
        self.v = dict()  # second-moment estimate per key
        self.n = 0  # step counter used for bias correction
        # Not used by this class; kept under its original (misspelled) name
        # so external code that touches it keeps working.
        self.creted_momtem_grad_index = set()

    def __call__(self, params, grads, i):
        """Update ``params`` in place from ``grads``.

        :param params: tensor to update; modified in place.
        :param grads: gradient tensor with the same shape as ``params``.
        :param i: hashable key selecting which moment estimates to use.
        """
        # With n == 0 both bias-correction terms are 1 - beta**0 == 0, which
        # made the step size 0/0 = NaN and overwrote params with NaN.  Treat
        # a call made before the first increase_n() as step 1.
        step = max(self.n, 1)

        # The update is plain arithmetic, not part of any autograd graph;
        # no_grad also allows in-place updates of leaf tensors that require grad.
        with torch.no_grad():
            if i not in self.m:
                self.m[i] = torch.zeros_like(params)
            if i not in self.v:
                self.v[i] = torch.zeros_like(params)

            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grads
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * torch.square(grads)

            # Fold both bias corrections into a single scalar step size.
            alpha = self.lr * np.sqrt(1 - np.power(self.beta2, step))
            alpha = float(alpha / (1 - np.power(self.beta1, step)))

            params.sub_(alpha * self.m[i] / (torch.sqrt(self.v[i]) + self.eps))

    def increase_n(self):
        """Advance the shared step counter by one."""
        self.n += 1