# train_galore_official.py
import torch
import argparse
import os
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, set_seed
)
from datasets import load_dataset
from galore_optimizer import AdamW as GaLoreAdamW

from utils import set_train_seed
from model import freeze_bert_layers
from dataset_prep import (
    dataset_fields, dataset_to_num_labels, dataset_best_metrics,
    compute_metrics_with_args
)

def parse_args():
    """Read the command-line options for a GaLore fine-tuning run on a GLUE task.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``model_name``,
        ``n_last_layers``, ``batch_size``, ``epochs``, ``lr``, ``rank``,
        ``update_proj_gap``, ``seed`` and ``output_dir``.
    """
    glue_tasks = ["cola", "sst2", "mrpc", "qqp", "mnli", "qnli", "rte", "wnli"]

    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", type=str, default="mrpc", choices=glue_tasks)

    # (flag, type, default) for every remaining option, registered in this order
    # so that --help output lists them exactly as before.
    plain_options = (
        ("--model_name", str, "bert-base-uncased"),
        ("--n_last_layers", int, 1),
        ("--batch_size", int, 32),
        ("--epochs", int, 10),
        ("--lr", float, 2e-5),
        ("--rank", int, 32),
        ("--update_proj_gap", int, 50),
        ("--seed", int, 42),
        ("--output_dir", str, "runs-galore-official"),
    )
    for flag, value_type, fallback in plain_options:
        cli.add_argument(flag, type=value_type, default=fallback)

    return cli.parse_args()

def _tokenize_glue(dataset, tokenizer, task_name):
    """Tokenize every split of a GLUE ``DatasetDict`` and expose torch tensors.

    Args:
        dataset: the ``DatasetDict`` returned by ``load_dataset("glue", task_name)``.
        tokenizer: tokenizer used to encode the text columns.
        task_name: GLUE task key used to look up the text columns in
            ``dataset_fields``.

    Returns:
        The tokenized ``DatasetDict``, formatted to yield torch tensors for
        ``input_ids``, ``attention_mask``, ``token_type_ids`` (when present)
        and ``label``.
    """
    fields = dataset_fields[task_name]

    def tokenize_function(examples):
        # NOTE: padding="max_length" pads every example to 512 tokens, which is
        # expensive for short-sentence tasks; dynamic padding would be cheaper.
        if len(fields) == 2:
            return tokenizer(examples[fields[0]], examples[fields[1]],
                             truncation=True, padding="max_length", max_length=512)
        return tokenizer(examples[fields[0]], truncation=True,
                         padding="max_length", max_length=512)

    tokenized = dataset.map(tokenize_function, batched=True)
    cols = ["input_ids", "attention_mask", "label"]
    # Only some tokenizers emit token_type_ids; include them when present.
    if "token_type_ids" in tokenized["train"].column_names:
        cols.insert(2, "token_type_ids")
    tokenized.set_format("torch", columns=cols)
    return tokenized


def _build_galore_param_groups(model, rank, update_proj_gap, weight_decay,
                               scale=4.0, proj_type="std"):
    """Partition the model's trainable parameters into optimizer param groups.

    Every parameter with ``requires_grad=True`` lands in exactly one group;
    empty groups are omitted.

    * GaLore group: weights of trainable ``nn.Linear`` layers whose smaller
      dimension is strictly greater than ``rank``. It carries the GaLore keys
      ``rank``, ``update_proj_gap``, ``scale`` and ``proj_type``, plus
      ``weight_decay``.
    * Decay group: the remaining trainable parameters with ``ndim >= 2``.
    * No-decay group: the remaining trainable parameters with ``ndim < 2``
      (biases, normalization weights), which get ``weight_decay=0.0``.

    Args:
        model: the ``torch.nn.Module`` being fine-tuned.
        rank: GaLore projection rank.
        update_proj_gap: steps between projector recomputations.
        weight_decay: decoupled weight decay for the GaLore and decay groups.
        scale: GaLore update scale factor.
        proj_type: GaLore projection type.

    Returns:
        A list of param-group dicts suitable for a torch optimizer.
    """
    # Keyed by id() to avoid listing a shared weight twice (torch rejects a
    # parameter that appears in more than one group).
    galore_by_id = {}
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear) or not module.weight.requires_grad:
            continue
        # A layer whose smaller dimension is <= rank (e.g. a num_labels-wide
        # classification head) gains no memory saving from a rank-`rank`
        # projection. In the reference GaLore implementation the projected-back
        # update is also multiplied by `scale`, which would silently enlarge
        # that layer's step. Such layers are optimized as ordinary AdamW params.
        if min(module.weight.shape) > rank:
            galore_by_id[id(module.weight)] = module.weight
    galore_params = list(galore_by_id.values())

    decay_params, no_decay_params = [], []
    for param in model.parameters():
        if not param.requires_grad or id(param) in galore_by_id:
            continue
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    groups = []
    if decay_params:
        groups.append({"params": decay_params, "weight_decay": weight_decay})
    if no_decay_params:
        groups.append({"params": no_decay_params, "weight_decay": 0.0})
    if galore_params:
        groups.append({
            "params": galore_params,
            "rank": rank,
            "update_proj_gap": update_proj_gap,
            "scale": scale,
            "proj_type": proj_type,
            "weight_decay": weight_decay,
        })
    return groups


def main():
    """Fine-tune a sequence-classification model on a GLUE task with GaLore-AdamW.

    Options come from ``parse_args``. The pretrained model is partially frozen
    via ``freeze_bert_layers``. The remaining trainable parameters are optimized
    by GaLoreAdamW, and training and evaluation run through ``Trainer``, with a
    checkpoint evaluated and saved every epoch.
    """
    args = parse_args()
    set_seed(args.seed)
    set_train_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = load_dataset("glue", args.dataset)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenized = _tokenize_glue(dataset, tokenizer, args.dataset)
    num_labels = dataset_to_num_labels[args.dataset]

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name, num_labels=num_labels
    ).to(device)

    freeze_bert_layers(model, args.n_last_layers, args.model_name)

    weight_decay = 0.01
    # Trainer only applies TrainingArguments.weight_decay to an optimizer it
    # builds itself. Because we pass our own optimizer via `optimizers=`, the
    # decay must live on the param groups, or it is silently never applied.
    param_groups = _build_galore_param_groups(
        model,
        rank=args.rank,
        update_proj_gap=args.update_proj_gap,
        weight_decay=weight_decay,
    )
    optimizer = GaLoreAdamW(param_groups, lr=args.lr)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        # Kept consistent with the values already set on the optimizer above.
        learning_rate=args.lr,
        weight_decay=weight_decay,
        warmup_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model=dataset_best_metrics.get(args.dataset, "accuracy"),
        logging_steps=50,
        report_to="none",
        seed=args.seed,
    )

    # MNLI has matched/mismatched validation sets; evaluate on the matched one.
    eval_split = "validation_matched" if args.dataset == "mnli" else "validation"

    for name, param in model.named_parameters():
        status = "❄️ Frozen" if not param.requires_grad else "🔥 Unfrozen"
        print(f"{name}: {status}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized[eval_split],
        compute_metrics=compute_metrics_with_args(args=args),
        # Custom optimizer; None lets Trainer build its default LR scheduler.
        optimizers=(optimizer, None),
    )

    print(f"\nStarting GaLore training on {args.dataset} | Last {args.n_last_layers} layers unfrozen")
    trainer.train()

# Run training only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()