import os
import torch
import yaml
import torch.nn as nn
from utils.hessian_compute import power_iteration_total, estimate_eigenvalues_lanczos, power_iteration, hessian_vector_product_with_diagonal, compute_rayleigh_quotient, cosine_similarity, compute_rayleigh_quotient_update
from utils.precond_optimizer import compute_adam_preconditioner
from utils.utils import save_variable
from tqdm import tqdm


def save_config(config, save_dir):
    """Write ``config`` to ``<save_dir>/config.yaml`` in block-style YAML.

    The directory is created if it does not exist, and the path of the
    written file is printed once the dump has finished.

    Args:
        config: Object to serialize with ``yaml.dump``.
        save_dir: Directory that will hold ``config.yaml``.
    """
    os.makedirs(save_dir, exist_ok=True)
    target_path = os.path.join(save_dir, "config.yaml")
    with open(target_path, "w") as stream:
        yaml.dump(config, stream, default_flow_style=False)
    print(f"saving_config: {target_path}")

def evaluate(model, test_iter, device):
    """Return the mean per-batch MSE loss of ``model`` over ``test_iter``.

    The model is switched to eval mode and gradients are disabled while the
    batches are processed. The average is taken over the batches actually
    yielded, so ``test_iter`` may be any iterable of ``(X, y)`` pairs and does
    not need to implement ``__len__``.

    Args:
        model: Callable module; ``model(X)`` must be comparable to ``y`` with
            ``nn.MSELoss``.
        test_iter: Iterable yielding ``(X, y)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when
        ``test_iter`` yields no batches.
    """
    model.eval()
    criterion = nn.MSELoss()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            total_loss += criterion(model(X), y).item()
            num_batches += 1

    # Count batches instead of calling len(test_iter): this avoids a
    # ZeroDivisionError on empty input and works for unsized iterables.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def _build_optimizer(model, config):
    """Create the optimizer named by ``config["optimizer"]`` for ``model``.

    Supports "Adam", "AdamW" and "SGD" (momentum fixed to 0.0). For the Adam
    variants ``eps`` is always passed through ``float`` so that a value that
    was loaded as a string (e.g. ``1e-8`` from a YAML file) is accepted.

    Raises:
        ValueError: If the optimizer name is not one of the supported ones.
    """
    name = config["optimizer"]
    if name == "Adam" or name == "AdamW":
        optimizer_cls = torch.optim.Adam if name == "Adam" else torch.optim.AdamW
        return optimizer_cls(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
            eps=float(config["eps"]),
            betas=tuple(config["betas"]),
        )
    if name == "SGD":
        return torch.optim.SGD(
            model.parameters(),
            lr=config["learning_rate"],
            momentum=0.0,
            weight_decay=config["weight_decay"],
        )
    raise ValueError(f"Unknown optimizer: {config['optimizer']}")


def _unflatten_params(flat_vector, param_shapes, device):
    """Cut ``flat_vector`` into consecutive tensors of ``param_shapes`` on ``device``."""
    tensors = []
    offset = 0
    for shape in param_shapes:
        numel = shape.numel()
        tensors.append(flat_vector[offset:offset + numel].reshape(shape).to(device))
        offset += numel
    return tensors


def train(model, train_iter, device, save_dir, config, test_iter=None):
    """Train ``model`` with MSE loss while logging curvature diagnostics.

    For every batch the gradient norm, the flattened gradient and (for Adam
    variants) the bias-corrected second-moment estimate are recorded. On
    epochs that are multiples of ``config["hessian_interval"]`` the leading
    eigenpairs returned by ``power_iteration_total`` and several Rayleigh
    quotients / cosine similarities are computed for each batch. On epochs
    that are multiples of ``config["param_interval"]`` the parameter, update
    and gradient vectors are stored. Everything is written to ``save_dir``
    through ``save_variable`` once training finishes.

    NOTE(review): parameters are snapshotted *before* ``optimizer.step()``, so
    ``update_vector`` is the displacement produced by the previous step (and
    is all zeros on the very first batch). Confirm this lag is intended.

    Args:
        model: Module to train; moved to ``device``.
        train_iter: Iterable of ``(X, y)`` batches, re-iterated every epoch.
        device: Device used for the model and the batches.
        save_dir: Output directory (created if missing).
        config: Mapping providing "epochs", "optimizer", "learning_rate",
            "weight_decay", "hessian_interval", "param_interval" and, for
            Adam/AdamW, "eps" and "betas".
        test_iter: Optional iterable of ``(X, y)`` batches; when given, the
            test loss is evaluated and printed after every training batch.

    Returns:
        The trained ``model``.

    Raises:
        ValueError: If ``config["optimizer"]`` is not supported.
    """
    print("training on device: ", device)
    os.makedirs(save_dir, exist_ok=True)
    save_config(config, save_dir)
    model.to(device)
    n_epochs = config["epochs"]

    optimizer = _build_optimizer(model, config)
    is_adam = config["optimizer"] in ["Adam", "AdamW"]

    criterion = nn.MSELoss()
    hessian_interval = config["hessian_interval"]
    param_interval = config["param_interval"]

    epoch_list, loss_list, grad_norm_list, test_loss_list = [], [], [], []
    hessian_eig_list, hessian_vt_eig_list = [], []
    hessian_vt_eig_grad_list, hessian_vt_eig_update_list = [], []
    param_vectors, update_vectors, grad_vectors = [], [], []

    # Only Adam-style optimizers carry a second-moment estimate to record.
    hat_vt_vectors = [] if is_adam else None

    cosine_similarity_grad_eigen_list_H = []
    cosine_similarity_grad_eigen_list_H_hat = []
    cosine_similarity_update_eigen_list_H = []
    cosine_similarity_update_eigen_list_H_hat = []
    cosine_similarity_eigen_eigen_list_H = []
    cosine_similarity_eigen_eigen_list_H_hat = []
    cosine_similarity_eigen_H_eigen_H_hat_list = []

    prev_params = [p.clone().detach().cpu() for p in model.parameters()]
    # Eigenvectors from the previous Hessian step; passed back in as the
    # initial vectors of the next power iteration.
    v_initial_H, v_initial_diag = None, None

    for epoch in tqdm(range(n_epochs)):
        epoch_list.append(epoch)
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            model.train()
            optimizer.zero_grad()
            outputs = model(X)
            loss = criterion(outputs, y)
            loss.backward()

            total_grad_norm = torch.norm(
                torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None]),
                p=2,
            )
            grad_norm_list.append(total_grad_norm.item())

            grad_vector = torch.cat(
                [p.grad.detach().flatten().cpu() for p in model.parameters() if p.grad is not None]
            )

            if is_adam:
                vt_list = []
                for p in model.parameters():
                    state = optimizer.state[p]
                    exp_avg_sq = state.get('exp_avg_sq')
                    if exp_avg_sq is None:
                        # No optimizer state yet (before the first step).
                        exp_avg_sq = torch.zeros_like(p)
                    vt_list.append(exp_avg_sq.detach().flatten().cpu())
                vt_vector = torch.cat(vt_list)

                # Bias correction; falls back to epoch + 1 when the optimizer
                # has not stored a step counter yet.
                step_count = next(iter(optimizer.state.values()), {}).get('step', epoch + 1)
                _, beta2 = optimizer.defaults.get('betas', (0.9, 0.999))
                hat_vt_vector = vt_vector / (1 - beta2 ** step_count)
            else:
                hat_vt_vector = None

            # Snapshot taken before optimizer.step(): see the NOTE in the docstring.
            new_params = [p.clone().detach().cpu() for p in model.parameters()]
            param_vector = torch.cat([p.flatten() for p in new_params])
            old_param_vector = torch.cat([p.flatten() for p in prev_params])
            update_vector = param_vector - old_param_vector

            if epoch % hessian_interval == 0:
                model.eval()
                preconditioner = compute_adam_preconditioner(optimizer)

                eig, eig_diag = power_iteration_total(model, X, y, vt=hat_vt_vector, num_iters=500,
                                                     eps=1e-8, tol=1e-6, device=device,
                                                     v_initial=v_initial_H, diag_initial=v_initial_diag)
                (max_eigen_H, v_H), (max_eigen_diag, v_diagH) = eig, eig_diag

                cosine_similarity_eigen_H_eigen_H_hat_list.append(cosine_similarity(v_H, v_diagH))

                with torch.no_grad():
                    # Alignment of the new eigenvector with the previous one.
                    cosine_similarity_value = cosine_similarity(v_H, v_initial_H)
                    cosine_similarity_eigen_eigen_list_H.append(cosine_similarity_value)

                    v_initial_H = v_H.detach().clone() if v_H is not None else None
                    hessian_eig_list.append(max_eigen_H)

                with torch.no_grad():
                    cosine_similarity_value = cosine_similarity(v_diagH, v_initial_diag)
                    cosine_similarity_eigen_eigen_list_H_hat.append(cosine_similarity_value)

                    v_initial_diag = v_diagH.detach().clone() if v_diagH is not None else None
                    hessian_vt_eig_list.append(max_eigen_diag)

                # Rayleigh quotient along the current gradient direction.
                grads = [p.grad.detach().clone() for p in model.parameters()
                         if p.requires_grad and p.grad is not None]
                hvp = hessian_vector_product_with_diagonal(model, criterion, X, y, vector=grads, precond=preconditioner, create_graph=True, device=device, specified_layers=None)
                grad_eigenvalue = compute_rayleigh_quotient(hvp, grads)
                hessian_vt_eig_grad_list.append(grad_eigenvalue)

                # Rayleigh quotient along the (previous-step) update direction.
                param_shapes = [p.shape for p in model.parameters()]
                update_vector_unflatten = _unflatten_params(update_vector, param_shapes, device)
                hvp = hessian_vector_product_with_diagonal(model, criterion, X, y, vector=update_vector_unflatten, precond=preconditioner, create_graph=True, device=device, specified_layers=None)
                update_eigenvalue = compute_rayleigh_quotient_update(hvp, update_vector_unflatten, grads)
                hessian_vt_eig_update_list.append(update_eigenvalue)

                cosine_similarity_value = cosine_similarity(grads, v_H)
                cosine_similarity_grad_eigen_list_H.append(cosine_similarity_value)

                cosine_similarity_value = cosine_similarity(grads, v_diagH)
                cosine_similarity_grad_eigen_list_H_hat.append(cosine_similarity_value)

                cosine_similarity_value = cosine_similarity(update_vector_unflatten, v_H)
                cosine_similarity_update_eigen_list_H.append(cosine_similarity_value)

                cosine_similarity_value = cosine_similarity(update_vector_unflatten, v_diagH)
                cosine_similarity_update_eigen_list_H_hat.append(cosine_similarity_value)

            optimizer.step()

            if epoch % param_interval == 0:
                param_vectors.append(param_vector.cpu())
                update_vectors.append(update_vector.cpu())
                grad_vectors.append(grad_vector.cpu())

                if is_adam:
                    hat_vt_vectors.append(hat_vt_vector.cpu())

            del param_vector, update_vector, grad_vector
            torch.cuda.empty_cache()

            prev_params = new_params
            loss_list.append(loss.item())

            if test_iter is not None:
                test_loss = evaluate(model, test_iter, device)
                test_loss_list.append(test_loss)
                print(f"Epoch {epoch}, Test Loss: {test_loss:.4f}")

    # Per-step scalar series; each is stored as "<name>.pkl".
    scalar_series = [
        ("epoch_list", epoch_list),
        ("loss_list", loss_list),
        ("hessian_eig_list", hessian_eig_list),
        ("hessian_vt_eig_list", hessian_vt_eig_list),
        ("hessian_vt_eig_grad_list", hessian_vt_eig_grad_list),
        ("hessian_vt_eig_update_list", hessian_vt_eig_update_list),
        ("grad_norm_list", grad_norm_list),
        ("cosine_similarity_grad_eigen_list_H", cosine_similarity_grad_eigen_list_H),
        ("cosine_similarity_grad_eigen_list_H_hat", cosine_similarity_grad_eigen_list_H_hat),
        ("cosine_similarity_eigen_eigen_list_H", cosine_similarity_eigen_eigen_list_H),
        ("cosine_similarity_eigen_eigen_list_H_hat", cosine_similarity_eigen_eigen_list_H_hat),
        ("cosine_similarity_update_eigen_list_H", cosine_similarity_update_eigen_list_H),
        ("cosine_similarity_update_eigen_list_H_hat", cosine_similarity_update_eigen_list_H_hat),
        ("cosine_similarity_eigen_H_eigen_H_hat_list", cosine_similarity_eigen_H_eigen_H_hat_list),
    ]
    for name, values in scalar_series:
        save_variable(torch.tensor(values), os.path.join(save_dir, f"{name}.pkl"))

    # Flattened per-step vectors, stacked into one 2-D tensor each.
    vector_series = [
        ("param_vectors", param_vectors),
        ("update_vectors", update_vectors),
        ("grad_vectors", grad_vectors),
    ]
    for name, vectors in vector_series:
        save_variable(torch.stack(vectors), os.path.join(save_dir, f"{name}.pkl"))

    if test_iter is not None:
        save_variable(torch.tensor(test_loss_list), os.path.join(save_dir, "test_loss_list.pkl"))

    if is_adam:
        save_variable(torch.stack(hat_vt_vectors), os.path.join(save_dir, "hat_vt_vectors.pkl"))

    return model