import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pickle
import random

class RNNModel(nn.Module):
    """Single-layer RNN followed by dropout and a per-time-step linear head.

    Args:
        input_size: Feature size of every input time step.
        hidden_size: Size of the RNN hidden state.
        output_size: Feature size of every output time step.
        dropout: Dropout probability applied to the RNN outputs.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout=0.3):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, lengths):
        """Run the network on a padded batch.

        Args:
            x: Padded batch-first input of shape (batch, time, input_size).
            lengths: Number of valid time steps of every sequence in ``x``
                (tensor or sequence of ints, in any order).

        Returns:
            Tensor of shape (batch, time, output_size) with the same time
            dimension as ``x``.
        """
        # pack_padded_sequence requires the lengths to live on the CPU.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.rnn(packed_x)
        # total_length keeps the output aligned with the input even when the
        # longest sequence in the batch is shorter than the padded width.
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=x.size(1)
        )
        output = self.dropout(output)
        return self.fc(output)

class EmbeddingDataset(Dataset):
    """Paired dataset of perturbed sequences and their original counterparts.

    Each item is ``(perturbed, original, length)`` where both sequences are
    float tensors and ``length`` is the number of entries in the perturbed
    sequence.
    """

    def __init__(self, perturbed_text, original_text):
        # Convert every sequence to a NumPy array once, up front.
        self.perturbed_text = [np.array(seq) for seq in perturbed_text]
        self.original_text = [np.array(seq) for seq in original_text]

    def __len__(self):
        return len(self.perturbed_text)

    def __getitem__(self, idx):
        noisy = self.perturbed_text[idx]
        clean = self.original_text[idx]
        noisy_tensor = torch.FloatTensor(noisy)
        clean_tensor = torch.FloatTensor(clean)
        return noisy_tensor, clean_tensor, len(noisy)

def collate_fn(batch):
    """Collate ``(perturbed, original, length)`` samples into padded tensors.

    The samples are ordered by length, longest first, and both sequence
    groups are zero-padded to the longest sequence in the batch.

    Args:
        batch: Sequence of ``(perturbed, original, length)`` tuples. It is
            not modified.

    Returns:
        Tuple ``(perturbed_padded, original_padded, lengths)`` where the
        padded tensors are batch-first and ``lengths`` is a LongTensor.
    """
    # sorted() instead of list.sort(): leaves the caller's list untouched and
    # also accepts non-list sequences. The sort is stable, so ties keep their
    # original relative order.
    ordered = sorted(batch, key=lambda sample: sample[2], reverse=True)
    perturbed, original, lengths = zip(*ordered)

    # Pad sequences
    perturbed_padded = nn.utils.rnn.pad_sequence(perturbed, batch_first=True)
    original_padded = nn.utils.rnn.pad_sequence(original, batch_first=True)

    return perturbed_padded, original_padded, torch.LongTensor(lengths)

def prepare_data(perturbed_text, original_text, train_ratio=0.8, val_ratio=0.2):
    """Randomly split paired sequences into train, validation and test sets.

    The split sizes are ``int(train_ratio * n)`` and ``int(val_ratio * n)``;
    every remaining sample goes to the test set. With the default ratios the
    test set therefore only receives rounding leftovers (often nothing), so
    pass smaller ratios to obtain a real hold-out set. The shuffle uses
    NumPy's global random state (``np.random.permutation``).

    Args:
        perturbed_text: List of perturbed sequences.
        original_text: List of original sequences, aligned with
            ``perturbed_text``.
        train_ratio: Fraction of samples used for training.
        val_ratio: Fraction of samples used for validation.

    Returns:
        ``(train_perturbed, train_original, val_perturbed, val_original,
        test_perturbed, test_original)`` as lists.

    Raises:
        AssertionError: If the two lists, or any pair of corresponding
            sequences, differ in length.
        ValueError: If the ratios are negative or sum to more than 1.
    """
    assert len(perturbed_text) == len(original_text), "Perturbed and original text lists must have the same length"

    for perturbed, original in zip(perturbed_text, original_text):
        assert len(perturbed) == len(original), "Corresponding sublists in perturbed and original text must have the same length"

    # Small tolerance so that float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and val_ratio must be non-negative and sum to at most 1")

    total_samples = len(perturbed_text)
    train_end = int(train_ratio * total_samples)
    val_end = train_end + int(val_ratio * total_samples)

    indices = np.random.permutation(total_samples)
    splits = (indices[:train_end], indices[train_end:val_end], indices[val_end:])

    result = []
    for split_indices in splits:
        result.append([perturbed_text[i] for i in split_indices])
        result.append([original_text[i] for i in split_indices])

    return tuple(result)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, patience=5, clip_grad=1.0, noise_std=0.02):
    """
    Train the model with validation and early stopping.

    Each epoch adds Gaussian noise to the perturbed inputs, optimises the
    model on ``train_loader`` and then evaluates it with ``validate_model``
    on ``val_loader``. Training stops once the validation loss has not
    improved for ``patience`` consecutive epochs. Before returning, the
    weights of the epoch with the lowest validation loss are restored.

    The criterion is applied to the full padded tensors; padded time steps
    are not masked here.

    Args:
        model: Module called as ``model(perturbed, lengths)``.
        train_loader: Iterable of ``(perturbed, original, lengths)`` batches.
        val_loader: Iterable of validation batches in the same format.
        criterion: Loss called as ``criterion(output, original)``.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        device: Device the input and target tensors are moved to.
        patience: Number of non-improving epochs tolerated before stopping.
        clip_grad: Max gradient norm; ``None`` or a non-positive value
            disables clipping.
        noise_std: Standard deviation of the noise added to the inputs;
            ``0`` or ``None`` disables it.

    Returns:
        ``(train_losses, val_losses)``: per-epoch average losses.
    """
    import copy

    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    best_model_state = None

    for epoch in range(num_epochs):
        # Training phase (validate_model leaves the model in eval mode, so
        # training mode has to be re-enabled every epoch).
        model.train()
        total_train_loss = 0
        num_train_batches = 0

        for perturbed, original, lengths in train_loader:
            perturbed, original = perturbed.to(device), original.to(device)
            if noise_std:
                perturbed = perturbed + torch.randn_like(perturbed) * noise_std
            optimizer.zero_grad()
            output = model(perturbed, lengths)
            loss = criterion(output, original)
            loss.backward()
            if clip_grad is not None and clip_grad > 0:
                nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            total_train_loss += loss.item()
            num_train_batches += 1

        # Guard against an empty loader, mirroring validate_model.
        avg_train_loss = total_train_loss / num_train_batches if num_train_batches > 0 else 0
        train_losses.append(avg_train_loss)

        # Validation phase
        avg_val_loss = validate_model(model, val_loader, criterion, device)
        val_losses.append(avg_val_loss)

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() returns references to the live parameters, which
            # the optimizer keeps updating in place; a deep copy is needed
            # to really freeze the best weights.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break

    # Restore the best weights whether training stopped early or ran out of
    # epochs.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return train_losses, val_losses

def validate_model(model, val_loader, criterion, device):
    """Return the average per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating. ``0`` is returned when the loader yields no
    batches.
    """
    model.eval()
    running_loss = 0
    batches_seen = 0

    with torch.no_grad():
        for noisy, target, seq_lengths in val_loader:
            noisy = noisy.to(device)
            target = target.to(device)
            prediction = model(noisy, seq_lengths)
            running_loss += criterion(prediction, target).item()
            batches_seen += 1

    if batches_seen <= 0:
        return 0
    return running_loss / batches_seen

def save_model(model, path):
    """Write the model's state dict to ``path`` and report the location."""
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to {path}")

def load_model(model, path, device):
    """Load a saved state dict into ``model`` and move the model to ``device``.

    Args:
        model: Module whose architecture matches the saved state dict.
        path: Location passed to ``torch.load``.
        device: Target device, also used as ``map_location`` while loading.

    Returns:
        The same ``model`` instance, with the loaded weights.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    print(f"Model loaded from {path}")
    return model

def predict(model, perturbed_text, device):
    """Run the model on a single sequence and return its prediction.

    A list input is first converted to a float32 NumPy array; the sequence
    is then handed to ``batch_predict`` as a batch of one.
    """
    sequence = perturbed_text
    if isinstance(sequence, list):
        sequence = np.array(sequence, dtype=np.float32)
    predictions = batch_predict(model, [sequence], device, batch_size=1)
    return predictions[0]

class InferenceDataset(Dataset):
    """Dataset of input sequences only, for use at inference time.

    Each item is ``(tensor, length)`` where ``tensor`` is the float32
    sequence and ``length`` is the size of its first dimension.
    """

    def __init__(self, perturbed_text):
        # perturbed_text: List[np.ndarray or list], each shape [T, dim]
        self.perturbed_text = [
            np.asarray(seq, dtype=np.float32) for seq in perturbed_text
        ]

    def __len__(self):
        return len(self.perturbed_text)

    def __getitem__(self, idx):
        sequence = self.perturbed_text[idx]
        num_steps = sequence.shape[0]
        return torch.from_numpy(sequence), num_steps


def collate_fn_infer(batch):
    """Collate ``(tensor, length)`` samples into a padded batch.

    Returns the batch-first, zero-padded tensor together with a LongTensor
    of the sequence lengths, in the order the samples were given.
    """
    # batch: List[(tensor[T,dim], length)]
    sequences, seq_lengths = zip(*batch)
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    length_tensor = torch.LongTensor(seq_lengths)
    return padded, length_tensor


def batch_predict(model, perturbed_text_list, device, batch_size=64, num_workers=0):
    """Run the model over many sequences and return unpadded predictions.

    The sequences are batched in their given order, padded per batch, passed
    through the model in eval mode, and every prediction is cut back to the
    length of its input sequence.

    Args:
        model: Module called as ``model(x_pad, lengths)``; it is switched to
            eval mode and left there.
        perturbed_text_list: List of sequences (arrays or lists), each of
            shape [T, dim].
        device: ``torch.device`` or device string (e.g. ``"cpu"``) on which
            the inputs are placed.
        batch_size: Number of sequences per batch.
        num_workers: Worker processes used by the DataLoader.

    Returns:
        List of NumPy arrays, one per input sequence and in input order,
        each holding the first ``T`` time steps of the model output.
    """
    # Accept plain strings as well as torch.device objects; a str has no
    # ``.type`` attribute.
    device = torch.device(device)

    loader = DataLoader(
        InferenceDataset(perturbed_text_list),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only helps for host-to-GPU copies.
        pin_memory=device.type == "cuda",
        collate_fn=collate_fn_infer,
    )

    model.eval()
    outputs = []

    with torch.inference_mode():
        for x_pad, lengths in loader:
            x_pad = x_pad.to(device, non_blocking=True)

            y_pad = model(x_pad, lengths)      # [B, Tmax, dim]
            y_pad = y_pad.detach().cpu().numpy()

            # Copy each slice so the results do not keep the whole padded
            # batch array alive.
            outputs.extend(
                y_pad[i, :seq_len, :].copy()
                for i, seq_len in enumerate(lengths.tolist())
            )

    return outputs
