import argparse
import os
import random
import numpy as np
import torch
import ast
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments
)
from torch import nn
from torch.utils.data import DataLoader, Dataset

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is given to ``random``, ``numpy.random`` and
    ``torch``. When CUDA is available, every CUDA device generator is
    seeded with it as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def generate_synthetic_example(seq_length=10, K=5, noise_range=100, idx=None):
    """Build one synthetic "predict the next signal" example.

    A start value in ``[0, K - 1]`` and a step in ``[1, K - 1]`` are drawn
    from the global ``random`` generator; the signal at position ``t`` is
    ``(start + t * step) % K``. Each position is rendered as the token
    ``S<signal>_N<noise>`` with ``noise`` drawn from
    ``[0, noise_range - 1]``.

    Args:
        seq_length: Number of positions generated; the last one only
            contributes its signal as the target.
        K: Modulus of the signal sequence.
        noise_range: Exclusive upper bound of the per-token noise value.
        idx: Optional identifier; when not ``None`` it is stored as a
            string under the ``"id"`` key.

    Returns:
        A dict with ``"context"`` (space-joined tokens of all positions but
        the last), ``"target"`` (the last signal as a string) and,
        optionally, ``"id"``.
    """
    # Draw order matters for reproducibility: start, step, then one noise
    # value per position.
    start = random.randint(0, K - 1)
    step = random.randint(1, K - 1)

    signal_values = [(start + pos * step) % K for pos in range(seq_length)]
    pieces = [
        f"S{value}_N{random.randint(0, noise_range - 1)}"
        for value in signal_values
    ]

    record = {
        "context": " ".join(pieces[:-1]),
        "target": str(signal_values[-1]),
    }
    if idx is not None:
        record["id"] = str(idx)
    return record

def generate_synthetic_dataset(num_examples, seq_length=10, K=5, noise_range=100):
    """Generate ``num_examples`` synthetic examples with ids ``"0"`` .. ``"num_examples - 1"``.

    Every example is produced by ``generate_synthetic_example`` with the
    given ``seq_length``, ``K`` and ``noise_range``; its position in the
    returned list is passed as ``idx``.
    """
    dataset = []
    for position in range(num_examples):
        dataset.append(
            generate_synthetic_example(seq_length, K, noise_range, idx=position)
        )
    return dataset

class SyntheticDataset(Dataset):
    """Map-style dataset exposing an in-memory sequence of example records."""

    def __init__(self, examples):
        # The sequence is stored by reference; it is not copied.
        self.examples = examples

    def __len__(self):
        """Return how many examples are stored."""
        size = len(self.examples)
        return size

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``, unmodified."""
        record = self.examples[idx]
        return record

def collate_fn(batch, tokenizer):
    """Turn a list of context/target examples into a padded causal-LM batch.

    For each example the prompt is ``item["context"]`` and the full text is
    the prompt, a single space, and ``item["target"]``. Both are tokenized;
    the number of prompt tokens decides how many leading label positions are
    masked with ``-100`` so that only the target tokens contribute to the
    loss.

    Args:
        batch: Sequence of dicts with at least ``"context"`` and ``"target"``
            string entries.
        tokenizer: Callable tokenizer returning a mapping with an
            ``"input_ids"`` list, and offering ``pad(...)``. Its optional
            ``padding_side`` attribute is honoured when padding labels.

    Returns:
        The object returned by ``tokenizer.pad`` (with ``"input_ids"`` and
        the attention mask as tensors), extended with a ``"labels"`` tensor
        of dtype ``torch.long`` that has the same width as ``"input_ids"``.
    """
    inputs = []
    labels = []

    for item in batch:
        prompt = item["context"]
        combined = prompt + " " + item["target"]

        combined_ids = tokenizer(
            combined,
            add_special_tokens=True,
            return_tensors=None,
        )["input_ids"]
        prompt_ids = tokenizer(
            prompt,
            add_special_tokens=True,
            return_tensors=None,
        )["input_ids"]

        # Assumes the tokenization of `prompt` is a prefix of the
        # tokenization of `combined`, so its length marks where the target
        # tokens start.
        prompt_len = len(prompt_ids)

        inputs.append(combined_ids)
        labels.append([-100] * prompt_len + combined_ids[prompt_len:])

    batch_enc = tokenizer.pad(
        {"input_ids": inputs},
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    max_len = batch_enc["input_ids"].size(1)

    # Labels must be padded on the same side as the input ids; otherwise a
    # left-padding tokenizer would shift every label away from its token.
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    labels_padded = []
    for seq in labels:
        filler = [-100] * (max_len - len(seq))
        labels_padded.append(filler + seq if pad_left else seq + filler)

    batch_enc["labels"] = torch.tensor(labels_padded, dtype=torch.long)
    return batch_enc

class GPT2WithResidualBeta(nn.Module):
    """Wrap a frozen model and learn one scalar ``beta`` per transformer block.

    Forward hooks on every block in ``base_model.transformer.h`` replace the
    block's output ``y`` (the first element when the block returns a tuple)
    with ``x + beta * (y - x)``, where ``x`` is the block's first positional
    input. All parameters of ``base_model`` have ``requires_grad`` set to
    ``False``; ``betas`` is the only parameter created here and starts at 1
    for every layer, which leaves each block's output unchanged up to
    floating-point rounding.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, base_model: nn.Module):
        """Freeze ``base_model``, create the betas and install the hooks.

        Args:
            base_model: Module exposing its blocks as
                ``base_model.transformer.h``.
        """
        super().__init__()
        self.base_model = base_model

        for p in self.base_model.parameters():
            p.requires_grad = False

        # One learnable scale per block.
        self.num_layers = len(self.base_model.transformer.h)
        self.betas = nn.Parameter(torch.ones(self.num_layers))

        # Scratch slots: the pre-hook stores a block's input here and the
        # post-hook consumes it.
        self._block_inputs = [None] * self.num_layers

        # Handles of all registered hooks, so they can be removed later.
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a pre-hook and a post-hook to every transformer block."""

        def make_pre_hook(layer_idx: int):
            def pre_hook(module, inputs):
                # Only positional arguments reach this hook; the block's
                # input is expected to be the first of them.
                self._block_inputs[layer_idx] = inputs[0]
            return pre_hook

        def make_post_hook(layer_idx: int):
            def post_hook(module, inputs, output):
                x = self._block_inputs[layer_idx]
                # Drop the cached reference right away so the activation is
                # not kept alive by this wrapper after the forward pass.
                self._block_inputs[layer_idx] = None

                if isinstance(output, tuple):
                    y = output[0]
                    extras = output[1:]
                else:
                    y = output
                    extras = None

                beta = self.betas[layer_idx].to(dtype=y.dtype, device=y.device)

                # Interpolate between the block input (beta = 0) and the
                # block output (beta = 1).
                y_new = x + beta * (y - x)

                if extras is None:
                    return y_new
                return (y_new, *extras)
            return post_hook

        for i, block in enumerate(self.base_model.transformer.h):
            h1 = block.register_forward_pre_hook(make_pre_hook(i))
            h2 = block.register_forward_hook(make_post_hook(i))
            self.hooks.extend([h1, h2])

    def remove_hooks(self):
        """Detach every registered hook (best effort) and forget the handles.

        A handle that fails to detach does not stop the others from being
        removed; the failure is reported as a ``RuntimeWarning``.
        """
        import warnings

        for h in self.hooks:
            try:
                h.remove()
            except Exception as exc:
                warnings.warn(
                    f"Failed to remove hook {h!r}: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        self.hooks = []
        # Nothing will consume cached inputs any more.
        self._block_inputs = [None] * self.num_layers

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Delegate to ``base_model``; the hooks apply the beta scaling.

        Returns:
            Whatever ``base_model`` returns for the given arguments.
        """
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def print_beta_values(self):
        """Print the current beta of every layer with four decimals."""
        betas = self.betas.detach().cpu().numpy()
        print("Beta values:")
        for i, b in enumerate(betas):
            print(f"Layer {i}: {b:.4f}")
        print("-" * 30)

def evaluate_model(model, tokenizer, device, examples, batch_size=16):
    """Compute average loss and next-token accuracy over ``examples``.

    Each example is tokenized twice (prompt alone and prompt plus target);
    label positions covered by the prompt are masked with ``-100``. Accuracy
    looks only at the first supervised position of every row: the
    prediction is the argmax of the logits one position earlier.

    Side effects: moves ``model`` to ``device``, switches it to eval mode
    (the previous mode is not restored), reseeds the global ``torch``,
    ``random`` and ``numpy`` generators with 42, and prints progress
    information.

    Args:
        model: Callable returning an object with ``loss`` and ``logits``
            when given ``input_ids``, ``attention_mask`` and ``labels``.
        tokenizer: Tokenizer providing ``__call__``, ``pad`` and ``decode``.
        device: Device the model and batches are moved to.
        examples: Sequence of dicts with ``"prompt"``, ``"combined"`` and
            ``"target_text"`` entries.
        batch_size: Number of examples per forward pass (default 16).

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the batch losses weighted
        by batch length and divided by the number of examples; ``accuracy``
        is the share of rows with a supervised token whose first target
        token was predicted correctly. Both are ``0`` when there is nothing
        to evaluate.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Fixed seeds kept for run-to-run reproducibility; note that this
    # overrides whatever seed the caller set earlier.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    model = model.to(device)
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    processed_examples = examples
    print(f"Processed {len(processed_examples)} test examples")

    if processed_examples:
        print("Example processed test item:")
        print(f"Prompt: {processed_examples[0]['prompt'][:100]}...")
        print(f"Target: {processed_examples[0]['target_text']}")

    with torch.no_grad():
        for batch_idx in range(0, len(processed_examples), batch_size):
            batch = processed_examples[batch_idx:batch_idx + batch_size]

            inputs = []
            label_rows = []
            for item in batch:
                combined_ids = tokenizer(
                    item["combined"], add_special_tokens=True, return_tensors=None
                )["input_ids"]
                prompt_len = len(
                    tokenizer(
                        item["prompt"], add_special_tokens=True, return_tensors=None
                    )["input_ids"]
                )
                inputs.append(combined_ids)
                label_rows.append([-100] * prompt_len + combined_ids[prompt_len:])

            batch_enc = tokenizer.pad(
                {"input_ids": inputs},
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )

            # Labels are padded on the right, i.e. this assumes the
            # tokenizer pads input ids on the right as well.
            max_len = batch_enc["input_ids"].size(1)
            labels_padded = [
                row + [-100] * (max_len - len(row)) for row in label_rows
            ]

            input_ids = batch_enc["input_ids"].to(device)
            attention_mask = batch_enc["attention_mask"].to(device)
            labels = torch.tensor(labels_padded, dtype=torch.long).to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            # The batch loss is weighted by the number of examples so that a
            # shorter final batch does not count as much as a full one.
            total_loss += outputs.loss.item() * len(batch)

            logits = outputs.logits
            label_mask = labels != -100
            # argmax over the 0/1 mask yields the first supervised position
            # (or 0 when the row has none).
            first_label_positions = label_mask.float().argmax(dim=1)
            rows = torch.arange(labels.size(0), device=labels.device)

            # A row counts only if it has a supervised token that is not at
            # position 0, because the prediction comes from the previous
            # position.
            valid = (first_label_positions > 0) & label_mask[rows, first_label_positions]
            prev_positions = (first_label_positions - 1).clamp(min=0)
            pred_tokens = logits[rows, prev_positions].argmax(dim=-1)
            true_tokens = labels[rows, first_label_positions]

            correct_predictions += int(((pred_tokens == true_tokens) & valid).sum().item())
            total_samples += int(valid.sum().item())

            # Show one decoded example, but only when the first row really
            # has a target token to show.
            if batch_idx == 0 and bool(valid[0].item()):
                print("First batch example:")
                pos = first_label_positions[0].item()
                example_prompt = tokenizer.decode(input_ids[0][:pos])
                example_true_token = tokenizer.decode(labels[0, pos].unsqueeze(0))
                example_pred_token = tokenizer.decode(pred_tokens[0].unsqueeze(0))
                print(f"Prompt:\n{example_prompt}")
                print(f"True next token: '{example_true_token}'")
                print(f"Predicted next token: '{example_pred_token}'")
                print("-" * 30)

    avg_loss = total_loss / len(processed_examples) if processed_examples else 0
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0

    print(f"Total correct: {correct_predictions}, Total samples: {total_samples}")
    print(f"Final Accuracy: {accuracy:.4f}")

    return avg_loss, accuracy

def _write_eval_results(path, loss, accuracy):
    """Write a loss/accuracy pair to ``path`` as two ``name: value`` lines."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"Test Loss: {loss:.4f}\n")
        f.write(f"Test Accuracy: {accuracy:.4f}\n")


def main():
    """Train per-layer residual betas on top of a frozen checkpoint.

    Steps: parse the command line, seed the RNGs, generate the synthetic
    data, evaluate the checkpoint as loaded, wrap it in
    ``GPT2WithResidualBeta``, train with ``Trainer``, then save the betas
    and evaluate again. Every artifact is written to its own file directly
    inside ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="path/to/your/folder")
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    os.makedirs(args.output_dir, exist_ok=True)

    # One distinct file per artifact, all directly under output_dir (which
    # exists at this point). Sharing a single path would make every write
    # overwrite the previous one.
    config_path = os.path.join(args.output_dir, "config.txt")
    initial_results_path = os.path.join(args.output_dir, "initial_test_results.txt")
    betas_tensor_path = os.path.join(args.output_dir, "betas.pt")
    betas_text_path = os.path.join(args.output_dir, "betas.txt")
    final_results_path = os.path.join(args.output_dir, "final_test_results.txt")

    with open(config_path, "w", encoding="utf-8") as f:
        for arg in vars(args):
            f.write(f"{arg}: {getattr(args, arg)}\n")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # Batches need a pad token; reuse EOS when none is configured.
        tokenizer.pad_token = tokenizer.eos_token

    train_examples = generate_synthetic_dataset(10000, seq_length=10, K=17, noise_range=100)
    train_dataset = SyntheticDataset(train_examples)

    # Load model from checkpoint
    print(f"Loading model from checkpoint: {args.checkpoint_dir}")
    base_model = AutoModelForCausalLM.from_pretrained(args.checkpoint_dir)
    base_model = base_model.to(device)

    print("Evaluating model on test dataset...")

    # NOTE(review): the "test" evaluation below runs on the examples that are
    # also used for training the betas; there is no held-out split.
    processed_train_examples = [
        {
            "prompt": ex["context"],
            "combined": ex["context"] + " " + ex["target"],
            "target_text": ex["target"],
        }
        for ex in train_examples
    ]

    test_loss, test_accuracy = evaluate_model(base_model, tokenizer, device, processed_train_examples)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    _write_eval_results(initial_results_path, test_loss, test_accuracy)

    model = GPT2WithResidualBeta(base_model)
    model.to(device)

    print("Initial beta values:")
    model.print_beta_values()

    training_args = TrainingArguments(
        seed=args.seed,
        data_seed=args.seed,
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        # The raw dict examples must reach collate_fn untouched.
        remove_unused_columns=False,
        report_to=["none"],
        save_strategy="no"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=lambda b: collate_fn(b, tokenizer),
    )

    trainer.train()

    print("Final beta values after training:")
    model.print_beta_values()

    final_betas = model.betas.detach().cpu()
    torch.save({"betas": final_betas}, betas_tensor_path)

    print(f"Training complete. Beta parameters saved to {betas_tensor_path}")

    with open(betas_text_path, "w", encoding="utf-8") as f:
        for i, beta in enumerate(final_betas.numpy()):
            f.write(f"Layer {i}: {beta:.6f}\n")

    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(model, tokenizer, device, processed_train_examples)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    _write_eval_results(final_results_path, test_loss, test_accuracy)

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()