import copy
import torch
import torch.nn as nn
import numpy as np
import os
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.preprocessing import label_binarize
from sklearn import metrics
from utils.data_utils import read_client_data
from flcore.trainmodel.models import BaseHeadSplit


class Client(object):
    """
    Base class for clients in federated learning.

    Handles local data loading, evaluation, and parameter copying. The local
    model is not kept as an attribute: it is written to disk with
    ``save_item`` under ``save_folder_name`` and reloaded with ``load_item``
    whenever it is evaluated.
    """

    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        """
        Initialize a client with experiment parameters and data statistics.

        Args:
            args (argparse.Namespace): Experiment configuration. Reads
                ``algorithm``, ``dataset``, ``device``, ``save_folder_name``,
                ``save_folder_name_full``, ``num_classes``, ``batch_size``,
                ``local_learning_rate`` and ``local_epochs``; ``args`` is also
                passed to ``BaseHeadSplit`` when a fresh model is built.
            id (int): Client ID.
            train_samples (int): Number of training samples.
            test_samples (int): Number of testing samples.
            **kwargs: Optional ``train_slow`` / ``send_slow`` flags; each
                defaults to ``False`` when omitted.
        """
        # Re-seed the global RNG so the model constructed below is reproducible.
        torch.manual_seed(0)
        self.algorithm = args.algorithm
        self.dataset = args.dataset
        self.device = args.device
        self.id = id
        self.role = 'Client_' + str(self.id)
        self.save_folder_name = args.save_folder_name_full

        self.num_classes = args.num_classes
        self.train_samples = train_samples
        self.test_samples = test_samples
        self.batch_size = args.batch_size
        self.learning_rate = args.local_learning_rate
        self.local_epochs = args.local_epochs

        # A fresh model is built and saved when the folder name is exactly
        # 'temp' or does not contain 'temp' at all.
        # NOTE(review): other '*temp*' folder names skip this, presumably to
        # reuse a model saved by an earlier run — confirm with the caller.
        if args.save_folder_name == 'temp' or 'temp' not in args.save_folder_name:
            model = BaseHeadSplit(args, self.id).to(self.device)
            save_item(model, self.role, 'model', self.save_folder_name)

        self.train_slow = kwargs.get('train_slow', False)
        self.send_slow = kwargs.get('send_slow', False)
        # Time-cost accumulators; not updated anywhere in this class.
        self.train_time_cost = {'num_rounds': 0, 'total_cost': 0.0}
        self.send_time_cost = {'num_rounds': 0, 'total_cost': 0.0}

        self.loss = nn.CrossEntropyLoss()

    def load_train_data(self, batch_size=None):
        """
        Build a DataLoader over this client's training split.

        Args:
            batch_size (int, optional): Batch size; ``self.batch_size`` if None.

        Returns:
            DataLoader: Unshuffled loader that drops the last incomplete batch.
        """
        if batch_size is None:
            batch_size = self.batch_size
        train_data = read_client_data(self.dataset, self.id, is_train=True)
        # NOTE(review): training data is not shuffled between epochs —
        # confirm this is intentional.
        return DataLoader(train_data, batch_size, drop_last=True, shuffle=False)

    def load_test_data(self, batch_size=None):
        """
        Build a DataLoader over this client's test split.

        Args:
            batch_size (int, optional): Batch size; ``self.batch_size`` if None.

        Returns:
            DataLoader: Unshuffled loader that keeps the last incomplete batch.
        """
        if batch_size is None:
            batch_size = self.batch_size
        test_data = read_client_data(self.dataset, self.id, is_train=False)
        return DataLoader(test_data, batch_size, drop_last=False, shuffle=False)

    def clone_model(self, model, target):
        """
        Copy parameter values from ``model`` into ``target``.

        Parameters are paired positionally (via ``zip``), and each copy is a
        fresh clone, so the two models do not share storage afterwards.

        Args:
            model (torch.nn.Module): Source model.
            target (torch.nn.Module): Model whose parameters are overwritten.
        """
        for param, target_param in zip(model.parameters(), target.parameters()):
            target_param.data = param.data.clone()

    def update_parameters(self, model, new_params):
        """
        Overwrite ``model``'s parameters with clones of ``new_params``.

        Args:
            model (torch.nn.Module): Model to update.
            new_params (iterable): Tensors or parameters, in the same order
                as ``model.parameters()``.
        """
        for param, new_param in zip(model.parameters(), new_params):
            param.data = new_param.data.clone()

    def _load_model(self):
        """
        Load this client's saved model, move it to ``self.device`` and set eval mode.

        Raises:
            FileNotFoundError: If no saved model exists for this client.
        """
        model = load_item(self.role, 'model', self.save_folder_name)
        if model is None:
            # load_item only prints a message and returns None on a missing
            # file; fail here with a clear error instead of an AttributeError.
            raise FileNotFoundError(
                f"{self.role}: no saved model found in {self.save_folder_name}")
        # Keep the model on the same device the batches are moved to.
        model.to(self.device)
        model.eval()
        return model

    def _to_device(self, x, y):
        """
        Move one batch to ``self.device``.

        When ``x`` is a list, only ``x[0]`` is moved (in place) and the other
        entries are left untouched.

        Returns:
            tuple: ``(x, y)`` on ``self.device``.
        """
        if isinstance(x, list):
            x[0] = x[0].to(self.device)
        else:
            x = x.to(self.device)
        return x, y.to(self.device)

    def test_metrics(self):
        """
        Evaluate the saved local model on this client's test split.

        Returns:
            tuple: ``(num_correct, num_samples, auc)``. ``auc`` is the
            micro-averaged ROC AUC of the raw model outputs against one-hot
            labels. Returns ``(0, 0, 0.0)`` when the test split is empty.

        Raises:
            FileNotFoundError: If no saved model exists for this client.
        """
        testloaderfull = self.load_test_data()
        model = self._load_model()

        # label_binarize returns a single column for two classes, so binarize
        # against three classes and keep only the first two columns.
        nc = self.num_classes + 1 if self.num_classes == 2 else self.num_classes
        class_ids = np.arange(nc)

        test_acc = 0
        test_num = 0
        y_prob = []
        y_true = []

        with torch.no_grad():
            for x, y in testloaderfull:
                x, y = self._to_device(x, y)
                output = model(x)

                test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item()
                test_num += y.shape[0]

                y_prob.append(output.detach().cpu().numpy())
                lb = label_binarize(y.detach().cpu().numpy(), classes=class_ids)
                if self.num_classes == 2:
                    lb = lb[:, :2]
                y_true.append(lb)

        if test_num == 0:
            # Nothing to score; np.concatenate would raise on the empty lists.
            return 0, 0, 0.0

        y_prob = np.concatenate(y_prob, axis=0)
        y_true = np.concatenate(y_true, axis=0)

        auc = metrics.roc_auc_score(y_true, y_prob, average='micro')

        return test_acc, test_num, auc

    def train_metrics(self):
        """
        Evaluate the saved local model's loss on this client's training split.

        Returns:
            tuple: ``(loss_sum, num_samples)``. ``loss_sum`` is the
            sample-weighted sum of ``self.loss`` over the evaluated batches;
            divide by ``num_samples`` for the mean loss.

        Raises:
            FileNotFoundError: If no saved model exists for this client.
        """
        trainloader = self.load_train_data()
        model = self._load_model()

        train_num = 0
        losses = 0
        with torch.no_grad():
            # NOTE(review): load_train_data drops the last incomplete batch, so
            # those samples are excluded here. A client with fewer samples than
            # batch_size reports (0, 0) — confirm this is intended.
            for x, y in trainloader:
                x, y = self._to_device(x, y)
                output = model(x)
                loss = self.loss(output, y)
                train_num += y.shape[0]
                # self.loss is a per-batch mean; re-weight by batch size.
                losses += loss.item() * y.shape[0]

        return losses, train_num

def save_item(item, role, item_name, item_path=None):
    """
    Serialize ``item`` with ``torch.save`` to ``<item_path>/<role>_<item_name>.pt``.

    Args:
        item: Any object ``torch.save`` can serialize (e.g. a model or tensor).
        role (str): File-name prefix identifying the owner, e.g. ``'Client_0'``.
        item_name (str): File-name suffix identifying the item, e.g. ``'model'``.
        item_path (str): Target directory; created (with parents) if missing.

    Raises:
        TypeError: If ``item_path`` is None.
    """
    if item_path is None:
        raise TypeError("save_item: item_path must be a directory path, not None")
    # exist_ok=True replaces a separate exists() check, which could still fail
    # with FileExistsError if the directory appeared between check and create.
    os.makedirs(item_path, exist_ok=True)
    torch.save(item, os.path.join(item_path, role + "_" + item_name + ".pt"))

def load_item(role, item_name, item_path=None, map_location=None):
    """
    Load an object previously written by ``save_item``.

    Args:
        role (str): File-name prefix used when saving, e.g. ``'Client_0'``.
        item_name (str): File-name suffix used when saving, e.g. ``'model'``.
        item_path (str): Directory the item was saved in.
        map_location: Forwarded to ``torch.load``. The default None restores
            tensors to the device they were saved from.

    Returns:
        The deserialized object, or None (after printing a message) if the
        file does not exist.
    """
    import inspect  # local: only used to probe torch.load's signature

    path = os.path.join(item_path, role + "_" + item_name + ".pt")
    load_kwargs = {'map_location': map_location}
    # Newer torch releases default to weights_only=True, which rejects whole
    # pickled objects such as the nn.Module instances save_item stores.
    # SECURITY: weights_only=False runs arbitrary pickle code; only point this
    # at files the program itself wrote, never at untrusted input.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    try:
        return torch.load(path, **load_kwargs)
    except FileNotFoundError:
        print(role, item_name, 'Not Found')
        return None
