import torch
import tqdm
import sys
def trainer_mlp(train_loader, test_loader, model, optimizer, criterion, args):
    """Train ``model`` and record accuracies plus periodic weight/gradient snapshots.

    Runs ``args.epochs`` epochs over ``train_loader``. When ``args.optimizer``
    equals ``'lbfgs'`` the optimizer is stepped with a closure that recomputes
    the loss; otherwise a plain zero_grad / forward / backward / step cycle is
    used. After each epoch the model is switched to eval mode and its accuracy
    on both loaders is measured, the prediction being the index of the maximum
    along dimension 1 of the model output.

    On epochs 0, 50, 100, ... the ``weight`` of every module that has one, and
    the gradient left on it by the last training batch, are stored as NumPy
    arrays.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches for training.
        test_loader: Iterable yielding ``(inputs, labels)`` batches for evaluation.
        model: Module called as ``model(inputs)``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        args: Object exposing ``epochs``, ``optimizer`` and ``print`` attributes
            and usable with ``vars()``.

    Returns:
        dict with keys ``'train_acc'`` and ``'test_acc'`` (one float per
        epoch), ``'layer_weights'`` and ``'layer_grads'`` (module name -> list
        of snapshots; a gradient entry is ``None`` when no gradient existed),
        ``'model_state_dict'`` and ``'model_config'`` (``vars(args)``).

    Raises:
        ZeroDivisionError: If either loader yields no samples.
    """
    snapshot_every = 50

    # Only track modules whose ``weight`` is an actual tensor; ``hasattr`` alone
    # also matches modules that register ``weight`` as None and would then
    # fail when the snapshot dereferences it.
    tracked = [
        (name, module)
        for name, module in model.named_modules()
        if getattr(module, 'weight', None) is not None
    ]
    layer_weights = {name: [] for name, _ in tracked}
    layer_grads = {name: [] for name, _ in tracked}
    train_acc = []
    test_acc = []

    def _accuracy(loader):
        """Return the fraction of samples in ``loader`` that are classified correctly."""
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total

    for epoch in tqdm.tqdm(range(args.epochs), desc="Training Epochs"):
        model.train()
        for inputs, labels in train_loader:
            if args.optimizer != 'lbfgs':
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            else:
                # LBFGS may need to re-evaluate the loss several times per step.
                def closure():
                    optimizer.zero_grad()
                    output = model(inputs)
                    loss = criterion(output, labels)
                    loss.backward()
                    return loss
                loss = optimizer.step(closure)

        if epoch % snapshot_every == 0:
            for name, module in tracked:
                grad = module.weight.grad
                # ``.numpy()`` on a CPU tensor shares memory with it, so without
                # ``.copy()`` later in-place optimizer updates would silently
                # rewrite every snapshot taken so far.
                layer_grads[name].append(
                    None if grad is None else grad.detach().cpu().numpy().copy()
                )
                layer_weights[name].append(module.weight.detach().cpu().numpy().copy())

        model.eval()
        train_accuracy = _accuracy(train_loader)
        train_acc.append(train_accuracy)

        test_accuracy = _accuracy(test_loader)
        test_acc.append(test_accuracy)

        if args.print:
            print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.4f}, Train Accuracy: {train_accuracy:.4f}, NM Test Accuracy: {test_accuracy:.4f}")

    results = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'layer_weights': layer_weights,
        'layer_grads': layer_grads,
        'model_state_dict': model.state_dict(),
        'model_config': vars(args),
    }
    return results


def trainer_transformer(train_loader, test_loader, model, optimizer, criterion, args):
    """Train a two-input model and record accuracies plus periodic snapshots.

    For every batch the model is called as ``model(inputs, memory)`` with
    ``memory`` being a clone of ``inputs``. Only the output at the last index
    of dimension 1 (presumably the final sequence position - confirm against
    the model) is fed to ``criterion`` and used for accuracy, the prediction
    being the index of the maximum along the last dimension.

    On epochs 0, 50, 100, ... the ``weight`` of every module that has one, and
    the gradient left on it by the last training batch, are stored as NumPy
    arrays.

    Args:
        train_loader: Iterable yielding ``(inputs, targets)`` batches for training.
        test_loader: Iterable yielding ``(inputs, targets)`` batches for evaluation.
        model: Module called as ``model(inputs, memory)``.
        optimizer: Optimizer updating ``model``'s parameters; stepped without a
            closure.
        criterion: Callable ``criterion(last_prediction, targets)`` returning
            the loss.
        args: Object exposing ``epochs`` and ``print`` attributes and usable
            with ``vars()``.

    Returns:
        dict with keys ``'train_acc'`` and ``'test_acc'`` (one float per
        epoch), ``'layer_weights'`` and ``'layer_grads'`` (module name -> list
        of snapshots; a gradient entry is ``None`` when no gradient existed),
        ``'model_state_dict'`` and ``'model_config'`` (``vars(args)``).

    Raises:
        ZeroDivisionError: If either loader yields no targets.
    """
    snapshot_every = 50

    # Only track modules whose ``weight`` is an actual tensor; ``hasattr`` alone
    # also matches modules that register ``weight`` as None and would then
    # fail when the snapshot dereferences it.
    tracked = [
        (name, module)
        for name, module in model.named_modules()
        if getattr(module, 'weight', None) is not None
    ]
    layer_weights = {name: [] for name, _ in tracked}
    layer_grads = {name: [] for name, _ in tracked}
    train_acc = []
    test_acc = []

    def _accuracy(loader):
        """Return the fraction of targets in ``loader`` that are predicted correctly."""
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in loader:
                memory = inputs.clone()
                outputs = model(inputs, memory)
                last_prediction = outputs[:, -1, :]
                _, predicted = torch.max(last_prediction, dim=-1)
                correct += (predicted == targets).sum().item()
                total += targets.numel()
        return correct / total

    for epoch in tqdm.tqdm(range(args.epochs), desc="Training Epochs"):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            memory = inputs.clone()
            outputs = model(inputs, memory)
            last_prediction = outputs[:, -1, :]
            loss = criterion(last_prediction, targets)
            loss.backward()
            optimizer.step()

        if epoch % snapshot_every == 0:
            for name, module in tracked:
                grad = module.weight.grad
                # ``.numpy()`` on a CPU tensor shares memory with it, so without
                # ``.copy()`` later in-place optimizer updates would silently
                # rewrite every snapshot taken so far.
                layer_grads[name].append(
                    None if grad is None else grad.detach().cpu().numpy().copy()
                )
                layer_weights[name].append(module.weight.detach().cpu().numpy().copy())

        model.eval()
        train_accuracy = _accuracy(train_loader)
        train_acc.append(train_accuracy)

        test_accuracy = _accuracy(test_loader)
        test_acc.append(test_accuracy)

        if args.print:
            print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.4f}, Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}")

    results = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'layer_weights': layer_weights,
        'layer_grads': layer_grads,
        'model_state_dict': model.state_dict(),
        'model_config': vars(args),
    }
    return results

# Hessian-vector product (HVP) computed by differentiating the gradient a second time
def hvp(loss, params, vector, retain_graph=False):
    """Return the Hessian-vector product of ``loss`` w.r.t. ``params`` as a flat tensor.

    The gradient of ``loss`` is taken with ``create_graph=True``, flattened,
    dotted with ``vector`` and differentiated again, which yields ``H @ vector``
    without materialising the Hessian.

    Args:
        loss: Scalar tensor attached to an autograd graph that reaches ``params``.
        params: Sequence of tensors to differentiate with respect to.
        vector: 1-D tensor with as many elements as all ``params`` combined; the
            flattened gradient is moved to its device before the dot product.
        retain_graph: Passed to the second differentiation. Leave ``False`` to
            let autograd free the buffers it traverses once the product is
            computed; pass ``True`` when ``loss`` will be differentiated again
            afterwards (repeated HVPs, or a later ``loss.backward()``).

    Returns:
        1-D tensor holding the concatenated, flattened Hessian-vector product.
    """
    # The first pass must always keep the forward graph alive: the second
    # differentiation below walks back through it, and freeing its buffers here
    # makes that second pass fail for any loss with intermediate nodes.
    grads = torch.autograd.grad(loss, params, create_graph=True, retain_graph=True)
    flat_grads = torch.cat([g.flatten() for g in grads]).to(vector.device)
    hv_product = torch.autograd.grad(flat_grads @ vector, params, retain_graph=retain_graph)
    return torch.cat([g.flatten() for g in hv_product])

# Power iteration to approximate the dominant (largest-magnitude) Hessian eigenvalue
def power_iteration_hvp(loss, params, num_iters=10):
    """Estimate the dominant Hessian eigenvalue of ``loss`` w.r.t. ``params``.

    Starts from a random unit vector, repeatedly applies the Hessian through
    ``hvp`` and renormalises, then returns the Rayleigh quotient ``v @ (H v)``
    of the final vector. Power iteration converges towards the eigenvalue of
    largest magnitude, so the result may be negative.

    Every ``hvp`` call retains the autograd graph, so ``loss`` can still be
    differentiated by the caller afterwards.

    Args:
        loss: Scalar tensor attached to an autograd graph that reaches ``params``.
        params: Non-empty sequence of tensors; the random start vector is
            created on the device and with the dtype of ``params[0]``.
        num_iters: Number of power-iteration steps.

    Returns:
        The eigenvalue estimate as a Python float; ``0.0`` when a
        Hessian-vector product vanishes, since there is then no direction left
        to normalise.
    """
    n_params = sum(p.numel() for p in params)
    # Match the parameter dtype so the dot product inside ``hvp`` does not fail
    # for parameters that are not in the default floating-point type.
    vector = torch.randn(n_params, device=params[0].device, dtype=params[0].dtype)
    vector = vector / torch.norm(vector)

    for _ in range(num_iters):
        hv_product = hvp(loss, params, vector, retain_graph=True)
        hv_norm = torch.norm(hv_product)
        if hv_norm == 0:
            # H @ v == 0 for a random v: dividing would only produce NaNs.
            return 0.0
        vector = hv_product / hv_norm

    # Rayleigh quotient of the (unit-norm) converged direction.
    max_eigenvalue = (vector @ hvp(loss, params, vector, retain_graph=True)).item()
    return max_eigenvalue

def trainer_mlp_hessian(train_loader, test_loader, model, optimizer, criterion, args):
    """Train ``model`` while also recording Hessian eigenvalue estimates and weight ranks.

    Training and evaluation follow the same scheme as a plain classifier
    trainer: ``args.epochs`` epochs, a closure-based step when
    ``args.optimizer`` equals ``'lbfgs'``, and accuracy taken from the index of
    the maximum along dimension 1 of the model output.

    On epochs 0, 50, 100, ...:
      * for every training batch (non-LBFGS path only) and every module named
        ``'embedding'``, ``'fc2'`` or ``'fc3'``, a 5-step power-iteration
        estimate of the dominant Hessian eigenvalue of the batch loss w.r.t.
        that module's weight is appended, together with the weight's matrix
        rank; for ``'fc2'``, when the model has an ``embedding`` attribute, the
        rank of ``embedding.weight @ fc2.weight`` is appended under ``'EW'``;
      * after the epoch, the ``weight`` of every module that has one, and the
        gradient left on it by the last batch, are stored as NumPy arrays.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches for training.
        test_loader: Iterable yielding ``(inputs, labels)`` batches for evaluation.
        model: Module called as ``model(inputs)``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        args: Object exposing ``epochs``, ``optimizer`` and ``print`` attributes
            and usable with ``vars()``.

    Returns:
        dict with keys ``'train_acc'``, ``'test_acc'``, ``'layer_weights'``,
        ``'layer_grads'``, ``'layer_hessian_eigs'``, ``'model_state_dict'``,
        ``'layer_ranks'`` and ``'model_config'`` (``vars(args)``).

    Raises:
        ZeroDivisionError: If either loader yields no samples.
    """
    snapshot_every = 50
    hessian_layers = ('embedding', 'fc2', 'fc3')

    # Only track modules whose ``weight`` is an actual tensor; ``hasattr`` alone
    # also matches modules that register ``weight`` as None and would then
    # fail when the snapshot dereferences it.
    tracked = [
        (name, module)
        for name, module in model.named_modules()
        if getattr(module, 'weight', None) is not None
    ]
    layer_weights = {name: [] for name, _ in tracked}
    layer_grads = {name: [] for name, _ in tracked}
    layer_hessian_eigs = {name: [] for name, _ in tracked}
    layer_ranks = {name: [] for name, _ in tracked}
    layer_ranks['EW'] = []
    train_acc = []
    test_acc = []

    def _accuracy(loader):
        """Return the fraction of samples in ``loader`` that are classified correctly."""
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total

    for epoch in tqdm.tqdm(range(args.epochs), desc="Training Epochs"):
        model.train()
        record_curvature = epoch % snapshot_every == 0
        for inputs, labels in train_loader:
            if args.optimizer != 'lbfgs':
                optimizer.zero_grad()
                model.zero_grad()  # also clears parameters the optimizer does not own
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                if record_curvature:
                    for name, module in tracked:
                        if name not in hessian_layers:
                            continue
                        # Must run before ``loss.backward()``: the estimate
                        # differentiates ``loss`` and keeps its graph intact.
                        max_eigenvalue = power_iteration_hvp(loss, [module.weight], num_iters=5)
                        layer_hessian_eigs[name].append(max_eigenvalue)

                        # Ranks are diagnostics only; keep them out of autograd.
                        with torch.no_grad():
                            rank = torch.linalg.matrix_rank(module.weight)
                            layer_ranks[name].append(rank.item())

                            if name == 'fc2' and hasattr(model, 'embedding'):
                                # Rank of embedding.weight @ fc2.weight. Assumes
                                # the second dimension of embedding.weight equals
                                # the first dimension of fc2.weight - TODO confirm
                                # against the model definition.
                                combined_matrix = torch.mm(model.embedding.weight, model.fc2.weight)
                                combined_rank = torch.linalg.matrix_rank(combined_matrix)
                                layer_ranks['EW'].append(combined_rank.item())
                # Nothing differentiates ``loss`` after this point, so the graph
                # can be released instead of being retained for every batch.
                loss.backward()
                optimizer.step()
            else:
                # LBFGS may need to re-evaluate the loss several times per step;
                # each call builds a fresh graph, so none has to be retained.
                def closure():
                    optimizer.zero_grad()
                    output = model(inputs)
                    loss = criterion(output, labels)
                    loss.backward()
                    return loss
                loss = optimizer.step(closure)

        if epoch % snapshot_every == 0:
            for name, module in tracked:
                grad = module.weight.grad
                # ``.numpy()`` on a CPU tensor shares memory with it, so without
                # ``.copy()`` later in-place optimizer updates would silently
                # rewrite every snapshot taken so far.
                layer_grads[name].append(
                    None if grad is None else grad.detach().cpu().numpy().copy()
                )
                layer_weights[name].append(module.weight.detach().cpu().numpy().copy())

        model.eval()
        train_accuracy = _accuracy(train_loader)
        train_acc.append(train_accuracy)

        test_accuracy = _accuracy(test_loader)
        test_acc.append(test_accuracy)

        if args.print:
            print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.4f}, Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}")

    results = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'layer_weights': layer_weights,
        'layer_grads': layer_grads,
        'layer_hessian_eigs': layer_hessian_eigs,
        'model_state_dict': model.state_dict(),
        'layer_ranks': layer_ranks,
        'model_config': vars(args),
    }
    return results
