import copy
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

class LSTMModel(nn.Module):
    """LSTM that maps a padded batch of sequences to per-timestep outputs.

    Variable-length batches are packed before the LSTM so padded timesteps do
    not influence the recurrent state. The LSTM output is then passed through a
    dropout layer (p=0.2) and a linear projection to ``output_size``.

    Args:
        input_size: feature size of each timestep of ``x``.
        hidden_size: LSTM hidden size.
        output_size: feature size of each output timestep.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers (PyTorch only applies it
            when ``num_layers > 1``).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.head_dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, lengths):
        """Run the model on a padded batch.

        Args:
            x: padded input of shape ``(batch, seq_len, input_size)``
                (``batch_first=True``).
            lengths: true length of each sequence, as a 1-D tensor or list.
                The batch does not need to be sorted.

        Returns:
            Tensor of shape ``(batch, seq_len, output_size)``. ``seq_len`` is
            always ``x.size(1)``, so the output lines up with a target padded
            like ``x``. Positions past a sequence's length come from a zero LSTM
            output, i.e. they equal ``fc.bias`` in eval mode.
        """
        # pack_padded_sequence requires lengths on the CPU; a CUDA tensor raises.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_x)
        # Without total_length, the time axis would shrink to max(lengths) and
        # no longer match a target padded to x.size(1).
        output, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=x.size(1))
        output = self.head_dropout(output)
        return self.fc(output)

class EmbeddingDataset(Dataset):
    """Map-style dataset pairing perturbed sequences with their originals.

    Every sample is converted to a NumPy array once, at construction time.
    Indexing returns ``(perturbed, original, length)``:

    - ``perturbed`` and ``original`` are float32 tensors.
    - ``length`` is ``len()`` of the perturbed sample, i.e. its first-axis size.
    """

    def __init__(self, perturbed_text, original_text):
        self.perturbed_text = list(map(np.array, perturbed_text))
        self.original_text = list(map(np.array, original_text))

    def __len__(self):
        # The dataset size is defined by the perturbed side.
        return len(self.perturbed_text)

    def __getitem__(self, idx):
        source = self.perturbed_text[idx]
        target = self.original_text[idx]
        return torch.FloatTensor(source), torch.FloatTensor(target), len(source)

def collate_fn(batch):
    """Collate ``(perturbed, original, length)`` samples into padded batch tensors.

    Samples are ordered by length, longest first. A new sorted list is built,
    so the caller's ``batch`` is left untouched and any sequence type (list,
    tuple) is accepted. Both sides are zero-padded to the longest sequence in
    the batch.

    Args:
        batch: non-empty sequence of ``(perturbed_tensor, original_tensor, length)``.

    Returns:
        ``(perturbed_padded, original_padded, lengths)``. The padded tensors
        are batch-first; ``lengths`` is an int64 tensor in the sorted order.
    """
    # sorted() rather than list.sort(): don't mutate the caller's container.
    ordered = sorted(batch, key=lambda sample: sample[2], reverse=True)
    perturbed, original, lengths = zip(*ordered)
    perturbed_padded = nn.utils.rnn.pad_sequence(perturbed, batch_first=True)
    original_padded = nn.utils.rnn.pad_sequence(original, batch_first=True)
    return perturbed_padded, original_padded, torch.LongTensor(lengths)

def prepare_data(perturbed_text, original_text, train_ratio=0.8, val_ratio=0.2, seed=None):
    """Shuffle paired sequences and split them into train / val / test subsets.

    The split sizes are:

    - ``train``: ``int(train_ratio * N)`` samples.
    - ``val``: ``int(val_ratio * N)`` samples.
    - ``test``: whatever remains.

    Pairs stay aligned: the i-th perturbed item of a split corresponds to the
    i-th original item of the same split.

    NOTE(review): with the default ratios (0.8 / 0.2) the test split only
    receives samples lost to integer truncation, which is usually none. Pass
    e.g. ``val_ratio=0.1`` to reserve a real test set.

    Args:
        perturbed_text: list of perturbed sequences.
        original_text: list of original sequences, same length as
            ``perturbed_text``. Each item must have the same ``len()`` as its
            perturbed counterpart.
        train_ratio: fraction of samples assigned to the training split.
        val_ratio: fraction of samples assigned to the validation split.
        seed: if given, shuffle with ``np.random.default_rng(seed)`` for a
            reproducible split. If ``None``, use the global ``np.random`` state,
            as before.

    Returns:
        ``(train_perturbed, train_original, val_perturbed, val_original,
        test_perturbed, test_original)``.

    Raises:
        ValueError: on mismatched lengths or invalid ratios.
    """
    # Explicit raises: assert statements are stripped under `python -O`.
    if len(perturbed_text) != len(original_text):
        raise ValueError("Perturbed and original text lists must have the same length")
    for perturbed, original in zip(perturbed_text, original_text):
        if len(perturbed) != len(original):
            raise ValueError("Corresponding sublists in perturbed and original text must have the same length")
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and val_ratio must be non-negative and sum to at most 1")

    total_samples = len(perturbed_text)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)

    if seed is None:
        indices = np.random.permutation(total_samples)
    else:
        indices = np.random.default_rng(seed).permutation(total_samples)

    val_end = train_size + val_size
    splits = (indices[:train_size], indices[train_size:val_end], indices[val_end:])

    result = []
    for split in splits:
        result.append([perturbed_text[i] for i in split])
        result.append([original_text[i] for i in split])
    return tuple(result)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device,
                patience=5, clip_grad=1.0, noise_std=0.02, verbose=False):
    """Train ``model`` with early stopping and restore the best-validation weights.

    Each epoch has two passes:

    - Training: one pass over ``train_loader`` with Gaussian input noise and
      gradient-norm clipping.
    - Validation: one pass over ``val_loader`` under ``torch.no_grad()``.

    Each pass reports the mean of its per-batch losses. Whenever the validation
    loss improves, an independent (deep) copy of ``model.state_dict()`` is kept.
    Training stops once the loss fails to improve for ``patience`` consecutive
    epochs. Before returning, the best snapshot (if any) is loaded back into
    ``model``.

    Args:
        model: module called as ``model(perturbed, lengths)``.
        train_loader: iterable of ``(perturbed, original, lengths)`` batches
            supporting ``len()``. ``lengths`` is passed through unchanged (not
            moved to ``device``).
        val_loader: same format as ``train_loader``, used for validation.
        criterion: loss callable ``criterion(output, original)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        device: device that inputs and targets are moved to.
        patience: consecutive non-improving epochs before stopping early.
        clip_grad: max gradient L2 norm. ``None`` or ``<= 0`` disables clipping.
        noise_std: std of the Gaussian noise added to training inputs. ``0``
            disables it; the default equals the previously hard-coded 0.02.
        verbose: if true, print the per-epoch train/val loss.

    Returns:
        ``(model, best_val_loss)``.

    Raises:
        ZeroDivisionError: if either loader has zero length.
    """
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for perturbed, original, lengths in train_loader:
            perturbed, original = perturbed.to(device), original.to(device)
            # Input-noise augmentation (training pass only).
            if noise_std:
                perturbed = perturbed + torch.randn_like(perturbed) * noise_std
            optimizer.zero_grad()
            output = model(perturbed, lengths)
            # NOTE(review): the loss also covers padded timesteps; whether that
            # matters depends on criterion/model — confirm.
            loss = criterion(output, original)
            loss.backward()
            if clip_grad is not None and clip_grad > 0:
                nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for perturbed, original, lengths in val_loader:
                perturbed, original = perturbed.to(device), original.to(device)
                output = model(perturbed, lengths)
                loss = criterion(output, original)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        if verbose:
            print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() returns tensors that share storage with the live
            # parameters; a shallow dict copy would keep tracking later updates.
            best_model = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1

        if epochs_no_improve == patience:
            print(f'Early stopping triggered after epoch {epoch+1}')
            break

    if best_model is not None:
        model.load_state_dict(best_model)

    return model, best_val_loss

def save_model(model, path, config):
    """Write ``model``'s weights and its constructor ``config`` to ``path``.

    The checkpoint is a dict with two entries:

    - ``'model_state_dict'``: the model's ``state_dict()``.
    - ``'config'``: the given config object.

    It is serialized with ``torch.save``.
    """
    payload = {'model_state_dict': model.state_dict()}
    payload['config'] = config
    torch.save(payload, path)

def load_model(path, device):
    """Rebuild an ``LSTMModel`` from a checkpoint and move it to ``device``.

    Steps:

    1. Load the checkpoint with ``torch.load``, mapping its tensors to
       ``device``.
    2. Construct ``LSTMModel(**checkpoint['config'])``.
    3. Load ``checkpoint['model_state_dict']`` into it.

    The returned model is in whatever train/eval mode a freshly constructed
    module has; this function does not call ``eval()``.

    Security note: ``torch.load`` unpickles the file (fully so on PyTorch
    versions where ``weights_only`` defaults to False). Only load checkpoints
    from trusted sources.
    """
    ckpt = torch.load(path, map_location=device)
    net = LSTMModel(**ckpt['config'])
    net.load_state_dict(ckpt['model_state_dict'])
    net = net.to(device)
    return net

def predict(model, perturbed_text, device):
    """Run ``model`` on a single, unbatched sequence and return a NumPy array.

    The input is converted to a float32 array, given a leading batch axis of
    size 1, and passed with ``lengths = [seq.shape[0]]``. The batch axis is
    removed from the output before it is returned.

    Args:
        model: module called as ``model(x, lengths)``. It is switched to eval
            mode and left there.
        perturbed_text: any array-like (list, tuple, NumPy array) whose first
            axis is the sequence length.
        device: device the input is moved to.

    Returns:
        The model output for the sequence as a NumPy array (copied back to CPU).
    """
    model.eval()
    with torch.no_grad():
        # np.asarray accepts lists, tuples and arrays alike; the original only
        # converted lists, so tuple input crashed on `.shape`.
        seq = np.asarray(perturbed_text, dtype=np.float32)
        # torch.tensor copies, so the caller's array is never aliased.
        inputs = torch.tensor(seq).unsqueeze(0).to(device)
        lengths = torch.LongTensor([seq.shape[0]])
        output = model(inputs, lengths)
        return output.squeeze(0).cpu().numpy()

def evaluate_model(model, test_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``test_loader``.

    The model is switched to eval mode, and the pass runs under
    ``torch.no_grad()``. Batches are counted while iterating, so the loader
    does not need to support ``len()``.

    Args:
        model: module called as ``model(perturbed, lengths)``.
        test_loader: iterable of ``(perturbed, original, lengths)`` batches.
            ``lengths`` is passed to the model unchanged.
        criterion: loss callable ``criterion(output, original)``.
        device: device that inputs and targets are moved to.

    Returns:
        Average of ``criterion`` over the batches, as a Python float.

    Raises:
        ValueError: if ``test_loader`` yields no batches. This happens, for
            example, when the test split is empty.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for perturbed, original, lengths in test_loader:
            perturbed, original = perturbed.to(device), original.to(device)
            output = model(perturbed, lengths)
            loss = criterion(output, original)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; cannot compute an average loss")
    return total_loss / num_batches
