import os
from CARD import CARD
from alphabet import Alphabet
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import random
from util import CoordBatchConverter, load_structure, extract_coords_from_structure, seq_rec_rate

import argparse
import csv  # 导入csv模块
from transformers import AutoTokenizer, AutoModel
from Bio.PDB import PDBParser, PPBuilder, Select, is_aa
from complex_datasets import *
random.seed(1)

def get_amino_acid_centers(chain):
    """Collect the C-alpha coordinate of each standard amino acid in a chain.

    The chain is split into polypeptides with ``PPBuilder`` and residues are
    visited in that order. A residue is skipped when it has no ``CA`` atom or
    when ``is_aa(..., standard=True)`` rejects its residue name.

    Args:
        chain: Chain object accepted by ``PPBuilder.build_peptides``.

    Returns:
        A numpy array built from the kept ``CA`` coordinates, one entry per
        kept residue (an empty array when nothing qualifies).
    """
    ca_coords = [
        residue['CA'].get_coord()
        for peptide in PPBuilder().build_peptides(chain)
        for residue in peptide
        # Same short-circuit as a guard clause: is_aa is only consulted when
        # the residue actually carries a CA atom.
        if 'CA' in residue and is_aa(residue.get_resname(), standard=True)
    ]
    return np.array(ca_coords)

def get_rna_atoms(structure, rna_chains):
    """Gather the coordinates of every atom in the selected RNA chains.

    Only the first model (``structure[0]``) is inspected. A chain is selected
    when ``chain.id in rna_chains`` holds, and within it only residues whose
    ``id[0]`` is a single space are kept.

    Args:
        structure: Indexable structure whose first model iterates over chains.
        rna_chains: Container (or string) tested with ``in`` against chain IDs.

    Returns:
        A numpy array built from the collected ``atom.coord`` values; it is
        empty when no atom matches.
    """
    coords = [
        atom.coord
        for chain in structure[0]
        if chain.id in rna_chains
        for residue in chain
        if residue.id[0] == ' '
        for atom in residue
    ]
    return np.array(coords)

def find_k_closest_amino_acids(amino_acid_centers, rna_atoms, k):
    """Return the indices of the ``k`` amino acids nearest to any RNA atom.

    For each center the distance to its closest RNA atom is computed; indices
    are then ordered by that distance (ties keep their original order because
    the sort is stable) and the first ``k`` are returned.

    Args:
        amino_acid_centers: Iterable of coordinate vectors, one per residue.
        rna_atoms: 2-D array of RNA atom coordinates (one atom per row).
        k: Maximum number of indices to return.

    Returns:
        A list of at most ``k`` integer indices into ``amino_acid_centers``,
        closest first.
    """
    # Distance from every residue center to its nearest RNA atom.
    nearest_rna_dist = [
        np.linalg.norm(rna_atoms - center, axis=1).min()
        for center in amino_acid_centers
    ]
    # Rank residue indices by that distance.
    ranked = sorted(range(len(nearest_rna_dist)), key=nearest_rna_dist.__getitem__)
    return ranked[:k]

def extract_protein_sequence(pdb_path, chains):
    """Read the amino-acid sequence of the requested chains from a PDB file.

    Args:
        pdb_path: Path of the PDB file to parse.
        chains: Comma-separated chain IDs, e.g. ``"A"`` or ``"A,B"``.

    Returns:
        The sequences of all polypeptides found by ``PPBuilder`` in the
        matching chains of the first model, concatenated in file order into a
        single string (empty when nothing matches).
    """
    first_model = PDBParser(QUIET=True).get_structure('protein', pdb_path)[0]
    wanted_ids = chains.split(",")
    peptide_builder = PPBuilder()

    fragments = [
        str(peptide.get_sequence())
        for chain in first_model
        if chain.id in wanted_ids
        for peptide in peptide_builder.build_peptides(chain)
    ]
    return "".join(fragments)

def encode_protein_sequence(protein_sequence):
    """Embed a protein sequence with the ESM-2 checkpoint at ``/path/to/esm2``.

    The tokenizer and model are loaded from disk on every call and the model
    is run on CUDA without gradient tracking.

    Args:
        protein_sequence: Amino-acid sequence passed straight to the tokenizer.

    Returns:
        ``last_hidden_state`` of the model with the first and last positions
        along dimension 1 removed -- presumably shaped
        (batch, sequence length, hidden size); confirm against the checkpoint.
    """
    checkpoint = "/path/to/esm2"
    esm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    esm_model = AutoModel.from_pretrained(checkpoint).cuda()

    encoded = esm_tokenizer(
        protein_sequence, return_tensors='pt', padding=True, truncation=True
    ).to('cuda')

    with torch.no_grad():
        hidden_states = esm_model(**encoded).last_hidden_state

    # Strip the first and last token positions (presumably the special
    # start/end tokens added by the tokenizer) so positions line up with
    # residues.
    return hidden_states[:, 1:-1]

class args_class:
    """Plain attribute container holding the model's hyper-parameters.

    Only the encoder/decoder embedding widths and the dropout rate are taken
    from the caller; every other setting is a fixed constant assigned in
    ``__init__``. ``local_rank`` is read from the ``LOCAL_RANK`` environment
    variable and falls back to ``-1`` when it is unset.
    """

    def __init__(self, encoder_embed_dim, decoder_embed_dim, dropout):
        # Runtime and optimisation settings.
        self.local_rank = int(os.environ.get("LOCAL_RANK", -1))
        self.device_id = 0
        self.epochs = 100
        self.lr = 2e-4
        self.batch_size = 24

        # Caller-supplied values.
        self.encoder_embed_dim = encoder_embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.dropout = dropout

        # gvp_* settings.
        self.gvp_top_k_neighbors = 15
        self.gvp_node_hidden_dim_vector = 256
        self.gvp_node_hidden_dim_scalar = 512
        self.gvp_edge_hidden_dim_scalar = 32
        self.gvp_edge_hidden_dim_vector = 1
        self.gvp_num_encoder_layers = 3
        self.gvp_dropout = 0.1

        # Encoder settings.
        self.encoder_layers = 3
        self.encoder_attention_heads = 4
        self.attention_dropout = 0.1
        self.encoder_ffn_embed_dim = 512

        # Decoder settings.
        self.decoder_layers = 3
        self.decoder_attention_heads = 4
        self.decoder_ffn_embed_dim = 512

        # Remaining fixed settings.
        self.attn_layer = 2
        self.protein_len = 65

def inference(model, coords, device, global_embedding, local_embedding, temp=1.0):
    """Predict one RNA sequence for a structure, conditioned on protein embeddings.

    Args:
        model: Network called as ``model(coords, confidence, padding_mask,
            local_embedding, global_embedding, protein_mask, use_protein=True,
            need_embedding=True, temperature=temp)``; it must also expose
            ``model.decoder.dictionary.get_tok``.
        coords: Coordinates of a single RNA structure in the form accepted by
            ``CoordBatchConverter``.
        device: CUDA device that the model and all tensors are moved to.
        global_embedding: Protein-level embedding tensor.
        local_embedding: Per-residue embedding tensor; its dimension 1 is
            taken as the number of selected residues.
        temp: Value forwarded to the model as ``temperature``.

    Returns:
        The predicted sequence: the arg-max token at every position, mapped
        through the model's dictionary and joined into one string.
    """
    model.eval()
    model.to(device)

    with torch.no_grad():
        converter = CoordBatchConverter(Alphabet(['A', 'G', 'C', 'U', 'X']))
        batch_coords, confidence, _, _, padding_mask, _ = converter(
            [(coords, None, None, None)]
        )

        batch_coords = batch_coords.cuda(device=device)
        confidence = confidence.cuda(device=device)
        padding_mask = padding_mask.bool().cuda(device=device)
        global_embedding = global_embedding.cuda(device=device)
        local_embedding = local_embedding.cuda(device=device)

        # Size the all-True protein mask from the embedding itself instead of
        # a fixed 64 columns: proteins with fewer residues than the requested
        # neighbourhood yield a shorter local embedding, and a 64-wide mask
        # would then disagree with it.
        num_local_residues = local_embedding.shape[1]
        protein_mask = torch.ones(1, num_local_residues).bool().cuda(device=device)

        # The second return value (the embedding requested via
        # need_embedding=True) is not needed here.
        logits, _ = model(batch_coords, confidence,
                          padding_mask, local_embedding,
                          global_embedding, protein_mask,
                          use_protein=True, need_embedding=True, temperature=temp)

        token_ids = logits.argmax(dim=-1).view(-1).tolist()
        get_tok = model.decoder.dictionary.get_tok
        pred_seq = ''.join(get_tok(token_id) for token_id in token_ids)

    return pred_seq

def get_esm2(pdb_path, protein_chains, rna_chains):
    """Build the global and RNA-proximal ESM-2 embeddings for a complex.

    Steps: parse the PDB file, collect RNA atom coordinates, take the C-alpha
    centers of the protein chain, pick the (up to) 64 residues closest to the
    RNA, embed the protein sequence and slice the embedding at those residues.

    Args:
        pdb_path: Path of the complex PDB file.
        protein_chains: Chain ID used to index the first model directly, and
            also passed to ``extract_protein_sequence``.
        rna_chains: Chain IDs of the RNA, forwarded to ``get_rna_atoms``.

    Returns:
        ``(global_embedding, local_embedding)`` on success, where the global
        embedding is the mean of the full embedding over dimension 1 (kept as
        a size-1 dimension) and the local embedding holds the selected
        residues only. ``None`` is returned (after printing a notice) when the
        file is missing, no RNA atom is found, or the number of C-alpha
        centers differs from the sequence length.
    """
    if not os.path.exists(pdb_path):
        print(f"警告：PDB 文件 {pdb_path} 不存在，跳过...")
        return None

    complex_structure = PDBParser(QUIET=True).get_structure('complex', pdb_path)

    rna_coords = get_rna_atoms(complex_structure, rna_chains)
    if rna_coords.size == 0:
        print(f"警告：未找到RNA原子，跳过 {pdb_path}...")
        return None

    ca_centers = get_amino_acid_centers(complex_structure[0][protein_chains])
    nearest_residues = find_k_closest_amino_acids(ca_centers, rna_coords, 64)
    sequence = extract_protein_sequence(pdb_path, protein_chains)

    # The residue indices are only meaningful if every sequence position has
    # a matching C-alpha center; otherwise dump the mismatch and give up.
    num_centers = ca_centers.shape[0]
    if num_centers != len(sequence):
        print(num_centers, len(sequence), sequence, sep="\n")
        return None

    per_residue = encode_protein_sequence(sequence)
    global_embedding = per_residue.mean(dim=1, keepdim=True)
    local_embedding = per_residue[:, nearest_residues]
    return global_embedding, local_embedding

def eval(model, pdb_path, save_path, _device, protein_chain, rna_chain,
         num_samples=1000, temp=1000, model_path='main_layer6.pth'):
    """Load the checkpoint and generate RNA sequences for one complex.

    Note: the name shadows the built-in ``eval``; it is kept because callers
    use it.

    Args:
        model: Model instance that receives the checkpoint weights.
        pdb_path: Path of the complex PDB file.
        save_path: Unused here; kept for interface compatibility.
        _device: CUDA device passed on to ``inference``.
        protein_chain: Protein chain ID passed to ``get_esm2``.
        rna_chain: RNA chain ID used for loading the RNA structure and passed
            to ``get_esm2``.
        num_samples: Number of ``inference`` calls to make.
        temp: Temperature forwarded to ``inference``.
        model_path: Checkpoint file holding the state dict (defaults to the
            previously hard-coded ``main_layer6.pth``).

    Returns:
        A list with ``num_samples`` predicted sequences, or an empty list when
        ``get_esm2`` could not produce embeddings for this file.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. map_location='cpu' avoids depending on the device the
    # checkpoint was saved from; load_state_dict copies the weights onto the
    # model's own device.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # get_esm2 returns None on its skip paths (missing file, no RNA atoms,
    # sequence/coordinate mismatch); unpacking that directly would raise.
    embeddings = get_esm2(pdb_path, protein_chains=protein_chain, rna_chains=rna_chain)
    if embeddings is None:
        print(f"Warning: no protein embeddings for {pdb_path}; skipping.")
        return []
    global_embedding, local_embedding = embeddings

    rna_structure = filter_standard_bases(load_structure(pdb_path, rna_chain))
    coords, _ = extract_coords_from_structure(rna_structure)

    return [
        inference(model, coords, _device, global_embedding, local_embedding, temp)
        for _ in tqdm(range(num_samples), desc=f'Generating {num_samples} sequences')
    ]

if __name__ == '__main__':
    # Command-line entry point: generate RNA sequences for every .pdb file in
    # a directory, de-duplicate them and write one sequence per CSV row.
    parser = argparse.ArgumentParser(
        description='Generate RNA sequences for the protein-RNA complexes in a directory')
    # Without a directory nothing can be processed, so require it up front.
    parser.add_argument('-pdb', '--pdb_dir', type=str, required=True,
                        help='path to the directory containing the .pdb files')
    parser.add_argument('-chain', '--chain', default='A', type=str, help='chain ID of protein')
    parser.add_argument('-rna_chain', '--rna_chain', default='B', type=str, help='chain ID of RNA')
    parser.add_argument('-save', '--save_path', default='./example/round1.csv', type=str,
                        help='path of the output CSV file')
    parser.add_argument('-device', '--device', default=0, type=int, help='Assign the device to run the model')
    parser.add_argument('-temp', '--temperature', default=5, type=float, help='temperature for sampling')
    parser.add_argument('-num_samples', '--num_samples', default=1000, type=int, help='Number of RNA sequences to generate')
    args = parser.parse_args()

    pdb_dir = args.pdb_dir
    save_path = args.save_path
    _device = args.device
    temp = args.temperature
    num_samples = args.num_samples
    protein_chain = args.chain
    rna_chain = args.rna_chain

    # Create the output directory before the expensive generation so a bad
    # path fails early instead of after all sequences were produced.
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

    model_args = args_class(512, 512, 0.1)
    dictionary = Alphabet(['A', 'G', 'C', 'U', 'X'])
    model = CARD(model_args, dictionary).cuda(device=_device)

    # Sorted so runs are reproducible regardless of directory listing order.
    pdb_files = sorted(f for f in os.listdir(pdb_dir) if f.endswith('.pdb'))
    num = len(pdb_files)
    print(f"Number of .pdb files in {pdb_dir}: {num}")
    if num == 0:
        print(f"Warning: no .pdb files found in {pdb_dir}; the output will be empty.")

    # Split the requested total across files; the first `extra` files take one
    # more sample so the remainder of the division is not silently dropped.
    per_file, extra = divmod(num_samples, num) if num else (0, 0)

    all_seqs = []
    for index, pdb_name in enumerate(pdb_files):
        quota = per_file + (1 if index < extra else 0)
        if quota <= 0:
            continue
        pdb_path = os.path.join(pdb_dir, pdb_name)
        pred_seqs = eval(model, pdb_path, save_path, _device, rna_chain=rna_chain,
                         protein_chain=protein_chain, num_samples=quota, temp=temp)
        all_seqs.extend(pred_seqs)

    # Order-preserving de-duplication keeps the CSV deterministic (a plain
    # set() would shuffle rows between runs).
    unique_seqs = list(dict.fromkeys(all_seqs))

    csv_file = save_path
    with open(csv_file, mode='w', newline='') as file:
        csv.writer(file).writerows([seq] for seq in unique_seqs)

    # Report what was actually written, not what was requested.
    print(f'Generated {len(all_seqs)} sequences ({len(unique_seqs)} unique) and saved to {csv_file}')