import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
import random
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first sequences.

    The encoding table is computed once in the constructor and stored as the
    non-trainable buffer ``pe`` with shape ``(1, max_len, d_model)``. Even
    feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed in the table. Inputs whose
            dimension 1 is longer than this cannot be broadcast against the
            table and will fail in ``forward``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so the angles are truncated to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        pe = pe.unsqueeze(0)
        # Registered as a buffer so it follows the module across devices and
        # is saved in the state dict without being a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` using the first ``x.size(1)`` positions.

        Args:
            x: Tensor indexed as ``(batch, seq_len, d_model)``.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Encoder-only transformer mapping feature sequences to same-width outputs.

    Inputs are projected from ``input_size`` to ``d_model``, given positional
    encodings, passed through a stack of pre-norm, batch-first encoder layers
    and projected back to ``input_size``.

    Args:
        input_size: Feature width of the input and of the output.
        d_model: Internal width of the transformer.
        nhead: Number of attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability for the positional encoding and the
            encoder layers.
    """

    def __init__(self, input_size, d_model=128, nhead=8,
                 num_encoder_layers=2, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.model_type = 'Transformer'

        # The sub-modules are created in this exact order on purpose: it
        # fixes both the parameter registration order and the sequence of
        # random numbers drawn during default initialisation.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template, num_layers=num_encoder_layers
        )
        self.encoder = nn.Linear(input_size, d_model)
        self.decoder = nn.Linear(d_model, input_size)

        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform init to every parameter with more than one dim."""
        for tensor in self.parameters():
            if tensor.dim() <= 1:
                # Biases and other 1-D parameters keep their default init.
                continue
            nn.init.xavier_uniform_(tensor)

    def forward(self, src, src_key_padding_mask=None):
        """Run the model on a batch-first input.

        Args:
            src: Tensor indexed as ``(batch, seq_len, input_size)``.
            src_key_padding_mask: Optional mask forwarded unchanged to the
                encoder stack.

        Returns:
            The decoder projection of the encoder output; its last dimension
            is ``input_size``.
        """
        projected = self.encoder(src)  # (batch, seq_len, d_model)
        with_positions = self.pos_encoder(projected)
        hidden = self.transformer_encoder(
            with_positions, src_key_padding_mask=src_key_padding_mask
        )
        return self.decoder(hidden)

class EmbeddingDataset(Dataset):
    """Paired dataset of perturbed and original sequences.

    Every sequence is converted to a float tensor once, at construction time.
    Each item is the tuple ``(perturbed, original, length)`` where ``length``
    is the length of the perturbed tensor along its first dimension.

    Args:
        perturbed_text: Iterable of array-like perturbed sequences.
        original_text: Iterable of array-like original sequences, aligned
            index-by-index with ``perturbed_text``.
    """

    def __init__(self, perturbed_text, original_text):
        self.perturbed = list(map(self._to_float_tensor, perturbed_text))
        self.original = list(map(self._to_float_tensor, original_text))

    @staticmethod
    def _to_float_tensor(sequence):
        """Convert one array-like sequence to a ``torch.FloatTensor``."""
        return torch.FloatTensor(np.array(sequence))

    def __len__(self):
        """Return the number of perturbed sequences."""
        return len(self.perturbed)

    def __getitem__(self, idx):
        """Return ``(perturbed, original, len(perturbed))`` for ``idx``."""
        noisy = self.perturbed[idx]
        clean = self.original[idx]
        return noisy, clean, len(noisy)

def collate_fn(batch):
    """Zero-pad a list of dataset items into dense batch tensors.

    Args:
        batch: Sequence of ``(perturbed, original, length)`` tuples, where the
            two tensors are indexed as ``(seq_len, feat_dim)``.

    Returns:
        A tuple ``(padded_perturbed, padded_original, lengths, padding_mask)``.
        Both padded tensors are ``(batch, max_len, feat_dim)``; ``lengths`` is
        a tensor built from the per-item lengths; ``padding_mask`` is a
        boolean ``(batch, max_len)`` tensor that is ``True`` at padded steps.
    """
    noisy_seqs, clean_seqs, seq_lengths = zip(*batch)
    n_items = len(batch)
    longest = max(seq_lengths)

    padded_perturbed = torch.zeros(n_items, longest, noisy_seqs[0].shape[-1])
    padded_original = torch.zeros(n_items, longest, clean_seqs[0].shape[-1])
    # Start with everything marked as padding, then clear the real steps.
    padding_mask = torch.ones(n_items, longest, dtype=torch.bool)

    for row, n_steps in enumerate(seq_lengths):
        padded_perturbed[row, :n_steps] = noisy_seqs[row]
        padded_original[row, :n_steps] = clean_seqs[row]
        padding_mask[row, :n_steps] = False

    return padded_perturbed, padded_original, torch.tensor(seq_lengths), padding_mask

def prepare_data(perturbed_text, original_text, *, train_frac=0.8, val_frac=0.2, seed=None):
    """Shuffle paired samples and split them into train/validation/test lists.

    The same random permutation is applied to both inputs, so element ``k`` of
    each returned perturbed list still corresponds to element ``k`` of the
    matching original list. The test split is whatever remains after the
    train and validation splits; with the default fractions that is only the
    rounding remainder and is often empty.

    Args:
        perturbed_text: Indexable collection of perturbed samples.
        original_text: Indexable collection of original samples with the same
            length as ``perturbed_text``.
        train_frac: Fraction of samples assigned to the training split.
        val_frac: Fraction of samples assigned to the validation split.
        seed: Optional seed for a private random generator. When ``None`` the
            global NumPy random state is used.

    Returns:
        A 6-tuple of lists: ``(train_perturbed, train_original,
        val_perturbed, val_original, test_perturbed, test_original)``.

    Raises:
        AssertionError: If the two inputs differ in length.
        ValueError: If a fraction is negative or the fractions sum to more
            than 1.
    """
    if len(perturbed_text) != len(original_text):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept for
        # callers that already handle AssertionError.
        raise AssertionError("Length do not match")
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1 + 1e-9:
        raise ValueError(
            "train_frac and val_frac must be non-negative and sum to at most 1"
        )

    n_samples = len(perturbed_text)
    if seed is None:
        indices = np.random.permutation(n_samples)
    else:
        # A private generator keeps the global random state untouched.
        indices = np.random.RandomState(seed).permutation(n_samples)

    train_end = int(train_frac * n_samples)
    val_end = train_end + int(val_frac * n_samples)
    splits = (indices[:train_end], indices[train_end:val_end], indices[val_end:])

    result = []
    for split in splits:
        result.append([perturbed_text[i] for i in split])
        result.append([original_text[i] for i in split])
    return tuple(result)

def train_model(model, train_loader, val_loader, criterion, optimizer,
               num_epochs, device, scheduler=None, patience=5, clip_grad=1.0,
               noise_std=0.02, verbose=False):
    """Train ``model`` with early stopping and restore the best weights.

    Each epoch runs one pass over ``train_loader`` (adding Gaussian noise to
    the inputs), then evaluates on ``val_loader``. A snapshot of the weights
    is kept whenever the mean validation loss improves, and the model is
    reset to the best snapshot before returning.

    Args:
        model: Module called as ``model(x, src_key_padding_mask=mask)``.
        train_loader: Sized iterable yielding
            ``(perturbed, original, lengths, padding_mask)`` batches.
        val_loader: Sized iterable with the same batch layout.
        criterion: Callable ``criterion(outputs, original)`` returning a
            scalar loss tensor.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
            ``ReduceLROnPlateau`` receives the mean validation loss.
        patience: Stop once this many consecutive epochs fail to improve the
            validation loss.
        clip_grad: Maximum gradient norm; values ``<= 0`` disable clipping.
        noise_std: Standard deviation of the noise added to training inputs;
            values ``<= 0`` disable it.
        verbose: When true, print per-epoch losses and the early-stop epoch.

    Returns:
        ``(model, best_val_loss)``.

    Raises:
        ZeroDivisionError: If either loader has length zero.
    """
    import copy  # local import: only needed to snapshot the best weights

    best_val_loss = float('inf')
    best_weights = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for perturbed, original, _, padding_mask in train_loader:
            perturbed = perturbed.to(device)
            original = original.to(device)
            padding_mask = padding_mask.to(device)

            if noise_std > 0:
                perturbed = perturbed + torch.randn_like(perturbed) * noise_std

            optimizer.zero_grad()
            outputs = model(perturbed, src_key_padding_mask=padding_mask)
            # NOTE(review): the criterion sees every position of the batch
            # tensors, so any padded steps also contribute to the loss.
            loss = criterion(outputs, original)
            loss.backward()

            if clip_grad > 0:
                nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for perturbed, original, _, padding_mask in val_loader:
                perturbed = perturbed.to(device)
                original = original.to(device)
                padding_mask = padding_mask.to(device)

                outputs = model(perturbed, src_key_padding_mask=padding_mask)
                val_loss += criterion(outputs, original).item()

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        if verbose:
            print(f'Epoch {epoch+1}/{num_epochs}, '
                  f'Train Loss: {avg_train_loss:.4f}, '
                  f'Val Loss: {avg_val_loss:.4f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() hands back references to the live tensors, so a
            # deep copy is required; a shallow copy would be overwritten by
            # the following optimizer steps.
            best_weights = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                if verbose:
                    print(f'Early stopping at epoch {epoch+1}')
                break

    # best_weights stays None when no epoch ran or the validation loss never
    # compared below infinity (for example, when it is NaN).
    if best_weights is not None:
        model.load_state_dict(best_weights)

    return model, best_val_loss

def predict(model, perturbed_text, device=None):
    """Run the model on one sequence and return the result as a NumPy array.

    The model is switched to eval mode and left in that mode.

    Args:
        model: Module called as ``model(x, src_key_padding_mask=mask)``.
        perturbed_text: One sequence indexed as ``(seq_len, feat_dim)``;
            a Python list is converted to a NumPy array first.
        device: Device to run on. Defaults to the device of the model's
            first parameter.

    Returns:
        ``numpy.ndarray`` holding the model output with the leading batch
        dimension removed.
    """
    model.eval()
    target_device = next(model.parameters()).device if device is None else device

    sequence = np.array(perturbed_text) if isinstance(perturbed_text, list) else perturbed_text

    # Add a leading batch dimension: (1, seq_len, feat_dim).
    batch = torch.FloatTensor(sequence).unsqueeze(0).to(target_device)

    # A single sequence has no padding, so nothing is masked.
    n_steps = batch.size(1)
    no_padding = torch.zeros(1, n_steps, dtype=torch.bool).to(target_device)

    with torch.no_grad():
        restored = model(batch, src_key_padding_mask=no_padding)
        return restored.squeeze(0).cpu().numpy()

def batch_predict(model, embeddings_list, batch_size=32):
    """Run the model over many variable-length sequences in padded batches.

    The model is switched to eval mode and left in that mode. Inference runs
    on the device of the model's first parameter.

    Args:
        model: Module called as ``model(x, src_key_padding_mask=mask)``.
        embeddings_list: Sliceable collection of sequences, each indexed as
            ``(seq_len, feat_dim)``.
        batch_size: Number of sequences per forward pass.

    Returns:
        List with one NumPy array per input sequence, in input order, each
        trimmed back to the length of its input sequence.
    """
    model.eval()
    device = next(model.parameters()).device

    results = []
    for start in range(0, len(embeddings_list), batch_size):
        chunk = embeddings_list[start:start + batch_size]
        if len(chunk) == 0:
            continue

        n_rows = len(chunk)
        longest = max(len(seq) for seq in chunk)
        feat_dim = len(chunk[0][0])

        # Dense (n_rows, longest, feat_dim) input, zero-filled where padded.
        padded = torch.zeros(
            (n_rows, longest, feat_dim), dtype=torch.float32, device=device
        )
        # Everything starts out flagged as padding; real steps are cleared.
        mask = torch.ones((n_rows, longest), dtype=torch.bool, device=device)

        for row, seq in enumerate(chunk):
            values = torch.as_tensor(seq, dtype=torch.float32, device=device)
            n_steps = values.shape[0]
            padded[row, :n_steps] = values
            mask[row, :n_steps] = False

        with torch.no_grad():
            outputs = model(padded, src_key_padding_mask=mask).cpu().numpy()

        # Drop the padded tail of every row before handing results back.
        results.extend(outputs[row, :len(seq)] for row, seq in enumerate(chunk))

    return results