import copy
from math import ceil
from warnings import catch_warnings
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset,Subset
from core.function import gather_flat_grad, get_trainable_hyper_params, loss_adjust_cross_entropy, gather_flat_hyper_params
from utils.svrg import SVRG_Snapshot
from numpy import random
from torch.autograd import grad
import torch.nn.functional as F
from torch.optim import SGD
from core.function import assign_hyper_gradient
from itertools import cycle

from collections.abc import Sequence
import numpy as np

def combine_idxs(a, b):
    """Concatenate two collections of sample indices.

    Args:
        a: First index collection (``numpy.ndarray``, ``list`` or ``tuple``).
        b: Second index collection.

    Returns:
        A ``numpy.ndarray`` if either argument is an array; ``a + b`` when
        both are sequences of the same type; otherwise a ``list`` holding
        the items of ``a`` followed by the items of ``b``.

    Raises:
        TypeError: If ``a`` is neither an array nor a sequence (and ``b`` is
            not an array either).
    """
    # Check both operands: ``list + ndarray`` would otherwise be evaluated by
    # NumPy as an element-wise sum (or a broadcast error), not a join.
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        return np.concatenate([a, b])
    if isinstance(a, Sequence):  # covers list, tuple
        if isinstance(b, type(a)):
            return a + b
        # Mixed sequence types (e.g. list + tuple) cannot be added directly.
        return list(a) + list(b)
    raise TypeError(f"Unsupported type for idxs: {type(a)}")



def corrupt_labels(labels, p, num_classes):
    """Randomly replace labels with a uniformly drawn class index.

    Each label is independently replaced, with probability ``p``, by an
    integer drawn uniformly from ``0 .. num_classes - 1``. The draw may
    coincide with the original label.

    Args:
        labels: A scalar (``int``/``float``), a ``list``, a ``torch.Tensor``
            or anything else ``torch.tensor`` accepts (e.g. a NumPy array).
            Must be 0- or 1-dimensional.
        p: Per-label replacement probability.
        num_classes: Number of classes; replacements lie in
            ``[0, num_classes)``.

    Returns:
        The corrupted labels in a container matching the input: ``list`` for
        a list, ``torch.Tensor`` for a tensor, ``int`` for a Python scalar,
        and a NumPy array otherwise. A 0-dim tensor input comes back as a
        1-element 1-D tensor.

    Raises:
        ValueError: If ``labels`` has more than one dimension.
    """
    if isinstance(labels, torch.Tensor):
        # Same result as torch.tensor(labels) but without the
        # copy-construct UserWarning.
        labels_tensor = labels.detach().clone()
    else:
        labels_tensor = torch.tensor(labels)

    if labels_tensor.dim() == 0:
        labels_tensor = labels_tensor.unsqueeze(0)
    elif labels_tensor.dim() != 1:
        raise ValueError("labels must be a 1-dimensional tensor")

    corrupted_labels = labels_tensor.clone()

    for i in range(len(labels_tensor)):
        if random.random() < p:
            # ``random`` is numpy.random here: randint's upper bound is
            # EXCLUSIVE, so pass num_classes (not num_classes - 1) to make
            # the last class reachable.
            corrupted_labels[i] = random.randint(0, num_classes)

    if isinstance(labels, list):
        return corrupted_labels.tolist()
    elif isinstance(labels, torch.Tensor):
        return corrupted_labels
    elif isinstance(labels, (int, float)):
        return int(corrupted_labels[0])
    else:
        return corrupted_labels.numpy()
    
class DatasetSplit_t(Dataset):
    """View of ``dataset`` restricted to ``idxs`` with noisy labels.

    Labels are passed through ``corrupt_labels`` on every ``__getitem__``
    call, so the noise is re-sampled each time an item is read rather than
    fixed once per sample.

    Args:
        dataset: Indexable object yielding ``(image, label)`` pairs.
        idxs: Iterable of indices into ``dataset``; materialised as a list.
        corruption_prob: Probability that a label is replaced (default 0.3).
        num_classes: Number of classes labels are drawn from (default 10).
    """

    def __init__(self, dataset, idxs, corruption_prob=0.3, num_classes=10):
        self.dataset = dataset
        self.idxs = list(idxs)
        # Previously hard-coded in __getitem__; defaults keep old behaviour.
        self.corruption_prob = corruption_prob
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return ``(image, possibly_corrupted_label)`` for position ``item``."""
        image, label = self.dataset[self.idxs[item]]
        label = corrupt_labels(label, self.corruption_prob, self.num_classes)
        return image, label

class DatasetSplit_v(Dataset):
    """View of ``dataset`` restricted to ``idxs``; labels are left untouched."""

    def __init__(self, dataset, idxs):
        """Store the backing dataset and a list copy of the selected indices."""
        self.dataset = dataset
        self.idxs = [*idxs]

    def __len__(self):
        """Return how many samples this view exposes."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at position ``item``."""
        source_index = self.idxs[item]
        sample, target = self.dataset[source_index]
        return sample, target

class Client():
    def __init__(self, args, client_id, net, dataset=None, idxs=None, param = None, hyper_param = None, theta = None) -> None:
        """Set up the client's model copies, data loaders and loss callables.

        Args:
            args: Configuration object; ``args.local_bs`` is read here and
                ``args.device`` is read by the gradient methods.
            client_id: Position of this client. ``idxs[client_id]`` supplies
                the training indices and ``idxs[-client_id - 1]`` the
                validation indices.
            net: Model; stored as ``self.net`` and deep-copied into
                ``self.init_net``. Gradients of both are zeroed.
            dataset: Indexable dataset of ``(image, label)`` pairs.
            idxs: Indexable collection of per-client index collections.
            param, hyper_param, theta: Stored unchanged on the instance.

        Each index set is shuffled and cut into chunks of about 0.6% of the
        training shard. The chunks are kept as endless cycles
        (``idx_cycle``, ``val_idx_cycle``, ``sl_idx_cycle``) and each loader
        is built from the first chunk of its cycle.
        """
        self.client_id = client_id
        self.args = args
        self.net = net
        self.init_net = copy.deepcopy(net)
        self.net.zero_grad()
        self.init_net.zero_grad()

        train_idxs = idxs[client_id]
        val_idxs = idxs[-client_id - 1]

        shuffled_idxs = random.permutation(train_idxs)
        # int(0.006 * n) is 0 for shards with fewer than 167 samples, which
        # would make range(..., step=0) raise; clamp to at least one sample.
        batch_size = max(1, int(0.006 * len(shuffled_idxs)))

        def split(indices):
            """Cut ``indices`` into consecutive chunks of ``batch_size``."""
            return [indices[i:i + batch_size]
                    for i in range(0, len(indices), batch_size)]

        batches = split(shuffled_idxs)

        # Training and validation indices together feed the ``sl`` loader.
        combined_idxs = combine_idxs(train_idxs, val_idxs)

        self.idx_cycle = cycle(batches)
        current_idxs = next(self.idx_cycle)
        self.ldr_train = DataLoader(
            Subset(DatasetSplit_t(dataset, current_idxs), range(len(current_idxs))),
            batch_size=self.args.local_bs, shuffle=True)

        # Validation data keeps its clean labels (DatasetSplit_v).
        shuffled_val_idxs = random.permutation(val_idxs)
        val_batches = split(shuffled_val_idxs)
        self.val_idx_cycle = cycle(val_batches)
        current_val_idxs = next(self.val_idx_cycle)
        self.ldr_val = DataLoader(
            Subset(DatasetSplit_v(dataset, current_val_idxs), range(len(current_val_idxs))),
            batch_size=self.args.local_bs, shuffle=True
        )

        shuffled_sl_idxs = random.permutation(combined_idxs)
        sl_batches = split(shuffled_sl_idxs)
        self.sl_idx_cycle = cycle(sl_batches)
        current_sl_idxs = next(self.sl_idx_cycle)
        self.ldr_sl = DataLoader(
            Subset(DatasetSplit_t(dataset, current_sl_idxs), range(len(current_sl_idxs))),
            batch_size=self.args.local_bs, shuffle=True
        )

        self.theta = theta
        self.param = param
        self.hyper_param = hyper_param

        # Weight of the parameter-norm penalty in loss_adjust_cross_entropy.
        self.beta = 0.001
        self.loss_func = self.loss_adjust_cross_entropy #nn.CrossEntropyLoss()
        self.val_loss = self.cross_entropy
        self.loss_func1 = self.loss_cross_entropy
        self.bs = args.local_bs
          


    def grad_d_out_d_y(self, net = None, hyper_param = None):
        if net == None:
            self.net0 = copy.deepcopy(self.net)
        else:
            self.net0 = copy.deepcopy(net)
        if hyper_param == None:
            hyper_param = self.hyper_param 
        self.net0.train()
        params = list(self.net0.parameters())
        d_out_d_y = [0 for i in range(len(params))]
        for batch_idx, (images, labels) in enumerate(self.ldr_val):
            images, labels = images.to(
                self.args.device), labels.to(self.args.device)
            self.net0.zero_grad()
            log_probs = self.net0(images)
            loss = self.val_loss(log_probs, labels)
            grad= torch.autograd.grad(loss, params)
            for i in range(len(params)):
                d_out_d_y[i]+=grad[i]
        return [x / (batch_idx+1) for x in d_out_d_y]

    def grad_d_in_d_y(self, net = None, hyper_param = None):
        if net == None:
            self.net0 = copy.deepcopy(self.net)
        else:
            self.net0 = copy.deepcopy(net)
        if hyper_param == None:
            hyper_param = self.hyper_param 
        self.net0.train()
        params = list(self.net0.parameters())
        d_in_d_y = [0 for i in range(len(params))]
        for batch_idx, (images, labels) in enumerate(self.ldr_train):
            hyper_params_b = hyper_param[0+self.bs *batch_idx:self.bs +self.bs *batch_idx]
            images, labels = images.to(
                self.args.device), labels.to(self.args.device)
            self.net0.zero_grad()
            log_probs = self.net0(images)
            loss = self.loss_func(log_probs, labels,hyper_params_b)
            grad= torch.autograd.grad(loss, params)
            for i in range(len(params)):
                d_in_d_y[i]+=grad[i]
        return [x / (batch_idx+1) for x in d_in_d_y]
    
    
    def grad_d_t(self, net = None):
        if net == None:
            self.net0 = copy.deepcopy(self.net)
        else:
            self.net0 = copy.deepcopy(net)
        self.net0.train()
        params = list(self.net0.parameters())
        d_in_d_y = [0 for i in range(len(params))]
        for batch_idx, (images, labels) in enumerate(self.ldr_sl):
            images, labels = images.to(
                self.args.device), labels.to(self.args.device)
            self.net0.zero_grad()
            labels = corrupt_labels(labels, 0.6, 10)
            log_probs = self.net0(images)
            loss = self.loss_func1(log_probs, labels)
            grad= torch.autograd.grad(loss, params)
            for i in range(len(params)):
                d_in_d_y[i]+=grad[i]
        return [x / (batch_idx+1) for x in d_in_d_y]
    

    def hvp_iter(self, p, lr):
        """Run one step of an iterative Hessian-vector recursion.

        On the first step (``self.hyper_iter == 0``) the training gradient is
        recomputed via ``grad_d_in_d_y`` and ``self.counter`` is initialised
        to a copy of ``p``. Each step then updates
        ``counter <- counter - lr * hessian_term`` and returns
        ``p + counter``.

        Args:
            p: Tensor accumulated so far; must support ``clone()``.
            lr: Step size multiplying the Hessian term.

        Returns:
            ``p`` plus the updated ``self.counter``.

        NOTE(review): presumably a Neumann-series approximation of an
        inverse-Hessian-vector product - confirm against the caller.
        NOTE(review): ``self.hyper_iter`` is not initialised anywhere in the
        visible code; callers apparently must set it to 0 beforehand.
        NOTE(review): ``grad_d_in_d_y`` calls ``torch.autograd.grad`` without
        ``create_graph=True`` and returns a list of tensors, while
        ``grad_outputs`` here is a single flattened tensor - verify that this
        second differentiation actually runs.
        """
        if self.hyper_iter == 0:
            # First step: cache the training gradient and seed the recursion.
            self.d_in_d_y = self.grad_d_in_d_y()
            self.counter = p.clone()
        old_counter = self.counter
        # Differentiate the cached gradient w.r.t. the parameters of net0,
        # weighted by the current counter, and flatten the result.
        hessian_term = gather_flat_grad(
            grad(self.d_in_d_y, self.net0.parameters(),
                 grad_outputs=self.counter.view(-1), retain_graph=True)
        )
        self.counter = old_counter - lr * hessian_term
        p = p+self.counter
        self.hyper_iter += 1
        return p
    
    


    def grad_d_in_d_x(self, net = None, hyper_param = None):
        """Return the training-loss gradient w.r.t. the hyper-parameters.

        A deep copy of ``net`` (or of ``self.net``) is stored in
        ``self.net0``. For batch ``k`` of ``self.ldr_train`` the loss is
        differentiated w.r.t. the slice ``hyper_param[k * bs:(k + 1) * bs]``
        and the entries are written to positions ``k * bs + i`` of the
        result.

        Args:
            net: Optional model to evaluate instead of ``self.net``.
            hyper_param: Sliceable hyper-parameters; defaults to
                ``self.hyper_param``.

        Returns:
            A list of length ``len(get_trainable_hyper_params(hyper_param))``
            whose filled positions hold per-entry gradients; untouched
            positions remain ``0``.
        """
        # Identity checks: ``==`` may be overloaded by tensors/arrays.
        self.net0 = copy.deepcopy(self.net if net is None else net)
        if hyper_param is None:
            hyper_param = self.hyper_param
        self.net0.train()
        hp = get_trainable_hyper_params(hyper_param)
        d_in_d_x = [0 for _ in range(len(hp))]
        for batch_idx, (images, labels) in enumerate(self.ldr_train):
            offset = self.bs * batch_idx
            hyper_params_b = hyper_param[offset:offset + self.bs]
            images, labels = images.to(
                self.args.device), labels.to(self.args.device)
            self.net0.zero_grad()
            log_probs = self.net0(images)
            loss = self.loss_func(log_probs, labels, hyper_params_b)
            batch_grad = torch.autograd.grad(loss, hyper_params_b)[0]
            # Use the gradient's real length: the final slice can be shorter
            # than self.bs, and indexing up to self.bs would overrun it.
            for i in range(len(batch_grad)):
                d_in_d_x[offset + i] = batch_grad[i]
        return d_in_d_x



    def hyper_grad(self, p):
        """Combine direct and indirect terms into a hyper-gradient.

        The indirect term is the flattened gradient of ``grad_d_in_d_y()``
        w.r.t. the trainable hyper-parameters, weighted by ``p``. If
        ``self.grad_d_out_d_x()`` is available and succeeds, the result is
        ``direct - indirect``; otherwise it is ``-indirect``.

        Args:
            p: Tensor used (flattened) as ``grad_outputs``.

        Returns:
            The hyper-gradient as produced by ``gather_flat_grad``.

        NOTE(review): ``grad_d_out_d_x`` is not defined in the visible part
        of this class, so unless it is supplied elsewhere the fallback
        branch is always taken.
        """
        d_in_d_y = self.grad_d_in_d_y()
        indirect_grad = gather_flat_grad(
            grad(d_in_d_y,
                 get_trainable_hyper_params(self.hyper_param),
                 grad_outputs=p.view(-1),
                 allow_unused=True)
        )
        try:
            direct_grad = self.grad_d_out_d_x()
            result = direct_grad - indirect_grad
        except Exception:
            # Best-effort fallback when no direct term can be computed.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            result = -indirect_grad
        return result

    def hyper_update(self, hg):
        """Apply ``hg`` as the hyper-gradient and take one optimizer step.

        Returns the flattened initial hyper-parameters minus the flattened
        current ones, i.e. the negated displacement after the step.
        """
        assign_hyper_gradient(self.hyper_param, hg)
        self.hyper_optimizer.step()
        current_flat = gather_flat_hyper_params(self.hyper_param)
        initial_flat = gather_flat_hyper_params(self.hyper_param_init)
        return -current_flat + initial_flat
    
    def hyper_svrg_update(self, hg):
        """Take one hyper-parameter step with a variance-corrected gradient.

        When ``self.grad_d_out_d_x`` can be evaluated at both the current and
        the initial hyper-parameters, the step direction is
        ``direct(current) - direct(initial) + hg``; otherwise ``hg`` is used
        unchanged.

        Args:
            hg: Hyper-gradient term to apply.

        Returns:
            Flattened initial hyper-parameters minus the flattened current
            ones after the optimizer step.
        """
        try:
            direct_grad = self.grad_d_out_d_x()
            direct_grad_0 = self.grad_d_out_d_x(hyper_param=self.hyper_param_init)
            h = direct_grad - direct_grad_0 + hg
        except Exception:
            # Best-effort fallback when the direct term is unavailable.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            h = hg
        assign_hyper_gradient(self.hyper_param, h)
        self.hyper_optimizer.step()
        return -gather_flat_hyper_params(self.hyper_param)+gather_flat_hyper_params(self.hyper_param_init)
    
    def hyper_sgd_update(self, hg):
        """Take one hyper-parameter step using ``direct + hg`` when possible.

        When ``self.grad_d_out_d_x()`` can be evaluated, its result is added
        to ``hg``; otherwise ``hg`` alone is used as the step direction.

        Args:
            hg: Hyper-gradient term to apply.

        Returns:
            Flattened initial hyper-parameters minus the flattened current
            ones after the optimizer step.
        """
        try:
            direct_grad = self.grad_d_out_d_x()
            h = direct_grad + hg
        except Exception:
            # Best-effort fallback when the direct term is unavailable.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            h = hg
        assign_hyper_gradient(self.hyper_param, h)
        self.hyper_optimizer.step()
        return -gather_flat_hyper_params(self.hyper_param)+gather_flat_hyper_params(self.hyper_param_init)


    def loss_adjust_cross_entropy(self, logits, targets, hyper_params):
        reg = self.beta * sum([torch.norm(k) for k in list(self.net0.parameters())])
        loss =torch.mean(torch.sigmoid(hyper_params)* F.cross_entropy(logits, targets))+reg
        return loss
    
    def cross_entropy(self, logits, targets):
        return F.cross_entropy(logits, targets)
    
    def loss_cross_entropy(self, logits, targets):
        reg = sum([torch.norm(k) for k in list(self.net0.parameters())])
        loss =F.cross_entropy(logits, targets)+0.001*reg
        return loss
    
    


