import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer
from distill import distillation_loss

class Client:
    """A client that trains its own model locally, optionally distilling from given logits.

    Each client holds a model, a shuffled loader over its private training data
    and an unshuffled loader over a public dataset. When ``global_logits`` for
    the public dataset are passed to :meth:`train`, the local objective mixes
    cross-entropy on private data with ``distillation_loss`` on public data.
    """

    def __init__(self, client_id, model, train_dataset, public_dataset, device, args):
        """Build the client's data loaders and tokenizer.

        Args:
            client_id: Identifier for this client (stored; not otherwise used here).
            model: Classifier called as ``model(input_ids)``. It presumably returns
                logits of shape (batch, num_classes), given the CrossEntropyLoss /
                ``argmax(dim=1)`` usage; confirm against the model code.
            train_dataset: Private dataset whose batches provide ``'text'`` and ``'label'``.
            public_dataset: Dataset whose batches provide ``'text'``; used for distillation.
            device: torch device the model and inputs are moved to.
            args: Config namespace. ``batch_size`` is read here and ``lr`` /
                ``local_epochs`` in ``train``. ``T`` / ``alpha`` are read only when
                distilling. Optional ``tokenizer_name`` defaults to ``'roberta-large'``.
        """
        self.client_id = client_id
        self.model = model
        self.train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        # shuffle=False keeps public batches in a fixed order so that batch i can be
        # paired position-by-position with global_logits[i] in train().
        self.public_loader = DataLoader(public_dataset, batch_size=args.batch_size, shuffle=False)
        self.device = device
        self.args = args
        tokenizer_name = getattr(args, 'tokenizer_name', 'roberta-large')
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)

    def _encode(self, texts):
        """Tokenize a batch of strings and return its ``input_ids`` tensor on ``self.device``."""
        # NOTE(review): the attention mask is discarded, so padding tokens are not masked
        # out by the model; pass it through if the model's forward accepts one.
        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        return encoded['input_ids'].to(self.device)

    def _public_pairs(self, global_logits):
        """Yield ``(public_batch, logits_for_that_batch)`` pairs forever, restarting each pass.

        ``global_logits`` must be re-iterable (e.g. a list); it is zipped with
        ``self.public_loader`` anew on every pass.

        Raises:
            ValueError: if a full pass produces no pairs (empty loader or empty logits).
        """
        while True:
            produced = False
            for pair in zip(self.public_loader, global_logits):
                produced = True
                yield pair
            if not produced:
                raise ValueError(
                    "distillation requested but public_loader / global_logits produced no batches"
                )

    def train(self, global_logits=None):
        """Train the model locally for ``args.local_epochs`` epochs and return its state dict.

        Args:
            global_logits: Optional iterable of logit tensors, one per public batch, in
                ``public_loader`` order. When given, every training step also computes
                ``distillation_loss`` on the next public batch, wrapping around once all
                public batches are used. The step then optimizes
                ``alpha * task_loss + (1 - alpha) * distill_loss``.

        Returns:
            ``self.model.state_dict()`` after training.

        Raises:
            ValueError: if ``global_logits`` is given but pairs with no public batches.
        """
        self.model.to(self.device)
        self.model.train()
        optimizer = torch.optim.AdamW(
            [p for p in self.model.parameters() if p.requires_grad], lr=self.args.lr
        )
        criterion = nn.CrossEntropyLoss()

        public_pairs = None
        if global_logits is not None:
            # Materialize so the logits can be re-zipped on every pass over the public data.
            public_pairs = self._public_pairs(list(global_logits))

        for _ in range(self.args.local_epochs):
            for batch in self.train_loader:
                labels = batch['label'].to(self.device)
                loss = criterion(self.model(self._encode(batch['text'])), labels)

                if public_pairs is not None:
                    # Advance through the public data rather than always reusing batch 0.
                    pub_batch, z_global = next(public_pairs)
                    z_local = self.model(self._encode(pub_batch['text']))
                    loss_distill = distillation_loss(z_local, z_global.to(self.device), T=self.args.T)
                    loss = self.args.alpha * loss + (1 - self.args.alpha) * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return self.model.state_dict()

    def evaluate(self, test_loader):
        """Return the fraction of examples in ``test_loader`` whose argmax prediction is correct.

        Raises:
            ValueError: if ``test_loader`` yields no examples (accuracy is undefined).
        """
        # Move the model here too, so evaluate() works even if train() was never called.
        self.model.to(self.device)
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch in test_loader:
                labels = batch['label'].to(self.device)
                preds = self.model(self._encode(batch['text'])).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        if total == 0:
            raise ValueError("test_loader yielded no examples; accuracy is undefined")
        return correct / total
