import copy
from math import ceil
from warnings import catch_warnings
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from core.function import gather_flat_grad, get_trainable_hyper_params, loss_adjust_cross_entropy, gather_flat_hyper_params
from utils.svrg import SVRG_Snapshot
from numpy import random
from torch.autograd import grad
import torch.nn.functional as F
from torch.optim import SGD
from core.function import assign_hyper_gradient

class DatasetSplit(Dataset):
    """Expose only the samples of ``dataset`` whose positions appear in ``idxs``.

    Position ``item`` of this view maps to ``dataset[idxs[item]]``. Each
    underlying sample must unpack into exactly two values, which are returned
    as an ``(image, label)`` tuple.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialize the indices so the view can be sized and indexed repeatedly.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label


class Client:
    """A federated-learning client with a local model and private data splits.

    The model's trainable parameters are partitioned by name. Tensors whose
    parameter name contains ``"header"`` are the *y* variables (called
    ``params`` below). All other tensors are the *x* variables (called
    ``hyper_param`` below).

    Naming of the gradient methods:

    - ``d_in`` is the training objective ``self.loss_func``, evaluated on
      ``self.ldr_train``.
    - ``d_out`` is the validation objective ``self.val_loss``, evaluated on
      ``self.ldr_val``.
    """

    def __init__(self, args, client_id, net, dataset=None, idxs=None, param=None, hyper_param=None, theta=None) -> None:
        """Set up the client's model handles, data loaders and bookkeeping.

        Args:
            args: Run configuration. This class reads ``args.local_bs`` (the
                loader batch size) and ``args.device``.
            client_id: Index of this client. ``idxs[client_id]`` selects its
                training samples and ``idxs[-client_id - 1]`` its validation
                samples. NOTE(review): presumably ``idxs`` holds training
                splits at the front and validation splits at the back.
                Confirm against the caller.
            net: The client's model. It is kept by reference as ``self.net``
                (not copied); a deep copy is stored as ``self.init_net``.
            dataset: Indexable dataset whose items unpack into
                ``(image, label)``. Required in practice despite the ``None``
                default.
            idxs: Sequence of index collections into ``dataset``. Required in
                practice despite the ``None`` default.
            param, hyper_param, theta: Stored unchanged as attributes.
        """
        self.client_id = client_id
        self.args = args
        self.net = net
        self.init_net = copy.deepcopy(net)
        self.net.zero_grad()
        self.init_net.zero_grad()
        # Regularization strength used by cross_entropy_reg; 0 disables it.
        self.beta = 0

        self.ldr_train = DataLoader(DatasetSplit(
            dataset, idxs[client_id]), batch_size=self.args.local_bs, shuffle=True)
        self.ldr_val = DataLoader(DatasetSplit(
            dataset, idxs[-client_id - 1]), batch_size=self.args.local_bs, shuffle=True)

        self.theta = theta
        self.param = param
        self.hyper_param = hyper_param
        self.val_loss = self.cross_entropy
        self.loss_func = self.cross_entropy_reg
        self.hyper_iter = 0

    # ------------------------------------------------------------------ helpers

    def _copy_net(self, net=None):
        """Return a deep copy of ``net``, or of ``self.net`` when ``net`` is None."""
        return copy.deepcopy(self.net if net is None else net)

    @staticmethod
    def _split_params(net):
        """Split ``net``'s parameters into ``(header_params, other_params)``.

        Both lists follow the order of ``net.named_parameters()``.
        """
        header, rest = [], []
        for name, tensor in net.named_parameters():
            (header if "header" in name else rest).append(tensor)
        return header, rest

    def _mean_batch_grad(self, net, loader, loss_fn, wrt):
        """Return the gradient of ``loss_fn`` w.r.t. ``wrt``, averaged over batches.

        ``net`` is switched to training mode. ``loss_fn`` is called as
        ``loss_fn(logits, labels)``. Every batch carries equal weight,
        regardless of its size, so a smaller final batch is not down-weighted.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        net.train()
        totals = [0] * len(wrt)
        num_batches = 0
        for images, labels in loader:
            images = images.to(self.args.device)
            labels = labels.to(self.args.device)
            net.zero_grad()
            loss = loss_fn(net(images), labels)
            batch_grads = torch.autograd.grad(loss, wrt)
            totals = [acc + g for acc, g in zip(totals, batch_grads)]
            num_batches += 1
        if num_batches == 0:
            # Previously this surfaced as a NameError on an unbound loop variable.
            raise ValueError("cannot average gradients over an empty data loader")
        return [t / num_batches for t in totals]

    # --------------------------------------------------------- gradient oracles

    def grad_d_in_d_y(self, net=None):
        """Return the batch-averaged training-loss gradient w.r.t. the header params.

        The work is done on a deep copy of ``net`` (default: ``self.net``),
        which is kept as ``self.net0``. The loss is
        ``self.loss_func(logits, labels, params)``. With the default
        ``loss_func`` (``cross_entropy_reg``), the header params therefore
        also enter the regularizer.

        Returns:
            One gradient tensor per header parameter, in
            ``named_parameters()`` order.
        """
        self.net0 = self._copy_net(net)
        params, _ = self._split_params(self.net0)
        return self._mean_batch_grad(
            self.net0,
            self.ldr_train,
            lambda logits, labels: self.loss_func(logits, labels, params),
            params,
        )

    def grad_d_out_d_y(self, net=None):
        """Return the batch-averaged validation-loss gradient w.r.t. the header params.

        The work is done on a deep copy of ``net`` (default: ``self.net``),
        which is kept as ``self.net0``.

        Returns:
            One gradient tensor per header parameter, in
            ``named_parameters()`` order.
        """
        self.net0 = self._copy_net(net)
        params, _ = self._split_params(self.net0)
        return self._mean_batch_grad(self.net0, self.ldr_val, self.val_loss, params)

    def grad_d_out_d_x(self, net=None):
        """Return the batch-averaged validation-loss gradient w.r.t. the non-header params.

        Works on a throw-away deep copy of ``net`` (default: ``self.net``).

        Returns:
            One gradient tensor per non-header parameter, in
            ``named_parameters()`` order.
        """
        net = self._copy_net(net)
        _, hyper_param = self._split_params(net)
        return self._mean_batch_grad(net, self.ldr_val, self.val_loss, hyper_param)

    def grad_d_in_d_x(self, net=None):
        """Return the batch-averaged training-loss gradient w.r.t. the non-header params.

        Works on a throw-away deep copy of ``net`` (default: ``self.net``).
        The loss is ``self.loss_func(logits, labels, header_params)``.

        Returns:
            One gradient tensor per non-header parameter, in
            ``named_parameters()`` order.
        """
        net = self._copy_net(net)
        params, hyper_param = self._split_params(net)
        return self._mean_batch_grad(
            net,
            self.ldr_train,
            lambda logits, labels: self.loss_func(logits, labels, params),
            hyper_param,
        )

    # ------------------------------------------------------------------- losses

    def cross_entropy(self, logits, targets):
        """Return the mean cross-entropy between ``logits`` and integer ``targets``."""
        return F.cross_entropy(logits, targets)

    def cross_entropy_reg(self, logits, targets, param):
        """Return the cross-entropy plus ``0.5 * beta * sum(||k|| for k in param)``.

        NOTE(review): ``torch.norm(k)`` is the *unsquared* 2-norm of each
        tensor. Conventional L2 weight decay uses the squared norm, so
        confirm which form is intended.
        """
        reg = self.beta * sum([torch.norm(k) for k in param])
        return F.cross_entropy(logits, targets) + 0.5 * reg



