from . import model as M
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm

class ActivationDataset(Dataset):
    """Map-style dataset that lazily loads one activation file per sample.

    Each element of ``data_list`` is a dict with a ``'path'`` (anything
    ``torch.load`` accepts) and a ``'label'``. Files are read on access, so
    only the index lives in memory.
    """

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        record = self.data_list[idx]
        path, label = record['path'], record['label']
        # NOTE(review): torch.load unpickles its input; only point this at trusted files.
        per_layer = torch.load(path, map_location='cpu')

        # Each saved entry is presumably one layer's hidden state with a leading
        # batch dim of 1 (TODO confirm); squeeze(0) drops dim 0 only when it has
        # size 1, and the results are stacked along a new leading axis.
        activation = torch.stack([layer.squeeze(0) for layer in per_layer], dim=0)
        return {'activation': activation, 'label': label}

    def split(self, split_idx):
        """Split into ``(data_list[:split_idx], data_list[split_idx:])`` as two new datasets."""
        head = self.data_list[:split_idx]
        tail = self.data_list[split_idx:]
        return ActivationDataset(head), ActivationDataset(tail)

def padding_collate_fn(batch):
    """Right-pad variable-length activations into one dense batch.

    Args:
        batch: list of dicts with ``'activation'`` and ``'label'``. An activation
            with 4 dims has ``squeeze(0)`` applied; after that, dim 0 is treated
            as layers, dim 1 as sequence length and dim 2 as hidden size.

    Returns:
        (padded_batch, mask, labels):
            padded_batch -- ``(batch, num_layers, max_len, hidden_dim)`` zeros in
                the default float dtype, filled from the left with each sample.
            mask -- ``(batch, max_len)``; 1 for real positions, 0 for padding.
            labels -- float32 tensor of shape ``(batch,)``.
    """
    activations = []
    for sample in batch:
        act = sample['activation']
        activations.append(act.squeeze(0) if act.dim() == 4 else act)
    labels = torch.tensor([sample['label'] for sample in batch], dtype=torch.float32)

    seq_lens = [act.size(1) for act in activations]
    n_samples = len(activations)
    n_layers = activations[0].size(0)
    hidden = activations[0].size(2)
    longest = max(seq_lens)

    padded = torch.zeros(n_samples, n_layers, longest, hidden)
    mask = torch.zeros(n_samples, longest)
    for row in range(n_samples):
        n = seq_lens[row]
        padded[row, :, :n, :] = activations[row]
        mask[row, :n] = 1

    return padded, mask, labels


def load_model(input_dim=2048, num_filters=64, layer_kernel_size=3, dropout=0.5, pooling='max'):
    """Build an ``M.hiddenDetector`` and move it to CUDA when available, else CPU.

    All keyword arguments are forwarded unchanged to the detector's constructor.
    """
    target = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    detector = M.hiddenDetector(
        input_dim=input_dim,
        num_filters=num_filters,
        layer_kernel_size=layer_kernel_size,
        dropout=dropout,
        pooling=pooling,
    )
    return detector.to(target)

def load_demo_dataset(dataset_path='residuals/xstest'):
    """Index activation files under ``dataset_path`` into a shuffled ActivationDataset.

    ``<dataset_path>/safe/*.pt`` files get label 0 and ``<dataset_path>/unsafe/*.pt``
    files get label 1. Files are only listed here; ActivationDataset loads them lazily.

    Args:
        dataset_path: root folder containing ``safe/`` and ``unsafe/``. ``None``
            falls back to ``'residuals/xstest'`` (callers such as ``fit`` forward
            their own ``None`` default).

    Returns:
        ActivationDataset over the combined, shuffled index. The shuffle uses the
        global ``random`` state, so call ``set_seed`` first for reproducibility.
    """
    from pathlib import Path
    import random

    if dataset_path is None:
        # Without this, None would be formatted into the literal path "None/safe".
        dataset_path = 'residuals/xstest'

    print(f'dataset path: {dataset_path}')

    root = Path(dataset_path)
    # Sort before shuffling so the pre-shuffle order (and thus the seeded shuffle)
    # does not depend on filesystem listing order.
    safe_list = [{'path': f, 'label': 0} for f in sorted((root / 'safe').glob('*.pt'))]
    unsafe_list = [{'path': f, 'label': 1} for f in sorted((root / 'unsafe').glob('*.pt'))]

    all_list = safe_list + unsafe_list
    random.shuffle(all_list)

    print('Safe samples: {}, Unsafe samples: {}'.format(len(safe_list), len(unsafe_list)))
    if not all_list:
        # An empty dataset otherwise only surfaces later, far from the cause.
        print(f'Warning: no .pt files found under {root / "safe"} or {root / "unsafe"}')

    return ActivationDataset(all_list)

def predict(model, prompt, llm, tokenizer, device, start_layer=10, end_layer=15):
    """Return the detector's probability for a single chat prompt.

    Args:
        model: detector called as ``model(hidden_states, mask)``; its output is
            treated as a logit and passed through ``sigmoid``.
        prompt: dict with a ``'prompt'`` string and an optional ``'injection'``
            string appended verbatim after the chat-templated text.
        llm: language model called with ``output_hidden_states=True``.
        tokenizer: provides ``apply_chat_template`` and ``encode``.
        device: device for the token ids, hidden states and mask.
        start_layer, end_layer: half-open slice of ``outputs.hidden_states`` fed
            to the detector.

    Returns:
        float: ``sigmoid(logit)`` for the prompt.
    """
    model.eval()
    with torch.no_grad():
        chat = [{"role": "user", "content": prompt['prompt']}]
        # `add_generation_prompt` appends the assistant-turn header. The previous
        # `add_generation_token` is not an apply_chat_template parameter and was
        # only forwarded to the template renderer, where it had no effect.
        # NOTE(review): must match how the training residuals were extracted.
        text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        injection = prompt.get('injection')
        if injection is not None:
            # Presumably a prefill of the assistant reply — TODO confirm.
            text += injection
        chat_tokens = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to(device)

        outputs = llm(chat_tokens, output_hidden_states=True)

        # (layers, batch, seq, hidden) -> (batch, layers, seq, hidden)
        hidden_states = torch.stack(outputs.hidden_states[start_layer:end_layer]).permute(1, 0, 2, 3)
        # padding_collate_fn materializes float32 training batches; an fp16/bf16
        # LLM would otherwise hand the detector a different dtype.
        hidden_states = hidden_states.to(device=device, dtype=torch.float32)

        batch_size, _, seq_len, _ = hidden_states.shape
        # Single unpadded sequence: every position is valid.
        mask = torch.ones((batch_size, seq_len), dtype=torch.float32, device=device)

        logits = model(hidden_states, mask)
        return torch.sigmoid(logits).squeeze(0).squeeze(0).item()

def eval(model, eval_dataloader, criterion):
    """Evaluate a binary detector; return (mean loss, accuracy) over all samples.

    Args:
        model: called as ``model(inputs, masks)``; outputs are treated as logits
            and thresholded at ``sigmoid(outputs) >= 0.5``.
        eval_dataloader: iterable of ``(inputs, masks, labels)`` batches, e.g. a
            DataLoader using ``padding_collate_fn``.
        criterion: applied as ``criterion(outputs, labels)``; its value is
            treated as a per-batch mean and re-weighted by batch size.

    Returns:
        tuple[float, float]: sample-weighted mean loss and accuracy.

    Raises:
        ZeroDivisionError: if ``eval_dataloader`` yields no samples.

    Leaves ``model`` in eval mode. The name shadows the builtin ``eval`` within
    this module; it is kept for existing callers.
    """
    model.eval()
    # Honor an explicit ``model.device`` when the model defines one; a plain
    # nn.Module has no such attribute, so fall back to its parameters' device.
    device = getattr(model, 'device', None)
    if device is None:
        device = next(model.parameters()).device

    total_loss = 0.0
    total_acc = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, masks, labels in eval_dataloader:
            inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)

            outputs = model(inputs, masks)
            loss = criterion(outputs, labels)

            predictions = (torch.sigmoid(outputs) >= 0.5).float()

            total_acc += (predictions == labels).sum().item()
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)

    avg_loss = total_loss / total_samples
    avg_acc = total_acc / total_samples

    return avg_loss, avg_acc

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices if present)."""
    import random
    import numpy as np

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def fit(num_epochs=50, input_dim=2048, lr=0.001, batch_size=16, seed=999, pooling='max', use_demo=True, dataset_path=None):
    """Train a hiddenDetector on the demo activation dataset, evaluating each epoch.

    The dataset is shuffled once under ``seed`` and split 80/20 into train/eval.

    Args:
        num_epochs: number of epochs; also the cosine schedule's ``T_max``.
        input_dim: forwarded to ``load_model``.
        lr: initial AdamW learning rate, cosine-annealed towards 1e-5.
        batch_size: batch size for both loaders.
        seed: passed to ``set_seed`` before loading data.
        pooling: forwarded to ``load_model``.
        use_demo: only ``True`` is supported.
        dataset_path: root with ``safe/`` and ``unsafe/`` sub-folders; ``None``
            uses ``load_demo_dataset``'s own default.

    Returns:
        tuple: ``(model, max_acc)``, where ``model`` holds the weights after the
        final epoch. ``max_acc`` is the best per-epoch eval accuracy, which may
        come from an earlier epoch than the returned weights.

    Raises:
        NotImplementedError: if ``use_demo`` is False.
        ValueError: if the dataset contains no samples.
    """
    set_seed(seed)
    if not use_demo:
        raise NotImplementedError("Custom dataset loading not implemented yet.")

    # Forwarding None would override load_demo_dataset's default path.
    if dataset_path is None:
        dataset = load_demo_dataset()
    else:
        dataset = load_demo_dataset(dataset_path=dataset_path)
    if len(dataset) == 0:
        # Otherwise no training step runs and eval() divides by zero after epoch 1.
        raise ValueError(f"No samples found for dataset_path={dataset_path!r}")
    train_dataset, eval_dataset = dataset.split(int(0.8 * len(dataset)))

    data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=padding_collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=padding_collate_fn,
    )

    model = load_model(input_dim=input_dim, pooling=pooling)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    # Stepped once per epoch, so lr reaches eta_min after num_epochs epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-5,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Use the device load_model actually placed the parameters on.
    device = next(model.parameters()).device

    max_acc = 0

    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (inputs, masks, labels) in enumerate(data_loader):
            inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs, masks)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(data_loader)}], Loss: {loss.item():.4f}")
        scheduler.step()

        eval_avg_loss, eval_avg_acc = eval(model, eval_dataloader, criterion)
        max_acc = max(max_acc, eval_avg_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}] Evaluation Loss: {eval_avg_loss:.4f}, Accuracy: {eval_avg_acc:.4f}")
    print(f"Training complete. Max Accuracy: {max_acc:.4f}")

    return model, max_acc

def _last_token_features(inputs, masks):
    """Pick the last stored layer's activation at each sample's final unpadded token.

    ``inputs`` is ``(batch, layers, seq, hidden)`` and ``masks`` is ``(batch, seq)``,
    both as produced by ``padding_collate_fn`` (right padding, 1 = real token).
    """
    last_layer = inputs[:, -1, :, :]
    # With right padding, the last valid token index is (number of ones) - 1.
    last_idx = masks.sum(dim=1).long() - 1
    rows = torch.arange(inputs.size(0), device=inputs.device)
    return last_layer[rows, last_idx]


def _eval_probe(model, eval_dataloader, criterion):
    """Return (mean loss, accuracy) of a last-token probe; NaNs if the loader is empty."""
    model.eval()
    total_loss = 0.0
    total_correct = 0.0
    total = 0
    with torch.no_grad():
        for inputs, masks, labels in eval_dataloader:
            inputs, masks, labels = inputs.to(model.device), masks.to(model.device), labels.to(model.device)
            outputs = model(_last_token_features(inputs, masks))
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            total_correct += ((torch.sigmoid(outputs) >= 0.5).float() == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, total_correct / total


def train_probe(num_epochs=10, input_dim=4096, seed=42, batch_size=16, dataset_path='residuals/qwen2.5-7B/train', lr=0.001):
    """Train a ``linearProbe`` on the last-token activation of the last stored layer.

    The dataset is shuffled once under ``seed`` and split 80/20. The held-out 20%
    is evaluated after every epoch, mirroring ``fit``.

    Args:
        num_epochs: number of training epochs.
        input_dim: forwarded to ``M.linearProbe``.
        seed: passed to ``set_seed`` before loading data.
        batch_size: batch size for both loaders.
        dataset_path: root with ``safe/`` and ``unsafe/`` sub-folders.
        lr: AdamW learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        The trained probe, left in eval mode by the final evaluation pass.
    """
    set_seed(seed)
    dataset = load_demo_dataset(dataset_path=dataset_path)
    train_dataset, eval_dataset = dataset.split(int(0.8 * len(dataset)))

    data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=padding_collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=padding_collate_fn,
    )

    model = M.linearProbe(input_dim=input_dim)
    model.to(model.device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    criterion = torch.nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (inputs, masks, labels) in enumerate(data_loader):
            inputs, masks, labels = inputs.to(model.device), masks.to(model.device), labels.to(model.device)
            features = _last_token_features(inputs, masks)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(data_loader)}], Loss: {loss.item():.4f}")

        # The generic eval() calls model(inputs, masks), which the probe does not accept.
        eval_avg_loss, eval_avg_acc = _eval_probe(model, eval_dataloader, criterion)
        print(f"Epoch [{epoch+1}/{num_epochs}] Evaluation Loss: {eval_avg_loss:.4f}, Accuracy: {eval_avg_acc:.4f}")
    return model