import argparse
import os
import random
import numpy as np
import torch
import ast
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments
)
from torch import nn

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to every generator.
    """
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)

    # CUDA generators are only seeded explicitly when a GPU is visible.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _relationship_question(query):
    """Render the question text for one raw ``query`` field.

    A missing query yields a generic question. A string that starts with
    "(" is parsed as a Python literal; when it evaluates to a list/tuple
    with at least two items, the first two are named in the question. Every
    other case (including parse failures) embeds the raw query verbatim.
    """
    if query is None:
        return "What is the relationship? Answer:"

    verbatim = f"What is the relationship between {query}? Answer:"
    looks_like_tuple = isinstance(query, str) and query.strip().startswith("(")
    if not looks_like_tuple:
        return verbatim

    try:
        parsed = ast.literal_eval(query)
    except Exception:
        return verbatim

    if isinstance(parsed, (list, tuple)) and len(parsed) >= 2:
        return f"What is the relationship between {parsed[0]} and {parsed[1]}? Answer:"
    return verbatim


def collate_fn(batch, tokenizer):
    """Turn raw examples into a padded batch with prompt-masked labels.

    Each example is rendered as ``"Story: ...\\nQuery: ..."`` followed by a
    space and the target text. Label positions covering the prompt are set
    to -100, so only the tokens after the prompt carry a label.

    Args:
        batch: Iterable of mappings read via ``.get`` for the keys
            ``"story"``, ``"query"`` and ``"target_text"``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            exposing a ``pad`` method (presumably a HuggingFace tokenizer —
            confirm against the caller).

    Returns:
        The object returned by ``tokenizer.pad`` with an extra ``"labels"``
        entry: a ``torch.long`` tensor padded with -100 to the batch width.
    """
    token_rows = []
    label_rows = []

    for item in batch:
        story = item.get("story", "")
        question = _relationship_question(item.get("query"))
        answer = item.get("target_text", "")

        prompt = f"Story: {story}\nQuery: {question}"
        full_text = prompt + " " + answer

        full_ids = tokenizer(full_text, add_special_tokens=True, return_tensors=None)["input_ids"]
        prompt_ids = tokenizer(prompt, add_special_tokens=True, return_tensors=None)["input_ids"]

        # Mask the prompt prefix; keep the remaining token ids as labels.
        n_prompt = len(prompt_ids)
        token_rows.append(full_ids)
        label_rows.append([-100] * n_prompt + full_ids[n_prompt:])

    batch_enc = tokenizer.pad(
        {"input_ids": token_rows},
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Pad labels on the right with -100 up to the padded input width.
    width = batch_enc["input_ids"].size(1)
    for row in label_rows:
        row.extend([-100] * (width - len(row)))
    batch_enc["labels"] = torch.tensor(label_rows, dtype=torch.long)

    return batch_enc

class GPT2WithResidualBeta(nn.Module):
    """Wrap a frozen causal LM and learn one scalar ``beta`` per block.

    Forward hooks on every module in ``base_model.transformer.h`` replace a
    block's output ``y`` (given block input ``x``) with
    ``x + beta * (y - x)``. All betas start at 1, so the wrapped model
    initially reproduces the base model's outputs. Every parameter of the
    base model has ``requires_grad`` set to False; ``betas`` is the only
    parameter this wrapper adds.

    Args:
        base_model: Module exposing ``transformer.h``, an iterable of block
            modules whose first positional input is the tensor to blend
            with (presumably a GPT-2 style LM head model — confirm against
            the caller).
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base_model = base_model

        # Freeze the base model; only the betas are trained.
        for p in self.base_model.parameters():
            p.requires_grad = False

        self.num_layers = len(self.base_model.transformer.h)
        self.betas = nn.Parameter(torch.ones(self.num_layers))

        # Per-layer slot holding the block input between pre- and post-hook.
        self._block_inputs = [None] * self.num_layers

        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach one pre-hook and one post-hook to every block."""

        def make_pre_hook(layer_idx: int):
            def pre_hook(module, inputs):
                # Remember the block input so the post-hook can blend with it.
                self._block_inputs[layer_idx] = inputs[0]
            return pre_hook

        def make_post_hook(layer_idx: int):
            def post_hook(module, inputs, output):
                x = self._block_inputs[layer_idx]
                # Drop the cached reference right away so the wrapper does
                # not keep hidden states alive after the forward pass.
                self._block_inputs[layer_idx] = None

                # Blocks may return a bare tensor or a tuple whose first
                # element is the tensor to rescale.
                if isinstance(output, tuple):
                    y = output[0]
                    extras = output[1:]
                else:
                    y = output
                    extras = None

                beta = self.betas[layer_idx].to(dtype=y.dtype, device=y.device)

                # beta == 1 keeps the block output, beta == 0 skips the block.
                y_new = x + beta * (y - x)

                if extras is None:
                    return y_new
                return (y_new, *extras)
            return post_hook

        for i, block in enumerate(self.base_model.transformer.h):
            h1 = block.register_forward_pre_hook(make_pre_hook(i))
            h2 = block.register_forward_hook(make_post_hook(i))
            self.hooks.extend([h1, h2])

    def remove_hooks(self):
        """Detach every registered hook and clear any cached block inputs."""
        for h in self.hooks:
            h.remove()
        self.hooks = []
        self._block_inputs = [None] * self.num_layers

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Delegate to the base model; the hooks apply the beta scaling.

        Returns:
            Whatever ``base_model`` returns for the given arguments.
        """
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def print_beta_values(self):
        """Print the current beta of every layer to stdout."""
        betas = self.betas.detach().cpu().numpy()
        print("Beta values:")
        for i, b in enumerate(betas):
            print(f"Layer {i}: {b:.4f}")
        print("-" * 30)


def process_test_data(test_dataset):
    """Render every test example into prompt / combined / target strings.

    The prompt format mirrors the one used for training batches:
    ``"Story: {story}\\nQuery: {question}"``. The question is built from the
    ``query`` field as follows:

    * missing query -> ``"What is the relationship? Answer:"``
    * string starting with "(" that parses to a list/tuple with at least
      two items -> the first two items are named in the question
    * anything else, including strings that fail to parse or that parse to
      something other than such a pair -> the raw query is embedded verbatim

    Args:
        test_dataset: Iterable of mappings read via ``.get`` for the keys
            ``"story"``, ``"query"`` and ``"target_text"``.

    Returns:
        A list of dicts with the keys ``"prompt"``, ``"combined"`` (the
        prompt, a space, and the target text) and ``"target_text"``.
    """
    processed_examples = []

    for example in test_dataset:
        story = example.get("story", "")
        target_text = example.get("target_text", "")
        query = example.get("query", None)

        if query is None:
            query_str = "What is the relationship? Answer:"
        else:
            # Default to the verbatim form so that a query which parses but
            # is not a pair gets the same prompt at test time as it does in
            # the training collator.
            query_str = f"What is the relationship between {query}? Answer:"
            if isinstance(query, str) and query.strip().startswith("("):
                try:
                    parsed_query = ast.literal_eval(query)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                    print(f"Error parsing query '{query}': {e}")
                else:
                    if isinstance(parsed_query, (list, tuple)) and len(parsed_query) >= 2:
                        query_str = (
                            f"What is the relationship between {parsed_query[0]} "
                            f"and {parsed_query[1]}? Answer:"
                        )

        prompt = f"Story: {story}\nQuery: {query_str}"
        combined = prompt + " " + target_text

        processed_examples.append({
            "prompt": prompt,
            "combined": combined,
            "target_text": target_text,
        })

    return processed_examples

def _encode_eval_batch(batch, tokenizer):
    """Tokenize processed examples into a padded batch with masked labels.

    Label positions covering the prompt, and all padding positions, are set
    to -100; only tokens after the prompt keep their ids as labels.
    """
    inputs = []
    labels = []

    for item in batch:
        tokenized_combined = tokenizer(item["combined"], add_special_tokens=True, return_tensors=None)
        tokenized_prompt = tokenizer(item["prompt"], add_special_tokens=True, return_tensors=None)

        prompt_len = len(tokenized_prompt["input_ids"])
        input_ids = tokenized_combined["input_ids"]

        inputs.append(input_ids)
        labels.append([-100] * prompt_len + input_ids[prompt_len:])

    batch_enc = tokenizer.pad(
        {"input_ids": inputs},
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    max_len = batch_enc["input_ids"].size(1)
    labels_padded = [row + [-100] * (max_len - len(row)) for row in labels]
    batch_enc["labels"] = torch.tensor(labels_padded, dtype=torch.long)
    return batch_enc


def evaluate_model(model, test_dataset, tokenizer, device, batch_size=16):
    """Compute the loss and first-target-token accuracy on a test set.

    The global Python, NumPy and PyTorch generators are re-seeded with 42
    and ``model`` is switched to eval mode; neither is restored afterwards.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` and returning an object with ``.loss`` and
            ``.logits``.
        test_dataset: Sized iterable of raw examples, passed to
            ``process_test_data``.
        tokenizer: Tokenizer used to encode, pad and decode.
        device: Device the batch tensors are moved to.
        batch_size: Number of examples per forward pass (default 16).

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-batch
        ``outputs.loss`` weighted by the number of examples in each batch
        and divided by the example count (NOTE(review): if ``outputs.loss``
        is a token-level mean this is not an exact per-token average —
        confirm). ``accuracy`` is the fraction of examples whose first
        labelled token equals the argmax of the logits at the preceding
        position; examples without any labelled token are not counted.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    print(f"Test dataset length: {len(test_dataset)}")

    processed_examples = process_test_data(test_dataset)
    print(f"Processed {len(processed_examples)} test examples")

    if processed_examples:
        print("Example processed test item:")
        print(f"Prompt: {processed_examples[0]['prompt'][:100]}...")
        print(f"Target: {processed_examples[0]['target_text']}")

    with torch.no_grad():
        for batch_idx in range(0, len(processed_examples), batch_size):
            batch = processed_examples[batch_idx:batch_idx + batch_size]
            batch_enc = _encode_eval_batch(batch, tokenizer)

            input_ids = batch_enc["input_ids"].to(device)
            attention_mask = batch_enc["attention_mask"].to(device)
            labels = batch_enc["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item() * len(batch)

            logits = outputs.logits
            # argmax over the 0/1 mask yields the first labelled position,
            # or 0 when a row has no labelled token at all.
            first_label_positions = (labels != -100).float().argmax(dim=1)
            labels_at_first = labels.gather(1, first_label_positions.unsqueeze(1)).squeeze(1)

            # A row is scored only if a real label exists at a position > 0,
            # so that position - 1 is a valid logits index.
            scored = (first_label_positions > 0) & (labels_at_first != -100)
            rows = scored.nonzero(as_tuple=True)[0]
            if rows.numel() > 0:
                cols = first_label_positions[rows]
                # Logits at position t predict the token at position t + 1.
                predicted = logits[rows, cols - 1].argmax(dim=-1)
                correct_predictions += (predicted == labels[rows, cols]).sum().item()
                total_samples += rows.numel()

            if batch_idx == 0:
                print("First batch example:")
                if bool(scored[0]):
                    pos = first_label_positions[0].item()
                    example_prompt = tokenizer.decode(input_ids[0][:pos])
                    example_true_token = tokenizer.decode(labels[0, pos].unsqueeze(0))
                    example_pred_token = tokenizer.decode(logits[0, pos - 1].argmax(dim=-1).unsqueeze(0))
                    print(f"Prompt:\n{example_prompt}")
                    print(f"True next token: '{example_true_token}'")
                    print(f"Predicted next token: '{example_pred_token}'")
                else:
                    # Avoid decoding the -100 ignore index or reading
                    # logits at position -1.
                    print("No supervised target token in the first example; skipping preview.")
                print("-" * 30)

    avg_loss = total_loss / len(processed_examples) if processed_examples else 0
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0

    print(f"Total correct: {correct_predictions}, Total samples: {total_samples}")
    print(f"Final Accuracy: {accuracy:.4f}")

    return avg_loss, accuracy

def main():
    """Evaluate a checkpoint, train per-layer betas, then evaluate again.

    Steps:
        1. Parse CLI arguments, seed the RNGs and write the arguments to
           ``config.txt`` in the output directory.
        2. Load the ``gpt2`` tokenizer and the CLUTRR dataset, dropping
           examples whose ``task_name`` ends with "1.2" or "1.3".
        3. Evaluate the unmodified checkpoint and write
           ``test_metrics_before_new.txt``.
        4. Wrap the model in ``GPT2WithResidualBeta`` and train with the
           HuggingFace ``Trainer``.
        5. Save the betas to ``betas.pt`` and ``betas.txt``, evaluate again
           and write ``test_metrics_after_new.txt``.

    Every artifact gets its own filename directly inside ``--output_dir`` so
    that no output overwrites another and no missing sub-directory is
    needed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="path/to/your/folder")
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs(args.output_dir, exist_ok=True)

    # One distinct file per artifact, all directly under output_dir.
    config_path = os.path.join(args.output_dir, "config.txt")
    metrics_before_path = os.path.join(args.output_dir, "test_metrics_before_new.txt")
    betas_tensor_path = os.path.join(args.output_dir, "betas.pt")
    betas_text_path = os.path.join(args.output_dir, "betas.txt")
    metrics_after_path = os.path.join(args.output_dir, "test_metrics_after_new.txt")

    with open(config_path, "w") as f:
        for arg in vars(args):
            f.write(f"{arg}: {getattr(args, arg)}\n")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # Padding needs a pad token; reuse EOS when none is configured.
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("CLUTRR/v1", "gen_train23_test2to10")

    def task_filter(example):
        """Keep examples whose task name does not end in "1.2" or "1.3"."""
        task_name = example.get("task_name", "")
        return not task_name.endswith(("1.2", "1.3"))

    # NOTE(review): the betas are trained on the filtered *test* split and
    # then evaluated on that same split, so the post-training metrics are
    # not held-out numbers. Switch train_ds to dataset["train"] if that is
    # not intended — confirm with the experiment owner.
    test_dataset = dataset["test"].filter(task_filter)
    train_ds = test_dataset

    print(f"Loading model from checkpoint: {args.checkpoint_dir}")
    base_model = AutoModelForCausalLM.from_pretrained(args.checkpoint_dir)
    base_model = base_model.to(device)

    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(base_model, test_dataset, tokenizer, device)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    with open(metrics_before_path, "w") as f:
        f.write(f"Test Loss: {test_loss:.4f}\n")
        f.write(f"Test Accuracy: {test_accuracy:.4f}\n")

    model = GPT2WithResidualBeta(base_model)
    model.to(device)

    print("Initial beta values:")
    model.print_beta_values()

    training_args = TrainingArguments(
        seed=args.seed,
        data_seed=args.seed,
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        # The collator reads the raw story/query/target_text columns.
        remove_unused_columns=False,
        report_to=["none"],
        save_strategy="no"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=lambda b: collate_fn(b, tokenizer),
    )

    trainer.train()

    print("Final beta values after training:")
    model.print_beta_values()

    final_betas = model.betas.detach().cpu()
    torch.save({"betas": final_betas}, betas_tensor_path)

    print(f"Training complete. Beta parameters saved to {betas_tensor_path}")

    with open(betas_text_path, "w") as f:
        for i, beta in enumerate(final_betas.numpy()):
            f.write(f"Layer {i}: {beta:.6f}\n")

    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(model, test_dataset, tokenizer, device)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    with open(metrics_after_path, "w") as f:
        f.write(f"Test Loss: {test_loss:.4f}\n")
        f.write(f"Test Accuracy: {test_accuracy:.4f}\n")


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()