import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTokenizer
from tqdm import tqdm

# Configuration
class TrainConfig:
    """Static hyper-parameters and filesystem locations for decoder training.

    Values are read directly as class attributes (e.g. ``TrainConfig.lr``).
    """

    # --- Data / pretrained assets -------------------------------------------
    train_embed_path: str = "/root/autodl-tmp/dataset/embeddings/train/embeddings.pt"
    train_text_path: str = "/root/autodl-tmp/dataset/text/train/categories.txt"
    clip_model_path: str = "/root/autodl-tmp/model/clip"

    # --- Model shape --------------------------------------------------------
    embed_dim: int = 768       # width of the input embeddings and of the decoder (d_model)
    vocab_size: int = 49408    # size of the token-embedding / output layers; token ids must be below this
    max_seq_len: int = 77      # tokenizer pad/truncate length and number of learned positions

    # --- Optimisation -------------------------------------------------------
    batch_size: int = 256
    lr: float = 1e-4
    epochs: int = 300
    mask_prob: float = 0.15    # probability of zeroing each embedding dimension during training
    num_gpus: int = 2          # GPU count requested for nn.DataParallel

    # --- Checkpointing ------------------------------------------------------
    save_dir: str = "/root/autodl-tmp/untokenizer/untokenizer"
    checkpoint_name: str = "decoder_epoch{}.pth"  # formatted with the 1-based epoch number

# Dataset
class EmbeddingDataset(Dataset):
    """Pair each precomputed embedding with the tokenized caption on the same index.

    Row ``i`` of the tensor stored at ``TrainConfig.train_embed_path`` is matched with
    line ``i`` of ``TrainConfig.train_text_path``, so both sources must have the same
    number of entries.

    Raises:
        ValueError: if the number of embeddings and the number of text lines differ.
    """

    def __init__(self):
        self.embeddings = torch.load(TrainConfig.train_embed_path, map_location='cpu').float()
        with open(TrainConfig.train_text_path, "r", encoding="utf-8") as f:
            self.texts = [line.strip() for line in f]
        # Fail fast on misaligned data: otherwise a short text file raises IndexError
        # at a random point mid-epoch, and a long one is silently misaligned/truncated.
        if len(self.embeddings) != len(self.texts):
            raise ValueError(
                f"Embedding/text count mismatch: {len(self.embeddings)} embeddings in "
                f"{TrainConfig.train_embed_path} vs {len(self.texts)} lines in "
                f"{TrainConfig.train_text_path}"
            )
        self.tokenizer = CLIPTokenizer.from_pretrained(TrainConfig.clip_model_path)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return ``(embedding, tokens)`` for sample ``idx``.

        ``tokens`` is a 1-D tensor of ``TrainConfig.max_seq_len`` token ids, padded or
        truncated by the tokenizer to exactly that length.
        """
        embedding = self.embeddings[idx]
        text = self.texts[idx]
        tokens = self.tokenizer(
            text,
            max_length=TrainConfig.max_seq_len,
            padding="max_length",
            return_tensors="pt",
            truncation=True
        ).input_ids.squeeze(0)
        return embedding, tokens

# Model
class EmbeddingDecoder(nn.Module):
    """Transformer decoder that turns one embedding vector into per-position token logits.

    The input embedding is linearly projected and used as a length-1 memory sequence
    that every decoder layer cross-attends to. Target tokens are embedded, given learned
    absolute position embeddings, and decoded under a causal (subsequent) mask.
    """

    def __init__(self):
        super().__init__()
        width = TrainConfig.embed_dim
        vocab = TrainConfig.vocab_size

        # Submodules are registered in a fixed order: it determines both the state_dict
        # layout and the order in which parameters consume the RNG at initialisation.
        self.embed_proj = nn.Linear(width, width)
        self.token_embed = nn.Embedding(vocab, width)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=width,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
            ),
            num_layers=6,
        )
        # Learned absolute positions: one row per slot up to max_seq_len.
        self.pos_encoder = nn.Embedding(TrainConfig.max_seq_len, width)
        self.output = nn.Linear(width, vocab)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with two or more dimensions."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src_embeds, tgt_tokens):
        """Compute next-token logits.

        Args:
            src_embeds: [batch, embed_dim] conditioning embeddings.
            tgt_tokens: [batch, seq_len] input token ids; seq_len must not exceed
                ``TrainConfig.max_seq_len`` (size of the position table).

        Returns:
            Logits shaped [batch, vocab_size, seq_len] (class dimension second, the
            layout ``nn.CrossEntropyLoss`` takes for sequence targets).
        """
        length = tgt_tokens.shape[1]
        device = tgt_tokens.device

        # Length-1 memory in sequence-first layout: [1, batch, embed_dim]
        # (the decoder layers use the default batch_first=False).
        memory = self.embed_proj(src_embeds)[None]

        # Token + position embeddings; [1, seq_len, embed_dim] positions broadcast over batch.
        positions = self.pos_encoder(torch.arange(length, device=device))[None]
        tgt = (self.token_embed(tgt_tokens) + positions).transpose(0, 1)  # [seq_len, batch, embed_dim]

        causal_mask = nn.Transformer.generate_square_subsequent_mask(length).to(tgt.device)
        hidden = self.decoder(tgt, memory, tgt_mask=causal_mask)  # [seq_len, batch, embed_dim]

        logits = self.output(hidden)     # [seq_len, batch, vocab_size]
        return logits.permute(1, 2, 0)   # [batch, vocab_size, seq_len]

# Train
def _next_token_labels(targets, eos_token_id, ignore_index=-100):
    """Build next-token labels from ``targets``, ignoring every position after the first EOS.

    Positions after the first EOS are treated as padding. The first EOS itself stays a
    target so the model learns where to stop. This works whether the tokenizer pads with
    a dedicated id or with the EOS id itself.

    Args:
        targets: [batch, seq_len] token ids.
        eos_token_id: id of the end-of-text token; if None, nothing is masked.
        ignore_index: value written into positions the loss should skip.

    Returns:
        [batch, seq_len - 1] labels aligned with the ``targets[:, :-1]`` decoder inputs.
    """
    labels = targets[:, 1:].clone()
    if eos_token_id is None:
        return labels
    is_eos = (labels == eos_token_id).long()
    # Count of EOS tokens strictly before each position; > 0 means "past the first EOS".
    eos_before = is_eos.cumsum(dim=1) - is_eos
    labels[eos_before > 0] = ignore_index
    return labels


def main():
    """Train EmbeddingDecoder to reconstruct caption tokens from their embeddings.

    The unwrapped model state_dict is written to ``TrainConfig.save_dir`` after every epoch.
    """
    os.makedirs(TrainConfig.save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = EmbeddingDecoder().to(device)
    # Never request more GPUs than are visible: DataParallel rejects invalid device ids.
    n_gpus = min(TrainConfig.num_gpus, torch.cuda.device_count())
    if n_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(n_gpus)))

    dataset = EmbeddingDataset()
    dataloader = DataLoader(dataset, batch_size=TrainConfig.batch_size, shuffle=True)
    eos_token_id = dataset.tokenizer.eos_token_id

    # The padding id depends on the tokenizer (it is not necessarily 0), so padding is
    # masked explicitly in _next_token_labels instead of via a hard-coded ignore_index.
    ignore_index = -100
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    optimizer = torch.optim.AdamW(model.parameters(), lr=TrainConfig.lr)

    for epoch in range(TrainConfig.epochs):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

        for embeddings, targets in progress_bar:
            embeddings = embeddings.to(device)
            targets = targets.to(device)

            # Augmentation: zero each embedding dimension with probability mask_prob.
            mask = torch.rand_like(embeddings) < TrainConfig.mask_prob
            embeddings = embeddings.clone()
            embeddings[mask] = 0

            # Teacher forcing: the decoder sees tokens[:-1] and predicts tokens[1:].
            inputs = targets[:, :-1]
            labels = _next_token_labels(targets, eos_token_id, ignore_index)

            outputs = model(embeddings, inputs)  # [batch, vocab_size, seq_len - 1]
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            step_loss = loss.item()  # single host sync per step
            total_loss += step_loss
            progress_bar.set_postfix(loss=step_loss)

        # Guard against an empty dataloader.
        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

        # Unwrap DataParallel so checkpoint keys carry no "module." prefix.
        ckpt = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
        save_path = os.path.join(TrainConfig.save_dir, TrainConfig.checkpoint_name.format(epoch+1))
        torch.save(ckpt, save_path)


if __name__ == "__main__":
    main()
