from transformers import AutoTokenizer,AutoModelForCausalLM, pipeline,BitsAndBytesConfig, Trainer, TrainingArguments,DataCollatorWithPadding,TrainerCallback
from peft import LoraConfig
from torch.utils.data import DataLoader, TensorDataset
import torch
import pandas as pd
import numpy as np
from huggingface_hub import login
import gc
import torch.nn as nn
from peft import LoraConfig, get_peft_model
import sys

# Generous recursion limit, kept from the original setup.
sys.setrecursionlimit(35000)

# Start from a clean Python/CUDA allocator state before loading the model.
gc.collect()
torch.cuda.empty_cache()

#login(token='') # Enter your huggingface token and remove the '#'
base_model_path = 'meta-llama/Llama-3.1-8B-Instruct'
compute_dtype = torch.float16  # dtype handed to bitsandbytes as the 4-bit compute dtype
print(torch.cuda.is_available())  # if false then make sure you have GPU

# Load the base model quantised to 4-bit NF4, without double quantisation.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=compute_dtype,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map="auto",
    quantization_config=bnb_config,
)

# Reuse EOS as the padding token (presumably because the tokenizer defines none) and pad on the right.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# LoRA adapters on these projection layers (Llama attention and MLP linears).
target_modules = ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=128,
    lora_dropout=0.10,
    bias="none",
    target_modules=target_modules,
    inference_mode=False,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # Verify trainable parameters

# Device for the custom tokenizers and their inputs; the LLM itself was placed by device_map="auto".
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(torch.__version__)
print(torch.cuda.is_available())
# Define structural and temporal tokenizers
class StructuralTokenizer(nn.Module):
    """Two-layer MLP projecting structural node features into token vectors.

    Computes ``LayerNorm(fc2(ReLU(fc1(X_s))))``.

    Args:
        input_dim: width of each input feature vector.
        hidden_dim: width of the hidden ReLU layer.
        output_dim: width of each produced token (normalised by LayerNorm).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order (fc1, fc2, layer_norm) fixes parameter init order and state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, X_s):
        """Map ``X_s`` of shape (..., input_dim) to tokens of shape (..., output_dim)."""
        out = self.fc1(X_s)
        out = torch.relu(out)
        out = self.fc2(out)
        return self.layer_norm(out)

class TemporalTokenizer(nn.Module):
    """Two-layer MLP projecting temporal embeddings into token vectors.

    Same architecture as StructuralTokenizer: ``LayerNorm(fc2(ReLU(fc1(X_t))))``.

    Args:
        input_dim: width of each input vector.
        hidden_dim: width of the hidden ReLU layer.
        output_dim: width of each produced token (normalised by LayerNorm).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order (fc1, fc2, layer_norm) fixes parameter init order and state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, X_t):
        """Map ``X_t`` of shape (..., input_dim) to tokens of shape (..., output_dim)."""
        projected = self.fc2(torch.relu(self.fc1(X_t)))
        normalised = self.layer_norm(projected)
        return normalised
# Token width; must match the LLM's input-embedding width because train_on_chunk
# concatenates these tokens with the model's prompt embeddings along dim 0.
d_LLM = 4096
# Number of Laplacian eigenvectors per node; presumably must equal the n_eigen
# used inside preprocess_tokens, since it is the structural tokenizer's input width.
n_eigen = 6
structural_tokenizer = StructuralTokenizer(n_eigen, d_LLM, d_LLM).to(device)
temporal_tokenizer = TemporalTokenizer(input_dim=d_LLM, hidden_dim=d_LLM, output_dim=d_LLM).to(device)

# Preprocess Structural, Temporal, and Prompt embeddings
class _TemporalEmbedding(nn.Module):
    """Single-head scaled dot-product self-attention over the rows of a profile.

    Each row of ``x`` is projected to queries, keys and values of width
    ``embed_dim``; the output has one ``embed_dim`` vector per input row.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.W_Q = nn.Linear(input_dim, embed_dim)
        self.W_K = nn.Linear(input_dim, embed_dim)
        self.W_V = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # softmax(Q K^T / sqrt(d_k)) V over the last two dims.
        scores = torch.softmax(Q @ K.transpose(-2, -1) / np.sqrt(K.shape[-1]), dim=-1)
        return scores @ V


# One embedder per (input_dim, embed_dim, device). Created lazily and reused so every
# profile is projected with the SAME weights, not a fresh random init per call.
_TEMPORAL_EMBEDDERS = {}


def _get_temporal_embedder(input_dim, embed_dim, device):
    """Return the shared _TemporalEmbedding for these dimensions on ``device``."""
    key = (input_dim, embed_dim, str(device))
    if key not in _TEMPORAL_EMBEDDERS:
        _TEMPORAL_EMBEDDERS[key] = _TemporalEmbedding(input_dim, embed_dim).to(device)
    return _TEMPORAL_EMBEDDERS[key]


def preprocess_tokens(X_profile, adj_matrix, prompt_text, d_LLM, m, model, tokenizer, device):
    """Build the structural, temporal and prompt token groups for one profile.

    Uses the module-level ``structural_tokenizer`` and ``temporal_tokenizer``.

    Args:
        X_profile: tensor of one profile's measurements; assumed (T, m), with one
            row per time point and one column per feature — TODO confirm.
        adj_matrix: square numpy adjacency matrix of the feature graph.
        prompt_text: instruction text embedded with the model's input embeddings.
        d_LLM: output width of the temporal attention embedding.
        m: number of features per row of ``X_profile``.
        model: model whose ``get_input_embeddings()`` layer embeds the prompt.
        tokenizer: tokenizer used to turn ``prompt_text`` into ids.
        device: device the inputs are moved to.

    Returns:
        ``(structural_tokens, temporal_tokens, prompt_embeds)``: one row per graph
        node, one row per profile row, and one row per prompt token respectively.

    NOTE(review): both tokenizers run with autograd enabled, but custom_collator
    detaches ``inputs_embeds`` and the Trainer only optimises ``model``. These
    projections (and the temporal embedder) therefore never seem to be trained —
    confirm this is intended.
    """
    import networkx as nx
    from scipy.linalg import eigh

    # Must equal the input width the structural tokenizer was built with.
    n_eigen = structural_tokenizer.fc1.in_features

    # Structural embedding: eigenvectors of the normalised graph Laplacian. eigh
    # returns eigenvalues in ascending order, so these columns belong to the smallest ones.
    with torch.no_grad():
        G = nx.from_numpy_array(adj_matrix)
        laplacian = nx.normalized_laplacian_matrix(G).toarray()
        _, eigenvectors = eigh(laplacian)
        raw_structural = torch.tensor(eigenvectors[:, :n_eigen], dtype=torch.float32).to(device)  # (n_nodes, n_eigen)

    structural_tokens = structural_tokenizer(raw_structural)

    # Temporal embedding: the same cached attention weights are used for every profile.
    temporal_embedder = _get_temporal_embedder(m, d_LLM, device)
    with torch.no_grad():
        temporal_output = temporal_embedder(X_profile.to(device))  # (T, d_LLM)

    temporal_tokens = temporal_tokenizer(temporal_output)

    # Prompt tokens come from the pretrained tokenizer and the LLM's own embedding table.
    prompt_ids = tokenizer(prompt_text, return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
    prompt_embeds = model.get_input_embeddings()(prompt_ids).squeeze(0)

    return structural_tokens, temporal_tokens, prompt_embeds

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-computed input embeddings with label ids.

    Item ``i`` is ``{"inputs_embeds": inputs_embeds[i], "labels": labels[i]}``;
    the dataset length is ``len(inputs_embeds)``.
    """

    def __init__(self, inputs_embeds, labels):
        self.inputs_embeds, self.labels = inputs_embeds, labels

    def __len__(self):
        return len(self.inputs_embeds)

    def __getitem__(self, idx):
        return dict(inputs_embeds=self.inputs_embeds[idx], labels=self.labels[idx])

class CustomTrainer(Trainer):
    """Trainer that feeds pre-computed ``inputs_embeds`` (not token ids) to the model.

    The dataloaders are built directly from the stored datasets with the
    configured collator. Batches are dicts with ``inputs_embeds`` and ``labels``.
    """

    def get_train_dataloader(self):
        """Return a shuffling DataLoader over ``self.train_dataset``."""
        # pin_memory is off: batches already live on `device` (the collator moves
        # labels there), and only CPU tensors can be pinned.
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            pin_memory=False,
        )

    def get_eval_dataloader(self, eval_dataset=None):
        """Return a non-shuffling DataLoader over ``eval_dataset`` or ``self.eval_dataset``."""
        dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            pin_memory=False,
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on ``inputs_embeds``/``labels`` and return its built-in loss.

        Extra keyword arguments from the Trainer (e.g. ``num_items_in_batch``) are
        accepted but not forwarded to the model.
        """
        outputs = model(inputs_embeds=inputs["inputs_embeds"], labels=inputs["labels"])
        if return_outputs:
            return outputs.loss, outputs
        return outputs.loss


def custom_collator(features):
    """Stack a list of CustomDataset items into one batch dict.

    ``features`` here is the list of per-sample dicts (it shadows the module-level
    feature-name list). ``inputs_embeds`` is copied, detached from any earlier graph
    and marked ``requires_grad``. This is presumably so gradient checkpointing still
    back-propagates into the LoRA weights. ``labels`` are moved to the module-level
    ``device``.
    """
    embeds = torch.stack([sample["inputs_embeds"] for sample in features])
    targets = torch.stack([sample["labels"] for sample in features])
    return {
        "inputs_embeds": embeds.clone().detach().requires_grad_(True),
        "labels": targets.to(device),
    }

# Directory for Trainer checkpoints and the final adapter/tokenizer save.
output_dir = 'LLAMA3_finetuneexp_male'
# CSV with an 'ID' column, one column per entry of `features`, and an 'Explanation' column.
file_path = "dummy_data/reasoning_data.csv"
# Steroid metabolite columns; presumably their order matches the node order of adj_matrix — confirm.
features = ['Adiol', 'Bdiol', 'A', 'Etio', 'E', 'T']
# Feature graph for the structural (Laplacian eigenvector) embedding; the diagonal holds self-loops.
# NOTE(review): the matrix is not symmetric (row 1 col 2 is 1, row 2 col 1 is 0) — confirm the
# intended edges, since the graph is built with networkx's default undirected Graph.
adj_matrix = np.array(
    [
        [1, 0, 0, 0, 0, 1],
        [0, 1, 1, 1, 0, 1],
        [1, 0, 1, 0, 0, 0],
        [0, 1, 0, 1, 0, 0],
        [0, 0, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 1],
    ]
)
n = len(adj_matrix)  # Number of nodes

# We load data in chunks because memory cannot load all profiles at once if we finetune on a large number of profiles
def _explanation_labels(explanation, length):
    """Tokenize ``explanation`` to exactly ``length`` label ids, with padding masked as -100."""
    encoded = tokenizer(
        explanation,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=length,
    )
    label_ids = encoded.input_ids[0].clone()
    # pad_token is eos, so unmasked padding would add an eos target at every padded position.
    label_ids[encoded.attention_mask[0] == 0] = -100
    return label_ids


def train_on_chunk(chunk, eval_chunk, current_epoch):
    """Turn one CSV chunk into inputs_embeds/labels and run ``trainer.train()`` on it.

    Rows sharing an ``ID`` form one profile. Each profile with a non-blank
    explanation becomes one sample: [structural tokens | temporal tokens | prompt
    embeddings], zero-padded to the longest sample in the chunk. Its labels are
    the tokenized explanation.

    Relies on the module-level ``features``, ``adj_matrix``, ``d_LLM``, ``model``,
    ``tokenizer``, ``device`` and ``trainer``.

    Args:
        chunk: DataFrame with an ``ID`` column, the ``features`` columns and an
            ``Explanation`` column.
        eval_chunk: unused; kept for call compatibility.
        current_epoch: unused; kept for call compatibility.
    """
    chunk = chunk[['ID'] + features + ['Explanation']]
    grouped = chunk.groupby('ID')  # samples sharing same ID belong to same profile
    # Iterating the groupby and .last() both follow the sorted group keys, so they stay aligned.
    last_explanations = grouped['Explanation'].last()  # last non-null explanation per ID
    prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. If any sample is abnormal, explain why."

    gc.collect()
    torch.cuda.empty_cache()

    combined_tokens_list = []
    explanations = []
    for (_, group), explanation in zip(grouped, last_explanations):
        if not (isinstance(explanation, str) and explanation.strip()):
            # With every label at -100 the mean cross-entropy has no valid target
            # and comes out NaN, so profiles without an explanation are skipped.
            continue
        # Use the profile's own rows directly; all-zero measurement rows are kept.
        profile = torch.tensor(group[features].to_numpy(dtype=np.float32))
        structural_tokens, temporal_tokens, prompt_tokens = preprocess_tokens(
            profile, adj_matrix, prompt_text, d_LLM, len(features), model, tokenizer, device
        )
        combined_tokens_list.append(torch.cat([structural_tokens, temporal_tokens, prompt_tokens], dim=0))
        explanations.append(explanation)

    if not combined_tokens_list:
        print("No profiles with an explanation in this chunk; skipping.")
        return

    # Zero-pad every sample at the end of the sequence dimension to a common length.
    max_tokens_len = max(tokens.shape[0] for tokens in combined_tokens_list)
    padded_combined_tokens_list = [
        torch.cat([tokens, torch.zeros(max_tokens_len - tokens.shape[0], tokens.shape[1], dtype=tokens.dtype, device=tokens.device)], dim=0)
        for tokens in combined_tokens_list
    ]
    combined_tokens = torch.stack(padded_combined_tokens_list).to(torch.float16).to(device)
    del combined_tokens_list, padded_combined_tokens_list

    labels = torch.stack(
        [_explanation_labels(text, combined_tokens.size(1)) for text in explanations]
    ).to(device)

    train_dataset = CustomDataset(combined_tokens, labels)

    print("Inputs Embeds Shape:", combined_tokens.shape)
    print("Labels Shape:", labels.shape)

    trainer.train_dataset = train_dataset
    train_results = trainer.train()
    print(f"Training Results: {train_results.metrics}")

    # The trainer holds the dataset too. Drop that reference as well, or this chunk's
    # GPU tensors stay alive until the next chunk replaces them.
    trainer.train_dataset = None
    del chunk, combined_tokens, labels, train_dataset
    gc.collect()
    torch.cuda.empty_cache()


# NOTE(review): trainer.train() is called once per CSV chunk (train_on_chunk), so the
# LR schedule, including the 500 warmup steps, is likely restarted on every call and
# may never finish warming up — confirm against the installed transformers version.
args = TrainingArguments(
    output_dir=output_dir,
    report_to=[],
    # Optimisation
    learning_rate=2e-5,
    warmup_steps=500,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    # Memory / precision
    gradient_checkpointing=True,
    fp16=True,
    # Presumably keeps the custom "inputs_embeds" key from being filtered out of batches.
    remove_unused_columns=False,
    # Logging and checkpoints
    logging_steps=10,
    logging_dir='./logs',
    save_strategy="epoch",
    save_total_limit=2,
)

# train_dataset is left unset here; train_on_chunk assigns it before each train() call.
trainer = CustomTrainer(
    model=model,
    args=args,
    tokenizer=None,
    data_collator=custom_collator,
)

# Main training loop
# Each epoch re-reads the CSV in chunks, so the whole dataset never has to be in memory at once.
# NOTE(review): chunks are cut by row count (chunksize), so one ID's rows can be split across
# two chunks and then treated as two partial profiles — confirm the file layout or align
# chunk boundaries to IDs.
NUM_EPOCHS = 10
global_epoch = 0

for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch + 1}/{NUM_EPOCHS}...")
    train_chunks = pd.read_csv(file_path, chunksize=25, nrows=50)  # dummy numbers
    for train_chunk in train_chunks:
        print(f"Processing normal chunk for epoch {epoch + 1}...")
        # No evaluation data is used here (train_on_chunk ignores eval_chunk).
        train_on_chunk(train_chunk, eval_chunk=None, current_epoch=global_epoch)
        del train_chunk
        gc.collect()
        torch.cuda.empty_cache()

    del train_chunks
    global_epoch += 1

    print(f"Finished epoch {epoch + 1}/{NUM_EPOCHS}")
    gc.collect()
    torch.cuda.empty_cache()

# Final model save
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# The adapter was trained on outputs of the structural/temporal tokenizers, which are randomly
# initialised at start-up. Save their weights so the same inputs can be rebuilt for inference.
# (Any other randomly initialised projection inside preprocess_tokens is not covered here.)
torch.save(
    {
        "structural_tokenizer": structural_tokenizer.state_dict(),
        "temporal_tokenizer": temporal_tokenizer.state_dict(),
    },
    f"{output_dir}/graph_tokenizers.pt",
)
