import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTokenizer
from tqdm import tqdm

class EvalConfig:
    """Static settings for the validation run, read directly as class attributes."""

    # --- checkpoint ---
    # NOTE(review): placeholder path; point this at the trained decoder weights before running.
    checkpoint_path: str = "youpath/decoder_epoch250.pth"

    # --- validation data (embedding row i pairs with text line i) ---
    val_embed_path: str = "/root/autodl-tmp/dataset/embeddings/val/embeddings.pt"
    val_text_path: str = "/root/autodl-tmp/dataset/text/val/categories.txt"

    # --- model / tokenizer geometry ---
    embed_dim: int = 768       # input embedding width and decoder d_model
    vocab_size: int = 49408    # token-embedding rows and output classes; presumably the CLIP vocab size
    max_seq_len: int = 77      # tokenizer max_length, positional table size, default generation length
    batch_size: int = 256
    clip_model_path: str = "/root/autodl-tmp/model/clip"


class EmbeddingValDataset(Dataset):
    """Validation samples of (precomputed embedding, tokenized text, raw text).

    Embeddings are loaded from ``EvalConfig.val_embed_path`` and texts from
    ``EvalConfig.val_text_path`` (one text per line). Item ``i`` pairs
    embedding row ``i`` with text line ``i``.

    Raises:
        ValueError: if the number of embeddings and texts differ, since the
            index-based pairing would then be wrong.
    """

    def __init__(self):
        self.embeddings = torch.load(EvalConfig.val_embed_path, map_location="cpu").float()
        # Explicit encoding so the result does not depend on the host locale.
        with open(EvalConfig.val_text_path, "r", encoding="utf-8") as f:
            self.texts = [line.strip() for line in f]
        # Samples are paired purely by index; a length mismatch means the two
        # files are out of sync and every reference could be misattributed.
        if len(self.texts) != len(self.embeddings):
            raise ValueError(
                f"{len(self.embeddings)} embeddings but {len(self.texts)} texts; "
                "the embedding and text files must be index-aligned"
            )
        self.tokenizer = CLIPTokenizer.from_pretrained(EvalConfig.clip_model_path)

    def __len__(self):
        """Return the number of (embedding, text) pairs."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return ``(embedding, token_ids, raw_text)`` for sample ``idx``.

        ``token_ids`` is a 1-D tensor of length ``EvalConfig.max_seq_len``,
        truncated/padded by the tokenizer.
        """
        emb = self.embeddings[idx]
        txt = self.texts[idx]
        tokens = self.tokenizer(
            txt,
            max_length=EvalConfig.max_seq_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids.squeeze(0)
        return emb, tokens, txt

# ———— Define model ————
class EmbeddingDecoder(nn.Module):
    """Autoregressive Transformer decoder that turns one embedding vector into text.

    The source embedding is linearly projected and used as a single-step
    cross-attention memory. Target tokens receive learned token + positional
    embeddings and are decoded under a causal mask. The decoder layers use
    the default sequence-first layout ``[L, batch, D]`` (``batch_first`` is
    not set), so inputs are permuted accordingly.
    """

    def __init__(self):
        super().__init__()
        # Attribute names/order are kept as-is so existing checkpoints load.
        self.embed_proj = nn.Linear(EvalConfig.embed_dim, EvalConfig.embed_dim)
        self.token_embed = nn.Embedding(EvalConfig.vocab_size, EvalConfig.embed_dim)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=EvalConfig.embed_dim, nhead=8,
            dim_feedforward=2048, dropout=0.1
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.pos_encoder = nn.Embedding(EvalConfig.max_seq_len, EvalConfig.embed_dim)
        self.output = nn.Linear(EvalConfig.embed_dim, EvalConfig.vocab_size)

    def _embed_tokens(self, tokens):
        """Return token + positional embeddings for ``[batch, L]`` ids as ``[L, batch, D]``."""
        positions = torch.arange(tokens.size(1), device=tokens.device)
        embeds = self.token_embed(tokens) + self.pos_encoder(positions).unsqueeze(0)  # [batch, L, D]
        return embeds.permute(1, 0, 2)

    def forward(self, src_embeds, tgt_tokens):
        """Teacher-forced decoding.

        Args:
            src_embeds: source embeddings, ``[batch, embed_dim]``.
            tgt_tokens: input token ids, ``[batch, L]`` with ``L <= max_seq_len``.

        Returns:
            Logits of shape ``[batch, vocab_size, L]`` (class dimension second,
            the layout ``nn.CrossEntropyLoss`` accepts directly).
        """
        memory = self.embed_proj(src_embeds).unsqueeze(0)       # [1, batch, D]
        tgt = self._embed_tokens(tgt_tokens)                    # [L, batch, D]
        L = tgt.size(0)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(tgt.device)
        out = self.decoder(tgt, memory, tgt_mask=tgt_mask)      # [L, batch, D]
        out = self.output(out)                                  # [L, batch, V]
        return out.permute(1, 2, 0)                             # [batch, V, L]

    @torch.no_grad()
    def generate(self, src_embeds, tokenizer, max_len=None):
        """Greedy-decode one string per source embedding.

        Args:
            src_embeds: ``[batch, embed_dim]`` source embeddings.
            tokenizer: tokenizer exposing ``bos_token_id``/``cls_token_id``,
                ``eos_token_id``/``sep_token_id`` and ``decode``.
            max_len: total sequence length including the start token. Falsy
                values mean ``EvalConfig.max_seq_len``. The value is capped at
                the positional-embedding table size to avoid indexing past it.

        Returns:
            A list of ``batch`` decoded strings, cut at the first end token.
        """
        batch_size = src_embeds.size(0)
        device = src_embeds.device
        memory = self.embed_proj(src_embeds).unsqueeze(0)  # [1, B, D]

        # Use explicit None checks: a valid token id of 0 is falsy and must
        # not fall through to the alternative.
        bos = tokenizer.bos_token_id
        if bos is None:
            bos = tokenizer.cls_token_id
        # Resolved before the loop so it is bound even if no step runs.
        eos = tokenizer.eos_token_id
        if eos is None:
            eos = tokenizer.sep_token_id

        max_len = max_len or EvalConfig.max_seq_len
        max_len = min(max_len, self.pos_encoder.num_embeddings)

        outputs = torch.full((batch_size, 1), bos, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            tgt = self._embed_tokens(outputs)  # [t, B, D]
            mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
            dec = self.decoder(tgt, memory, tgt_mask=mask)  # [t, B, D]
            logits = self.output(dec[-1])                   # [B, V]
            next_tokens = logits.argmax(dim=-1, keepdim=True)  # [B, 1]
            outputs = torch.cat([outputs, next_tokens], dim=1)  # [B, t+1]

            finished = finished | (next_tokens.squeeze(-1) == eos)
            if finished.all():
                break

        texts = []
        for row in outputs.tolist():
            # Strip the start token first so a tokenizer whose bos equals eos
            # does not truncate every sequence to nothing.
            if row and row[0] == bos:
                row = row[1:]
            if eos in row:
                row = row[:row.index(eos)]
            texts.append(tokenizer.decode(row, skip_special_tokens=True))
        return texts


def _loss_mask(labels, pad_id, eos_id):
    """Return a bool mask of label positions that should contribute to the loss.

    Padding positions are excluded. When the tokenizer pads with its
    end-of-text token (``pad_id == eos_id``), the first pad-valued label in
    each row is the genuine end-of-sequence target and is kept.
    """
    is_pad = labels == pad_id
    if pad_id == eos_id:
        first_pad = is_pad & (is_pad.long().cumsum(dim=1) == 1)
        is_pad = is_pad & ~first_pad
    return ~is_pad


def _normalize_text(text):
    """Lower-case and collapse whitespace so decoded text and raw references compare fairly."""
    return " ".join(text.lower().split())


def evaluate():
    """Score the decoder on the validation set and print the metrics.

    Loads weights from ``EvalConfig.checkpoint_path`` and runs teacher-forced
    scoring on each batch of ``EmbeddingValDataset``. Prints the per-token
    loss and perplexity. It also greedy-decodes every sample with
    ``EmbeddingDecoder.generate`` and prints the exact-match accuracy against
    the raw reference text (case/whitespace-insensitive).

    Padding is masked with the tokenizer's own pad id rather than a hard-coded
    id, so padded positions are neither scored nor counted.

    Raises:
        ValueError: if the validation set contains no scorable target tokens.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = EmbeddingDecoder().to(device)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(EvalConfig.checkpoint_path, map_location=device)
    model.load_state_dict(state)
    model.eval()

    dataset = EmbeddingValDataset()
    tokenizer = dataset.tokenizer
    dataloader = DataLoader(dataset, batch_size=EvalConfig.batch_size, shuffle=False)

    # Id 0 is an ordinary vocabulary token for CLIP's BPE vocab, so mask with
    # the tokenizer's real pad id; fall back to 0 only if none is configured.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    eos_id = tokenizer.eos_token_id
    ignore_index = -100
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="sum")

    total_loss = 0.0
    total_tokens = 0
    n_exact = 0
    n_samples = 0

    with torch.no_grad():
        for embeds, tokens, raw_text in tqdm(dataloader, desc="Evaluating"):
            embeds = embeds.to(device)
            tokens = tokens.to(device)

            # Teacher forcing: predict token i+1 from tokens 0..i.
            inputs = tokens[:, :-1]
            labels = tokens[:, 1:]
            mask = _loss_mask(labels, pad_id, eos_id)
            logits = model(embeds, inputs)  # [B, V, L]: the (N, C, d1) layout CrossEntropyLoss accepts
            loss = criterion(logits, labels.masked_fill(~mask, ignore_index))
            total_loss += loss.item()
            total_tokens += mask.sum().item()

            preds = model.generate(embeds, tokenizer)
            n_exact += sum(
                int(_normalize_text(pred) == _normalize_text(ref))
                for pred, ref in zip(preds, raw_text)
            )
            n_samples += len(raw_text)

    if total_tokens == 0:
        raise ValueError("validation set contains no non-padding target tokens")
    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    print(f"Validation Loss per token: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    print(f"Exact match: {n_exact / n_samples:.4f} ({n_exact}/{n_samples})")


if __name__ == "__main__":
    # Run the validation pass only when executed as a script, not on import.
    evaluate()
