import logging
import os
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np

import torch
from torch import nn
import pickle
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, DistilBertModel, AutoTokenizer
from torch.nn import CrossEntropyLoss
from transformers.models.vit.configuration_vit import ViTConfig
from peft import LoraConfig, get_peft_model
from sklearn.metrics import accuracy_score

from clients.base import Client
from utils.model_ops import print_trainable_parameters




class BERTClassifier(nn.Module):
    """Pretrained BERT encoder followed by two alternative linear heads.

    Both heads read the encoder's ``pooler_output``. The ``classifier`` head
    produces ``num_classes`` values per example and the ``projection_layer``
    head produces ``proj_dim`` values; ``forward`` applies exactly one of them.
    """

    def __init__(self, bert_model_name, num_classes, proj_dim):
        """Load the encoder and create both heads.

        Args:
            bert_model_name: Identifier handed to ``BertModel.from_pretrained``.
            num_classes: Output width of the classification head.
            proj_dim: Output width of the projection head.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Both heads consume the same encoder width.
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.projection_layer = nn.Linear(hidden_size, proj_dim)

    def forward(self, input_ids, attention_mask, projection=False):
        """Encode the batch and apply one head.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            projection: When truthy, use ``projection_layer``; otherwise use
                ``classifier``.

        Returns:
            The output of the selected linear head.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        head = self.projection_layer if projection else self.classifier
        return head(encoded.pooler_output)
    
class DistilBERTClassifier(nn.Module):
    """Pretrained DistilBERT encoder followed by two alternative linear heads.

    The hidden state at sequence position 0 is used as the pooled
    representation. The ``classifier`` head produces ``num_classes`` values
    and the ``projection_layer`` head produces ``proj_dim`` values;
    ``forward`` applies exactly one of them.
    """

    def __init__(self, distilbert_model_name, num_classes, proj_dim):
        """Load the encoder and create both heads.

        Args:
            distilbert_model_name: Identifier handed to
                ``DistilBertModel.from_pretrained``.
            num_classes: Output width of the classification head.
            proj_dim: Output width of the projection head.
        """
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(distilbert_model_name)

        # Both heads consume the same encoder width.
        hidden_size = self.distilbert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.projection_layer = nn.Linear(hidden_size, proj_dim)

    def forward(self, input_ids, attention_mask, projection=False):
        """Encode the batch and apply one head to the first-token state.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            projection: When truthy, use ``projection_layer``; otherwise use
                ``classifier``.

        Returns:
            The output of the selected linear head.
        """
        encoded = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the sequence axis stands in for a pooled output.
        first_token_state = encoded.last_hidden_state[:, 0]
        head = self.projection_layer if projection else self.classifier
        return head(first_token_state)

def text_classification_model(base_path: str, proj_dim, num_classes, lora_rank):
    """Build a LoRA-wrapped text classifier selected from ``base_path``.

    The backbone is chosen by substring: ``"googlebert"`` selects
    ``BERTClassifier`` and ``"distilbert"`` selects ``DistilBERTClassifier``
    (checked in that order). The model is wrapped with ``get_peft_model`` and
    the parameters of both heads are then marked trainable again.

    Args:
        base_path: Model path/identifier; also passed to the backbone class.
        proj_dim: Output width of the projection head.
        num_classes: Output width of the classification head.
        lora_rank: Value used for ``r`` in the ``LoraConfig``.

    Returns:
        The PEFT-wrapped model.

    Raises:
        NotImplementedError: If ``base_path`` matches neither backbone.
    """
    print("base path", base_path)

    # Pick the backbone class and the attention sub-modules LoRA should adapt.
    if "googlebert" in base_path:
        backbone_cls = BERTClassifier
        attention_targets = ["query", "value", "key"]
    elif "distilbert" in base_path:
        backbone_cls = DistilBERTClassifier
        attention_targets = ["q_lin", "k_lin", "v_lin"]
    else:
        raise NotImplementedError

    backbone = backbone_cls(base_path, num_classes=num_classes, proj_dim=proj_dim)
    peft_config = LoraConfig(
        r=lora_rank,
        lora_alpha=1,
        target_modules=attention_targets,
        bias="none",
    )
    model = get_peft_model(backbone, peft_config)

    # Re-enable gradients on both heads after the PEFT wrapping.
    for name, parameter in model.named_parameters():
        if "classifier" in name or "projection_layer" in name:
            parameter.requires_grad = True

    print_trainable_parameters(model)

    return model

def text_classification_eval_function(model, test_loader, train_loader, device):
    """Create an evaluation closure over ``model`` and the two loaders.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            whose output is fed to cross-entropy against the batch labels.
        test_loader: Iterable of batches used for ``evaluation_set="test"``.
        train_loader: Iterable of batches used for ``evaluation_set="train"``.
        device: Device the model and each batch are moved to for evaluation.

    Returns:
        A function ``eval_function(evaluation_set)`` returning a dict with
        ``"accuracy"`` and ``"avg_loss"``.
    """

    def eval_function(evaluation_set: Literal["train", "test"]):
        """Evaluate the model on the selected split.

        Each batch must be a mapping with ``'input_ids'``,
        ``'attention_mask'`` and ``'label'`` entries. The model is switched to
        eval mode, moved to ``device`` for the pass and moved back to CPU
        afterwards; it is left in eval mode.

        Args:
            evaluation_set: ``"test"`` or ``"train"``.

        Returns:
            ``{"accuracy": ..., "avg_loss": ...}`` where ``accuracy`` comes
            from ``accuracy_score`` and ``avg_loss`` is the cross-entropy
            averaged over all evaluated samples (NaN if there were none).

        Raises:
            NotImplementedError: If ``evaluation_set`` is not a known split.
        """
        if evaluation_set == "test":
            loader = test_loader
        elif evaluation_set == "train":
            loader = train_loader
        else:
            raise NotImplementedError(
                f"Unknown evaluation_set {evaluation_set!r}; expected 'train' or 'test'."
            )

        # Sum the loss inside each batch and divide by the sample count at the
        # end. Averaging per-batch means would over-weight a short final batch.
        criterion = nn.CrossEntropyLoss(reduction="sum")

        model.eval()
        model.to(device)
        predictions = []
        actual_labels = []
        loss_sum = 0.0
        with torch.no_grad():
            for batch in loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss_sum += criterion(outputs, labels).item()
                _, preds = torch.max(outputs, dim=1)
                predictions.extend(preds.cpu().tolist())
                actual_labels.extend(labels.cpu().tolist())
        accuracy = accuracy_score(actual_labels, predictions)
        model.to("cpu")

        num_samples = len(actual_labels)
        # Keep a NumPy scalar so callers that relied on the np.mean result
        # type (e.g. ``.item()``) keep working.
        if num_samples:
            avg_loss = np.float64(loss_sum / num_samples)
        else:
            avg_loss = np.float64("nan")
        return {
            "accuracy": accuracy,
            "avg_loss": avg_loss,
        }

    return eval_function

def text_classification_train_function(
    num_epochs,
    model,
    optimizer,
    train_dataloader,
    public_loader,
    device,
    train_args,
):
    """Create a local-training closure over ``model`` and ``optimizer``.

    Args:
        num_epochs: Default number of passes over ``train_dataloader``.
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Iterable of batches; each batch must be a mapping
            with ``'input_ids'``, ``'attention_mask'`` and ``'label'``.
        public_loader: Accepted for interface parity; not used here.
        device: Device the model and each batch are moved to for training.
        train_args: Accepted for interface parity; not used here.

    Returns:
        A function ``train_function(num_epochs=...)`` that trains the model
        with cross-entropy and returns the loss of the last processed batch,
        or ``None`` when no batch was processed.
    """

    def train_function(num_epochs=num_epochs):
        """Run the training loop and return the last batch loss (or ``None``).

        The model is put in train mode, moved to ``device`` for the loop and
        moved back to CPU before returning.
        """
        criterion = nn.CrossEntropyLoss()
        # Stays None if the loop body never runs (zero epochs or an empty
        # dataloader); previously that case raised UnboundLocalError.
        loss = None

        model.train()
        model.to(device)

        for epoch in range(num_epochs):
            logging.info(f"    | epoch: {epoch}")
            for batch in train_dataloader:
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        model.to("cpu")
        return loss

    return train_function


class TextClassificationDataset(Dataset):
    """Dataset backed by a pickled mapping with ``"texts"`` and ``"labels"``.

    Items are tokenized on access, padded/truncated to ``max_length``.
    """

    def __init__(self, file_path, tokenizer, max_length,):
        """Read the pickle at ``file_path`` and keep the tokenizer settings.

        Args:
            file_path: Path of a pickle file holding a mapping with the keys
                ``"texts"`` and ``"labels"``.
            tokenizer: Callable invoked per item with ``return_tensors``,
                ``max_length``, ``padding`` and ``truncation`` keywords.
            max_length: Length passed to the tokenizer as ``max_length``.
        """
        # NOTE: pickle.load can execute arbitrary code while deserializing;
        # only point this at files from a trusted source.
        with open(file_path, "rb") as handle:
            payload = pickle.load(handle)
        self.texts = payload["texts"]
        self.labels = payload["labels"]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of stored texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize item ``idx`` and return its tensors.

        Returns:
            A dict with flattened ``'input_ids'`` and ``'attention_mask'``
            tensors from the tokenizer and ``'label'`` as a tensor.
        """
        text, label = self.texts[idx], self.labels[idx]
        encoded = self.tokenizer(
            text,
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'label': torch.tensor(label),
        }


def create_text_classification_client(
    args,
    base_models_path,
    train_files,
    test_files,
    public_loader,
    lr,
    num_local_epochs,
    num_classes,
    device,
    train_func,
    train_args,
):
    """Assemble a ``Client`` for a text-classification task.

    Builds the tokenizer, the LoRA-wrapped model, an AdamW optimizer, the
    train/test datasets and loaders, and the train/eval closures, then hands
    everything to ``Client``. The tokenizer and ``max_length`` are also
    attached to the returned client as attributes.

    Args:
        args: Object providing ``proj_dim``, ``rank`` and ``batch_size``.
        base_models_path: Model path; must contain ``"googlebert"`` or
            ``"distilbert"`` (checked in that order).
        train_files: Pickle path for the training dataset.
        test_files: Pickle path for the test dataset.
        public_loader: Passed through to ``Client`` and the train closure.
        lr: Currently ignored; see the review note in the body.
        num_local_epochs: Default epoch count for the train closure.
        num_classes: Number of target classes.
        device: Device used by the train/eval closures.
        train_func: Accepted but not used by this function.
        train_args: Passed through to ``Client`` and the train closure.

    Returns:
        The configured ``Client`` instance.

    Raises:
        NotImplementedError: If ``base_models_path`` matches no known backbone.
    """
    # The tokenizer is chosen by backbone family, not loaded from the path.
    if "googlebert" in base_models_path:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    elif "distilbert" in base_models_path:
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    else:
        raise NotImplementedError

    # Hard-coded hyperparameters for this client type.
    # NOTE(review): the assignment below replaces whatever the caller passed
    # as ``lr``; both the optimizer and ``server_lr`` get 1e-3. Confirm this
    # override is intentional before relying on the ``lr`` argument.
    max_length = 128
    lr = 1e-3

    model = text_classification_model(
        base_models_path, args.proj_dim, num_classes, args.rank
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def build_dataset(path):
        # Same tokenizer and sequence length for both splits.
        return TextClassificationDataset(
            file_path=path,
            tokenizer=tokenizer,
            max_length=max_length,
        )

    def build_loader(dataset, shuffle):
        # The two loaders differ only in whether they shuffle.
        return DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            num_workers=1,
            drop_last=False,
            pin_memory=True,
            shuffle=shuffle,
        )

    train_dataset = build_dataset(train_files)
    test_dataset = build_dataset(test_files)

    train_loader = build_loader(train_dataset, shuffle=True)
    test_loader = build_loader(test_dataset, shuffle=False)

    local_train = text_classification_train_function(
        num_local_epochs,
        model,
        optimizer,
        train_loader,
        public_loader,
        device,
        train_args,
    )
    local_eval = text_classification_eval_function(
        model, test_loader, train_loader, device
    )

    client = Client(
        task=f"text_clasification_n_label_{num_classes}",
        model=model,
        local_optimizer=optimizer,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        train_loader=train_loader,
        test_loader=test_loader,
        public_loader=public_loader,
        local_train_func=local_train,
        local_eval_fun=local_eval,
        train_args=train_args,
        num_classes=num_classes,
        server_lr=lr,
    )
    client.max_length = max_length
    client.tokenizer = tokenizer

    return client