import os
import json
import torch
import time
import random
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from transformers import RobertaModel
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

def set_seed(seed=42):
    """Seed the random number generators used by this script.

    The same value is handed to Python's ``random``, NumPy's global
    generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.

    Args:
        seed: Seed value passed to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def load_reasoning_step_training_data(data_file, labels_file, target_label):
    """Load reasoning steps and binarise their categories for training.

    Args:
        data_file: Path to a JSON file readable by ``pandas.read_json`` that
            has a ``step`` column holding the step texts.
        labels_file: Path to a JSON file readable by ``pandas.read_json`` that
            has a ``category`` column holding one label per step.
        target_label: Category treated as the positive class.

    Returns:
        A tuple ``(texts, binary_labels, unique_labels)`` where
        ``binary_labels[i]`` is 1 when the i-th category equals
        ``target_label`` and 0 otherwise, and ``unique_labels`` is ``[0, 1]``.

    Raises:
        ValueError: If the two files do not contain the same number of rows.
    """
    texts = pd.read_json(data_file)['step'].tolist()
    labels = pd.read_json(labels_file)['category'].tolist()

    # Texts and labels are paired purely by position, so a length mismatch
    # means the files are out of sync; fail loudly instead of training on
    # misaligned pairs.
    if len(texts) != len(labels):
        raise ValueError(
            f"Mismatched inputs: {len(texts)} steps in {data_file!r} "
            f"but {len(labels)} labels in {labels_file!r}"
        )

    # Create binary labels based on target label (order is preserved, no shuffling).
    binary_labels = [1 if label == target_label else 0 for label in labels]
    unique_labels = [0, 1]

    return texts, binary_labels, unique_labels


def load_reasoning_step_testing_data(data_file, labels_file, target_label, evaluation=False):
    """Load reasoning steps and binarised categories for testing.

    Args:
        data_file: Path to a JSON file readable by ``pandas.read_json`` with a
            ``step`` column; when ``evaluation`` is true it must also have
            ``sample_id`` and ``step_id`` columns.
        labels_file: Path to a JSON file readable by ``pandas.read_json`` with
            a ``category`` column holding one label per step.
        target_label: Category treated as the positive class.
        evaluation: When true, also return the per-step identifiers.

    Returns:
        ``(texts, binary_labels)`` when ``evaluation`` is false, otherwise
        ``(texts, binary_labels, index_list, step_ids_list)`` where the last
        two lists come from the ``sample_id`` and ``step_id`` columns.

    Raises:
        ValueError: If the two files do not contain the same number of rows.
    """
    df_steps = pd.read_json(data_file)
    texts = df_steps['step'].tolist()
    labels = pd.read_json(labels_file)['category'].tolist()

    # Steps and labels are matched by position only; refuse misaligned files
    # rather than silently scoring against the wrong labels.
    if len(texts) != len(labels):
        raise ValueError(
            f"Mismatched inputs: {len(texts)} steps in {data_file!r} "
            f"but {len(labels)} labels in {labels_file!r}"
        )

    binary_labels = [1 if label == target_label else 0 for label in labels]

    if not evaluation:
        return texts, binary_labels

    index_list = df_steps['sample_id'].tolist()
    step_ids_list = df_steps['step_id'].tolist()
    return texts, binary_labels, index_list, step_ids_list

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch random number generators.

    This definition comes after an earlier ``set_seed`` in the file and is
    therefore the one bound to the name at import time. It keeps the
    ``seed=42`` default of that earlier definition so that a bare
    ``set_seed()`` call keeps working.

    Args:
        seed: Seed value passed to every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

class TextClassificationDataset(Dataset):
    """Dataset that tokenises one text per item for sequence classification.

    Args:
        texts: Sequence of input texts (each item is converted with ``str``).
        labels: Sequence of integer class labels, indexed in parallel with
            ``texts``.
        tokenizer: Callable invoked with the text and the keyword arguments
            ``add_special_tokens``, ``truncation``, ``max_length``,
            ``padding`` and ``return_tensors``; its result must expose
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Length every sequence is truncated/padded to.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenise the ``idx``-th text and return it with its label.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` (the
            tokenizer tensors with their leading dimension removed) and
            ``"labels"`` (a scalar ``torch.long`` tensor).
        """
        text = str(self.texts[idx])
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )
        # Drop only the leading dimension (assumed to be the size-1 batch axis
        # added by return_tensors="pt"). A bare squeeze() would also collapse
        # the sequence axis when max_length == 1 and yield 0-d tensors that
        # no longer collate into (batch, seq) batches.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

class BERTClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear output layer.

    Args:
        model_name: Identifier or path passed to ``BertModel.from_pretrained``.
        num_classes: Number of output scores produced by the linear layer.
    """

    def __init__(self, model_name, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.3)
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the linear layer's scores for the encoded inputs."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is used as the sequence summary
        # (per the transformers docs this is the pooled [CLS] representation).
        pooled = encoded[1]
        return self.fc(self.dropout(pooled))

def train(model, loader, optimizer, scheduler, device, loss_fn):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    For every batch the gradients are reset, the model is called on
    ``input_ids`` and ``attention_mask``, the loss is back-propagated, and
    both the optimizer and the scheduler are stepped (the scheduler advances
    once per batch).

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        scheduler: Learning-rate scheduler stepped after each optimizer step.
        device: Device the batch tensors are moved to.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.

    Returns:
        The sum of batch losses divided by the number of batches processed,
        or ``0.0`` when ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of relying on len(loader): this
    # also works for loaders without __len__ and avoids a ZeroDivisionError
    # when the loader is empty.
    num_batches = 0
    for batch in tqdm(loader, desc="Training"):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader``.

    The model is put in eval mode and run without gradient tracking; the
    predicted class for each example is the argmax over dimension 1 of the
    model output.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        A tuple ``(accuracy, weighted_f1, report)`` where ``report`` is the
        text produced by ``classification_report``.
    """
    model.eval()
    predicted = []
    expected = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)

            logits = model(ids, mask)
            batch_preds = torch.argmax(logits, dim=1)
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(targets.cpu().numpy())

    return (
        accuracy_score(expected, predicted),
        f1_score(expected, predicted, average="weighted"),
        classification_report(expected, predicted),
    )