import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW

from gradiend.setups.emotion import EmotionSetup

# 1. Data and tokenizer setup
setup = EmotionSetup()

model_id = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The model is intentionally NOT loaded here: it is instantiated in the
# model-setup section further down, and nothing in between uses it. Loading
# it at this point as well only built a second copy that was overwritten
# before ever being used.

# Each item of these splits is consumed by EmotionDataset below, which reads
# the model inputs from key `False` and the regression targets from key
# 'label'.
# NOTE(review): the exact item schema is defined by EmotionSetup — confirm there.
train_data = setup.create_training_data(tokenizer, batch_size=1, max_size=10000)
eval_data = setup.create_training_data(tokenizer, split='val', max_versions_per_masked_word=100, batch_size=1)


# 2. Custom Dataset
class EmotionDataset(Dataset):
    """Expose pre-built training items as dicts usable by a ``DataLoader``.

    ``data`` must support ``len()`` and integer indexing. Every item is
    expected to be a mapping holding the model inputs under the key ``False``
    and the regression targets under the key ``'label'``.

    NOTE(review): the targets are presumably ordered
    [arousal, valence, dominance] — confirm against the data source.
    """

    def __init__(self, data):
        """Store ``data`` as-is; items are only read lazily in ``__getitem__``."""
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the model inputs of item ``idx`` with its targets as ``'labels'``.

        The result is a fresh ``dict`` so that adding ``'labels'`` never
        writes back into the wrapped data.
        """
        item = self.data[idx]
        # Shallow copy: the stored mapping stays untouched, values are shared.
        model_input = dict(item[False])
        model_input['labels'] = item['label']
        return model_input

# Wrap both splits so that a DataLoader can index them.
train_dataset = EmotionDataset(train_data)
eval_dataset = EmotionDataset(eval_data)

# NOTE(review): the training loader iterates in a fixed order (shuffle=False);
# confirm this is intentional.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=False)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=64)

# 3. Model setup
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# num_labels=3 gives the classification head three outputs, used here as
# three regression targets.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=3,
)
model.to(device)

# NOTE(review): transformers' AdamW is deprecated upstream in favour of
# torch.optim.AdamW — consider migrating (defaults differ, so verify).
optimizer = AdamW(model.parameters(), lr=1e-5)
# Mean squared error between the model outputs and the target vectors.
loss_fn = nn.MSELoss()

# 4. Training loop
def train(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Relies on the module-level ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        model: Model called with ``input_ids`` and ``attention_mask``; its
            output must expose ``.logits``.
        loader: Iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.

    Returns:
        The sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``len(loader)`` is zero.
    """
    model.train()
    running_loss = 0
    for batch in loader:
        optimizer.zero_grad()

        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device),
        }
        # batch['labels'] is stacked and transposed so that it lines up with
        # the logits. NOTE(review): assumes a sequence of per-dimension
        # tensors, i.e. (batch_size, 3) after the transpose — confirm.
        stacked = torch.stack(batch['labels'])
        targets = stacked.T.to(device, dtype=model.dtype)

        logits = model(**inputs).logits  # shape: (batch_size, 3)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)

# 5. Evaluation loop
def evaluate(model, loader):
    """Score ``model`` on ``loader`` and report MSE and Pearson correlation.

    Runs the model without gradient tracking, gathers all predictions and
    targets, prints the per-output Pearson correlations and MSE values, and
    returns the MSE figures. Relies on the module-level ``device``.

    Args:
        model: Model called with ``input_ids`` and ``attention_mask``; its
            output must expose ``.logits``.
        loader: Iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.

    Returns:
        Dict with ``mse_arousal``, ``mse_valence``, ``mse_dominance`` (the
        first three per-output MSE values) and ``mse_avg`` (their uniform
        average).
    """
    model.eval()
    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for batch in loader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
            }
            # Kept un-transposed per batch; the transpose happens once after
            # all batches are concatenated along axis 1 below.
            batch_targets = torch.stack(batch['labels']).to(device, dtype=model.dtype)

            logits = model(**inputs).logits

            pred_chunks.append(logits.cpu().numpy())
            target_chunks.append(batch_targets.cpu().numpy())

    # Predictions are joined along the sample axis; targets are joined along
    # axis 1 and then transposed so both arrays index samples first.
    preds = np.concatenate(pred_chunks, axis=0)
    targets = np.concatenate(target_chunks, axis=1).T

    mse = mean_squared_error(targets, preds, multioutput='raw_values')
    mse_avg = mean_squared_error(targets, preds, multioutput='uniform_average')

    # Pearson correlation, one value per output column.
    correlations = np.array([
        np.corrcoef(targets[:, col], preds[:, col])[0, 1]
        for col in range(preds.shape[1])
    ])
    mean_cor = np.mean(correlations)
    print(f"Pearson Correlation: {mean_cor:.4f}, Arousal: {correlations[0]:.4f}, Valence: {correlations[1]:.4f}, Dominance: {correlations[2]:.4f}")

    print(f"Evaluation MSE: {mse}, Average MSE: {mse_avg}")
    metrics = {}
    metrics["mse_arousal"] = mse[0]
    metrics["mse_valence"] = mse[1]
    metrics["mse_dominance"] = mse[2]
    metrics["mse_avg"] = mse_avg
    return metrics

# 6. Training loop across epochs
num_epochs = 5
for epoch in range(num_epochs):
    # One full pass over the training data, then a full evaluation pass.
    train_loss = train(model, train_loader)
    eval_metrics = evaluate(model, eval_loader)

    # Collect the per-epoch summary first, then emit it line by line.
    report = [
        f"Epoch {epoch+1}/{num_epochs}",
        f"  Train Loss     : {train_loss:.4f}",
        f"  Eval MSE (avg) : {eval_metrics['mse_avg']:.4f}",
        f"    Arousal      : {eval_metrics['mse_arousal']:.4f}",
        f"    Valence      : {eval_metrics['mse_valence']:.4f}",
        f"    Dominance    : {eval_metrics['mse_dominance']:.4f}",
    ]
    for line in report:
        print(line)