import torch
import torch.nn as nn
import logging
import argparse
from torch.utils.data import Dataset, DataLoader
import json
from task_tracker.utils.model import load_model
from task_tracker.config.models import models, cache_dir, database_dir
from task_tracker.ragsys.load import load_json_file, mean_pooling
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModel

# 定义LoRA层：对Transformer的自注意力机制进行低秩适配
class LoRALayer(nn.Module):
    """Low-rank adapter computing ``lora_B(lora_A(x))``, optionally on top of a base layer.

    Args:
        input_dim: Size of the last dimension of ``x``.
        rank: Inner (bottleneck) dimension of the low-rank factorisation.
        layer_name: Prefix recorded on the factors as ``lora_A.name`` /
            ``lora_B.name`` (``"<layer_name>_A"`` / ``"<layer_name>_B"``).
        base_layer: Optional module to adapt (expected to expose ``weight`` and,
            when ``output_dim`` is omitted, ``out_features`` like ``nn.Linear``).
            When given, ``forward`` returns ``base_layer(x) + lora_B(lora_A(x))``,
            ``lora_B`` is zero-initialised so the wrapped layer initially matches
            ``base_layer`` exactly, and the factors are placed on the device of
            ``base_layer.weight``. The base layer's ``requires_grad`` flags are
            left untouched.
        output_dim: Output size of ``lora_B``. Defaults to
            ``base_layer.out_features`` when wrapping, otherwise ``input_dim``
            (the original square adapter).
    """

    def __init__(self, input_dim, rank=8, layer_name="lora", base_layer=None, output_dim=None):
        super(LoRALayer, self).__init__()
        self.rank = rank
        if output_dim is None:
            output_dim = base_layer.out_features if base_layer is not None else input_dim
        # The two low-rank factors: A projects down to `rank`, B projects back up.
        self.lora_A = nn.Linear(input_dim, rank, bias=False)
        self.lora_A.name = f"{layer_name}_A"
        self.lora_B = nn.Linear(rank, output_dim, bias=False)
        self.lora_B.name = f"{layer_name}_B"
        self.base_layer = base_layer
        if base_layer is not None:
            # Standard LoRA start: a zero delta, so the adapted layer equals the base layer.
            nn.init.zeros_(self.lora_B.weight)
            self.to(base_layer.weight.device)

    def forward(self, x):
        # Run the factors in their own dtype (e.g. float32 adapters on a bfloat16
        # model) and cast the delta back to the input dtype.
        delta = self.lora_B(self.lora_A(x.to(self.lora_A.weight.dtype))).to(x.dtype)
        if self.base_layer is None:
            return delta
        return self.base_layer(x) + delta

# 定义LoRA增强的RAG模型
class LoRARAG(nn.Module):
    """Retrieval-augmented wrapper that injects LoRA adapters into a base model's decoder.

    Every decoder layer's ``self_attn.q_proj`` / ``k_proj`` / ``v_proj`` is wrapped so
    it computes ``pretrained_projection(x) + LoRA(x)``; the pretrained projections are
    kept. At forward time each query is augmented with the top-k documents retrieved
    from ``index`` and the combined prompt is fed to the base model.

    Args:
        base_model: Model exposing ``get_decoder().layers[i].self_attn.{q,k,v}_proj``
            (assumed to be ``nn.Linear``-like — TODO confirm for the configured models).
        model_tokenizer: Tokenizer for ``base_model``; must be able to pad a batch.
        rank: LoRA rank.
        index: FAISS index over the embeddings of ``documents``.
        emb_model: Encoder used to embed queries (should be the one that built ``index``).
        tokenizer: Tokenizer for ``emb_model``.
        documents: Document strings, positionally aligned with the vectors in ``index``.
    """

    def __init__(self, base_model, model_tokenizer, rank=8, index=None, emb_model=None, tokenizer=None, documents=None):
        super(LoRARAG, self).__init__()
        self.base_model = base_model.to('cuda' if torch.cuda.is_available() else 'cpu')
        self.index = index
        self.emb_model = emb_model
        self.tokenizer = tokenizer
        self.documents = documents
        self.model_tokenizer = model_tokenizer

        # Assumes the decoder exposes `.layers` whose attention has q/k/v projections.
        for idx, layer in enumerate(self.base_model.get_decoder().layers):
            attn = layer.self_attn
            for proj_name in ("q_proj", "k_proj", "v_proj"):
                alias = f"lora_{proj_name}_{idx}"
                wrapped = _LoRAProjection(getattr(attn, proj_name), rank, alias)
                setattr(attn, proj_name, wrapped)
                # Top-level alias to the LoRA factors only; aliasing the whole wrapper
                # would duplicate the pretrained weights in state_dict().
                self.add_module(alias, wrapped.lora)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the retrieval encoder in eval mode.

        ``emb_model`` is a registered submodule, so a plain ``train()`` would enable its
        dropout and make query embeddings stochastic. ``build_index`` never switches its
        encoder to train mode, so retrieval should not either.
        """
        super().train(mode)
        if isinstance(self.emb_model, nn.Module):
            self.emb_model.eval()
        return self

    def forward(self, input_ids, attention_mask, query, top_k=5, labels=None):
        """Retrieve context for each query, build the augmented prompt and run the base model.

        Args:
            input_ids: Kept for interface compatibility; not fed to the model, because
                the input is re-tokenised from ``query`` plus retrieved context.
            attention_mask: Kept for interface compatibility; it describes
                ``input_ids``, whose length differs from the augmented prompt, so the
                mask of the re-tokenised prompt is used instead.
            query: Sequence of query strings, one per batch element.
            top_k: Number of documents retrieved per query.
            labels: Optional targets forwarded unchanged (moved to the model's device).
                NOTE(review): decoder-only HF models require labels aligned with the
                augmented ``input_ids``; confirm the dataset labels satisfy that.

        Returns:
            The base model output, computed with ``output_hidden_states=True``.
            ``outputs.encoder_last_hidden_state`` is always assigned: the model's own
            value when present (encoder-decoder models), otherwise the last entry of
            ``outputs.hidden_states``, or ``None`` if neither exists.
        """
        device = next(self.base_model.parameters()).device
        prefix = "Consider the following request that you must answer based on the given text: "
        combined_inputs = [
            prefix + q + "<CONTEXT>" + self.retrieve_documents(q, top_k) + "</CONTEXT>"
            for q in query
        ]

        encoded = self.model_tokenizer(
            combined_inputs,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(device)
        if labels is not None:
            labels = labels.to(device)

        # LoRA is applied inside the wrapped q/k/v projections, so the hidden states
        # need no post-processing here.
        outputs = self.base_model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            labels=labels,
            output_hidden_states=True,
        )

        last_hidden = getattr(outputs, "encoder_last_hidden_state", None)
        if last_hidden is None:
            all_hidden = getattr(outputs, "hidden_states", None)
            if all_hidden:
                last_hidden = all_hidden[-1]
        outputs.encoder_last_hidden_state = last_hidden
        return outputs

    def retrieve_documents(self, query, top_k=5):
        """Return the ``top_k`` nearest documents for ``query`` joined by single spaces."""
        query_embedding = self.encode_query(query)
        # FAISS (without its torch bridge) expects a float32 numpy array, not a tensor.
        query_array = query_embedding.detach().float().cpu().numpy()
        _, indices = self.index.search(query_array, top_k)
        # FAISS pads missing results with -1, which would silently pick the last document.
        retrieved_docs = [self.documents[i] for i in indices[0] if i >= 0]
        return " ".join(retrieved_docs)

    def encode_query(self, query):
        """Embed ``query`` with ``emb_model``: mean pooling followed by L2 normalisation."""
        device = next(self.emb_model.parameters()).device
        query_encoded = self.tokenizer(query, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            query_output = self.emb_model(**query_encoded)
        query_embedding = mean_pooling(query_output, query_encoded["attention_mask"])
        return F.normalize(query_embedding, p=2, dim=1)


class _LoRAProjection(nn.Module):
    """Pretrained projection plus a trainable low-rank delta: ``base_layer(x) + lora(x)``."""

    def __init__(self, base_layer, rank, layer_name):
        super().__init__()
        self.base_layer = base_layer
        self.lora = LoRALayer(base_layer.in_features, rank, layer_name)
        # LoRALayer maps in_features -> in_features, but a projection may have
        # out_features != in_features (e.g. grouped-query k/v), so size B to the real
        # output. B starts at zero so the wrapped layer initially equals the pretrained one.
        self.lora.lora_B = nn.Linear(rank, base_layer.out_features, bias=False)
        self.lora.lora_B.name = f"{layer_name}_B"
        nn.init.zeros_(self.lora.lora_B.weight)
        # Keep the adapter on the base weights' device (in its default dtype).
        self.lora.to(base_layer.weight.device)

    def forward(self, x):
        base_out = self.base_layer(x)
        delta = self.lora(x.to(self.lora.lora_A.weight.dtype))
        return base_out + delta.to(base_out.dtype)

class MitigationLoss(nn.Module):
    """Weighted sum of token cross-entropy and an activation-shift-index (ASI) term.

    ``loss = alpha * CE(logits, target_labels) + (1 - alpha) * BCE(asi, safe_labels)``

    Args:
        alpha: Weight of the cross-entropy term; ``1 - alpha`` weights the ASI term.
    """

    def __init__(self, alpha=0.7):
        super(MitigationLoss, self).__init__()
        self.alpha = alpha

    def forward(self, outputs, target_labels, safe_labels):
        """Compute the combined loss.

        Args:
            outputs: Model output with ``logits`` and either ``encoder_last_hidden_state``
                or ``hidden_states`` (the last entry is used).
            target_labels: Token ids with as many elements as ``logits`` has positions;
                entries equal to -100 are ignored by the cross-entropy.
            safe_labels: One 0/1 label per batch element.

        Raises:
            ValueError: If ``outputs`` carries no hidden states.
        """
        logits = outputs.logits
        # reshape (not view): the inputs may be non-contiguous slices.
        ce_loss = nn.CrossEntropyLoss()(
            logits.reshape(-1, logits.size(-1)),
            target_labels.to(logits.device).reshape(-1),
        )

        hidden_states = self._last_hidden_state(outputs)
        asi_loss = self.activation_shift_index_loss(hidden_states, safe_labels)

        return self.alpha * ce_loss + (1 - self.alpha) * asi_loss

    @staticmethod
    def _last_hidden_state(outputs):
        """Return the final hidden states of encoder-decoder or decoder-only outputs."""
        hidden = getattr(outputs, "encoder_last_hidden_state", None)
        if hidden is not None:
            return hidden
        all_hidden = getattr(outputs, "hidden_states", None)
        if all_hidden is not None and len(all_hidden) > 0:
            return all_hidden[-1]
        raise ValueError(
            "MitigationLoss needs hidden states: run the model with output_hidden_states=True"
        )

    def activation_shift_index_loss(self, hidden_states, safe_labels):
        """BCE between a per-sample mean activation (used as a logit) and ``safe_labels``.

        All non-batch dimensions are averaged so there is exactly one logit per sample,
        matching the shape of ``safe_labels``.
        """
        hidden = hidden_states.float()
        asi = hidden.flatten(1).mean(dim=1) if hidden.dim() > 1 else hidden
        target = safe_labels.to(device=asi.device, dtype=asi.dtype).reshape(asi.shape)
        return nn.BCEWithLogitsLoss()(asi, target)

# 定义数据集类
class SafetyDataset(Dataset):
    """JSON-backed dataset yielding tokenised prompts, shifted answers and safety labels.

    ``data_path`` is read with ``json.load``; each sample must provide the keys
    ``"task_prompt"``, ``"answer"`` and ``"safe_label"`` (presumably a JSON list of
    objects — TODO confirm). Prompts and answers are padded/truncated to ``max_length``.
    """

    def __init__(self, data_path, tokenizer, max_length=128):
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenise ``text`` to exactly ``max_length`` tokens (batch of one)."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return 1-D tensors for one sample plus its raw prompt string.

        ``decoder_input_ids`` is the answer without its last token and ``labels`` is
        the answer shifted left by one. Padding positions in ``labels`` are set to -100
        so ``CrossEntropyLoss`` ignores them (pad ids may equal EOS, so they cannot be
        filtered by value).
        """
        sample = self.data[idx]
        inputs = self._encode(sample["task_prompt"])
        targets = self._encode(sample["answer"])

        # Drop the tokenizer's batch dimension so the DataLoader stacks to (B, L).
        target_ids = targets.input_ids[0]
        labels = target_ids[1:].clone()
        labels[targets.attention_mask[0][1:] == 0] = -100

        return {
            "input_ids": inputs.input_ids[0],
            "attention_mask": inputs.attention_mask[0],
            "decoder_input_ids": target_ids[:-1],
            "labels": labels,
            "safe_labels": torch.tensor(sample["safe_label"], dtype=torch.long),
            "task_prompt": sample["task_prompt"]
        }

# 定义训练函数
def train_model(model, dataloader, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs using ``MitigationLoss(alpha=0.7)``.

    Each batch must provide ``input_ids``, ``attention_mask``, ``labels``,
    ``safe_labels`` and ``task_prompt``. Tensors are moved to the device of the
    model's first parameter before the forward pass. The mean loss per epoch is printed.
    """
    model.train()
    loss_fn = MitigationLoss(alpha=0.7)
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in dataloader:
            optimizer.zero_grad()
            # DataLoader output lives on the CPU; strings (task_prompt) pass through as-is.
            batch = {key: value.to(device) if torch.is_tensor(value) else value for key, value in batch.items()}

            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
                query=batch['task_prompt']
            )

            # MitigationLoss needs both the token targets and the safety labels.
            loss = loss_fn(outputs, batch['labels'], batch['safe_labels'])
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        num_batches = max(len(dataloader), 1)  # avoid ZeroDivisionError on an empty loader
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss / num_batches:.4f}")

def build_index(file_path, model_path, index_name, batch_size=500):
    """Embed the ``"context"`` documents of ``file_path`` and save a flat L2 FAISS index.

    Embeddings are mean-pooled encoder outputs, L2-normalised, computed in chunks of
    ``batch_size`` documents (lower it to reduce memory use) on CUDA when available,
    else on the CPU. The index is written to ``index_name`` and returned.

    Raises:
        ValueError: If ``file_path`` yields no documents.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    emb_model = AutoModel.from_pretrained(model_path).to(device)
    documents = load_json_file(file_path, "context")
    if len(documents) == 0:
        raise ValueError(f"No documents found in {file_path}; cannot build an empty FAISS index.")

    index = None
    for start in range(0, len(documents), batch_size):
        batch_documents = documents[start:start + batch_size]
        encoded_input = tokenizer(batch_documents, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            model_output = emb_model(**encoded_input)
        embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # The embedding width is only known after the first batch.
        if index is None:
            index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings.cpu().numpy())

    faiss.write_index(index, index_name)
    print(f"FAISS index stored in {index_name}")

    return index

# 主函数：加载模型、数据集和训练
def main(model_name, model_path,
         embedding_model_path="/hub/huggingface/models/bert/bert-base-uncased",
         data_path="/guardrail/TaskTracker/store/output_datasets/Reconnaissance/hotpotqa/merged_dataset_train.json",
         batch_size=16,
         num_epochs=10):
    """Load the base model, wrap it with retrieval + LoRA, and train it on the safety dataset.

    Args:
        model_name: Key into ``task_tracker.config.models.models``.
        model_path: Directory the base model and tokenizer are loaded from.
        embedding_model_path: Encoder used for retrieval (must match the FAISS index).
        data_path: JSON training file consumed by ``SafetyDataset``.
        batch_size: DataLoader batch size.
        num_epochs: Number of training epochs.
    """
    model = models[model_name]
    _load_base_model(model, model_path)

    # The FAISS index sits next to the document store. If it is missing, create it once with
    # build_index(database_dir, embedding_model_path, index_name=index_path).
    index_path = database_dir.replace('/database.json', '/index.faiss')
    index = faiss.read_index(index_path)
    emb_model = AutoModel.from_pretrained(embedding_model_path)
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)

    documents = load_json_file(database_dir, "context")

    lora_model = LoRARAG(model.model, model.tokenizer, rank=8, index=index, emb_model=emb_model, tokenizer=tokenizer, documents=documents)

    dataset = SafetyDataset(data_path=data_path, tokenizer=model.tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = _build_optimizer(lora_model)
    train_model(lora_model, dataloader, optimizer, num_epochs=num_epochs)


def _load_base_model(model, model_path):
    """Populate ``model.tokenizer`` / ``model.model`` from ``model_path`` on the best device.

    On any failure, logs a CUDA memory summary for every visible GPU and re-raises.
    """
    try:
        # NOTE(review): cache_dir is set to model_path although config.cache_dir is
        # imported at file level — confirm which cache directory is intended.
        loaded_model = load_model(
            model_path,
            cache_dir=model_path,
            torch_dtype=model.torch_dtype
        )
        model.tokenizer = loaded_model["tokenizer"]
        # Batch padding needs a pad token; EOS is reused (presumably none is defined).
        model.tokenizer.pad_token = model.tokenizer.eos_token
        model.model = loaded_model["model"]

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.model = model.model.to(device)
        # NOTE(review): this only logs; nothing below actually spreads work across GPUs.
        if torch.cuda.device_count() > 1:
            logging.info(f"Let's use {torch.cuda.device_count()} GPUs!")
        model.model.eval()
    except Exception:
        for i in range(torch.cuda.device_count()):
            logging.info(f"Memory summary for GPU {i}:")
            logging.info(torch.cuda.memory_summary(device=i))
        raise


def _build_optimizer(lora_model, base_lr=5e-5, lora_lr=1e-4):
    """Create AdamW with one learning rate for LoRA factors and another for the rest.

    Parameters are classified by whether ``'lora'`` occurs in their qualified name;
    parameters with ``requires_grad=False`` are skipped.
    """
    lora_params = []
    base_params = []
    for name, param in lora_model.named_parameters():
        if not param.requires_grad:
            continue
        (lora_params if 'lora' in name else base_params).append(param)

    return torch.optim.AdamW([
        {'params': base_params, 'lr': base_lr},
        {'params': lora_params, 'lr': lora_lr}
    ])

# 执行训练
if __name__ == "__main__":
    # Script entry point: parse the two mandatory CLI options, then start training.
    cli = argparse.ArgumentParser(description="Process activations for a specified model.")
    cli_options = (
        ("--model_name", "Name of the model to use"),
        ("--model_path", "Path to the model files"),
    )
    for flag, help_text in cli_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli_args = cli.parse_args()

    main(model_name=cli_args.model_name, model_path=cli_args.model_path)
