from random import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from transformers import AutoModel

from models.classifier import Classifier
from utils.data_utils import load_stance_detection_data


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch for reproducibility.

    Args:
        seed: integer seed applied to all three generators.
    """
    # Imported under a private alias because the module-level name ``random``
    # is bound to the function ``random.random`` (``from random import random``).
    import random as _py_random

    _py_random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators of all devices (CPU and CUDA).
    torch.manual_seed(seed)


class Sentence_Transformer_ST(nn.Module, Classifier):
    """Pretrained transformer encoder followed by an MLP classification head.

    The encoder is ``AutoModel.from_pretrained(pretrained_model_name)``; its
    ``pooler_output`` feeds a Linear(hidden, 256) -> ReLU -> Dropout(0.2) ->
    Linear(256, num_classes) head. Texts are always tokenized with the
    ``sentence-transformers/all-mpnet-base-v2`` tokenizer, regardless of
    ``pretrained_model_name``.
    """

    # Tokenizer checkpoint used throughout this class.
    # NOTE(review): independent of ``pretrained_model_name`` -- the vocabulary
    # only matches the encoder when an MPNet-compatible checkpoint is loaded.
    _TOKENIZER_NAME = 'sentence-transformers/all-mpnet-base-v2'
    # Truncation length shared by training and embedding extraction.
    _MAX_LENGTH = 128
    # Default destination of the fine-tuned weights written by ``train_clf``.
    _DEFAULT_SAVE_PATH = '/home/XXXXXX/MatchingBasedCausalExplanation/saved_models/sentiment_models/ST_finetuned/model.pth'

    def get_tokenizer(self):
        """Return a freshly loaded instance of the tokenizer this model uses."""
        return AutoTokenizer.from_pretrained(self._TOKENIZER_NAME)

    def get_representation_model(self):
        """Return the underlying transformer encoder."""
        return self.lm

    def get_classifier(self):
        """Return the MLP head applied on top of the encoder's pooled output."""
        return self.classification_head

    def get_embeddings(self, lst_texts, batch_size=32):
        """Encode texts with the encoder's ``pooler_output``.

        Switches the whole module to eval mode (and leaves it there), runs the
        encoder on the GPU when one is available, and always moves the encoder
        back to the CPU afterwards -- even if encoding fails part-way.

        Args:
            lst_texts: sliceable sequence of strings to encode.
            batch_size: number of texts tokenized and encoded per forward pass.

        Returns:
            A list of numpy arrays: the rows of the per-batch ``pooler_output``
            arrays stacked vertically, in input order (one per text, assuming
            ``pooler_output`` is (batch, hidden) -- TODO confirm for new
            checkpoints). An empty input yields an empty list.
        """
        self.eval()
        if len(lst_texts) == 0:
            # Nothing to encode; np.vstack([]) would raise.
            return []

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        embeddings_list = []
        self.lm.to(device)
        try:
            with torch.no_grad():
                for start in range(0, len(lst_texts), batch_size):
                    batch = lst_texts[start:start + batch_size]
                    inputs = self.tokenizer(batch, truncation=True, padding=True,
                                            return_tensors='pt', max_length=self._MAX_LENGTH)
                    outputs = self.lm(input_ids=inputs['input_ids'].to(device),
                                      attention_mask=inputs['attention_mask'].to(device))
                    # NOTE(review): sentence-transformers checkpoints are usually
                    # used with mean pooling; confirm pooler_output is meaningful.
                    embeddings_list.append(outputs['pooler_output'].cpu().numpy())
        finally:
            # Release the accelerator copy of the encoder on every exit path.
            self.lm.to('cpu')

        return list(np.vstack(embeddings_list))

    def __init__(self, pretrained_model_name, num_classes=5, setup_name='cebab'):
        """Build the encoder, the classification head and the tokenizer.

        Args:
            pretrained_model_name: checkpoint name/path for ``AutoModel.from_pretrained``.
            num_classes: number of logits produced by the head.
            setup_name: ``'cebab'`` makes ``train_clf`` read the CSV splits under
                ``sets/sources/``; any other value uses ``load_stance_detection_data()``.
        """
        super().__init__()
        self.setup_name = setup_name
        self.lm = AutoModel.from_pretrained(pretrained_model_name)

        self.classification_head = nn.Sequential(
            nn.Linear(self.lm.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self._TOKENIZER_NAME)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a tokenized batch."""
        outputs = self.lm(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): relies on the encoder exposing 'pooler_output'.
        pooled_output = outputs['pooler_output']
        return self.classification_head(pooled_output)

    def _load_train_val_data(self):
        """Return ``(train_texts, train_labels, val_texts, val_labels)`` for ``self.setup_name``."""
        if self.setup_name == 'cebab':
            train_set = pd.read_csv('sets/sources/train_set_42.csv')
            val_set = pd.read_csv('sets/sources/validation.csv')
            return (list(train_set['description'].values),
                    list(train_set['review_majority'].values),
                    list(val_set['description'].values),
                    list(val_set['review_majority'].values))

        data = load_stance_detection_data()
        # The 'test_base' split is used as the validation set here.
        return (list(data['train_base']['instruction']),
                list(data['train_base']['label']),
                list(data['test_base']['instruction']),
                list(data['test_base']['label']))

    def _evaluate(self, loader, criterion, device):
        """Return ``(mean per-batch loss, accuracy)`` of this model over ``loader``.

        Both values are ``nan`` when ``loader`` yields no batches, instead of
        raising ``ZeroDivisionError``.
        """
        self.eval()
        total_loss = 0.0
        correct_preds = 0
        total_preds = 0
        num_batches = 0
        with torch.no_grad():
            for batch in loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                logits = self(input_ids, attention_mask)
                total_loss += criterion(logits, labels).item()

                _, predicted = torch.max(logits, dim=1)
                correct_preds += (predicted == labels).sum().item()
                total_preds += labels.size(0)
                num_batches += 1

        if num_batches == 0:
            return float('nan'), float('nan')
        return total_loss / num_batches, correct_preds / total_preds

    def train_clf(self, save_path=None):
        """Fine-tune encoder and head end-to-end, then save the ``state_dict``.

        Trains for 5 epochs with Adam (lr 1e-4), cross-entropy loss and batch
        size 32, printing validation loss and accuracy after every epoch.

        Args:
            save_path: file the weights are written to. Defaults to the
                project's ``ST_finetuned/model.pth`` location. Missing parent
                directories are created.
        """
        import os  # only needed to create the output directory

        model = self
        if save_path is None:
            save_path = self._DEFAULT_SAVE_PATH

        train_texts, train_labels, val_texts, val_labels = self._load_train_val_data()

        # Reuse the tokenizer loaded in __init__ rather than downloading it again.
        train_dataset = CustomDataset(train_texts, train_labels, self.tokenizer, self._MAX_LENGTH)
        val_dataset = CustomDataset(val_texts, val_labels, self.tokenizer, self._MAX_LENGTH)

        batch_size = 32
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        num_epochs = 5
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)

        for epoch in range(num_epochs):
            model.train()
            for batch in train_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                optimizer.zero_grad()
                loss = criterion(model(input_ids, attention_mask), labels)
                loss.backward()
                optimizer.step()

            val_loss, val_accuracy = self._evaluate(val_loader, criterion, device)
            print(f"Epoch [{epoch + 1}/{num_epochs}] - Val Loss: {val_loss:.4f} - Val Accuracy: {val_accuracy:.4f}")

        print("Training finished!")

        # torch.save does not create directories; make sure the target exists.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path)

        print(f"Trained model weights saved at '{save_path}'")


# Dataset wrapper that tokenizes texts on demand for the DataLoaders built in train_clf.
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one text each time an item is accessed.

    Each item is a dict with ``'input_ids'`` and ``'attention_mask'`` tensors
    (presumably 1-D of length ``max_length``, since every text is padded or
    truncated to it) and the raw ``'label'`` taken from ``labels``.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        """Store the data and tokenization settings.

        Args:
            texts: indexable sequence of input strings.
            labels: indexable sequence of labels aligned with ``texts``.
            tokenizer: callable using the Hugging Face tokenizer call signature.
            max_length: length every encoding is padded/truncated to.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize ``texts[idx]`` and return it together with ``labels[idx]``."""
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,
                                  return_tensors='pt')
        # squeeze(0) removes only the batch axis added by return_tensors='pt';
        # a bare squeeze() would also collapse the sequence axis if max_length == 1.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'label': label}


# Number of label classes for the classification task; change it to match
# the label set in use.
num_classes = 5

# Fix the random seed at import time so runs are repeatable.
set_seed(42)
