from transformers import AutoTokenizer,AutoModelForCausalLM, pipeline,BitsAndBytesConfig, Trainer, TrainingArguments,DataCollatorWithPadding,TrainerCallback
from transformers import AutoModel
from peft import LoraConfig
from torch.utils.data import DataLoader, TensorDataset
import torch
import pandas as pd
import numpy as np
from huggingface_hub import login
import gc
import torch.nn as nn
from peft import LoraConfig, get_peft_model
import sys

# Raise Python's recursion limit far above the default (reason not stated in this script).
sys.setrecursionlimit(35000)

# Start from a clean host / GPU memory state.
gc.collect()
torch.cuda.empty_cache()

# To use a gated checkpoint, uncomment the next line and paste your Hugging Face token:
# login(token='')

# Where the fine-tuned model is written.
output_dir = 'LLAMA3_opt_pred_male_Twothousand'
# Training / evaluation CSVs.
file_path = "dummy_data/training_data.csv"
eval_file_path = "dummy_data/evaluation_data.csv"
# Any LLM checkpoint can be used here; keep d_LLM in sync with its hidden size.
base_model_path = 'meta-llama/Llama-3.1-8B'

# Prefer the first GPU, fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Define structural and temporal tokenizers
class StructuralTokenizer(nn.Module):
    """Lift per-node structural features to the LLM embedding width.

    A Linear -> ReLU -> Linear MLP whose output is normalised with LayerNorm.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Keep the creation order (fc1, then fc2): it determines which random
        # numbers each layer's initialisation consumes.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, X_s):
        """Map ``(..., input_dim)`` features to normalised ``(..., output_dim)`` tokens."""
        projected = self.fc2(nn.functional.relu(self.fc1(X_s)))
        return self.layer_norm(projected)

class TemporalTokenizer(nn.Module):
    """Project temporal-attention features through an MLP and LayerNorm.

    Structure is Linear -> ReLU -> Linear -> LayerNorm over the last dimension.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layer creation order is significant for random initialisation.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, X_t):
        """Return normalised ``(..., output_dim)`` tokens for ``(..., input_dim)`` input."""
        hidden = nn.functional.relu(self.fc1(X_t))
        return self.layer_norm(self.fc2(hidden))
# Hidden width of the LLM; change it to match the checkpoint in base_model_path.
d_LLM = 4096
# Number of Laplacian eigenvectors used as structural features per node.
n_eigen = 6

# Projection heads mapping structural and temporal features to d_LLM.
# Creation order (structural first) is kept so random initialisation is unchanged.
structural_tokenizer = StructuralTokenizer(input_dim=n_eigen, hidden_dim=d_LLM, output_dim=d_LLM).to(device)
temporal_tokenizer = TemporalTokenizer(input_dim=d_LLM, hidden_dim=d_LLM, output_dim=d_LLM).to(device)

compute_dtype = torch.float16

# 4-bit NF4 quantisation without double quantisation; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    load_in_4bit=True,
)

class RMSELoss(nn.Module):
    """Root-mean-square error over all elements, with ``eps`` added under the root.

    The epsilon keeps the square root (and its gradient) finite when the
    predictions match the targets exactly.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, y_pred, y_true):
        mse = (y_pred - y_true).pow(2).mean()
        return (mse + self.eps).sqrt()

class SteroidRegressor(nn.Module):
    """Regression head on top of a transformer backbone fed with ``inputs_embeds``.

    The backbone's ``last_hidden_state`` is mean-pooled over the sequence and
    passed through a two-layer MLP that predicts ``output_size`` values.

    Args:
        base_model: backbone called as ``base_model(inputs_embeds=...)`` whose
            output exposes ``last_hidden_state``.
        hidden_size: width of ``last_hidden_state``.
        output_size: number of regression targets.
    """

    def __init__(self, base_model, hidden_size, output_size=6):
        super().__init__()
        self.base_model = base_model
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, output_size)
        )

    def forward(self, inputs_embeds=None, attention_mask=None, labels=None):
        """Predict regression targets for a batch of embedding sequences.

        Args:
            inputs_embeds: embeddings, presumably ``(batch, seq_len, hidden_size)``.
            attention_mask: optional ``(batch, seq_len)`` mask with 1 for real
                positions and 0 for padding. When given, it is forwarded to the
                backbone and padded positions are excluded from the pooling.
                When omitted, every position is pooled.
            labels: optional targets of shape ``(batch, output_size)``.

        Returns:
            ``{"predictions": ...}``, plus ``"loss"`` (RMSE) when ``labels`` is given.
        """
        if attention_mask is None:
            base_outputs = self.base_model(inputs_embeds=inputs_embeds)
        else:
            base_outputs = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        # The backbone may return fp16 hidden states (bnb compute dtype is
        # float16) while the head stays fp32. Outside autocast, e.g. a direct
        # evaluation call, the Linear would fail on the dtype mismatch, so
        # pool in the head's dtype.
        head_dtype = next(self.regressor.parameters()).dtype
        hidden = base_outputs.last_hidden_state.to(head_dtype)
        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            weights = attention_mask.to(device=hidden.device, dtype=head_dtype).unsqueeze(-1)
            # clamp avoids division by zero for an all-padding row
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        predictions = self.regressor(pooled)

        if labels is None:
            return {"predictions": predictions}
        loss = RMSELoss()(predictions, labels)
        return {"predictions": predictions, "loss": loss}

    def gradient_checkpointing_enable(self, **kwargs):
        """Forward gradient-checkpointing activation to the backbone when supported."""
        if hasattr(self.base_model, "gradient_checkpointing_enable"):
            self.base_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self, **kwargs):
        """Forward gradient-checkpointing deactivation to the backbone when supported."""
        if hasattr(self.base_model, "gradient_checkpointing_disable"):
            self.base_model.gradient_checkpointing_disable(**kwargs)


# Load the 4-bit backbone and wrap it with the regression head.
base_model = AutoModel.from_pretrained(base_model_path, quantization_config=bnb_config, device_map="auto")
model = SteroidRegressor(base_model, hidden_size=d_LLM)

# Prompt tokenizer: reuse EOS as the padding token and pad on the right.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# LoRA adapters on every attention and MLP projection.
target_modules = ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]

# NOTE(review): task_type="SEQ_CLS" asks PEFT for its sequence-classification
# wrapper, although AutoModel loads a bare backbone and the head lives in
# SteroidRegressor. "FEATURE_EXTRACTION" may fit better; confirm before changing.
peft_config = LoraConfig(
    r=32,
    lora_alpha=128,
    lora_dropout=0.10,
    bias="none",
    target_modules=target_modules,
    task_type="SEQ_CLS",
    inference_mode=False,
)

model.base_model = get_peft_model(model.base_model, peft_config)

# Same expression as the device chosen at the top of the script.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Environment diagnostics.
print(torch.__version__)
print(torch.cuda.is_available())
class _TemporalEmbedding(nn.Module):
    """Single-head scaled dot-product self-attention over one profile's time points.

    Maps ``(T, input_dim)`` to ``(T, embed_dim)``.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.W_Q = nn.Linear(input_dim, embed_dim)
        self.W_K = nn.Linear(input_dim, embed_dim)
        self.W_V = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        scores = torch.softmax(torch.bmm(Q.unsqueeze(0), K.unsqueeze(0).transpose(1, 2)) / np.sqrt(K.shape[-1]), dim=-1)
        return torch.bmm(scores, V.unsqueeze(0)).squeeze(0)


# Frozen temporal embedders keyed by (input_dim, embed_dim, device). Each is
# built once and reused, so every profile, in training and in evaluation, is
# projected with the same weights.
_TEMPORAL_EMBEDDERS = {}


def _get_temporal_embedder(input_dim, embed_dim, device):
    """Return the cached, gradient-free temporal embedder, creating it on first use."""
    key = (input_dim, embed_dim, str(device))
    embedder = _TEMPORAL_EMBEDDERS.get(key)
    if embedder is None:
        embedder = _TemporalEmbedding(input_dim=input_dim, embed_dim=embed_dim).to(device)
        embedder.requires_grad_(False)
        _TEMPORAL_EMBEDDERS[key] = embedder
    return embedder


# Preprocess Structural, Temporal, and Prompt embeddings
def preprocess_tokens(X_profile, adj_matrix, prompt_text, d_LLM, m, model, tokenizer, device):
    """Build the structural, temporal and prompt embedding segments for one profile.

    Args:
        X_profile: tensor of one profile's measurements, presumably
            ``(T, m)`` (time points x metabolites); TODO confirm.
        adj_matrix: square numpy adjacency matrix of the metabolite graph.
        prompt_text: instruction text to embed with the LLM's input embeddings.
        d_LLM: LLM hidden width (output width of the temporal embedder).
        m: number of features per time point.
        model: object whose ``base_model.get_input_embeddings()`` embeds the prompt.
        tokenizer: Hugging Face tokenizer for the base model.
        device: device on which inputs are placed.

    Returns:
        ``(structural_tokens, temporal_tokens, prompt_embeds)``. Rows are graph
        nodes, time points and prompt tokens respectively. The widths come from
        the module-level ``structural_tokenizer`` / ``temporal_tokenizer`` and
        the LLM's embedding table.
    """
    import networkx as nx
    from scipy.linalg import eigh

    # Structural embedding: the first ``n_eigen`` eigenvectors (module-level
    # constant that also sizes ``structural_tokenizer``) of the normalised
    # graph Laplacian.
    with torch.no_grad():
        graph = nx.from_numpy_array(adj_matrix)
        laplacian = nx.normalized_laplacian_matrix(graph).toarray()
        _, eigenvectors = eigh(laplacian)
        raw_structural = torch.tensor(eigenvectors[:, :n_eigen], dtype=torch.float32).to(device)  # (n, n_eigen)

    structural_tokens = structural_tokenizer(raw_structural)

    # Temporal embedding: fixed self-attention over the profile's time points.
    temporal_embedder = _get_temporal_embedder(m, d_LLM, device)
    with torch.no_grad():
        temporal_output = temporal_embedder(X_profile.to(device))  # (T, d_LLM)

    temporal_tokens = temporal_tokenizer(temporal_output)

    # Prompt tokens embedded with the pretrained LLM's input embedding table.
    prompt_ids = tokenizer(prompt_text, return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
    prompt_embeds = model.base_model.get_input_embeddings()(prompt_ids).squeeze(0)

    return structural_tokens, temporal_tokens, prompt_embeds


class CustomDataset(torch.utils.data.Dataset):
    """Index-aligned pairs of precomputed input embeddings and regression targets."""

    def __init__(self, inputs_embeds, labels):
        # Presumably (num_profiles, seq_len, hidden) and (num_profiles, 6); not checked here.
        self.inputs_embeds = inputs_embeds
        self.labels = labels

    def __len__(self):
        return len(self.inputs_embeds)

    def __getitem__(self, idx):
        sample = dict(inputs_embeds=self.inputs_embeds[idx], labels=self.labels[idx])
        return sample

class CustomTrainer(Trainer):
    """Trainer with plain DataLoaders and a loss taken from the model's output dict.

    Pinned memory is disabled, presumably because the datasets built in this
    script already hold tensors on ``device``, and only CPU tensors can be pinned.
    """

    def get_train_dataloader(self):
        # Training order is reshuffled on every pass.
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            pin_memory=False,
        )

    def get_eval_dataloader(self, eval_dataset=None):
        # Fall back to the dataset attached to the trainer.
        dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            pin_memory=False,
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer Trainer versions may pass.
        outputs = model(inputs_embeds=inputs["inputs_embeds"], labels=inputs["labels"])
        loss = outputs["loss"]
        if return_outputs:
            return loss, outputs
        return loss

def custom_collator(features):
    """Stack CustomDataset samples into one batch dict.

    ``inputs_embeds`` becomes a fresh leaf tensor (copied, detached from any
    autograd history, and marked ``requires_grad``). ``labels`` is moved to
    the module-level ``device``.
    """
    embeds = torch.stack([sample["inputs_embeds"] for sample in features])
    targets = torch.stack([sample["labels"] for sample in features])
    # requires_grad on the input embeddings presumably lets gradient
    # checkpointing work with a frozen embedding path; confirm.
    embeds = embeds.clone().detach().requires_grad_(True)
    return {"inputs_embeds": embeds, "labels": targets.to(device)}

# Metabolite columns read from the CSVs; each regression target holds these
# values in this order.
features = ['Adiol', 'Bdiol', 'A', 'Etio', 'E', 'T']

# Metabolite graph. Rows/columns presumably follow the order of `features`
# (confirm). NOTE(review): the matrix is not symmetric, yet it is converted to
# an undirected networkx graph downstream, so edge direction is likely lost.
adj_matrix = np.array([
    [1, 0, 0, 0, 0, 1],  # Adiol
    [0, 1, 1, 1, 0, 1],  # Bdiol
    [1, 0, 1, 0, 0, 0],  # A
    [0, 1, 0, 1, 0, 0],  # Etio
    [0, 0, 1, 0, 1, 0],  # E
    [1, 1, 1, 0, 0, 1],  # T
])

# Number of graph nodes.
n = len(adj_matrix)

# We load data in chunks because memory cannot load all profiles at once if we finetune on a large number of profiles
def train_on_chunk(chunk, eval_chunk, current_epoch, normal_sample):
    """Fine-tune the module-level ``trainer`` on the profiles in one CSV chunk.

    Rows sharing an ``ID`` form one longitudinal profile. Each profile is
    turned into structural, temporal and prompt embeddings by
    ``preprocess_tokens``. Its regression target is the profile's last row.

    Args:
        chunk: DataFrame with an ``ID`` column and every column in ``features``.
        eval_chunk: unused. NOTE(review): the main loop passes the trainer
            positionally in this slot. Training always uses the module-level
            ``trainer``.
        current_epoch: unused; kept for caller compatibility.
        normal_sample: True selects the "all samples are normal" prompt,
            False the "last sample is anomaly" prompt.

    NOTE(review): the target row is also part of the profile passed to
    ``preprocess_tokens``, so the model input contains the value it is asked
    to predict. Confirm this is intended.
    """
    chunk = chunk[['ID'] + features]
    # Keep each profile's real rows. The previous zero-padding mask also
    # discarded genuine all-zero measurement rows.
    profiles = [
        torch.tensor(group.drop('ID', axis=1).to_numpy(dtype=np.float32), dtype=torch.float32)
        for _, group in chunk.groupby('ID')
    ]
    if not profiles:
        print("Skipping empty training chunk.")
        return
    gc.collect()
    torch.cuda.empty_cache()

    if normal_sample: # first train on normal samples using this prompt
        prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. All samples are normal. Learn the sequence."

    else: # then we train on samples with anomalies using this prompt
        prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. Learn the sequence but note that the last sample is anomaly."

    combined_tokens_list = []
    targets = []
    # custom_collator detaches inputs_embeds, so no gradient ever flows back
    # through these tokens. Building them without autograd avoids keeping a
    # graph alive for every profile in the chunk.
    with torch.no_grad():
        for profile in profiles:
            structural_tokens, temporal_tokens, prompt_tokens = preprocess_tokens(
                profile, adj_matrix, prompt_text, d_LLM, len(features), model, tokenizer, device
            )
            combined_tokens_list.append(
                torch.cat([structural_tokens.to(device), temporal_tokens.to(device), prompt_tokens.to(device)], dim=0)
            )
            # Regression target: last row of each profile
            targets.append(profile[-1])

    # Right-pad every sequence with zero vectors up to the longest in the chunk.
    max_len = max(t.shape[0] for t in combined_tokens_list)
    combined_tokens = torch.stack(
        [nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in combined_tokens_list]
    ).to(device=device, dtype=torch.float16)
    del combined_tokens_list
    print("Combined Token Dimensions:", combined_tokens.shape)
    gc.collect()
    torch.cuda.empty_cache()

    labels = torch.stack(targets).to(device)  # shape: (batch, 6)
    train_dataset = CustomDataset(combined_tokens, labels)

    print("Inputs Embeds Shape:", combined_tokens.shape)  # Should be (batch_size, seq_len, embedding_dim)
    print("Labels Shape:", labels.shape)

    trainer.train_dataset = train_dataset
    train_results = trainer.train()
    print(f"Training Results: {train_results.metrics}")

    del chunk, profiles, targets
    del combined_tokens, labels, train_dataset
    gc.collect()
    torch.cuda.empty_cache()

# Do the same for evaluation
def evaluate_on_chunk(chunk, trainer, normal_sample):
    """Evaluate ``trainer`` on the profiles in one CSV chunk and print metrics.

    Profiles are built exactly as in training (rows grouped by ``ID``, target
    = last row). After ``trainer.evaluate()``, a batched manual forward pass
    computes RMSE, MAE and MAPE.

    Args:
        chunk: DataFrame with an ``ID`` column and every column in ``features``.
        trainer: the trainer whose model is evaluated.
        normal_sample: selects the "normal" or the "anomaly" evaluation prompt.

    Returns:
        The metrics dict from ``trainer.evaluate()``, or ``{}`` if the chunk
        contains no profiles.
    """
    with torch.no_grad():
        chunk = chunk[['ID'] + features]
        # samples sharing same ID belong to same profile; keep their real rows
        # (a zero-padding mask would also drop genuine all-zero rows)
        profiles = [
            torch.tensor(group.drop('ID', axis=1).to_numpy(dtype=np.float32), dtype=torch.float32)
            for _, group in chunk.groupby('ID')
        ]
        if not profiles:
            print("Skipping empty evaluation chunk.")
            return {}

        if normal_sample:
            prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. All samples are normal. Confirm consistency."

        else:
            prompt_text = "Analyse the given longitudinal profile of male athletes, consisting of multiple steroid metabolites measured across different time points. The last sample is an anomaly. Identify and learn why is it anomaly."

        combined_tokens_list = []
        targets = []
        for profile in profiles:
            structural_tokens, temporal_tokens, prompt_tokens = preprocess_tokens(
                profile, adj_matrix, prompt_text, d_LLM, len(features), model, tokenizer, device
            )
            combined_tokens_list.append(
                torch.cat([structural_tokens.to(device), temporal_tokens.to(device), prompt_tokens.to(device)], dim=0)
            )
            # Regression target: last row of each profile
            targets.append(profile[-1])

        # Right-pad with zero vectors to the longest sequence in the chunk.
        max_len = max(t.shape[0] for t in combined_tokens_list)
        combined_tokens = torch.stack(
            [nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in combined_tokens_list]
        ).to(device=device, dtype=torch.float16)
        del combined_tokens_list
        print("eval Combined Token Dimensions:", combined_tokens.shape)
        gc.collect()
        torch.cuda.empty_cache()

        labels = torch.stack(targets).to(device)  # shape: (batch, 6)
        eval_dataset = CustomDataset(combined_tokens, labels)

        trainer.eval_dataset = eval_dataset
        results = trainer.evaluate()
        print(f"Evaluation Results: {results}")

        # Manual pass for the extra metrics. Batched to bound memory, and run
        # under fp16 autocast when the trainer trains in fp16, so the fp16
        # backbone output and the fp32 regression head agree on dtype.
        eval_model = trainer.model
        eval_model.eval()
        batch_size = trainer.args.eval_batch_size
        use_amp = bool(trainer.args.fp16) and device.type == "cuda"
        pred_batches = []
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            for start in range(0, combined_tokens.shape[0], batch_size):
                outputs = eval_model(inputs_embeds=combined_tokens[start:start + batch_size])
                pred_batches.append(outputs["predictions"].float())
        preds = torch.cat(pred_batches)

        errors = preds - labels
        rmse = torch.sqrt(torch.mean(errors ** 2)).item()
        mae = torch.mean(torch.abs(errors)).item()
        # MAPE becomes unstable for targets close to zero.
        mape = torch.mean(torch.abs(errors / (labels + 1e-8))).item() * 100

        print(f"Eval RMSE: {rmse:.4f}, MAE: {mae:.4f}, MAPE: {mape:.2f}%")

    del chunk, profiles, targets
    del combined_tokens, labels, eval_dataset, preds
    gc.collect()
    torch.cuda.empty_cache()
    return results


# Trainer configuration. train_on_chunk calls trainer.train() once per chunk,
# so "epoch" below means one pass over a single chunk.
args = TrainingArguments(
    # output locations
    output_dir=output_dir,
    logging_dir='./logs',
    report_to=[],
    # optimisation
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    # NOTE(review): each per-chunk trainer.train() call presumably restarts
    # the LR schedule; with small chunks, 500 warmup steps may never finish
    # and the LR stays tiny. Confirm.
    warmup_steps=500,
    # memory / precision
    gradient_checkpointing=True,
    fp16=True,
    # checkpoints and logging
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=10,
    # keep every dataset field; the model consumes "inputs_embeds" directly
    remove_unused_columns=False,
)

trainer = CustomTrainer(
    model=model,
    args=args,
    data_collator=custom_collator,
    tokenizer=None,
)

# Main training loop
import os


def _iter_profile_chunks(path, chunksize, nrows=None):
    """Yield row chunks of the CSV at *path* without splitting a profile (``ID``).

    ``pd.read_csv(..., chunksize=...)`` cuts by row count, so a profile
    straddling a boundary would be processed as two partial profiles. The
    first part would also get the wrong target row. The rows carrying the last
    ``ID`` of each chunk are therefore held back and prepended to the next
    chunk. Assumes the rows of one profile are contiguous in the file (TODO
    confirm for the real data).
    """
    pending = None
    for chunk in pd.read_csv(path, chunksize=chunksize, nrows=nrows):
        if pending is None or pending.empty:
            pending = chunk
            continue
        is_last_id = pending['ID'] == pending['ID'].iloc[-1]
        ready = pending[~is_last_id]
        pending = pd.concat([pending[is_last_id], chunk], ignore_index=True)
        if not ready.empty:
            yield ready
    if pending is not None and not pending.empty:
        yield pending


def _run_phase(epoch, normal_sample):
    """Train on every training chunk, evaluating on the next eval chunk when one is left."""
    kind = "normal" if normal_sample else "anomaly"
    print(f"Training on {kind} profiles...")
    eval_chunks = iter(_iter_profile_chunks(eval_file_path, chunksize=10, nrows=10))  # dummy numbers
    # zip() would stop training as soon as the eval file ran out of chunks;
    # iterate the training chunks and evaluate only while eval chunks remain.
    for train_chunk in _iter_profile_chunks(file_path, chunksize=15, nrows=15):  # dummy numbers
        print(f"Processing {kind} chunk for epoch {epoch + 1}...")
        train_on_chunk(train_chunk, trainer, current_epoch=global_epoch, normal_sample=normal_sample)
        eval_chunk = next(eval_chunks, None)
        if eval_chunk is not None:
            evaluate_on_chunk(eval_chunk, trainer, normal_sample=normal_sample)
        del train_chunk, eval_chunk
        gc.collect()
        torch.cuda.empty_cache()


NUM_EPOCHS = 10
global_epoch = 0

for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch + 1}/{NUM_EPOCHS}...")
    # First train on the normal profiles, then on the anomaly profiles.
    _run_phase(epoch, normal_sample=True)
    _run_phase(epoch, normal_sample=False)
    global_epoch += 1

    print(f"Finished epoch {epoch + 1}/{NUM_EPOCHS}")
    gc.collect()
    torch.cuda.empty_cache()

os.makedirs(output_dir, exist_ok=True)
torch.save(model.state_dict(), f"{output_dir}/pytorch_model.bin")
# The structural/temporal tokenizers are randomly initialised and are not part
# of `model`, yet they are needed to rebuild the input embeddings at inference
# time, so persist their weights as well.
torch.save(
    {
        "structural_tokenizer": structural_tokenizer.state_dict(),
        "temporal_tokenizer": temporal_tokenizer.state_dict(),
    },
    f"{output_dir}/feature_tokenizers.bin",
)
tokenizer.save_pretrained(output_dir)
