import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer, OPTForSequenceClassification, AdamW
from datasets import load_dataset
from tqdm import tqdm
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import wandb
from peft import get_peft_model, LoraConfig, TaskType

def estimate_smoothness_opt(model, batch, h=1e-5, num_samples=100, device=None):
    """
    Estimates the smoothness of the base model with respect to a batch of data.

    The gradient ``g`` of ``model(**batch).logits.sum()`` with respect to the
    non-LoRA parameters is evaluated at the current parameters ``x`` and at
    ``num_samples`` randomly perturbed points ``x + h * v`` (``v`` a random
    unit vector). The returned value is the mean of
    ``||g(x + h v) - g(x)|| / ||h v||`` over the samples.

    The model is left as it was found: parameter values and the train/eval
    mode are restored (also on error), and the gradients produced by the
    probing backward passes are cleared so they cannot leak into a later
    ``backward()`` call.

    Args:
        model: Module whose forward accepts ``**batch`` and returns an object
            with a ``logits`` attribute.
        batch: Dict of keyword arguments passed to the model.
        h: Length of the parameter perturbation.
        num_samples: Number of random directions to probe (must be >= 1).
        device: Accepted for backward compatibility; not used, because the
            random directions are created on the parameters' own device.

    Returns:
        The average finite-difference smoothness estimate as a Python float.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be a positive integer")

    # Remember the caller's mode; the probes run in eval mode (no dropout).
    was_training = model.training
    model.eval()

    # Get only the base model parameters. PEFT stores adapter weights under
    # names containing "lora_" (e.g. "...q_proj.lora_A.default.weight"), so a
    # prefix test alone never matches them.
    base_params = [
        p for n, p in model.named_parameters()
        if 'lora_' not in n and not n.startswith('base_model.model.lora')
    ]
    original_params = parameters_to_vector(base_params).detach().clone()

    def grad_at(params):
        # Load `params` into the model and return the flattened gradient of
        # the summed logits at that point.
        vector_to_parameters(params, base_params)
        model.zero_grad()
        outputs = model(**batch)
        outputs.logits.sum().backward()
        return parameters_to_vector(
            [p.grad for p in base_params if p.grad is not None]
        ).detach().clone()

    smoothness_estimates = []
    try:
        grad_x = grad_at(original_params)

        for _ in range(num_samples):
            v = torch.randn_like(original_params)
            v = v / v.norm()
            step = h * v
            grad_perturbed = grad_at(original_params + step)
            estimate = torch.norm(grad_perturbed - grad_x) / torch.norm(step)
            smoothness_estimates.append(estimate.item())
    finally:
        # Always undo the side effects of probing, even if a forward/backward
        # pass failed part-way through.
        with torch.no_grad():
            vector_to_parameters(original_params, base_params)
        model.zero_grad()
        model.train(was_training)

    return sum(smoothness_estimates) / len(smoothness_estimates)

def estimate_hessian_properties(model, batch, device, num_samples=10, h=1e-5):
    """
    Estimates properties of the Hessian for the base model using randomized methods.

    The Hessian ``H`` is that of ``model(**batch).logits.sum()`` with respect
    to the non-LoRA parameters. Hessian-vector products are approximated by
    forward finite differences of the gradient,
    ``H v ~= (g(x + h v) - g(x)) / h``, for ``num_samples`` random unit
    directions ``v``.

    * Frobenius norm: for ``v`` uniform on the unit sphere in ``d``
      dimensions ``E||H v||^2 = ||H||_F^2 / d``, so the estimate is
      ``sqrt(d * mean(||H v||^2))``.
    * Spectral norm: ``max ||H v||`` over the sampled unit directions, which
      is a lower bound on the largest singular value.

    The model is left as it was found: parameter values and the train/eval
    mode are restored (also on error) and the probing gradients are cleared.

    Args:
        model: Module whose forward accepts ``**batch`` and returns an object
            with a ``logits`` attribute.
        batch: Dict of keyword arguments passed to the model.
        device: Accepted for backward compatibility; the random directions
            are created on the parameters' own device.
        num_samples: Number of random directions to probe (must be >= 1).
        h: Finite-difference step length.

    Returns:
        Dict with float entries ``"hessian_frob_norm_estimate"`` and
        ``"hessian_spectral_norm_estimate"``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be a positive integer")

    # Remember the caller's mode; the probes run in eval mode (no dropout).
    was_training = model.training
    model.eval()

    # PEFT stores adapter weights under names containing "lora_"
    # (e.g. "...q_proj.lora_A.default.weight"); a prefix test alone never
    # matches them.
    base_params = [
        p for n, p in model.named_parameters()
        if 'lora_' not in n and not n.startswith('base_model.model.lora')
    ]
    original_params = parameters_to_vector(base_params).detach().clone()
    dim = original_params.numel()

    def compute_grad(params):
        # Load `params` into the model and return the flattened gradient of
        # the summed logits at that point.
        vector_to_parameters(params, base_params)
        model.zero_grad()
        outputs = model(**batch)
        loss = outputs.logits.sum()
        loss.backward()
        return parameters_to_vector(
            [p.grad for p in base_params if p.grad is not None]
        ).detach().clone()

    # Only ||H v||^2 is needed downstream, so keep scalars rather than the
    # full parameter-sized product vectors.
    hvp_sq_norms = []
    try:
        grad_x = compute_grad(original_params)

        for _ in range(num_samples):
            v = torch.randn_like(original_params)
            v = v / v.norm()

            grad_perturbed = compute_grad(original_params + h * v)

            hvp = (grad_perturbed - grad_x) / h
            hvp_sq_norms.append(torch.dot(hvp, hvp))
    finally:
        # Always undo the side effects of probing, even if a forward/backward
        # pass failed part-way through.
        with torch.no_grad():
            vector_to_parameters(original_params, base_params)
        model.zero_grad()
        model.train(was_training)

    sq_norms = torch.stack(hvp_sq_norms)

    # Estimate Frobenius norm (rescaled by the dimension, see docstring)
    frob_norm_estimate = (dim * sq_norms.mean()).sqrt()

    # Estimate spectral norm (largest singular value) as the largest ||H v||
    spectral_norm_estimate = sq_norms.max().sqrt()

    return {
        "hessian_frob_norm_estimate": frob_norm_estimate.item(),
        "hessian_spectral_norm_estimate": spectral_norm_estimate.item(),
    }

def compute_gradient_norm(model):
    """
    Returns the global L2 norm of all gradients currently stored on the model.

    Parameters whose ``.grad`` is ``None`` are skipped; if no parameter has a
    gradient the result is ``0.0``.
    """
    squared_total = sum(
        param.grad.data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None
    )
    return squared_total ** 0.5

def main():
    """
    Fine-tunes OPT-125M (wrapped with LoRA adapters) on SST2 and logs
    curvature diagnostics to Weights & Biases.

    Every optimisation step logs the loss and the gradient norm. Every
    ``eval_interval`` steps the smoothness and Hessian-norm estimators are run
    on the current batch and their results are logged as well.
    """
    # Initialize wandb
    wandb.init(project="opt-fine-tuning-base-smoothness", name="sst2-base-smoothness-analysis")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load OPT-125M model and tokenizer
    model_name = "facebook/opt-125m"
    model = OPTForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Define LoRA Config
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.SEQ_CLS
    )

    # Get PEFT model
    model = get_peft_model(model, lora_config)

    # Enable training of all parameters
    for param in model.parameters():
        param.requires_grad = True

    model.print_trainable_parameters()

    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load SST2 dataset
    dataset = load_dataset("glue", "sst2")
    train_dataset = dataset["train"]

    # Tokenize dataset
    def tokenize_function(examples):
        return tokenizer(examples["sentence"], padding="max_length", truncation=True, max_length=128)

    tokenized_datasets = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
    tokenized_datasets = tokenized_datasets.add_column("labels", train_dataset["label"])
    tokenized_datasets.set_format("torch")

    # Create DataLoader
    train_dataloader = DataLoader(
        tokenized_datasets,
        sampler=RandomSampler(tokenized_datasets),
        batch_size=8,
        collate_fn=lambda x: {key: torch.stack([sample[key] for sample in x]) for key in x[0]}
    )

    # Setup optimizer
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # Training loop
    num_epochs = 3
    eval_interval = 1000  # steps between smoothness / Hessian estimates
    num_training_steps = num_epochs * len(train_dataloader)
    progress_bar = tqdm(range(num_training_steps))

    for epoch in range(num_epochs):
        model.train()
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            labels = batch['labels']
            loss = F.cross_entropy(logits, labels)
            loss.backward()

            # Compute gradient norm before optimizer step
            grad_norm = compute_gradient_norm(model)

            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)

            # Log gradient norm for every step
            wandb.log({
                "step": progress_bar.n,
                "gradient_norm": grad_norm,
                "loss": loss.item()
            })

            # Estimate smoothness and Hessian properties every `eval_interval` steps
            if progress_bar.n % eval_interval == 0:
                smoothness = estimate_smoothness_opt(model, batch, h=1e-6, num_samples=100, device=device)
                hessian_properties = estimate_hessian_properties(model, batch, device)

                # The estimators switch the model to eval mode and run their
                # own backward passes. Return to training mode and drop any
                # leftover gradients so the next step is not affected.
                model.train()
                optimizer.zero_grad()

                # Log to wandb
                wandb.log({
                    "step": progress_bar.n,
                    "smoothness": smoothness,
                    **hessian_properties
                })

                print(f"Step {progress_bar.n}, Loss: {loss.item():.4f}, Gradient Norm: {grad_norm:.4f}, Estimated Smoothness: {smoothness:.4f}")

    progress_bar.close()
    wandb.finish()
    print("Fine-tuning complete. Results logged to wandb.")

# Run the fine-tuning experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()