import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import torch.nn as nn
import argparse
import warnings
import gc
from transformers import TrainingArguments, Trainer, TrainerCallback, AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from utils import set_train_seed
from model_compact import (
    freeze_bert_layers,
    add_activate_to_module,
    add_compression_tensor_to_module,
    get_svd_layers
)
from dataset_prep import dataset_fields, dataset_to_num_labels, dataset_best_metrics, compute_metrics_with_args
from custom_op_compact import register_filter
from compact_optimizer import AdamW

class ProjectionUpdateCallback(TrainerCallback):
    """Periodically resample the CompAct projection matrices during training.

    Every ``custom_args.calib_iter`` optimizer steps (only when
    ``custom_args.compress`` is truthy), a fresh random ``(in_features, r)``
    matrix with unit-norm columns is drawn for every ``nn.Linear`` whose
    ``activate`` attribute is truthy, and the whole batch is installed with
    ``model.set_projection_matrix``.

    The callback reads ``custom_args`` and ``model`` from its owning trainer.
    Pass the trainer to the constructor, or assign ``callback.trainer`` after the
    trainer has been built (the trainer and callback reference each other).
    """

    def __init__(self, trainer=None):
        super().__init__()
        self.trainer = trainer

    def on_step_end(self, args, state, control, **kwargs):
        """Resample projections when ``state.global_step`` hits a calibration step.

        Raises:
            RuntimeError: if no trainer has been attached to the callback.
        """
        trainer = getattr(self, "trainer", None)
        if trainer is None:
            raise RuntimeError(
                "ProjectionUpdateCallback.trainer is not set; pass the trainer to "
                "the constructor or assign callback.trainer before training."
            )

        custom_args = trainer.custom_args
        if not getattr(custom_args, "compress", False):
            return

        calib_iter = custom_args.calib_iter
        step = state.global_step
        # A non-positive calib_iter disables resampling instead of raising
        # ZeroDivisionError in the modulo below.
        if calib_iter <= 0 or step <= 0 or step % calib_iter != 0:
            return

        r = custom_args.r
        model = trainer.model
        device = next(model.parameters()).device

        print(f"\n{'='*60}\n[CompAct] Updating P at step {step} (rank={r})\n{'='*60}")

        P_dict = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and getattr(module, "activate", False):
                P_new = torch.randn(module.in_features, r, device=device)
                # Columns are normalized to unit L2 norm; any constant pre-scale
                # (such as 1/sqrt(r)) would cancel out here, so none is applied.
                P_dict[name] = P_new / P_new.norm(dim=0, keepdim=True)

        model.set_projection_matrix(P_dict)
        print(f"[CompAct] Updated P for {len(P_dict)} layers")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


class CustomTrainer(Trainer):
    """Hugging Face ``Trainer`` that builds a CompAct-aware ``AdamW`` optimizer.

    When ``custom_args.compress`` is set and the model exposes a non-empty
    ``wrapped_modules`` dict, every trainable parameter gets its own param group.
    Parameters owned directly by a wrapped module are additionally tagged with
    ``compact=True``, ``module_name``, ``alpha`` and ``r`` for the project
    ``AdamW`` (``compact_optimizer``). Otherwise the stock optimizer is used.

    Args:
        custom_args: argparse namespace from ``parse_args`` (reads ``compress``,
            ``learning_rate``, ``alpha`` and ``r``).
        modules_compressed: names of the layers selected for compression; stored
            on the instance, not read by this class.
        **kwargs: forwarded unchanged to ``Trainer.__init__``.
    """

    # Parameter-name substrings excluded from weight decay.
    _NO_DECAY = ("bias", "LayerNorm.weight")

    def __init__(self, custom_args=None, modules_compressed=None, **kwargs):
        super().__init__(**kwargs)
        self.custom_args = custom_args
        self.modules_compressed = modules_compressed

    def _compact_param_owners(self, wrapped_modules):
        """Map full parameter names to the wrapped module name that directly owns them."""
        owners = {}
        for module_name, module in self.model.named_modules():
            if module_name not in wrapped_modules:
                continue
            for param_name, _ in module.named_parameters(recurse=False):
                owners[f"{module_name}.{param_name}"] = module_name
        return owners

    def create_optimizer(self):
        """Create (once) and return the training optimizer.

        Falls back to ``Trainer.create_optimizer`` when compression is disabled
        or the model has no ``wrapped_modules``.

        Returns:
            The optimizer, also stored on ``self.optimizer``.
        """
        if self.optimizer is not None:
            return self.optimizer

        if not self.custom_args or not getattr(self.custom_args, "compress", False):
            return super().create_optimizer()

        wrapped_modules = getattr(self.model, "wrapped_modules", {})
        if not wrapped_modules:
            print("WARNING: No wrapped_modules found! CompAct will not work.")
            return super().create_optimizer()

        print(f"\n{'='*60}")
        print(f"[CustomTrainer] Found {len(wrapped_modules)} wrapped modules for CompAct")
        print(f"{'='*60}\n")

        owners = self._compact_param_owners(wrapped_modules)
        print(f"[CustomTrainer] Built param_to_module mapping with {len(owners)} entries")
        print(f"{'='*60}\n")

        # Decay and Adam hyperparameters come from TrainingArguments (the same
        # source the stock HF optimizer uses) instead of duplicated literals.
        decay = self.args.weight_decay
        lr = self.custom_args.learning_rate

        param_groups = []
        compact_count = 0
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            group = {
                "params": [param],
                "weight_decay": 0.0 if any(nd in name for nd in self._NO_DECAY) else decay,
                "lr": lr,
            }

            module_name = owners.get(name)
            if module_name is not None:
                compact_count += 1
                print(f"  [{compact_count}] Setting up CompAct for param: {name}")
                group.update({
                    "compact": True,
                    "module_name": module_name,
                    "alpha": self.custom_args.alpha,
                    "r": self.custom_args.r,
                })

            param_groups.append(group)

        print(f"\n[CustomTrainer] Total CompAct param groups: {compact_count}")
        print(f"{'='*60}\n")

        self.optimizer = AdamW(
            param_groups,
            lr=lr,
            weight_decay=decay,
            betas=(self.args.adam_beta1, self.args.adam_beta2),
            eps=self.args.adam_epsilon,
            correct_bias=True,
            model=self.model,
        )
        return self.optimizer


def parse_args():
    """Parse command-line options for CompAct / full fine-tuning on GLUE.

    When ``--output_dir`` is omitted it is derived as
    ``runs/{CompAct|FullFT}_{dataset}_r{r}_alpha_a{alpha}_{layers}``, where
    ``layers`` is ``full`` for ``--n-last-layers 0``, ``head`` for 1 and
    ``{n}L`` otherwise.

    Returns:
        argparse.Namespace with the parsed options and a populated ``output_dir``.

    Raises:
        SystemExit: via ``parser.error`` when ``--compress`` is given with a
            non-positive ``--r`` or ``--calib_iter``.
    """
    parser = argparse.ArgumentParser(description="CompAct Training on GLUE")
    parser.add_argument("--dataset", type=str, default="rte",
                        choices=["cola", "sst2", "mrpc", "qqp", "mnli", "qnli", "rte", "wnli"])
    parser.add_argument("--n-last-layers", type=int, default=1,
                        help="Number of last BERT layers to finetune")
    parser.add_argument("--calib_iter", type=int, default=50)
    parser.add_argument("--model_name", type=str, default="bert-base-uncased")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="If not provided, auto-generated from dataset/rank/alpha/layers")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--r", type=int, default=16, help="Compression rank")
    parser.add_argument("--alpha", type=float, default=2.0, help="Alpha factor")
    parser.add_argument("--compress", action="store_true", help="Enable CompAct compression")

    args = parser.parse_args()

    if args.compress:
        # r is used as the projection rank and calib_iter as a step modulus by
        # the CompAct code path; reject values that would break either.
        if args.r <= 0:
            parser.error("--r must be a positive integer when --compress is set")
        if args.calib_iter <= 0:
            parser.error("--calib_iter must be a positive integer when --compress is set")

    if args.output_dir is None:
        # ':g' keeps all significant digits (2.0 -> "2", 0.25 -> "0.25"), so
        # distinct alphas no longer collide the way ':.1f' rounding did
        # (0.2 and 0.25 both used to become "a0.2").
        alpha_str = f"a{args.alpha:g}"
        if args.n_last_layers == 0:
            layer_str = "full"
        elif args.n_last_layers == 1:
            layer_str = "head"
        else:
            layer_str = f"{args.n_last_layers}L"
        prefix = "CompAct" if args.compress else "FullFT"

        args.output_dir = f"runs/{prefix}_{args.dataset}_r{args.r}_alpha_{alpha_str}_{layer_str}"

    return args


def main():
    """Fine-tune a GLUE sequence classifier, optionally with CompAct compression.

    Flow: parse CLI args -> seed -> tokenize the GLUE task -> load the model and
    apply ``freeze_bert_layers`` -> (``--compress`` only) install initial
    projection matrices and wrap the selected layers via ``register_filter`` ->
    train with ``CustomTrainer`` -> save the final model to ``args.output_dir``.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): set_train_seed presumably seeds the global RNGs (confirm in
    # utils). It runs before model loading and the torch.randn P init below, so
    # moving it would change those random draws.
    set_train_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    
    print(f"\n{'='*80}")
    print(f"CompAct Training Configuration")
    print(f"{'='*80}")
    print(f"Dataset: {args.dataset}")
    print(f"Model: {args.model_name}")
    print(f"Output directory: {args.output_dir}")
    print(f"Compress: {args.compress}")
    if args.compress:
        print(f"Rank (r): {args.r}")
        print(f"Alpha: {args.alpha}")
        print(f"Calibration iter: {args.calib_iter}")
    print(f"Device: {device}")
    print(f"{'='*80}\n")

    dataset = load_dataset("glue", args.dataset)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    def tokenize(examples):
        # Single-sentence or sentence-pair input depending on the task's fields.
        # Every example is truncated/padded to exactly 512 tokens.
        # NOTE(review): fixed-length padding is memory-heavy; dynamic padding
        # via a collator could be considered.
        fields = dataset_fields[args.dataset]
        if len(fields) == 2:
            return tokenizer(examples[fields[0]], examples[fields[1]], truncation=True, padding="max_length", max_length=512)
        return tokenizer(examples[fields[0]], truncation=True, padding="max_length", max_length=512)

    tokenized = dataset.map(tokenize, batched=True)
    # Expose only model inputs and the label as torch tensors; token_type_ids is
    # included only when the tokenizer produced it.
    cols = ["input_ids", "attention_mask", "label"]
    if "token_type_ids" in tokenized["train"].column_names:
        cols.insert(2, "token_type_ids")
    tokenized.set_format("torch", columns=cols)

    num_labels = dataset_to_num_labels[args.dataset]
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=num_labels).to(device)
    # Per the --n-last-layers help text: only the last N layers are fine-tuned.
    freeze_bert_layers(model, args.n_last_layers, args.model_name)

    # Remain None / {} unless --compress is given.
    modules_compressed = None
    wrapped_modules = {}
    
    if args.compress:
        print(f"\n{'='*60}")
        print(f"[CompAct] Setting up compression...")
        print(f"{'='*60}\n")
        
        # NOTE(review): these presumably patch module classes with the
        # `activate` flag and `set_projection_matrix` used below — confirm in
        # model_compact.
        add_activate_to_module()
        add_compression_tensor_to_module()
        # Assumed to be a collection of module names: it is tested with `in`
        # against names from model.named_modules().
        modules_compressed = get_svd_layers(model)
        print(f"[CompAct] Found {len(modules_compressed)} layers to compress")

        # Initial projections: one random (in_features, r) matrix per compressed
        # layer, with each column normalized to unit L2 norm.
        # NOTE(review): layers are chosen here by name (modules_compressed), while
        # ProjectionUpdateCallback reselects by nn.Linear + `activate`; confirm
        # both select the same set.
        r = args.r
        P_dict = {}
        for name, module in model.named_modules():
            if name in modules_compressed:
                P = torch.randn(module.in_features, r, device=device)
                P = P / P.norm(dim=0, keepdim=True)
                P_dict[name] = P
        
        model.set_projection_matrix(P_dict)
        print(f"[CompAct] Initialized P for {len(P_dict)} layers with rank={r}")

        # register_filter returns the (possibly modified) model and a dict of
        # name -> wrapped module.
        model, wrapped_modules = register_filter(model, modules_compressed)
        print(f"[CompAct] Wrapped {len(wrapped_modules)} layers")
        
        # CustomTrainer.create_optimizer looks this attribute up to decide which
        # parameters get CompAct param groups.
        model.wrapped_modules = wrapped_modules
        
        # Debug dump of the first three wrapped modules.
        print(f"\n[CompAct] Verifying wrapped modules:")
        for i, (name, module) in enumerate(list(wrapped_modules.items())[:3], 1):
            print(f"  {i}. {name}")
            print(f"     Type: {type(module).__name__}")
            print(f"     ID: {id(module)}")
            print(f"     Has P: {hasattr(module, 'P')}")
            print(f"     P shape: {module.P.shape if hasattr(module, 'P') else 'N/A'}")
            print(f"     Activate: {getattr(module, 'activate', False)}")
        if len(wrapped_modules) > 3:
            print(f"  ... and {len(wrapped_modules) - 3} more layers")
        print(f"{'='*60}\n")

    # Training arguments
    # Evaluate and checkpoint every epoch; the best checkpoint by the task's
    # metric (default "accuracy") is reloaded at the end of training.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        warmup_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model=dataset_best_metrics.get(args.dataset, "accuracy"),
        logging_steps=50,
        report_to="none",
        seed = args.seed
    )

    # MNLI has matched/mismatched validation splits; the matched one is used.
    eval_split = "validation_matched" if args.dataset == "mnli" else "validation"
    # print(model)
    # Log which parameters are trainable after freezing/wrapping.
    for name, param in model.named_parameters():
        status = "❄️ Frozen" if not param.requires_grad else "🔥 Unfrozen"
        print(f"{name}: {status}")
    # Create trainer
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized[eval_split],
        compute_metrics=compute_metrics_with_args(args=args),
        callbacks=[ProjectionUpdateCallback()],
        custom_args=args,
        modules_compressed=modules_compressed,
    )

    # The callback reads `self.trainer`, which can only be attached once the
    # trainer exists.
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, ProjectionUpdateCallback):
            callback.trainer = trainer

    print("\n" + "="*80)
    print("Starting CompAct training...")
    print("="*80 + "\n")
    
    trainer.train()
    trainer.save_model(args.output_dir)
    
    print("\n" + "="*80)
    print(f"Training completed! Model saved to {args.output_dir}")
    print("="*80 + "\n")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()