import torch
from torch.utils.data import DataLoader
from transformers import (
    RobertaTokenizer,
    RobertaForSequenceClassification,
    AdamW,
    get_linear_schedule_with_warmup,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from peft import get_peft_model, LoraConfig, TaskType
import argparse
import warnings
import os
from datetime import datetime
import json
import yaml
import atexit
import wandb

from utils.data_utils import *
from models import *
from utils.misc import *
import transformers
from utils.eval_callback import create_evaluation_callback

# Import Block Hadamard HiRA modules
from block_hadamard_hira import BlockHadamardHiRAConfig, get_block_hadamard_hira_model, apply_block_hadamard_hira

import os
# Pin MASTER_ADDR / MASTER_PORT at import time. These are unconditional
# assignments, so values already present in the environment are overwritten.
# NOTE(review): presumably read by torch.distributed's env-based process-group
# initialisation (possibly triggered indirectly by Trainer) — confirm.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

def create_run_directory(args):
    """Build the on-disk layout for one training run and return its path.

    The run directory is
    ``experiments/block_hadamard_hira_commonsense_reasoning/<model>/<timestamp>_<params>``
    where ``<model>`` is the last ``/``-separated component of ``args.model``,
    ``<timestamp>`` is ``YYYYMMDD_HHMMSS`` local time, and ``<params>`` encodes
    rank, block count, learning rate, block arrangement, alpha and (only when
    it is positive) dropout.

    Side effects: creates the run directory with ``checkpoints/`` and ``logs/``
    subdirectories and writes ``vars(args)`` to ``config.json`` inside it.

    Args:
        args: Namespace providing ``model``, ``block_hira_r``, ``num_blocks``,
            ``lr``, ``block_arrangement``, ``block_hira_alpha`` and
            ``block_hira_dropout``.

    Returns:
        The (relative) path of the newly created run directory.
    """
    experiments_root = "experiments/block_hadamard_hira_commonsense_reasoning"
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    short_model = args.model.split('/')[-1]

    # Alpha is always part of the name; dropout is appended only when enabled.
    alpha_tag = f"alpha{args.block_hira_alpha}"
    if args.block_hira_dropout > 0:
        alpha_tag = f"{alpha_tag}_dropout{args.block_hira_dropout}"

    run_name = "_".join([
        f"rank_{args.block_hira_r}",
        f"blocks{args.num_blocks}",
        f"lr{args.lr}",
        args.block_arrangement,
        alpha_tag,
    ])
    run_dir = os.path.join(experiments_root, short_model, f"{stamp}_{run_name}")

    os.makedirs(run_dir, exist_ok=True)
    for subdir in ("checkpoints", "logs"):
        os.makedirs(os.path.join(run_dir, subdir), exist_ok=True)

    # Persist the exact CLI configuration next to the run's artefacts.
    with open(os.path.join(run_dir, "config.json"), 'w') as config_file:
        json.dump(vars(args), config_file, indent=4)

    return run_dir

def create_model_tokenizer_block_hira_cr(args):
    """Load the causal-LM backbone and its tokenizer for commonsense-reasoning training.

    The model is loaded from ``args.model`` in bfloat16 with
    ``device_map="auto"``. The tokenizer class is ``LlamaTokenizer`` when the
    model name contains "llama" (case-insensitive) but not ``"Llama-3"``, and
    ``AutoTokenizer`` in every other case; both receive the same keyword
    arguments. Afterwards the pad token id is set to 0 and the padding side to
    "left".

    Args:
        args: Namespace providing ``model`` and ``max_seq_length``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Only Llama checkpoints that are not Llama-3 go through LlamaTokenizer.
    wants_llama_tokenizer = "llama" in args.model.lower() and "Llama-3" not in args.model
    tokenizer_cls = LlamaTokenizer if wants_llama_tokenizer else AutoTokenizer
    tokenizer = tokenizer_cls.from_pretrained(
        args.model,
        use_fast=True,
        model_max_length=args.max_seq_length,
        padding="max_length",
    )

    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    return model, tokenizer

def create_block_hadamard_hira_model_cr(model, args):
    """Wrap ``model`` with a Block Hadamard HiRA adapter named ``"block_hira_cr"``.

    A ``BlockHadamardHiRAConfig`` is built from the CLI arguments and applied
    through ``get_block_hadamard_hira_model``. A short summary is printed, and
    if any module of the wrapped model exposes ``get_performance_info`` the
    first such module's report is printed as well.

    Args:
        model: The base model to adapt.
        args: Namespace providing ``block_hira_r``, ``block_hira_alpha``,
            ``block_hira_dropout``, ``target_modules``, ``num_blocks`` and
            ``block_arrangement``.

    Returns:
        A ``(wrapped_model, block_hira_config)`` tuple.
    """
    block_hira_config = BlockHadamardHiRAConfig(
        r=args.block_hira_r,
        alpha=args.block_hira_alpha,
        dropout=args.block_hira_dropout,
        target_modules=args.target_modules,
        bias="none",
        init_lora_weights=True,
        num_blocks=args.num_blocks,
        block_arrangement=args.block_arrangement,
        use_fast_inference=True,  # Enable optimized batched operations
    )
    model = get_block_hadamard_hira_model(model, block_hira_config, adapter_name="block_hira_cr")

    print(f"✅ Applied Block Hadamard HiRA with r={args.block_hira_r}, alpha={args.block_hira_alpha}")
    print(f"🧱 Block configuration: {args.num_blocks}×{args.num_blocks} blocks ({args.block_arrangement})")
    print(f"🎯 Target modules: {args.target_modules}")

    # First module (in named_modules() order) that can report performance info.
    reporting_layer = next(
        (module for _, module in model.named_modules() if hasattr(module, 'get_performance_info')),
        None,
    )

    if reporting_layer:
        perf_info = reporting_layer.get_performance_info()
        print(f"🚀 Performance optimizations enabled:")
        report_rows = (
            ("Batched GEMM", 'uses_batched_gemm', False),
            ("Vectorized Hadamard", 'vectorized_hadamard', False),
            ("Theoretical speedup", 'theoretical_speedup', 'N/A'),
            ("Memory overhead", 'memory_overhead', 'N/A'),
        )
        for label, key, fallback in report_rows:
            print(f"   └── {label}: {perf_info.get(key, fallback)}")

    return model, block_hira_config

def count_block_hadamard_hira_parameters(model, verbose=True):
    """Count model parameters with a Block Hadamard HiRA breakdown.

    Parameters are classified by substring match on their name:
    names containing ``block_lora_A`` / ``block_lora_B`` count as adapter
    parameters, names containing ``classifier`` / ``lm_head`` / ``output``
    count as classifier parameters. Only parameters with
    ``requires_grad=True`` enter either category. The two checks are
    independent, so a name matching both sets is counted in both.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad``.
        verbose: When true, print the breakdown to stdout.

    Returns:
        dict with the keys ``total_params``, ``trainable_params``,
        ``block_hira_params``, ``classifier_params``,
        ``non_classifier_params`` (trainable minus classifier),
        ``trainable_ratio`` and ``block_hira_ratio`` (both percentages).
        A ratio whose denominator is zero is reported as 0 instead of
        raising ``ZeroDivisionError``.
    """
    adapter_markers = ("block_lora_A", "block_lora_B")
    classifier_markers = ("classifier", "lm_head", "output")

    total_params = 0
    trainable_params = 0
    block_hira_params = 0
    classifier_params = 0

    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count

        if not param.requires_grad:
            continue
        trainable_params += param_count

        if any(marker in name for marker in adapter_markers):
            block_hira_params += param_count
        if any(marker in name for marker in classifier_markers):
            classifier_params += param_count

    non_classifier_params = trainable_params - classifier_params

    # Compute each ratio once, guarded, so the verbose report and the returned
    # dict agree and a fully frozen or parameter-less model does not crash.
    trainable_ratio = trainable_params / total_params * 100 if total_params > 0 else 0
    block_hira_ratio = block_hira_params / trainable_params * 100 if trainable_params > 0 else 0

    if verbose:
        print(f"\n📊 Block Hadamard HiRA Parameter Analysis:")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")
        print(f"   Block Hadamard HiRA parameters: {block_hira_params:,}")
        print(f"   Classifier parameters: {classifier_params:,}")
        print(f"   Other trainable: {non_classifier_params:,}")
        print(f"   Trainable ratio: {trainable_ratio:.2f}%")
        print(f"   Block HiRA ratio: {block_hira_ratio:.2f}%")

    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'block_hira_params': block_hira_params,
        'classifier_params': classifier_params,
        'non_classifier_params': non_classifier_params,
        'trainable_ratio': trainable_ratio,
        'block_hira_ratio': block_hira_ratio,
    }

def finetune():
    """Run one Block Hadamard HiRA fine-tuning job end to end.

    Takes no parameters: every setting is read from the module-level ``args``
    namespace, which is only assigned in the ``__main__`` section of this file.
    NOTE(review): calling this after importing the module (rather than running
    it as a script) would therefore presumably fail with ``NameError`` — confirm.

    Steps, in order:
      1. create the run directory and start a W&B run logging into it;
      2. load model and tokenizer, then the training dataset;
      3. wrap the model with the Block Hadamard HiRA adapter and report
         parameter counts (stdout and W&B);
      4. build the optimizer, ``TrainingArguments`` and optional periodic
         evaluation callback, then train with ``Trainer``;
      5. save tokenizer, adapter config, trainer state, the final model and a
         standalone adapter state dict plus its config.

    Returns:
        The run directory path created in step 1.
    """
    run_dir = create_run_directory(args)
    
    # Initialize wandb with the run directory
    wandb_run_name = os.path.basename(run_dir)
    wandb_run = wandb.init(
        project="block_hadamard_hira_commonsense_reasoning",
        config=args,
        dir=os.path.join(run_dir, "logs"),
        name=wandb_run_name
    )

    # Save wandb run ID to a file
    with open(os.path.join(run_dir, "wandb_run_id.txt"), "w") as f:
        f.write(wandb_run.id)
    
    # Create model and tokenizer
    model, tokenizer = create_model_tokenizer_block_hira_cr(args)
    
    # Data handling (load_and_preprocess_cr comes from a star import; it is
    # called before the adapter is applied and only receives the tokenizer)
    train_dataset = load_and_preprocess_cr(tokenizer=tokenizer, args=args)

    data_collator = transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    )
    # Packed as a dict so it can be splatted into Trainer(...) below
    data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
    
    # Apply Block Hadamard HiRA to model
    model, block_hira_config = create_block_hadamard_hira_model_cr(model, args)

    # Enhanced parameter analysis
    # (verbose=True already prints a breakdown; the block below prints the
    # same numbers again in a longer format)
    print("📊 Analyzing Block Hadamard HiRA model parameters...")
    param_counts = count_block_hadamard_hira_parameters(model, verbose=True)

    print("\n" + "="*60)
    print("🎯 Block Hadamard HiRA Parameter Analysis")
    print("="*60)
    print(f"📈 Total model parameters: {param_counts['total_params']:,}")
    print(f"🔒 Frozen parameters: {param_counts['total_params'] - param_counts['trainable_params']:,}")
    print(f"🔓 Trainable parameters: {param_counts['trainable_params']:,}")
    print(f"   └── Block Hadamard HiRA parameters: {param_counts['block_hira_params']:,}")
    print(f"   └── Classifier parameters: {param_counts['classifier_params']:,}")
    print(f"   └── Other trainable: {param_counts['non_classifier_params']:,}")
    print(f"📊 Trainable ratio: {param_counts['trainable_ratio']:.2f}%")
    print(f"🧱 Block HiRA ratio: {param_counts['block_hira_ratio']:.2f}%")
    
    print(f"\n🔧 Block Hadamard HiRA Configuration:")
    print(f"   Rank (r): {args.block_hira_r}")
    print(f"   Alpha: {args.block_hira_alpha}")
    print(f"   Dropout: {args.block_hira_dropout}")
    print(f"   Blocks: {args.num_blocks}×{args.num_blocks} ({args.block_arrangement})")
    print(f"   Target modules: {args.target_modules}")
    print("="*60)
    
    # Log Block Hadamard HiRA specific metrics
    log_dict = {
        "total_model_params": param_counts['total_params'],
        "frozen_params": param_counts['total_params'] - param_counts['trainable_params'],
        "total_trainable_params": param_counts['trainable_params'],
        "block_hira_params": param_counts['block_hira_params'],
        "classifier_params": param_counts['classifier_params'],
        "non_classifier_params": param_counts['non_classifier_params'],
        "trainable_ratio_percent": param_counts['trainable_ratio'],
        "block_hira_ratio_percent": param_counts['block_hira_ratio'],
        "block_hira_rank": args.block_hira_r,
        "block_hira_alpha": args.block_hira_alpha,
        "block_hira_dropout": args.block_hira_dropout,
        "num_blocks": args.num_blocks,
        "block_arrangement": args.block_arrangement,
        "target_modules": str(args.target_modules),
        "method": "Block_Hadamard_HiRA"
    }
    
    wandb.log(log_dict)

    # Setup optimizer
    # NOTE(review): built over model.parameters() (frozen ones included) with
    # only lr set, and handed to Trainer below. The weight_decay=0 given to
    # TrainingArguments presumably does not reach this optimizer, so AdamW's
    # own default weight decay would apply — confirm this is intended.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    
    # Training arguments
    # save_strategy="no": the Trainer is told not to write checkpoints itself;
    # the final model is saved explicitly after training.
    training_args = TrainingArguments(
        output_dir=os.path.join(run_dir, "checkpoints"),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        seed=args.seed,
        report_to="wandb",
        gradient_accumulation_steps=args.grad_acc_steps,
        save_strategy="no",
        bf16=True,
        tf32=False,
        fp16=False,
        logging_steps=1,
        logging_first_step=True,
        logging_dir=os.path.join(run_dir, "logs"),
    )
    
    # Save training arguments
    training_args_path = os.path.join(run_dir, "training_args.json")
    with open(training_args_path, 'w') as f:
        json.dump(training_args.to_dict(), f, indent=4)
    
    # Initialize callbacks list
    callbacks = []
    
    # Add periodic evaluation callback if enabled
    if args.enable_periodic_eval:
        print(f"\n🔍 Setting up periodic evaluation...")
        eval_callback = create_evaluation_callback(
            eval_steps=args.eval_steps,
            eval_data_path=args.eval_data_path,
            base_model_name=args.model,
            run_dir=run_dir,
            early_stopping_patience=args.early_stopping_patience,
            early_stopping_threshold=args.early_stopping_threshold,
            max_eval_samples=args.max_eval_samples,
            eval_batch_size=args.eval_batch_size,
            use_content_gating=False,  # Block Hadamard HiRA doesn't use content gating
        )
        callbacks.append(eval_callback)
        
        # Log evaluation settings
        wandb.log({
            "eval/enabled": True,
            "eval/eval_steps": args.eval_steps,
            "eval/max_eval_samples": args.max_eval_samples,
            "eval/early_stopping_patience": args.early_stopping_patience,
            "eval/early_stopping_threshold": args.early_stopping_threshold,
        })
    else:
        print("⚠️ Periodic evaluation disabled")
        wandb.log({"eval/enabled": False})
    
    # optimizers=(optimizer, None): the optimizer above is supplied explicitly,
    # the scheduler slot is left as None.
    trainer = Trainer(
        model=model,
        args=training_args,
        **data_module,
        optimizers=(optimizer, None),
        callbacks=callbacks,
    )
    
    # Save tokenizer
    tokenizer.save_pretrained(os.path.join(run_dir, "tokenizer"))
    
    # Save Block Hadamard HiRA config
    # (dumps the config object's raw __dict__; assumes every attribute is
    # JSON-serializable — TODO confirm)
    block_hira_config_path = os.path.join(run_dir, "block_hadamard_hira_config.json")
    with open(block_hira_config_path, 'w') as f:
        json.dump(block_hira_config.__dict__, f, indent=4)
    
    # Training
    # KV caching is switched off on the model config before training starts
    model.config.use_cache = False
    trainer.train()
    
    # After training
    final_model_path = os.path.join(run_dir, "final_model")
    trainer.save_state()
    
    # Save Block Hadamard HiRA adapter
    model.save_pretrained(final_model_path)
    
    # Also save just the Block Hadamard HiRA adapter weights
    adapter_path = os.path.join(run_dir, "block_hadamard_hira_adapter")
    os.makedirs(adapter_path, exist_ok=True)
    
    # Save Block Hadamard HiRA adapter weights
    # (adapter name must match the one used in create_block_hadamard_hira_model_cr)
    from block_hadamard_hira import get_adapter_state_dict
    adapter_state_dict = get_adapter_state_dict(model, "block_hira_cr")
    torch.save(adapter_state_dict, os.path.join(adapter_path, "adapter_model.bin"))
    
    # Save adapter config
    # (hand-written mirror of the values passed to BlockHadamardHiRAConfig,
    # plus a "peft_type" tag)
    with open(os.path.join(adapter_path, "adapter_config.json"), 'w') as f:
        json.dump({
            "r": args.block_hira_r,
            "alpha": args.block_hira_alpha,
            "dropout": args.block_hira_dropout,
            "target_modules": args.target_modules,
            "bias": "none",
            "init_lora_weights": True,
            "peft_type": "Block_Hadamard_HiRA",
            "num_blocks": args.num_blocks,
            "block_arrangement": args.block_arrangement
        }, f, indent=4)
    
    print(f"\n✅ Training completed!")
    print(f"📁 Run directory: {run_dir}")
    print(f"💾 Final model: {final_model_path}")
    print(f"🧱 Block Hadamard HiRA adapter: {adapter_path}")
    
    return run_dir

if __name__ == "__main__":
    # Script entry point: parse CLI options, seed the RNGs, choose target
    # modules for the model family, print a summary and launch training.
    # ``args`` must remain a module-level name because finetune() reads it
    # as a global rather than taking it as a parameter.
    parser = argparse.ArgumentParser(description="Block Hadamard HiRA training for commonsense reasoning")
    _add = parser.add_argument

    # Default adapter targets; also the reference list for the Llama check below.
    _llama_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    # Dataset
    _add("--data_path", type=str, default="data/commonsense/commonsense_170k.json", help="Path to the training data")
    _add('--train_on_inputs', action='store_true', help='Train on inputs')

    # Model
    _add("--model", type=str, default="meta-llama/Llama-3.2-1B", help="Model name")

    # Block Hadamard HiRA
    _add("--block_hira_r", type=int, default=32, help="Block Hadamard HiRA rank value per block")
    _add("--block_hira_alpha", type=float, default=32, help="Block Hadamard HiRA alpha scaling factor")
    _add("--block_hira_dropout", type=float, default=0.05, help="Block Hadamard HiRA dropout value")
    _add("--num_blocks", type=int, default=4, help="Number of blocks per dimension (e.g., 4 = 4x4 blocks)")
    _add("--block_arrangement", type=str, default="square", choices=["square"], help="Block arrangement pattern")
    _add(
        "--target_modules",
        type=str,
        nargs="+",
        default=_llama_target_modules,
        help="Target modules for Block Hadamard HiRA adaptation",
    )

    # Training
    _add("--batch_size", type=int, default=6, help="Batch size")
    _add("--grad_acc_steps", type=int, default=24, help="Gradient accumulation steps")
    _add("--epochs", type=int, default=2, help="Number of epochs")
    _add("--scheduler", type=str, default="linear", help="Learning rate scheduler")
    _add("--warmup_ratio", type=float, default=0.02, help="Warmup ratio")
    _add("--max_seq_length", type=int, default=256, help="Maximum sequence length")
    _add("--lr", type=float, default=2e-3, help="Learning rate")
    _add("--seed", type=int, default=42, help="Random seed")
    _add("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    # Periodic evaluation
    _add("--enable_periodic_eval", action="store_true", help="Enable periodic evaluation during training")
    _add("--eval_steps", type=int, default=500, help="Evaluate every N steps")
    _add("--eval_data_path", type=str, default="data/commonsense/boolq/test.json", help="Path to evaluation data")
    _add("--early_stopping_patience", type=int, default=3, help="Early stopping patience (0 to disable)")
    _add("--early_stopping_threshold", type=float, default=0.01, help="Early stopping threshold")
    _add("--max_eval_samples", type=int, default=300, help="Maximum samples for evaluation (to save time)")
    _add("--eval_batch_size", type=int, default=4, help="Batch size for evaluation")

    args = parser.parse_args()

    # Seed NumPy, then torch CPU, then every CUDA device with the same value.
    for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(args.seed)

    # Pick / report target modules per model family. For RoBERTa names the
    # user-supplied list is replaced unconditionally.
    _model_key = args.model.lower()
    if "llama" in _model_key:
        if args.target_modules == _llama_target_modules:
            print("✅ Using standard Llama target modules for Block Hadamard HiRA")
    elif "roberta" in _model_key:
        args.target_modules = ["query", "key", "value", "dense"]
        print("🔄 Adjusted target modules for RoBERTa model")
    else:
        print(f"⚠️ Using custom target modules: {args.target_modules}")

    # Launch summary
    for _line in (
        f"\n🚀 Starting Block Hadamard HiRA training...",
        f"📊 Model: {args.model}",
        f"🧱 Block Hadamard HiRA config: r={args.block_hira_r}, alpha={args.block_hira_alpha}, dropout={args.block_hira_dropout}",
        f"🔲 Block configuration: {args.num_blocks}×{args.num_blocks} ({args.block_arrangement})",
        f"🎯 Target modules: {args.target_modules}",
        f"📚 Data: {args.data_path}",
        f"⚙️ Training: {args.epochs} epochs, lr={args.lr}, batch_size={args.batch_size}",
    ):
        print(_line)

    # Run training
    run_dir = finetune()

    print(f"\n🎉 Block Hadamard HiRA training completed successfully!")
    print(f"📁 Results saved in: {run_dir}")