import os
import csv
import torch
import torch.nn as nn
import argparse
import numpy as np
from safetensors.torch import load_file
import torch
import sys
sys.stdout.flush()
import functools
print = functools.partial(print, flush=True)
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import torch.nn.functional as F

from typing import Any, Dict, List, Optional, Union
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForMultipleChoice,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    EarlyStoppingCallback
)

import transformers
print(transformers.__version__)
from dataclasses import dataclass

# ---- Per-epoch accuracy and memory logging callback ----
class ManualEvalAccuracyTracker(TrainerCallback):
    """Append per-epoch accuracy, CUDA memory and trainable-parameter count to a CSV.

    At the end of every epoch the callback runs its own pass over ``eval_dataset``
    and writes one row with columns
    ``epoch, accuracy, peak_memory_MB, intermediate_memory_MB, trainable_params``.

    Args:
        trainer: Trainer used to run the per-epoch evaluation pass.
        eval_dataset: Dataset scored at the end of each epoch.
        csv_path: Output CSV path. The file is truncated and a header row is
            written when the callback is constructed.
    """

    def __init__(self, trainer, eval_dataset, csv_path="epoch_log.csv"):
        self.trainer = trainer
        self.eval_dataset = eval_dataset
        self.csv_path = csv_path
        self.epoch = 0
        with open(self.csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["epoch", "accuracy", "peak_memory_MB", "intermediate_memory_MB", "trainable_params"])

    def on_epoch_begin(self, args, state, control, **kwargs):
        # Start every epoch with a fresh peak-memory counter.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    def on_epoch_end(self, args, state, control, **kwargs):
        self.epoch += 1

        torch.cuda.synchronize()

        # Memory still allocated once the epoch's training steps have finished.
        intermediate_memory = torch.cuda.memory_allocated() / (1024 ** 2)

        # Use ``predict`` rather than ``evaluate``. ``evaluate`` fires ``on_evaluate``
        # on every registered callback, so EarlyStoppingCallback would count this
        # extra pass on top of the Trainer's own per-epoch evaluation and use up
        # its patience twice as fast. It would also add a duplicate log entry.
        # ``metric_key_prefix="eval"`` keeps the metric named ``eval_accuracy``.
        prediction = self.trainer.predict(self.eval_dataset, metric_key_prefix="eval")
        acc = prediction.metrics.get("eval_accuracy", -1)

        torch.cuda.synchronize()

        # Peak since the last reset: covers this epoch's training plus the pass above.
        peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 2)

        # The parameter count is only recorded on the first row.
        num_params = ""
        model = kwargs.get("model")
        if self.epoch == 1 and model is not None:
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        with open(self.csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([self.epoch, acc, peak_memory, intermediate_memory, num_params])

        torch.cuda.reset_peak_memory_stats()



# ---- Adapter Layers ----
class KronLoRALinear(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable Kronecker-factored low-rank update.

    The update is ``ΔW = (B1 @ B2) ⊗ A`` with
    ``B1: [out_features // dim_col_A, rank_B]``,
    ``B2: [rank_B, in_features // dim_row_A]`` and ``A`` (stored as
    ``kronlora_A.weight``) of shape ``[dim_col_A, dim_row_A]``. ``ΔW`` therefore
    has the wrapped layer's ``[out_features, in_features]`` shape. The forward
    pass never materialises ``ΔW``. It reshapes the input to
    ``[..., in_features // dim_row_A, dim_row_A]`` and applies B2, A and B1 in turn.

    Args:
        original_linear: Layer to wrap; its parameters are frozen.
        rank_B: Inner rank of the ``B1 @ B2`` factorisation.
        dim_row_A: Input-side size of ``A``; must divide ``in_features``.
        dim_col_A: Output-side size of ``A``; must divide ``out_features``.
        alpha: The update is scaled by ``alpha / rank_B``.
        dropout_rate: Dropout applied to the update.

    Raises:
        ValueError: If ``dim_row_A`` / ``dim_col_A`` do not divide the layer's
            input / output sizes.
    """

    def __init__(self, original_linear, rank_B=8, dim_row_A=2, dim_col_A=16, alpha=32.0, dropout_rate=0.1):
        super().__init__()
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        if in_features % dim_row_A or out_features % dim_col_A:
            raise ValueError(
                f"dim_row_A={dim_row_A} must divide in_features={in_features} and "
                f"dim_col_A={dim_col_A} must divide out_features={out_features}"
            )

        self.original_linear = original_linear  # Original linear layer (frozen)
        self.in_features = in_features
        self.out_features = out_features
        self.dB1 = in_features // dim_row_A   # input-side size of B (columns of B2)
        self.dB2 = out_features // dim_col_A  # output-side size of B (rows of B1)
        self.dA1 = dim_row_A                  # input-side size of A
        self.dA2 = dim_col_A                  # output-side size of A
        self.alpha = alpha  # Scaling numerator for the low-rank update
        self.dropout_rate = dropout_rate
        self.rank_B = rank_B

        # Freeze the original weights
        for param in self.original_linear.parameters():
            param.requires_grad = False

        # A maps the trailing dA1 axis of the reshaped input to dA2 (weight shape
        # [dA2, dA1]). This order is required for the result to flatten back to
        # dB2 * dA2 == out_features.
        self.kronlora_A = nn.Linear(self.dA1, self.dA2, bias=False)
        # B1 is all zeros, so the adapter contributes nothing at initialisation.
        self.kronlora_B1 = nn.Parameter(torch.randn(self.dB2, rank_B) * 0)
        self.kronlora_B2 = nn.Parameter(torch.randn(rank_B, self.dB1) * (1 / rank_B))

        # A ~ N(1, 0.25^2): note the mean is 1.0, not 0.
        nn.init.normal_(self.kronlora_A.weight, mean=1.0, std=0.25)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return ``original_linear(x) + (alpha / rank_B) * dropout(x @ ΔW^T)``."""
        original_out = self.original_linear(x)

        lead = x.shape[:-1]
        # [..., in_features] -> [..., dB1, dA1]; reshape also handles non-contiguous x.
        x_reshaped = x.reshape(*lead, self.dB1, self.dA1)
        # B2 @ X -> [..., rank_B, dA1]
        B2x = torch.matmul(self.kronlora_B2, x_reshaped)
        # (B2 X) @ A^T -> [..., rank_B, dA2]
        B2xA = self.kronlora_A(B2x)
        # B1 @ (B2 X A^T) -> [..., dB2, dA2]
        B1B2xA = torch.matmul(self.kronlora_B1, B2xA)

        # Flatten the last two axes back to [..., out_features].
        kron_update = B1B2xA.reshape(*lead, -1)
        kron_update = self.dropout(kron_update)
        return original_out + (self.alpha / self.rank_B) * kron_update

class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable rank-``r`` additive update.

    Output is ``original_linear(x) + dropout((alpha / r) * lora_A(lora_B(x)))``.
    Naming note: ``lora_B`` is the Gaussian-initialised down-projection
    (``in_features -> r``). ``lora_A`` is the zero-initialised up-projection
    (``r -> out_features``), so the update is exactly zero at initialisation.
    """

    def __init__(self, original_linear, r=8, alpha=32.0, dropout_rate=0.1):
        super().__init__()
        self.original_linear = original_linear
        self.r = r
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_rate)

        # The base layer stays fixed; only the adapter matrices are trained.
        for weight in original_linear.parameters():
            weight.requires_grad = False

        d_in, d_out = original_linear.in_features, original_linear.out_features
        self.lora_B = nn.Linear(d_in, r, bias=False)   # down-projection
        self.lora_A = nn.Linear(r, d_out, bias=False)  # up-projection

        # Gaussian down-projection with std 1/r, zero up-projection.
        nn.init.normal_(self.lora_B.weight, std=1 / r)
        nn.init.zeros_(self.lora_A.weight)

    def forward(self, x):
        low_rank = self.lora_B(x)
        update = self.lora_A(low_rank)  # (B, *, out_features)
        scale = self.alpha / self.r
        base = self.original_linear(x)
        return base + self.dropout(scale * update)

# ---- LoRA Class uses Mistral ----
class LoRAMistralForMultipleChoice(nn.Module):
    """Causal LM whose ``nn.Linear`` layers are LoRA-wrapped, scored as multiple choice.

    Each choice sequence is scored by its summed token negative log-likelihood
    under the LM. The negated sums are used as the classification logits.

    Args:
        base_model_name: Hub id or local path passed to ``from_pretrained``.
        r: LoRA rank for every wrapped layer.
        alpha: LoRA scaling numerator (update scaled by ``alpha / r``).
        checkpoint_path: Optional safetensors file loaded (strictly) into this
            wrapper's state dict after the layers have been wrapped.
    """

    def __init__(self, base_model_name, r=8, alpha=32.0, checkpoint_path=None):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            token=os.environ.get("HF_TOKEN", None),
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=os.environ.get("HF_TOKEN", None))
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"
        # Wrap every nn.Linear in the model with a LoRA adapter.
        self._replace_linear(self.model, r, alpha)
        if checkpoint_path:
            # Keys must match this wrapper's state dict (i.e. prefixed with "model.").
            state_dict = load_file(checkpoint_path)
            self.load_state_dict(state_dict)

        # Leave only the adapter matrices trainable.
        self._freeze_non_lora_layers()

    def _replace_linear(self, module, r, alpha):
        """Recursively replace every ``nn.Linear`` below ``module`` with ``LoRALinear``."""
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(module, name, LoRALinear(child, r, alpha))
            else:
                self._replace_linear(child, r, alpha)

    def _freeze_non_lora_layers(self):
        """Freeze all parameters, then re-enable those whose name contains ``lora``."""
        for param in self.model.parameters():
            param.requires_grad = False

        for name, param in self.model.named_parameters():
            if 'lora' in name:
                param.requires_grad = True

    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on the wrapped model (called by ``Trainer``).

        The embeddings are frozen, so the hidden states entering each checkpointed
        decoder layer do not require grad. With reentrant checkpointing
        (Transformers' default unless ``use_reentrant`` is overridden), the adapter
        weights inside those layers would then get no gradient.
        ``enable_input_require_grads`` marks the embedding output as requiring grad
        so gradients flow through every checkpointed segment.
        """
        if not hasattr(self.model, "gradient_checkpointing_enable"):
            return None
        result = self.model.gradient_checkpointing_enable(**kwargs)
        if hasattr(self.model, "enable_input_require_grads"):
            self.model.enable_input_require_grads()
        return result

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score every choice by its negated summed token NLL.

        Args:
            input_ids: ``[batch_size, num_choices, seq_len]`` token ids.
            attention_mask: Optional mask of the same shape; masked positions do
                not contribute to a choice's NLL.
            labels: Optional ``[batch_size]`` indices of the correct choice.

        Returns:
            dict with ``logits`` ``[batch_size, num_choices]`` and, when ``labels``
            is given, the cross-entropy ``loss`` over choices.
        """
        batch_size, num_choices, seq_len = input_ids.shape

        # Flatten the choices into the batch dimension.
        input_ids = input_ids.view(-1, seq_len)
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, seq_len)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [batch_size * num_choices, seq_len, vocab_size]

        # Token t predicts token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_mask = attention_mask[:, 1:] if attention_mask is not None else None

        loss_fct = nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss.view(shift_labels.size())

        if shift_mask is not None:
            loss = loss * shift_mask

        per_choice_nll = loss.sum(dim=1)  # [batch_size * num_choices]
        per_choice_nll = per_choice_nll.view(batch_size, num_choices)
        logits = -per_choice_nll  # Higher log-probability = better

        if labels is not None:
            ce_loss_fct = nn.CrossEntropyLoss()
            loss = ce_loss_fct(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

# ---- Kron-LoRA Class uses Mistral ----
class KronMistralForMultipleChoice(nn.Module):
    """Causal LM with Kron-LoRA adapters on its linear layers, scored as multiple choice.

    Every ``nn.Linear`` except ``lm_head`` is wrapped with ``KronLoRALinear``
    using this constructor's factor sizes. ``lm_head`` is wrapped separately with
    a fixed configuration. Each choice sequence is scored by its summed token
    negative log-likelihood, and the negated sums are the classification logits.

    Args:
        base_model_name: Hub id or local path passed to ``from_pretrained``.
        rank_B, dim_row_A, dim_col_A, alpha: ``KronLoRALinear`` settings for the
            body layers.
        checkpoint_path: Optional state-dict file loaded (strictly) after wrapping;
            ``.safetensors`` files are read with safetensors, anything else with
            ``torch.load``.
    """

    def __init__(self, base_model_name, rank_B=8, dim_row_A=2, dim_col_A=16, alpha=32.0, checkpoint_path=None):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            token=os.environ.get("HF_TOKEN", None),
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=os.environ.get("HF_TOKEN", None))

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Wrap the linear layers; lm_head is excluded here and wrapped on its own below.
        self._replace_kronlinear(self.model, rank_B, dim_row_A, dim_col_A, alpha, exclude_modules=(self.model.lm_head,))
        # NOTE(review): lm_head always uses rank_B=8, dim_row_A=1, dim_col_A=8 and the
        # KronLoRALinear default alpha/dropout, independent of this constructor's
        # arguments. Confirm this is intended.
        self.model.lm_head = KronLoRALinear(self.model.lm_head, rank_B=8, dim_row_A=1, dim_col_A=8)

        if checkpoint_path:
            self.load_state_dict(self._load_checkpoint(checkpoint_path))

        self._freeze_non_kronlora_layers()
        # Sanity check: no original weight may be trainable.
        self._verify_frozen_weights()

    @staticmethod
    def _load_checkpoint(checkpoint_path):
        """Read a state dict from a ``.safetensors`` file or a ``torch.save`` file."""
        if str(checkpoint_path).endswith(".safetensors"):
            return load_file(checkpoint_path)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        return torch.load(checkpoint_path)

    def _replace_kronlinear(self, module, rank_B, dim_row_A, dim_col_A, alpha, exclude_modules=()):
        """Recursively replace ``nn.Linear`` children (except ``exclude_modules``) with ``KronLoRALinear``."""
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                if child not in exclude_modules and not isinstance(child, KronLoRALinear):
                    setattr(module, name, KronLoRALinear(
                        child, rank_B, dim_row_A, dim_col_A, alpha
                    ))
            else:
                self._replace_kronlinear(child, rank_B, dim_row_A, dim_col_A, alpha, exclude_modules)

    def _freeze_non_kronlora_layers(self):
        """Freeze all parameters, then re-enable those whose name contains ``kronlora``."""
        for param in self.model.parameters():
            param.requires_grad = False

        for name, param in self.model.named_parameters():
            if 'kronlora' in name:
                param.requires_grad = True

    def _verify_frozen_weights(self):
        """Assert that no parameter under an ``original_linear`` module is trainable."""
        for name, param in self.model.named_parameters():
            if "original_linear" in name:
                assert not param.requires_grad, f"Original weight {name} is still trainable!"

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score every choice by its negated summed token NLL.

        Args:
            input_ids: ``[batch_size, num_choices, seq_len]`` token ids.
            attention_mask: Optional mask of the same shape; masked positions do
                not contribute to a choice's NLL.
            labels: Optional ``[batch_size]`` indices of the correct choice.

        Returns:
            dict with ``logits`` ``[batch_size, num_choices]`` and, when ``labels``
            is given, the cross-entropy ``loss`` over choices.
        """
        batch_size, num_choices, seq_len = input_ids.shape

        # Flatten the choices into the batch dimension.
        input_ids = input_ids.view(-1, seq_len)
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, seq_len)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [batch_size * num_choices, seq_len, vocab_size]

        # Token t predicts token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_mask = attention_mask[:, 1:] if attention_mask is not None else None

        loss_fct = nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss.view(shift_labels.size())

        if shift_mask is not None:
            loss = loss * shift_mask

        per_choice_nll = loss.sum(dim=1)  # [batch_size * num_choices]
        per_choice_nll = per_choice_nll.view(batch_size, num_choices)
        logits = -per_choice_nll  # Higher log-probability = better

        if labels is not None:
            ce_loss_fct = nn.CrossEntropyLoss()
            loss = ce_loss_fct(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on the wrapped model (called by ``Trainer``).

        The embeddings are frozen, so the hidden states entering each checkpointed
        decoder layer do not require grad. With reentrant checkpointing
        (Transformers' default unless ``use_reentrant`` is overridden), the adapter
        weights inside those layers would then get no gradient.
        ``enable_input_require_grads`` marks the embedding output as requiring grad
        so gradients flow through every checkpointed segment.
        """
        if not hasattr(self.model, "gradient_checkpointing_enable"):
            return None
        result = self.model.gradient_checkpointing_enable(**kwargs)
        if hasattr(self.model, "enable_input_require_grads"):
            self.model.enable_input_require_grads()
        return result

# ---- DataCollator for HellaSwag ----
class DataCollatorForMultipleChoice:
    """Collate multiple-choice features into ``[batch_size, num_choices, seq_len]`` tensors.

    Each feature holds per-choice ``input_ids`` and ``attention_mask`` sequences
    and may hold an integer ``labels`` entry. All choices in the batch are
    flattened, padded together by ``tokenizer.pad`` and reshaped back.
    The input feature dicts are not modified.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        has_labels = "labels" in features[0]
        # Read (rather than pop) the labels so the caller's dicts stay intact and
        # collating the same features twice gives the same result.
        labels = (
            torch.tensor([f["labels"] for f in features], dtype=torch.long)
            if has_labels
            else None
        )

        # Number of choices per example (4 for HellaSwag).
        num_choices = len(features[0]["input_ids"])
        batch_size = len(features)

        # One entry per (example, choice), example-major.
        flattened_features = [
            {
                "input_ids": feature["input_ids"][i],
                "attention_mask": feature["attention_mask"][i],
            }
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=True,
            return_tensors="pt"
        )

        # Restore the [batch_size, num_choices, seq_len] layout.
        for k in batch:
            batch[k] = batch[k].view(batch_size, num_choices, -1)

        if has_labels:
            batch["labels"] = labels

        return batch

# ---- Preprocessing for HellaSwag ----
def preprocess_function(examples, tokenizer, max_length=128):
    """Tokenize a batch of HellaSwag examples into per-choice sentence pairs.

    Each example's ``ctx_a`` is paired with every one of its ``endings`` and
    tokenized as a sentence pair, padded/truncated to ``max_length``.

    Args:
        examples: Batched columns ``ctx_a``, ``endings`` (a list of endings per
            example) and optionally ``label`` (string, ``''`` when unlabeled).
        tokenizer: Tokenizer called on (first, second) sentence lists.
        max_length: Length every choice sequence is padded/truncated to.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` of shape
        ``[num_examples, num_choices, max_length]``. When a ``label`` column is
        present it also contains ``labels``, with unlabeled rows mapped to 0.
    """
    contexts = examples["ctx_a"]
    endings_per_example = examples["endings"]
    # Choices per example (4 for HellaSwag); every example must have the same count.
    num_choices = len(endings_per_example[0]) if len(endings_per_example) > 0 else 4

    # NOTE(review): only ``ctx_a`` is used as the context. HellaSwag also has a
    # ``ctx_b`` column (the opening of the sentence the endings continue).
    # Confirm that omitting it is intended.
    # Pair every context with each of its own endings. The comprehensions flatten
    # in linear time, unlike sum(list_of_lists, []).
    first_sentences = [ctx for ctx, endings in zip(contexts, endings_per_example) for _ in endings]
    second_sentences = [ending for endings in endings_per_example for ending in endings]

    tokenized = tokenizer(
        first_sentences,
        second_sentences,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="np"  # numpy arrays so they can be reshaped below
    )

    result = {
        "input_ids": tokenized["input_ids"].reshape(-1, num_choices, max_length),
        "attention_mask": tokenized["attention_mask"].reshape(-1, num_choices, max_length),
    }

    if "label" in examples:
        result["labels"] = np.array([int(lbl) if lbl != '' else 0 for lbl in examples["label"]])

    return result


def compute_metrics(eval_pred):
    """Return ``{"accuracy": ...}`` for multiple-choice logits.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted choice is the
    arg-max over axis 1, i.e. across the choices of ``[batch, num_choices]`` logits.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=1)
    accuracy = accuracy_score(gold, predicted)
    return {"accuracy": accuracy}

# ---- Main Entrypoint ----
def main():
    """Parse command-line options, build the chosen adapter model and fine-tune it on HellaSwag."""

    def _str2bool(value):
        """Parse a command-line boolean; ``type=bool`` would turn the string "False" into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter_type", type=str, choices=["kron", "lora"], default="kron")
    parser.add_argument("--model_name_or_path", type=str, default="mistralai/Mistral-7B-v0.1")
    # Only used by the LoRA adapter; Kron-LoRA keeps its own rank/factor defaults.
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--alpha", type=float, default=32.0)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--epoch_num", type=int, default=16)
    parser.add_argument("--HF_TOKEN", type=str, default=None)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=8)
    parser.add_argument("--WANDB_DISABLED", type=str, default="true")
    parser.add_argument("--eval_strategy", type=str, default="epoch")
    parser.add_argument("--save_strategy", type=str, default="no")
    parser.add_argument("--load_best_model_at_end", type=_str2bool, default=False)
    parser.add_argument("--save_total_limit", type=int, default=2)
    args = parser.parse_args()

    os.environ["WANDB_DISABLED"] = args.WANDB_DISABLED
    # Credentials come from --HF_TOKEN or an HF_TOKEN already in the environment;
    # never hard-code them in source. The model classes read HF_TOKEN from the
    # environment, so export it when it was passed on the command line.
    hf_token = args.HF_TOKEN or os.environ.get("HF_TOKEN")
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=hf_token, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("hellaswag", trust_remote_code=True)
    dataset = dataset.map(
        lambda x: preprocess_function(x, tokenizer),
        batched=True,
        remove_columns=dataset["train"].column_names,  # Drop the raw text columns
    )

    # Return numpy arrays from the datasets.
    dataset.set_format(type="numpy", columns=["input_ids", "attention_mask", "labels"])

    # For a quick smoke test, subsample with e.g.
    # .train_test_split(test_size=0.001, seed=1)["test"]
    train_dataset = dataset["train"]
    val_dataset = dataset["validation"]

    if args.adapter_type == "kron":
        print("Using Kron-LoRA")
        model = KronMistralForMultipleChoice(args.model_name_or_path, alpha=args.alpha, checkpoint_path=args.checkpoint_path)
    else:
        print("Using LoRA")
        model = LoRAMistralForMultipleChoice(args.model_name_or_path, r=args.lora_r, alpha=args.alpha, checkpoint_path=args.checkpoint_path)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        eval_strategy=args.eval_strategy,  # older transformers name this evaluation_strategy
        save_strategy=args.save_strategy,  # "no" by default; use "epoch" or "steps" to save checkpoints
        save_total_limit=args.save_total_limit,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        num_train_epochs=args.epoch_num,
        learning_rate=args.lr,
        fp16=True,
        logging_steps=10,
        load_best_model_at_end=args.load_best_model_at_end,
        gradient_checkpointing=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        disable_tqdm=False,  # Ensure progress bar shows
        logging_dir="./logs",  # Explicit logging directory
        report_to="none",  # Disable WandB if not needed
        log_level="info",  # More verbose logging
    )

    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    data_collator = DataCollatorForMultipleChoice(tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=4)  # Stops if no improvement in 4 evals
        ]
    )
    # The tracker needs the trainer instance, so it is attached after construction.
    tracker = ManualEvalAccuracyTracker(trainer=trainer, eval_dataset=val_dataset)
    trainer.add_callback(tracker)

    trainer.train()


if __name__ == "__main__":
    # sys is already imported at module level; no re-import needed here.
    main()


