import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pickle
import random

class LSTMModel(nn.Module):
    """LSTM that emits one output vector per timestep of a padded batch.

    The recurrent stack runs on packed sequences so padded timesteps never
    enter the LSTM; a dropout layer and a linear head are then applied to
    every timestep of the re-padded output.

    Args:
        input_size: Feature dimension of each input timestep.
        hidden_size: Hidden state size of the LSTM.
        output_size: Feature dimension of each output timestep.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. Ignored when
            ``num_layers`` is 1, because ``nn.LSTM`` only applies it between
            layers.
        head_dropout: Dropout probability applied to the LSTM output before
            the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.0, head_dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM applies dropout after every layer except the last, so with
        # a single layer a non-zero value does nothing but trigger a
        # UserWarning. Pass 0.0 in that case to keep construction quiet.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.head_dropout = nn.Dropout(p=head_dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, lengths):
        """Run the model on a padded batch.

        Args:
            x: Padded input, batch-first (batch, time, input_size).
            lengths: Per-sequence valid lengths, as a tensor (any device) or
                a sequence of ints.

        Returns:
            Tensor with the same batch and time dimensions as ``x`` and
            ``output_size`` features. Timesteps past a sequence's length are
            zero-padded before the head, so they come out as the linear
            layer's bias.
        """
        # pack_padded_sequence needs the lengths on the CPU.
        lengths_cpu = torch.as_tensor(lengths, dtype=torch.long, device='cpu')

        # Guard against lengths that exceed the padded time dimension.
        lengths_cpu = torch.clamp(lengths_cpu, max=x.size(1))

        packed_x = pack_padded_sequence(x, lengths_cpu, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_x)
        # total_length keeps the output time dimension equal to the input's
        # even if the longest sequence in the batch is shorter than the padding.
        output, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=x.size(1))

        output = self.head_dropout(output)
        return self.fc(output)


class EmbeddingDataset(Dataset):
    """Dataset of (perturbed, original) sequence pairs stored as float32 tensors.

    Each item is a tuple ``(perturbed, original, length)`` where ``length`` is
    the size of the perturbed tensor's first dimension.
    """

    def __init__(self, perturbed_text, original_text):
        def to_float_tensor(seq):
            # Go through NumPy so nested lists and arrays are both accepted.
            return torch.as_tensor(np.asarray(seq), dtype=torch.float32)

        self.perturbed_text = list(map(to_float_tensor, perturbed_text))
        self.original_text = list(map(to_float_tensor, original_text))

    def __len__(self):
        """Return the number of perturbed sequences held."""
        return len(self.perturbed_text)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` together with the perturbed sequence length."""
        noisy = self.perturbed_text[idx]
        clean = self.original_text[idx]
        n_steps = noisy.shape[0]
        return noisy, clean, n_steps


def collate_fn(batch):
    """Collate ``(perturbed, original, length)`` items into padded batch tensors.

    Returns a tuple ``(perturbed_padded, original_padded, lengths)`` where the
    first two are batch-first tensors zero-padded to the longest sequence in
    the batch and ``lengths`` is a ``LongTensor`` of the per-item lengths.
    """
    noisy_seqs, clean_seqs, seq_lens = zip(*batch)
    pad = nn.utils.rnn.pad_sequence
    noisy_batch = pad(noisy_seqs, batch_first=True)
    clean_batch = pad(clean_seqs, batch_first=True)
    length_tensor = torch.LongTensor(seq_lens)
    return noisy_batch, clean_batch, length_tensor


def prepare_data(perturbed_text, original_text, train_frac=0.8, val_frac=0.2):
    """Randomly split paired sequences into train, validation and test sets.

    The split is drawn with ``np.random.permutation``, so seed NumPy's global
    RNG beforehand for a reproducible partition.

    Args:
        perturbed_text: List of perturbed sequences.
        original_text: List of original sequences, aligned with
            ``perturbed_text`` and of matching per-item length.
        train_frac: Fraction of samples assigned to the training split.
        val_frac: Fraction of samples assigned to the validation split.

    Returns:
        ``(train_perturbed, train_original, val_perturbed, val_original,
        test_perturbed, test_original)``. The test split receives whatever is
        left after the (truncated) train and validation sizes; with the
        default fractions that is only the rounding remainder, so it is
        typically empty or a single sample.

    Raises:
        AssertionError: If the two lists, or any aligned pair of sequences,
            differ in length.
        ValueError: If the fractions are negative or sum to more than 1.
    """
    # Kept as asserts so the exception type existing callers see is unchanged.
    assert len(perturbed_text) == len(original_text), "Perturbed and original text lists must have the same length"

    for perturbed, original in zip(perturbed_text, original_text):
        assert len(perturbed) == len(original), "Corresponding sublists in perturbed and original text must have the same length"

    # Small tolerance so fractions such as 0.7 + 0.3 are not rejected for
    # floating-point noise.
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1 + 1e-9:
        raise ValueError("train_frac and val_frac must be non-negative and sum to at most 1")

    total_samples = len(perturbed_text)
    train_end = int(train_frac * total_samples)
    val_end = train_end + int(val_frac * total_samples)

    indices = np.random.permutation(total_samples)
    split_indices = (indices[:train_end], indices[train_end:val_end], indices[val_end:])

    result = []
    for split in split_indices:
        result.append([perturbed_text[i] for i in split])
        result.append([original_text[i] for i in split])
    return tuple(result)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, patience=5, clip_grad=1.0, noise_std=0.02):
    """Train ``model`` with early stopping and restore the best weights.

    Each epoch runs one pass over ``train_loader`` (with Gaussian noise added
    to the inputs) followed by one evaluation pass over ``val_loader``. The
    weights with the lowest average validation loss are restored before
    returning.

    Args:
        model: Module called as ``model(perturbed, lengths)``.
        train_loader: Iterable of ``(perturbed, original, lengths)`` batches;
            must support ``len()``.
        val_loader: Same batch format as ``train_loader``; must support ``len()``.
        criterion: Loss called as ``criterion(output, original)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        device: Device the input and target batches are moved to.
        patience: Stop after this many consecutive epochs without a new best
            validation loss.
        clip_grad: Max gradient norm; ``None`` or a non-positive value
            disables clipping.
        noise_std: Standard deviation of the Gaussian noise added to training
            inputs; ``0`` or ``None`` disables it.

    Returns:
        ``(model, best_val_loss)`` with the best weights loaded into ``model``.

    NOTE(review): the loss is computed on the tensors exactly as the loader
    yields them; if those are padded, padded timesteps contribute to the
    loss. Confirm this is intended.
    """
    import copy  # local import: only needed to snapshot the best weights

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for perturbed, original, lengths in train_loader:
            perturbed, original = perturbed.to(device), original.to(device)
            if noise_std:
                perturbed = perturbed + torch.randn_like(perturbed) * noise_std
            optimizer.zero_grad()
            output = model(perturbed, lengths)
            loss = criterion(output, original)
            loss.backward()
            if clip_grad is not None and clip_grad > 0:
                nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            total_loss += loss.item()

        # Computed for parity with the validation loss; not otherwise used here.
        avg_train_loss = total_loss / len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for perturbed, original, lengths in val_loader:
                perturbed, original = perturbed.to(device), original.to(device)
                output = model(perturbed, lengths)
                loss = criterion(output, original)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() tensors share storage with the live parameters, so
            # a shallow .copy() would be overwritten by the next optimizer
            # step. Deep-copy to freeze this epoch's weights.
            best_model = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1

        if epochs_no_improve == patience:
            print(f'Early stopping triggered after epoch {epoch+1}')
            break

    # Restore the best weights whether we stopped early or ran all epochs.
    if best_model is not None:
        model.load_state_dict(best_model)

    return model, best_val_loss

def save_model(model, path, config):
    """Write a checkpoint holding the model's weights and its config to ``path``.

    The checkpoint is a dict with keys ``'model_state_dict'`` and ``'config'``.
    """
    checkpoint = {}
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['config'] = config
    torch.save(checkpoint, path)

def load_model(path, device):
    """Rebuild an ``LSTMModel`` from a checkpoint file and move it to ``device``.

    The checkpoint must contain ``'config'`` (keyword arguments for
    ``LSTMModel``) and ``'model_state_dict'``.
    """
    # NOTE(review): torch.load deserialises the file; only load checkpoints
    # from sources you trust.
    saved = torch.load(path, map_location=device)
    net = LSTMModel(**saved['config'])
    net.load_state_dict(saved['model_state_dict'])
    net = net.to(device)
    return net

def predict(model, perturbed_text, device):
    """Run the model on a single sequence and return its output as a NumPy array.

    Args:
        model: Module called as ``model(batch, lengths)``; it is switched to
            eval mode and left there.
        perturbed_text: One sequence, as a nested list/tuple, NumPy array or
            tensor, whose first dimension is time.
        device: Device the input is moved to before the forward pass.

    Returns:
        The model output for the sequence with the batch dimension removed,
        as a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        # Convert the same way the dataset classes do, so every array-like
        # input (not just lists) yields a float32 tensor with a .shape.
        if isinstance(perturbed_text, torch.Tensor):
            sequence = perturbed_text.detach().to(dtype=torch.float32)
        else:
            sequence = torch.as_tensor(np.asarray(perturbed_text), dtype=torch.float32)

        lengths = torch.LongTensor([sequence.shape[0]])
        perturbed = sequence.unsqueeze(0).to(device)  # add the batch dimension

        output = model(perturbed, lengths)
        return output.squeeze(0).cpu().numpy()

def evaluate_model(model, test_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``test_loader``.

    Args:
        model: Module called as ``model(perturbed, lengths)``; it is switched
            to eval mode.
        test_loader: Iterable of ``(perturbed, original, lengths)`` batches.
            It does not need to support ``len()``.
        criterion: Loss called as ``criterion(output, original)``.
        device: Device the input and target batches are moved to.

    Returns:
        The unweighted average of the batch losses, or ``nan`` if the loader
        yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for perturbed, original, lengths in test_loader:
            perturbed, original = perturbed.to(device), original.to(device)
            output = model(perturbed, lengths)
            total_loss += criterion(output, original).item()
            num_batches += 1

    # An empty test split has no defined loss; report NaN instead of
    # crashing with ZeroDivisionError.
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


class InferenceDataset(Dataset):
    """Dataset of input sequences for inference, stored as float32 tensors.

    Each item is ``(tensor, length, idx)``; the index is returned so callers
    can place results back in the original order.
    """

    def __init__(self, seqs):
        tensors = []
        for seq in seqs:
            tensors.append(torch.as_tensor(np.asarray(seq), dtype=torch.float32))
        self.x = tensors

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the tensor at ``idx``, its first-dimension size, and ``idx``."""
        item = self.x[idx]
        n_steps = item.shape[0]
        return item, n_steps, idx


def collate_infer(batch):
    """Collate ``(tensor, length, idx)`` items into a padded inference batch.

    Returns ``(x_pad, lengths, indices)``: a batch-first tensor zero-padded to
    the longest sequence, plus ``LongTensor``s of the lengths and indices.
    """
    sequences, seq_lens, positions = zip(*batch)
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    length_tensor = torch.LongTensor(seq_lens)
    index_tensor = torch.LongTensor(positions)
    return padded, length_tensor, index_tensor


def batch_predict(model, seqs, device, batch_size=64):
    """Run the model over many sequences in padded batches.

    Args:
        model: Module called as ``model(x_pad, lengths)``; it is switched to
            eval mode.
        seqs: Sequences to run, each convertible by ``InferenceDataset``.
        device: Target device, as a ``torch.device`` or a string such as
            ``"cpu"`` / ``"cuda"``.
        batch_size: Number of sequences per batch.

    Returns:
        A list aligned with ``seqs``; entry ``i`` is a NumPy array holding the
        model output for sequence ``i`` with the padded timesteps trimmed off.
    """
    # Accept strings as well as torch.device objects, like the rest of the
    # functions that simply forward ``device`` to ``.to(...)``.
    device = torch.device(device)

    ds = InferenceDataset(seqs)
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
        collate_fn=collate_infer
    )

    out_list = [None] * len(ds)
    model.eval()
    with torch.inference_mode():
        for x_pad, lens, idxs in loader:
            x_pad = x_pad.to(device, non_blocking=True)
            y_pad = model(x_pad, lens)              # [B, T, D]
            y_pad = y_pad.cpu().numpy()
            lens_np = lens.numpy()
            idxs_np = idxs.numpy()

            for b, idx in enumerate(idxs_np):
                L = int(lens_np[b])
                # Copy so each result owns its memory instead of keeping the
                # whole padded batch array alive through a view.
                out_list[int(idx)] = y_pad[b, :L, :].copy()

    return out_list