from transformers import AutoTokenizer,AutoModelForCausalLM, pipeline,BitsAndBytesConfig, Trainer, TrainingArguments,DataCollatorWithPadding,TrainerCallback
from peft import LoraConfig
from torch.utils.data import DataLoader, TensorDataset
import torch
import pandas as pd
import numpy as np
from huggingface_hub import login
import gc
import torch.nn as nn
from peft import LoraConfig, get_peft_model
import sys

# NOTE(review): nothing visible in this file recurses deeply; confirm this raised
# recursion limit is still needed.
sys.setrecursionlimit(35000)

# Start from a clean slate of Python garbage and cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

# login(token='')  # Enter your Hugging Face token and uncomment this line.
base_model_path = 'meta-llama/Llama-3.1-8B'  # any causal LM checkpoint can be used here
compute_dtype = torch.float16
print(torch.cuda.is_available())  # if False, make sure a GPU is visible to PyTorch

# Base model weights are quantized to 4-bit NF4; computation runs in `compute_dtype`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    quantization_config=bnb_config,
    device_map="auto",
)

# Padding reuses the EOS token and is applied on the right.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# LoRA adapters are attached to every module whose name matches one of these.
target_modules = ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]

peft_config = LoraConfig(
    r=32,
    lora_alpha=128,
    lora_dropout=0.10,
    bias="none",
    target_modules=target_modules,
    task_type="CAUSAL_LM",
    inference_mode=False,
)
# NOTE(review): PEFT's k-bit training guide calls prepare_model_for_kbit_training()
# before get_peft_model(); it is not called here -- confirm that is intentional.
model = get_peft_model(model, peft_config)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Here we define structural and temporal tokenizers
class StructuralTokenizer(nn.Module):
    """Project per-node structural features into the LLM embedding space.

    Computes ``LayerNorm(fc2(ReLU(fc1(X_s))))`` over the last dimension, so an input
    of shape ``(..., input_dim)`` becomes ``(..., output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Submodules are created in this order so random initialisation consumes the
        # RNG in the same sequence and the state_dict keys stay fc1 / fc2 / layer_norm.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, X_s):
        hidden = self.fc1(X_s).relu()
        return self.layer_norm(self.fc2(hidden))

class TemporalTokenizer(nn.Module):
    """Project temporal embeddings into the LLM embedding space.

    Applies ``fc1`` -> ReLU -> ``fc2`` -> ``layer_norm`` along the last dimension;
    the result's last dimension is ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Creation order fixes both the RNG draw sequence and the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, X_t):
        projected = self.fc2(nn.functional.relu(self.fc1(X_t)))
        return self.layer_norm(projected)
# Width of the LLM's input embeddings, read from the loaded model so the structural
# and temporal tokens always match it (4096 for Llama-3.1-8B) without manual edits.
d_LLM = model.get_input_embeddings().embedding_dim
# Number of normalized-Laplacian eigenvectors used as per-node structural features.
n_eigen = 6
structural_tokenizer = StructuralTokenizer(input_dim=n_eigen, hidden_dim=d_LLM, output_dim=d_LLM).to(device)
temporal_tokenizer = TemporalTokenizer(input_dim=d_LLM, hidden_dim=d_LLM, output_dim=d_LLM).to(device)

# Preprocess Structural, Temporal, and Prompt embeddings
def preprocess_tokens(X_profile, adj_matrix, prompt_text, d_LLM, m, model, tokenizer, device):
    """Build the structural, temporal and prompt token sequences for one profile.

    Args:
        X_profile: tensor with one row per time point and ``m`` feature columns
            (assumed 2-D, i.e. not batched -- TODO confirm for all callers).
        adj_matrix: square numpy adjacency matrix of the feature graph.
        prompt_text: instruction text, embedded with the LLM's own embedding table.
        d_LLM: output width of the temporal self-attention; must equal the input
            width of the module-level ``temporal_tokenizer``.
        m: number of feature columns in ``X_profile``.
        model: model exposing ``get_input_embeddings()``.
        tokenizer: tokenizer used for ``prompt_text``.
        device: device for the structural/temporal inputs and the prompt ids.

    Returns:
        ``(structural_tokens, temporal_tokens, prompt_embeds)`` with shapes
        ``(n_nodes, structural width)``, ``(T, temporal width)`` and
        ``(prompt_len, embedding width)``.
    """
    import networkx as nx
    from scipy.linalg import eigh

    # Must equal the feature width the module-level structural_tokenizer was built for.
    n_eigen = structural_tokenizer.fc1.in_features

    # Structural embedding: leading eigenvectors of the normalized graph Laplacian.
    with torch.no_grad():
        G = nx.from_numpy_array(adj_matrix)
        L = nx.normalized_laplacian_matrix(G).toarray()
        _, eigenvectors = eigh(L)  # eigenvalues (and matching columns) in ascending order
        raw_structural = torch.tensor(
            eigenvectors[:, :n_eigen], dtype=torch.float32
        ).to(device)  # Shape: (n_nodes, n_eigen)

    structural_tokens = structural_tokenizer(raw_structural)

    # Temporal embedding: self-attention over the profile's time points. The same
    # embedder instance is reused for every call with these dimensions, so every
    # profile (training and evaluation) is projected with identical weights.
    temporal_embedder = get_temporal_embedder(m, d_LLM, device)
    with torch.no_grad():
        temporal_output = temporal_embedder(X_profile.to(device))  # Shape: (T, d_LLM)

    temporal_tokens = temporal_tokenizer(temporal_output)

    # Prompt tokens, embedded with the pretrained embedding table.
    # NOTE(review): the tokenizer may prepend a BOS token, which then lands after the
    # structural/temporal tokens in the combined sequence -- confirm that is intended.
    prompt_ids = tokenizer(prompt_text, return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
    prompt_embeds = model.get_input_embeddings()(prompt_ids).squeeze(0)

    return structural_tokens, temporal_tokens, prompt_embeds


class TemporalEmbedding(nn.Module):
    """Single-head, unmasked self-attention mapping ``(T, input_dim)`` to ``(T, embed_dim)``."""

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.W_Q = nn.Linear(input_dim, embed_dim)
        self.W_K = nn.Linear(input_dim, embed_dim)
        self.W_V = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        scores = torch.softmax(Q @ K.transpose(-2, -1) / np.sqrt(K.shape[-1]), dim=-1)
        return scores @ V


# Cache of TemporalEmbedding instances keyed by (input_dim, embed_dim, device string).
# NOTE(review): these weights are randomly initialised, never trained and never saved,
# so a fresh process builds different temporal inputs -- confirm this is acceptable.
_temporal_embedders = {}


def get_temporal_embedder(input_dim, embed_dim, device):
    """Return the shared TemporalEmbedding for these dimensions, creating it on first use."""
    key = (input_dim, embed_dim, str(device))
    embedder = _temporal_embedders.get(key)
    if embedder is None:
        embedder = TemporalEmbedding(input_dim=input_dim, embed_dim=embed_dim).to(device)
        _temporal_embedders[key] = embedder
    return embedder

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each pre-computed embedding sequence with its labels.

    Item ``i`` is ``{"inputs_embeds": inputs_embeds[i], "labels": labels[i]}``; the
    length is that of ``inputs_embeds``.
    """

    def __init__(self, inputs_embeds, labels):
        self.inputs_embeds, self.labels = inputs_embeds, labels

    def __len__(self):
        return len(self.inputs_embeds)

    def __getitem__(self, idx):
        return dict(inputs_embeds=self.inputs_embeds[idx], labels=self.labels[idx])

class CustomTrainer(Trainer):
    """Trainer whose batches carry pre-computed ``inputs_embeds`` instead of token ids.

    Both dataloaders are plain ``DataLoader``s over the dataset, using
    ``self.data_collator`` and no pinned memory. Training batches are shuffled;
    evaluation batches are not.
    """

    def _embeds_dataloader(self, dataset, batch_size, shuffle):
        """Build a DataLoader over ``dataset`` with this trainer's collator."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self.data_collator,
            pin_memory=False,
        )

    def get_train_dataloader(self):
        return self._embeds_dataloader(self.train_dataset, self.args.train_batch_size, shuffle=True)

    def get_eval_dataloader(self, eval_dataset=None):
        dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        return self._embeds_dataloader(dataset, self.args.eval_batch_size, shuffle=False)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the model's loss for one batch, plus the outputs if requested."""
        outputs = model(inputs_embeds=inputs["inputs_embeds"], labels=inputs["labels"])
        if return_outputs:
            return outputs.loss, outputs
        return outputs.loss


def custom_collator(features):
    """Stack a list of CustomDataset items into one batch dict.

    ``inputs_embeds`` is returned as a fresh leaf tensor with ``requires_grad=True``.
    NOTE(review): presumably this lets gradients reach the LoRA weights through
    gradient-checkpointed layers, since these inputs bypass the embedding layer -- confirm.
    ``labels`` are moved to the module-level ``device``.
    """
    batch_embeds = torch.stack([item["inputs_embeds"] for item in features])
    batch_labels = torch.stack([item["labels"] for item in features])
    return {
        "inputs_embeds": batch_embeds.clone().detach().requires_grad_(True),
        "labels": batch_labels.to(device),
    }

output_dir = 'LLAMA3_finetuneopt_male_Twothousand'  # checkpoints and the final model are saved here
file_path = "dummy_data/training_data.csv"
eval_file_path = "dummy_data/evaluation_data.csv"

# Metabolite columns read from every CSV row (besides 'ID').
features = ['Adiol', 'Bdiol', 'A', 'Etio', 'E', 'T']

# Feature-graph adjacency matrix, presumably row/column i <-> features[i] (confirm).
# NOTE(review): the matrix is not symmetric (e.g. [2][0] == 1 but [0][2] == 0), while
# preprocess_tokens builds an undirected networkx Graph from it -- confirm that
# direction is not meant to matter.
adj_matrix = np.array(
    [
        [1, 0, 0, 0, 0, 1],
        [0, 1, 1, 1, 0, 1],
        [1, 0, 1, 0, 0, 0],
        [0, 1, 0, 1, 0, 0],
        [0, 0, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 1],
    ]
)
n = len(adj_matrix)  # Number of nodes

# Data is loaded in chunks because, when fine-tuning on many profiles, they do not all fit in memory at once.
def train_on_chunk(chunk, eval_chunk, current_epoch, normal_sample):
    """Fine-tune the module-level ``trainer`` on the profiles contained in ``chunk``.

    Rows of ``chunk`` are grouped by ``ID``; each group is one longitudinal profile.
    For every profile, the structural, temporal and prompt tokens from
    ``preprocess_tokens`` are concatenated along the sequence axis. All sequences are
    right-padded with zero vectors to a common length, and ``trainer.train()`` runs
    once on the resulting dataset.

    Args:
        chunk: DataFrame with an ``ID`` column and every column listed in ``features``.
        eval_chunk: unused. NOTE(review): the main loop passes the trainer in this
            position, but this function always trains the module-level ``trainer``.
        current_epoch: unused; kept so existing call sites keep working.
        normal_sample: True selects the "All samples are normal" prompt, False the
            "The last sample is an anomaly" prompt.

    Returns:
        None. An empty chunk is skipped without training.
    """
    chunk = chunk[['ID'] + features]
    # Samples sharing the same ID belong to the same profile.
    profiles = [group.drop('ID', axis=1).to_numpy(dtype=np.float32) for _, group in chunk.groupby('ID')]
    if not profiles:
        print("Skipping empty training chunk.")
        return

    if normal_sample:  # first train on normal samples using this prompt
        prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. All samples are normal. Confirm consistency."
    else:  # then we train on samples with anomalies using this prompt
        prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. The last sample is an anomaly. Identify and learn why is it anomaly."

    # Inputs are built without autograd. The optimizer only updates `model`, and
    # custom_collator detaches inputs_embeds anyway, so a graph through the
    # structural/temporal tokenizers would only pin GPU memory until training ends.
    with torch.no_grad():
        combined_tokens_list = []
        for profile in profiles:
            # Rows whose features are all exactly zero are treated as padding and dropped.
            # NOTE(review): genuine all-zero measurements are dropped too -- confirm.
            kept_rows = torch.from_numpy(profile[(profile != 0).any(axis=-1)])
            structural_tokens, temporal_tokens, prompt_tokens = preprocess_tokens(
                kept_rows, adj_matrix, prompt_text, d_LLM, len(features), model, tokenizer, device
            )
            combined_tokens_list.append(
                torch.cat(
                    [structural_tokens.to(device), temporal_tokens.to(device), prompt_tokens.to(device)],
                    dim=0,
                )
            )
        # Right-pad every sequence with zero vectors to the longest one:
        # (num_profiles, max_len, d_LLM).
        combined_tokens = nn.utils.rnn.pad_sequence(combined_tokens_list, batch_first=True).to(
            device=device, dtype=torch.float16
        )
    del combined_tokens_list
    print("Combined Token Dimensions:", combined_tokens.shape)
    gc.collect()
    torch.cuda.empty_cache()

    # NOTE(review): labels are position indices 0..seq_len-2 (the last position is
    # ignored via -100), not token ids, and zero-padded positions are not excluded from
    # the loss -- confirm this is the intended training target.
    num_profiles, seq_len = combined_tokens.shape[:2]
    labels = torch.full((num_profiles, seq_len), -100, dtype=torch.long, device=device)
    labels[:, :-1] = torch.arange(seq_len - 1, device=device)

    train_dataset = CustomDataset(combined_tokens, labels)

    print("Inputs Embeds Shape:", combined_tokens.shape)  # Should be (batch_size, seq_len, embedding_dim)
    print("Labels Shape:", labels.shape)

    trainer.train_dataset = train_dataset
    train_results = trainer.train()
    print(f"Training Results: {train_results.metrics}")

    # Clean up to free memory. The trainer's own reference must be dropped too,
    # otherwise the chunk's GPU tensors stay alive until the next chunk replaces them.
    trainer.train_dataset = None
    del chunk, profiles, combined_tokens, labels, train_dataset
    gc.collect()
    torch.cuda.empty_cache()


# then do the same for evaluation
def evaluate_on_chunk(chunk, trainer, normal_sample):
    """Evaluate ``trainer`` on the profiles in ``chunk`` and return its metrics.

    Inputs are built the same way as for training: group rows by ``ID``, choose the
    prompt, concatenate the structural, temporal and prompt tokens, right-pad with
    zeros, and use position labels. Everything runs without autograd.

    Args:
        chunk: DataFrame with an ``ID`` column and every column listed in ``features``.
        trainer: trainer whose ``eval_dataset`` is set, evaluated, then cleared.
        normal_sample: True selects the "All samples are normal" prompt, False the
            "The last sample is an anomaly" prompt.

    Returns:
        The metrics dict from ``trainer.evaluate()``, or ``{}`` for an empty chunk.
    """
    chunk = chunk[['ID'] + features]
    profiles = [group.drop('ID', axis=1).to_numpy(dtype=np.float32) for _, group in chunk.groupby('ID')]
    if not profiles:
        print("Skipping empty evaluation chunk.")
        return {}

    if normal_sample:
        prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. All samples are normal. Confirm consistency."
    else:
        prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. The last sample is an anomaly. Identify and learn why is it anomaly."

    with torch.no_grad():
        combined_tokens_list = []
        for profile in profiles:
            # Rows whose features are all exactly zero are treated as padding and dropped.
            # NOTE(review): genuine all-zero measurements are dropped too -- confirm.
            kept_rows = torch.from_numpy(profile[(profile != 0).any(axis=-1)])
            structural_tokens, temporal_tokens, prompt_tokens = preprocess_tokens(
                kept_rows, adj_matrix, prompt_text, d_LLM, len(features), model, tokenizer, device
            )
            combined_tokens_list.append(
                torch.cat(
                    [structural_tokens.to(device), temporal_tokens.to(device), prompt_tokens.to(device)],
                    dim=0,
                )
            )
        # Right-pad every sequence with zero vectors to the longest one.
        combined_tokens = nn.utils.rnn.pad_sequence(combined_tokens_list, batch_first=True).to(
            device=device, dtype=torch.float16
        )
        del combined_tokens_list
        print("eval Combined Token Dimensions:", combined_tokens.shape)
        gc.collect()
        torch.cuda.empty_cache()

        # Same position-index labels as training; the last position is ignored (-100).
        num_profiles, seq_len = combined_tokens.shape[:2]
        labels = torch.full((num_profiles, seq_len), -100, dtype=torch.long, device=device)
        labels[:, :-1] = torch.arange(seq_len - 1, device=device)

        eval_dataset = CustomDataset(combined_tokens, labels)
        trainer.eval_dataset = eval_dataset
        results = trainer.evaluate()
        print(f"Evaluation Results: {results}")

    # Drop every reference to this chunk's tensors, including the trainer's, before freeing.
    trainer.eval_dataset = None
    del chunk, profiles, combined_tokens, labels, eval_dataset
    gc.collect()
    torch.cuda.empty_cache()
    return results


# Training configuration. trainer.train() is invoked once per data chunk (see
# train_on_chunk), so num_train_epochs=1 means one pass over each chunk per call.
args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    save_strategy="epoch",
    logging_steps=10,
    logging_dir='./logs',
    learning_rate=2e-5,
    # NOTE(review): every trainer.train() call starts again from global_step 0, so a
    # 500-step warmup may never complete when a chunk yields only a few optimizer
    # steps -- confirm against the real chunk sizes.
    warmup_steps=500,
    save_total_limit=2,
    remove_unused_columns=False,
    fp16=True,
    report_to=[],
)

# No tokenizer/processing_class is passed: batches arrive pre-embedded, and the
# `tokenizer=` keyword is deprecated (in favour of `processing_class`) in recent
# transformers releases.
trainer = CustomTrainer(
    model=model,
    args=args,
    data_collator=custom_collator,
)

# Main training loop
import itertools

NUM_EPOCHS = 10
global_epoch = 0

# One entry per pass: (normal_sample flag, banner, label used in progress messages).
TRAINING_PASSES = (
    (True, "Training on normal profiles...", "normal"),
    (False, "Training on anomaly profiles...", "anomaly"),
)

for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch + 1}/{NUM_EPOCHS}...")
    # First train on the normal profiles, then on the anomaly profiles.
    # NOTE(review): both passes read the same file_path / eval_file_path and differ only
    # in the prompt, so identical profiles are presented first as "normal" and then as
    # "anomaly" -- confirm the anomaly pass should not read separate files.
    # NOTE(review): chunks are cut by row count, so one profile ID can be split across
    # two chunks and treated as two shorter profiles.
    for normal_sample, banner, label in TRAINING_PASSES:
        with pd.read_csv(file_path, chunksize=15, nrows=15) as train_chunks:  # dummy numbers
            with pd.read_csv(eval_file_path, chunksize=10, nrows=10) as eval_chunks:  # dummy numbers
                print(banner)
                # zip_longest rather than zip: plain zip would stop at the shorter reader
                # and silently skip the remaining training (or evaluation) chunks.
                for train_chunk, eval_chunk in itertools.zip_longest(train_chunks, eval_chunks):
                    print(f"Processing {label} chunk for epoch {epoch + 1}...")
                    if train_chunk is not None:
                        train_on_chunk(train_chunk, trainer, current_epoch=global_epoch, normal_sample=normal_sample)
                    if eval_chunk is not None:
                        evaluate_on_chunk(eval_chunk, trainer, normal_sample=normal_sample)
                    del train_chunk, eval_chunk
                    gc.collect()
                    torch.cuda.empty_cache()

    global_epoch += 1

    print(f"Finished epoch {epoch + 1}/{NUM_EPOCHS}")
    gc.collect()
    torch.cuda.empty_cache()

# NOTE(review): only the PEFT model and the tokenizer are saved. structural_tokenizer,
# temporal_tokenizer and the temporal embedding module used by preprocess_tokens are
# not, so identical inputs cannot be rebuilt from output_dir alone.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
