# -*- coding: utf-8 -*-
import os
import random
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    DistilBertConfig
)
from transformers import (
    T5Tokenizer, 
    T5Config, 
    T5ForSequenceClassification,
    AdamW,
    get_linear_schedule_with_warmup
)

import random

# NOTE(review): `n` is not referenced anywhere else in this file (the CSV
# subsample size used when Raodom_Select is True is hard-coded to 80) —
# confirm whether this constant is still needed.
n = 20
class CSVDataset(Dataset):
    """Map-style dataset pairing raw texts with integer class labels.

    Each item is tokenized lazily on access with ``tokenizer`` (padded and
    truncated to ``max_length``) and returned as a dict of tensors in the
    shape expected by :func:`collate_batch`.

    Args:
        texts: Indexable sequence of inputs; each item is passed through
            ``str()`` before tokenization.
        labels: Indexable sequence of labels, one per text. Each value must be
            accepted by ``torch.tensor(..., dtype=torch.long)`` or be
            convertible with ``int()`` (e.g. the numeric string ``"3"``).
        tokenizer: Callable with the Hugging Face tokenizer call signature
            that returns ``input_ids`` and ``attention_mask`` tensors.
        max_length: Padding/truncation length passed to the tokenizer.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        # Catch misaligned inputs here instead of as an IndexError mid-epoch.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # With return_tensors='pt' the tokenizer returns a batch of one row;
        # flatten drops that leading dimension so collate_batch can stack items.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': self._label_to_tensor(label),
        }

    @staticmethod
    def _label_to_tensor(label):
        """Convert one raw label to a scalar ``torch.long`` tensor."""
        try:
            return torch.tensor(label, dtype=torch.long)
        except (TypeError, ValueError, RuntimeError):
            # torch.tensor rejects e.g. numeric strings such as "3"; fall back
            # to int(), which raises ValueError for genuinely invalid labels.
            return torch.tensor(int(label), dtype=torch.long)

def collate_batch(batch):
    """Stack a list of ``CSVDataset`` items into a single batch dict.

    Args:
        batch: List of per-example dicts holding ``input_ids``,
            ``attention_mask`` and ``label`` tensors of matching shapes.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` (note the
        plural key, which is what the training/eval loops read), each stacked
        along a new leading batch dimension.
    """
    # (output key, per-item key) pairs, in output order.
    key_pairs = (
        ('input_ids', 'input_ids'),
        ('attention_mask', 'attention_mask'),
        ('labels', 'label'),
    )
    return {
        out_key: torch.stack([sample[in_key] for sample in batch])
        for out_key, in_key in key_pairs
    }


# ---- Run configuration -------------------------------------------------------
CSV_PATH = ""          # training CSV location (must be filled in before running)
TEXT_COL = "text"      # CSV column holding the input text
LABEL_COL = "label"    # CSV column holding the class label
RANDOM_SEED = 3440
BATCH_SIZE = 8
NUM_EPOCHS = 20
MAX_LENGTH = 128       # tokenizer padding / truncation length
# NOTE(review): 5e-4 is well above the 2e-5..5e-5 range the BERT paper
# recommends for fine-tuning; confirm this learning rate is intended.
LR = 5e-4
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"{DEVICE}")

# Seed Python's and PyTorch's global RNGs for reproducible runs.
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# ---- Load the training CSV and encode its labels as class indices ------------
Raodom_Select = False  # True: train on a random subset of (up to) 80 rows
df = pd.read_csv(CSV_PATH)
_missing_cols = [col for col in (TEXT_COL, LABEL_COL) if col not in df.columns]
if _missing_cols:
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    raise ValueError(f"CSV is missing required column(s): {_missing_cols}")
if Raodom_Select:
    # Cap n at the row count: DataFrame.sample raises when n > len(df)
    # (sampling without replacement).
    sampled_df = df.sample(n=min(80, len(df)), random_state=RANDOM_SEED)
    texts = sampled_df[TEXT_COL].astype(str).tolist()
    raw_labels = sampled_df[LABEL_COL].tolist()
else:
    texts = df[TEXT_COL].astype(str).tolist()
    raw_labels = df[LABEL_COL].tolist()

# Drop rows with a missing label. A single NaN would otherwise force the
# categorical branch below and then fail the `label_map[str(l)]` lookup.
_labeled_rows = [(t, l) for t, l in zip(texts, raw_labels) if pd.notna(l)]
if not _labeled_rows:
    raise ValueError(f"CSV has no rows with a non-empty {LABEL_COL!r} value")
texts = [t for t, _ in _labeled_rows]
raw_labels = [l for _, l in _labeled_rows]

if all(isinstance(x, (int, float)) and float(x).is_integer() for x in raw_labels):
    # Integral numeric labels are used directly as class indices.
    labels = [int(x) for x in raw_labels]
    if min(labels) < 0:
        raise ValueError(f"numeric labels must be >= 0, got {min(labels)}")
    unique_labels = sorted(set(labels))
    label_map = {str(i): i for i in unique_labels}
    # The classifier head needs an output for every index up to the largest
    # label, even if some smaller indices never occur (e.g. in a subset).
    num_classes = max(labels) + 1
else:
    # Anything else (strings, non-integral numbers) is treated as categorical:
    # each distinct string value gets an index 0..K-1 in sorted order.
    unique_labels = sorted({str(x) for x in raw_labels})
    label_map = {lab: idx for idx, lab in enumerate(unique_labels)}
    labels = [label_map[str(l)] for l in raw_labels]
    num_classes = len(unique_labels)

# ---- Tokenizer + sequence-classification model --------------------------------
model_name = "t5-base"  # supported: "t5-base", "distilbert-base-uncased"

if model_name == "distilbert-base-uncased":
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model_config = DistilBertConfig.from_pretrained(
        model_name,
        num_labels=num_classes
    )
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        config=model_config
    )
elif model_name == "t5-base":
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model_config = T5Config.from_pretrained(
        model_name,
        num_labels=num_classes
    )
    model = T5ForSequenceClassification.from_pretrained(
        model_name,
        config=model_config
    )
else:
    # Fail here: otherwise `tokenizer`/`model` stay undefined and the script
    # dies later with a confusing NameError.
    raise ValueError(
        f"Unsupported model_name {model_name!r}; "
        "expected 't5-base' or 'distilbert-base-uncased'"
    )

model = model.to(DEVICE)


from torchtext.datasets import IMDB,AG_NEWS


train_dataset = CSVDataset(texts, labels, tokenizer, MAX_LENGTH)

# Validation data: the AG_NEWS test split from torchtext. torchtext's AG_NEWS
# labels are 1-based, hence the `- 1` to obtain 0-based class indices.
val_iter = AG_NEWS(split="test")
val_texts, val_labels = [], []

for label, text in val_iter:
    val_texts.append(text)
    val_labels.append(label - 1)

# Fail fast with a clear message instead of an out-of-range target error raised
# by CrossEntropyLoss in evaluate() after the first full training epoch.
if val_labels and max(val_labels) >= num_classes:
    raise ValueError(
        f"validation labels reach index {max(val_labels)}, but the model only has "
        f"{num_classes} output classes; the training CSV must use the same "
        "0-based class indices as AG_NEWS"
    )

val_dataset = CSVDataset(val_texts, val_labels, tokenizer, MAX_LENGTH)


train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_batch
)


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Per-batch schedule: LR decays linearly from LR towards 0.1 * LR over the
# whole run and is held at 0.1 * LR from then on.
# max(1, ...) matters because LambdaLR evaluates the lambda at step 0 while it
# is being constructed; an empty loader or NUM_EPOCHS == 0 would otherwise
# raise ZeroDivisionError here.
total_steps = max(1, NUM_EPOCHS * len(train_loader))
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: max(0.1, 1 - step / total_steps))

def evaluate(model, data_loader, criterion, device):
    """Compute the mean loss and accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        data_loader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors (the :func:`collate_batch` format).
        criterion: Loss function applied as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, acc)``: the loss averaged per example (each batch's loss
        weighted by its size) and the fraction of argmax predictions that
        match the labels. Both are ``0.0`` if the loader yields no examples.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    weighted_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(data_loader):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            logits = model(input_ids=ids, attention_mask=mask).logits
            batch_size = targets.size(0)

            weighted_loss += criterion(logits, targets).item() * batch_size
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        return 0.0, 0.0
    return weighted_loss / n_seen, n_correct / n_seen

def train(model, train_loader, val_loader, epochs):
    """Fine-tune ``model`` for ``epochs`` epochs, validating after each one.

    Relies on the module-level ``optimizer``, ``scheduler``, ``criterion``,
    ``DEVICE``, ``tokenizer``, ``label_map`` and ``model_config``. The
    scheduler is stepped once per batch. After every epoch the model is scored
    on ``val_loader`` with :func:`evaluate`. Whenever validation accuracy
    beats the best seen so far (the first epoch always counts), a checkpoint is
    written to ``checkpoints/best_distilbert_checkpoint.pth``. That file name
    is used regardless of which model is being trained.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        train_loader: Loader yielding :func:`collate_batch`-style batches.
        val_loader: Loader used for the per-epoch evaluation.
        epochs: Number of training epochs.

    Returns:
        The best validation accuracy observed, or ``None`` if ``epochs < 1``.
    """
    best_val_acc = None
    os.makedirs("checkpoints", exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
        for batch in pbar:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            optimizer.zero_grad()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Weight by batch size so the epoch figure is a per-example mean.
            train_loss += loss.item() * labels.size(0)
            pbar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

        num_train = len(train_loader.dataset)
        train_loss = train_loss / num_train if num_train else 0.0
        val_loss, val_acc = evaluate(model, val_loader, criterion, DEVICE)

        print(f"[Epoch {epoch}] training loss={train_loss:.4f}, evaluatation loss={val_loss:.4f}, acc={val_acc:.4f}")

        # Only overwrite the "best" checkpoint when validation accuracy improves.
        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # NOTE(review): the tokenizer and config objects are pickled into
            # this file, so torch.load will likely need weights_only=False
            # (its default became True in PyTorch 2.6).
            torch.save({
                "model_state_dict": model.state_dict(),
                "tokenizer": tokenizer,
                "label_map": label_map,
                "model_config": model_config,
                "val_acc": val_acc
            }, "checkpoints/best_distilbert_checkpoint.pth")

    return best_val_acc

train(model, train_loader, val_loader, NUM_EPOCHS)

# Persist the weights as they stand after the last epoch.
os.makedirs("checkpoints", exist_ok=True)
# torch.save needs a real file path (an empty string cannot be opened). The
# model name is embedded so runs with different backbones don't overwrite
# each other; '/' is replaced so hub-style names don't create sub-directories.
final_checkpoint_path = os.path.join(
    "checkpoints", f"final_{model_name.replace('/', '_')}_checkpoint.pth"
)
torch.save({
    "model_state_dict": model.state_dict(),
    "tokenizer": tokenizer,
    "label_map": label_map,
    "model_config": model_config
}, final_checkpoint_path)