import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import time
import wandb
from sklearn.datasets import make_spd_matrix
from random import randint
from torch.optim.lr_scheduler import ExponentialLR
import argparse
import matplotlib.pyplot as plt
import csv
import os

from data_generation import generate_L, top_k_eigen, generate_X, generate_batch, generate_batch_cov, CovarianceDataset, generate_batch_scale_d
from Work.transformer_PCA.main import TransformerModel
from loss import MeanRelativeSquaredError

def parse_args(argv=None):
    """Parse the command-line options of the transformer-PCA training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour of a
            plain ``parse_args()`` call.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Transformer-based PCA")

    # Problem and model sizes.
    parser.add_argument("--D", type=int, default=2, help="Dimension of each column vector")
    parser.add_argument("--N", type=int, default=5, help="Number of columns in each X matrix")
    parser.add_argument("--k", type=int, default=1, help="Number of top eigenvalues to use as labels")
    parser.add_argument("--n_embd", type=int, default=64, help="Embedding size for the transformer")
    parser.add_argument("--n_layer", type=int, default=12, help="Number of layers in the transformer")
    parser.add_argument("--n_head", type=int, default=4, help="Number of attention heads in the transformer")

    # Optimisation and logging.
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    parser.add_argument("--training_steps", type=int, default=60000, help="Total number of training steps")
    parser.add_argument("--n_training_data", type=int, default=1024000,
                        help="Total number of training samples; divided by the batch size to get the step count")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--log_step", type=int, default=50, help="Interval (in steps) at which to log the loss")
    parser.add_argument("--seed", type=int, default=1234, help="Random seed passed to torch.manual_seed")

    # Input and output locations.
    parser.add_argument("--plot_name", type=str, help="Specify the name of the output plot file.")
    parser.add_argument("--csv_name", type=str, help="Specify the name of the output csv file.")
    parser.add_argument("--dataset", type=str, help="Path of the training dataset given to CovarianceDataset.")
    parser.add_argument("--run_name", type=str, help="Name of the wandb run.")
    parser.add_argument("--save_model_to", type=str, help="Path where the trained model state_dict is saved.")

    # Behaviour switches (all default to False).
    parser.add_argument("--input_is_cov", action='store_true', help="Flag to specify if input is a covariance matrix")
    parser.add_argument("--predict_vector", action='store_true', help="Flag to specify if you want to predict eigenvectors")
    parser.add_argument("--predict_cov", action='store_true', help="Flag to specify if you want to predict covariance matrix")
    parser.add_argument("--is_relu", action='store_true', help="Flag to specify if you want relu in attention")
    parser.add_argument("--is_layernorm", action='store_true', help="Flag to specify if you want layer normalization")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point. Trains a TransformerModel on a CovarianceDataset,
    # saves/reloads the weights, evaluates on a freshly generated batch of
    # 1280 samples and finally plots the averaged training loss.
    torch.set_num_threads(6)
    args = parse_args()

    # ----- Problem and model sizes -----
    D = args.D  # Dimension of each column vector
    N = args.N  # Number of columns in each X matrix
    k = args.k  # Number of top eigenvalues to use as labels
    n_embd = args.n_embd
    n_layer = args.n_layer
    n_head = args.n_head
    n_training_data = args.n_training_data

    # ----- Training options -----
    # input_is_cov: True if the transformer input is a covariance matrix, so
    # we can test whether the transformer can do the power iteration method.
    input_is_cov = args.input_is_cov
    print("input_is_cov: ",input_is_cov)
    predict_vector = args.predict_vector
    print("predict_vector: ",predict_vector)
    predict_cov = args.predict_cov
    print("predict_cov: ",predict_cov)
    batch_size = args.batch_size
    csv_file = args.csv_name
    plot_file = args.plot_name
    lr = args.lr
    is_relu = args.is_relu
    is_layernorm = args.is_layernorm

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # The step count is derived from the number of samples, not from
    # --training_steps. It is only reported; the loop below runs over the
    # whole dataloader once.
    training_steps = n_training_data // batch_size

    wandb.init(project="transformer_pca", name=args.run_name, config={
        "learning_rate": lr,
        "batch_size": batch_size,
        "architecture": "gpt2",
        "dataset": "Generated Covariance Data",
        "training_steps": training_steps,
        "D": D,
        "N": N,
        "top_k_eigenvalues": k,
        "n_embd": n_embd,
        "n_layer": n_layer,
        "n_head": n_head,
        "input_is_covariance": input_is_cov,
        "predict_vector": predict_vector,
        "predict_cov": predict_cov,
        "is_relu": is_relu,
        "is_layernorm": is_layernorm,
    })

    # ----- Model, optimizer and loss -----
    # The keyword arguments are shared by the training model and the model
    # that is rebuilt for evaluation, so both always have the same layout.
    model_kwargs = dict(
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,
        input_is_cov=input_is_cov,
        predict_vector=predict_vector,
        predict_cov=predict_cov,
        is_relu=is_relu,
        is_layernorm=is_layernorm,
        k=k,
    )
    model = TransformerModel(D, N, N+10, **model_kwargs).to(device)
    print(f"model architecture:{model.name}")
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = MeanRelativeSquaredError()

    train_losses = []
    steps = []
    start_time = time.time()
    loss_sum = 0

    # Create the CSV file with its header if it does not exist yet. Skipped
    # when --csv_name was not given (os.path.exists(None) would raise).
    # NOTE(review): only the header is written; per-step rows are currently
    # not appended anywhere in this script.
    if csv_file and not os.path.exists(csv_file):
        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Step', 'Training Loss', 'Elapsed Time'])

    # The dataset is always asked for eigenvectors; the loop below tolerates
    # an optional third element per sample via star-unpacking.
    dataset = CovarianceDataset(args.dataset, k=k,predict_vector=True)
    # NOTE(review): the seed is set after the model was built, so it does not
    # cover the weight initialisation above — confirm this is intended.
    torch.manual_seed(args.seed)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    print("training_steps", training_steps)
    # Log roughly every 1024 samples. max(1, ...) avoids a modulo-by-zero
    # below when batch_size is larger than 1024.
    # NOTE(review): this replaces --log_step, which is therefore ignored.
    print_every = max(1, 1024 // batch_size)

    # ----- Training loop (one pass over the dataloader) -----
    model.train()
    for step, (X_train_batch, Y_train_batch, *Y_vector_batch) in enumerate(dataloader):
        first_step = step == 0
        if first_step:
            print(f"shape of Y_train_batch: {Y_train_batch.shape}")

        # Move the batch to the training device; keep only the top-k labels.
        X_train_batch = X_train_batch.to(device)
        Y_train_batch = Y_train_batch[:,:k].to(device)
        if predict_vector:
            if not Y_vector_batch:
                raise ValueError("--predict_vector was set but the dataset did not yield eigenvectors")
            # Swap the last two axes so that the vectors line up with the
            # (batch, k, D) view of the model output used for the loss.
            Y_vector_batch = torch.transpose(Y_vector_batch[0].to(device), 1, 2)

        if first_step:
            print(f"shape of X_train_batch: {X_train_batch.shape}")
            print(f"shape of Y_train_batch: {Y_train_batch.shape}")
            if predict_vector:
                print(f"shape of Y_vector_batch: {Y_vector_batch.shape}")

        output = model(X_train_batch)
        if first_step:
            print("shape of output", output.shape)

        if predict_vector:
            # -1 lets the batch dimension follow the actual batch, so a final
            # partial batch (len(dataset) % batch_size != 0) no longer fails.
            output = output.view(-1, k, D)
            if first_step:
                print(f"shape of output: {output.shape}")
            loss = 1 - F.cosine_similarity(output, Y_vector_batch, dim=-1).mean()
        else:
            loss = criterion(output, Y_train_batch)

        loss_sum += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step+1) % print_every == 0:
            elapsed_time = time.time() - start_time
            avg_loss = loss_sum / print_every
            train_losses.append(avg_loss)
            steps.append(step + 1)
            print()
            print(f"Step [{step+1}/{training_steps}], Training Loss: {avg_loss:.6f}, Elapsed Time: {elapsed_time:.2f}s")
            print()
            wandb.log({"training_loss": avg_loss, "step": step+1})
            if predict_vector:
                print(f"True top-{k} eigenvectors for 0th data in a batch:{Y_vector_batch[0,:]}")
                print(f"Predicted top-{k} eigenvectors for 0th data in a batch:{output[0,:]}")
            elif predict_cov:
                print(f"True covariance matrix for first data in first batch:{Y_train_batch[0,:]}")
                print(f"Predicted covariance matrix for first data in first batch:{output[0,:]}")
            else:
                print(f"True top-{k} eigenvalues for 0th~5th data in a batch:{Y_train_batch[:1,:k]}")
                print(f"Predicted top-{k} eigenvalues for 0th~5th data in a batch:{output[:1,:k]}")

            # Restart the timer and the running sum for the next window.
            start_time = time.time()
            loss_sum = 0

    # ----- Save and reload -----
    # Evaluate a freshly built model loaded from the checkpoint so that the
    # saved file itself is what gets tested. Without --save_model_to there is
    # nothing to reload, so the in-memory model is evaluated directly.
    save_path = args.save_model_to
    if save_path:
        torch.save(model.state_dict(), save_path)
        model_eval = TransformerModel(D, N, N+10, **model_kwargs).to(device)
        model_eval.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
    else:
        model_eval = model
    model_eval.eval()

    # ----- Evaluation on 1280 newly generated samples -----
    print("evaluation start.")
    n_eval = 1280
    X_test_1280, Y_test_1280, Y_vector_test_1280 = generate_batch(n_eval, D, N, k, input_is_cov,device=device)
    with torch.no_grad():  # no gradients are needed for testing
        if predict_vector:
            Y_vector_test_1280 = torch.transpose(Y_vector_test_1280, 1, 2)
            print("Y_vector_shape: ", Y_vector_test_1280.shape)
            eval_output = model_eval(X_test_1280).view(n_eval, k, D)
            print(f"eval_output shape: {eval_output.shape}")
            # One cosine-distance error per predicted eigenvector.
            for i in range(k):
                error = 1 - F.cosine_similarity(eval_output[:,i,:], Y_vector_test_1280[:,i,:],dim=-1).mean()
                wandb.log({f"eigenvector_testing_error_{i+1}": error.item()})
        else:
            eval_output = model_eval(X_test_1280)
            print(f"shape of eval_output: {eval_output.shape}")
            print(f"shape of Y_test_1280: {Y_test_1280.shape}")
            # One error per eigenvalue column, then the error over all columns.
            for i in range(k):
                error = criterion(eval_output[:,i], Y_test_1280[:,i])
                wandb.log({f"{i+1}-eigenvalue_testing_error": error.item()})
            error_total = criterion(eval_output, Y_test_1280)
            wandb.log({"total_error": error_total.item()})

    # ----- Plot the training loss -----
    plt.figure()
    plt.plot(steps, train_losses, label='Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()

    # Annotate the figure with the run configuration, one line per entry.
    plt.figtext(0.55, 0.75, "Model Parameters:", fontsize=12, ha="left", fontweight='bold')
    annotations = [
        (0.7, f"D: {D}"),
        (0.65, f"N: {N}"),
        (0.6, f"k: {k}"),
        (0.55, f"n_embd: {n_embd}"),
        (0.5, f"n_layer: {n_layer}"),
        (0.45, f"n_head: {n_head}"),
        (0.4, f"input_is_cove: {input_is_cov}"),
        (0.35, f"predict_vector: {predict_vector}"),
        (0.3, f"predict_cov: {predict_cov}"),
        (0.25, f"Batch Size: {batch_size}"),
    ]
    for y_pos, text in annotations:
        plt.figtext(0.55, y_pos, text, fontsize=10, ha="left")

    # Only write the figure when --plot_name was given; always release it.
    if plot_file:
        plt.savefig(plot_file)
    plt.close()
  
  
  
  
