import logging
import os
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np

import torch
from torch import nn
import pickle
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.nn import CrossEntropyLoss
from peft import LoraConfig, get_peft_model
from sklearn.metrics import accuracy_score

from clients.base import Client
from utils.model_ops import print_trainable_parameters


class GPT2Generator(nn.Module):
    """GPT-2 language model with an additional linear projection head.

    ``forward`` returns either the per-token vocabulary logits or, when
    ``projection=True``, a ``proj_dim``-wide projection of the final-layer
    hidden state taken at the last sequence position.
    """

    def __init__(self, model_name, proj_dim):
        """Load the pretrained GPT-2 and create the projection head.

        Args:
            model_name: Name or path handed to
                ``GPT2LMHeadModel.from_pretrained``.
            proj_dim: Output width of ``projection_layer``.
        """
        super().__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
        self.projection_layer = nn.Linear(self.gpt2.config.hidden_size, proj_dim)

    def forward(self, input_ids, attention_mask, projection=False):
        """Run GPT-2 and return logits or a projected sequence embedding.

        Args:
            input_ids: Token ids forwarded to GPT-2.
            attention_mask: Attention mask forwarded to GPT-2.
            projection: When true, return ``projection_layer`` applied to the
                final hidden state at sequence index ``-1``; otherwise return
                the language-model logits for every position.

        Returns:
            The GPT-2 logits tensor, or the projected tensor when
            ``projection`` is true.
        """
        want_projection = bool(projection)
        # Hidden states are only needed (and only requested) on the
        # projection path.
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=want_projection,
        )
        if not want_projection:
            return outputs.logits

        # projection_layer takes hidden_size features, so it must be fed the
        # final hidden state; the logits are vocabulary-sized and would raise
        # a shape mismatch.
        # NOTE(review): index -1 is a padding position when the batch is
        # right-padded -- confirm whether the last *real* token was intended.
        last_hidden_state = outputs.hidden_states[-1][:, -1, :]
        return self.projection_layer(last_hidden_state)


def text_generation_model(base_path: str, proj_dim, lora_rank):
    """Build a ``GPT2Generator`` and wrap it with LoRA adapters.

    Args:
        base_path: Name or path of the pretrained GPT-2 checkpoint.
        proj_dim: Output width of the generator's projection layer.
        lora_rank: LoRA rank ``r``.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    model = GPT2Generator(base_path, proj_dim=proj_dim)

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=1,
        target_modules=["c_attn", "c_proj"],
        # GPT-2's c_attn / c_proj are Conv1D layers that store weights as
        # (fan_in, fan_out); declare it explicitly instead of relying on
        # PEFT's warn-and-override fallback.
        fan_in_fan_out=True,
        bias="none",
    )
    # NOTE(review): projection_layer is neither a LoRA target nor listed in
    # modules_to_save, so it presumably stays frozen after wrapping --
    # confirm that is intended.
    model = get_peft_model(model, lora_config)

    print_trainable_parameters(model)

    return model


class ARCChallengeDataset(Dataset):
    """Prompt/response pairs loaded from a pickle file, tokenized per item.

    The pickle file must hold a mapping with ``"prompts"`` and ``"responses"``
    entries; item ``idx`` pairs ``prompts[idx]`` with ``responses[idx]``.
    """

    def __init__(self, file_path, tokenizer, max_length):
        """Load the pickled data and remember the tokenization settings.

        Args:
            file_path: Path of the pickle file to read.
            tokenizer: Callable tokenizer used for prompts and responses.
            max_length: Length every sequence is padded/truncated to.
        """
        # SECURITY: pickle.load can execute arbitrary code from the file;
        # only point this at trusted, locally produced data.
        with open(file_path, "rb") as file:
            data = pickle.load(file)
        self.prompts = data["prompts"]
        self.responses = data["responses"]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Tokenize one prompt/response pair.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` for the prompt
            and ``labels`` holding the response token ids, where padding
            positions are set to ``-100``.
        """
        prompt = self.prompts[idx]
        response = self.responses[idx]
        encoding = self.tokenizer(
            prompt, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True
        )
        target_encoding = self.tokenizer(
            response, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True
        )

        # Replace label padding with -100 so a loss built with
        # ignore_index=-100 skips it. The attention mask is used rather than
        # comparing against the pad id, because the pad token may be the same
        # id as a real EOS token.
        label_padding = target_encoding['attention_mask'].flatten() == 0
        labels = target_encoding['input_ids'].flatten().masked_fill(label_padding, -100)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': labels,
        }


def text_generation_train_function(model, train_loader, optimizer, device, num_epochs):
    """Build a zero-argument closure that trains ``model`` locally.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            and expected to return logits.
        train_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the model and batches are moved to while training.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``train_function``: runs the training loop, logs the mean loss per
        epoch, moves the model back to CPU and returns ``None``.
    """
    def train_function():
        model.train()
        model.to(device)
        loss_fn = CrossEntropyLoss(ignore_index=-100)

        try:
            for epoch in range(num_epochs):
                logging.info("Epoch: %s", epoch + 1)
                total_loss = 0.0
                num_batches = 0
                for batch in train_loader:
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    # Logits at position i are scored against labels at
                    # position i (no shift is applied here).
                    loss = loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                    num_batches += 1
                if num_batches == 0:
                    # An empty loader would otherwise divide by zero.
                    logging.warning("Epoch %s: train loader yielded no batches", epoch + 1)
                    continue
                logging.info("Epoch %s Loss: %s", epoch + 1, total_loss / num_batches)
        finally:
            # Release the device even when training fails part-way.
            model.to("cpu")
    return train_function


def text_generation_eval_function(model, test_loader, device):
    """Build a zero-argument closure that evaluates ``model``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            and expected to return logits.
        test_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        device: Device the model and batches are moved to while evaluating.

    Returns:
        ``eval_function``: returns ``{"avg_loss": mean_batch_loss}`` and
        moves the model back to CPU. ``avg_loss`` is ``nan`` when the loader
        yields no batches.
    """
    def eval_function():
        model.eval()
        model.to(device)
        loss_fn = CrossEntropyLoss(ignore_index=-100)
        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    loss = loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                    total_loss += loss.item()
                    num_batches += 1
        finally:
            # Release the device even when evaluation fails part-way.
            model.to("cpu")
        # An empty loader would otherwise divide by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        return {"avg_loss": avg_loss}
    return eval_function


def create_text_generation_client(
    args,
    base_models_path,
    train_files,
    test_files,
    public_loader,
    lr,
    num_local_epochs,
    device,
    train_args,
):
    """Assemble the ``Client`` for the text-generation task.

    Builds the tokenizer, the LoRA-wrapped GPT-2 model, the AdamW optimizer,
    the train/test datasets and loaders, and the local train/eval closures,
    then bundles them into a ``Client``.

    Args:
        args: Namespace providing ``proj_dim``, ``rank`` and ``batch_size``.
        base_models_path: Name or path of the pretrained GPT-2 checkpoint.
        train_files: Path of the pickled training data.
        test_files: Path of the pickled test data.
        public_loader: Loader forwarded unchanged to ``Client``.
        lr: Learning rate for the local optimizer, also passed to ``Client``
            as ``server_lr``. ``None`` falls back to ``1e-3``.
        num_local_epochs: Epochs run by each local training call.
        device: Device used for local training and evaluation.
        train_args: Forwarded unchanged to ``Client``.

    Returns:
        The configured ``Client`` with ``max_length`` and ``tokenizer``
        attached as attributes.
    """
    # NOTE(review): the tokenizer is always loaded from "gpt2-medium",
    # independently of base_models_path -- confirm this is intended.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    # GPT-2 ships without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    max_length = 128
    # Honour the caller-supplied learning rate. It used to be overwritten
    # unconditionally with 1e-3, which is kept only as the fallback.
    if lr is None:
        lr = 1e-3

    model = text_generation_model(
        base_models_path, args.proj_dim, args.rank
    )

    local_optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    train_dataset = ARCChallengeDataset(
        file_path=train_files,
        tokenizer=tokenizer,
        max_length=max_length,
    )
    test_dataset = ARCChallengeDataset(
        file_path=test_files,
        tokenizer=tokenizer,
        max_length=max_length,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        num_workers=1,
        drop_last=False,
        pin_memory=True,
        shuffle=True,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        num_workers=1,
        drop_last=False,
        pin_memory=True,
        shuffle=False,
    )

    text_generation_client = Client(
        task="text_generation",
        model=model,
        local_optimizer=local_optimizer,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        train_loader=train_loader,
        test_loader=test_loader,
        public_loader=public_loader,
        local_train_func=text_generation_train_function(model, train_loader, local_optimizer, device, num_local_epochs),
        local_eval_fun=text_generation_eval_function(model, test_loader, device),
        train_args=train_args,
        server_lr=lr,
    )
    # Stored on the client as plain attributes after construction.
    text_generation_client.max_length = max_length
    text_generation_client.tokenizer = tokenizer

    return text_generation_client
