import os
import csv
import shutil
import time
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.optim import Adam
import random
import numpy as np

from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy, EPianoDatasetSampler

from model.music_transformer import MusicTransformer
from model.loss import SmoothCrossEntropyLoss

from utilities.constants import *
from utilities.device import get_device, use_cuda
from utilities.lr_scheduling import LrStepTracker, get_lr
from utilities.argument_funcs import parse_train_args, print_train_args, write_model_params
from utilities.run_model import train_epoch, eval_model

from peft import LoraConfig, get_peft_model
from read_params import parse_params_from_file

# Column names written as the first row of results.csv; every per-epoch row
# appended by the training loop follows this same order.
CSV_HEADER = [
    "Epoch",
    "Learn rate",
    "Avg Train loss",
    "Train Accuracy",
    "Avg Eval loss",
    "Eval accuracy",
]

# Epoch index used for the evaluation-only pass over the not-yet-trained model,
# which provides a reference loss and accuracy before any training epoch runs.
BASELINE_EPOCH = -1

# _validate_args
def _validate_args(args):
    """
    ----------
    Checks argument combinations that main() cannot proceed with. Prints an ERROR line for the
    first problem found and returns False; returns True when training can start.
    ----------
    """

    if args.continue_weights is not None:
        if args.continue_epoch is None:
            print("ERROR: Need epoch number to continue from (-continue_epoch) when using continue_weights")
            return False
    elif args.continue_epoch is not None:
        print("ERROR: Need continue weights (-continue_weights) when using continue_epoch")
        return False

    # Output paths and data loaders are only prepared for multi_LoRA == 0 (single run) or inside the
    # multi-LoRA loop, which requires LoRA_finetune. Any other combination would reach the training
    # loop with those names unassigned, so reject it explicitly instead.
    if args.multi_LoRA != 0 and not args.LoRA_finetune:
        print("ERROR: multi_LoRA is only supported together with LoRA_finetune")
        return False
    if args.multi_LoRA < 0:
        print("ERROR: multi_LoRA must be >= 0")
        return False

    return True


# _prepare_output_paths
def _prepare_output_paths(args, root_dir):
    """
    ----------
    Writes the model parameter file and creates the "weights" and "results" folders under root_dir.
    Returns a dict with the keys weights_folder, results_file, best_loss_file, best_acc_file and
    best_text holding the paths used by the training loop.
    ----------
    """

    params_file = os.path.join(root_dir, "model_params.txt")
    write_model_params(args, params_file)

    weights_folder = os.path.join(root_dir, "weights")
    os.makedirs(weights_folder, exist_ok=True)

    results_folder = os.path.join(root_dir, "results")
    os.makedirs(results_folder, exist_ok=True)

    return {
        "weights_folder": weights_folder,
        "results_file": os.path.join(results_folder, "results.csv"),
        "best_loss_file": os.path.join(results_folder, "best_loss_weights.pickle"),
        "best_acc_file": os.path.join(results_folder, "best_acc_weights.pickle"),
        "best_text": os.path.join(results_folder, "best_epochs.txt"),
    }


# _open_tensorboard
def _open_tensorboard(args):
    """
    ----------
    Returns a SummaryWriter logging to <output_dir>/tensorboard, or None when args.no_tensorboard
    is set. The tensorboard import is kept local so it is only needed when logging is enabled.
    ----------
    """

    if args.no_tensorboard:
        return None

    from torch.utils.tensorboard import SummaryWriter

    tensorboard_dir = os.path.join(args.output_dir, "tensorboard")
    return SummaryWriter(log_dir=tensorboard_dir)


# _build_loaders
def _build_loaders(args, train_dataset, train_dataset_for_eval, test_dataset, sampler_seed, saving_root):
    """
    ----------
    Builds the training loader and the two evaluation loaders used by the training loop.

    When args.train_sample_ratio is not -1, an EPianoDatasetSampler (seeded with sampler_seed and
    given saving_root) is created and shared by all three loaders; otherwise the training loader
    shuffles and the evaluation loaders iterate in order.

    Returns (train_loader, train_loader_for_eval, test_loader).
    ----------
    """

    sampler = None
    if args.train_sample_ratio != -1:
        print("Use sampler for ratio =", args.train_sample_ratio)
        sampler = EPianoDatasetSampler(train_dataset, seed=sampler_seed,
                                       ratio=args.train_sample_ratio, saving_root=saving_root,
                                       shuffle=args.train_sample_shuffle)

    # DataLoader does not allow shuffle=True together with a sampler, hence shuffle=(sampler is None)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.n_workers,
                              shuffle=(sampler is None), sampler=sampler)

    train_loader_for_eval = DataLoader(train_dataset_for_eval, batch_size=args.batch_size, sampler=sampler)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers, sampler=sampler)

    return train_loader, train_loader_for_eval, test_loader


# _build_lora_model
def _build_lora_model(params_path, weight_path, lora_dropout):
    """
    ----------
    Rebuilds the base MusicTransformer from the hyperparameters stored in params_path, loads its
    weights from weight_path and wraps it with LoRA adapters on the modules named "Wq" and "Wv".
    Returns the wrapped (peft) model.
    ----------
    """

    params = parse_params_from_file(params_path)
    print("base model params: ", params)

    full_model = MusicTransformer(
        n_layers=params['n_layers'],
        num_heads=params['num_heads'],
        d_model=params['d_model'],
        dim_feedforward=params['dim_feedforward'],
        dropout=params['dropout'],
        max_sequence=params['max_sequence'],
        rpr=params['rpr'],
        enable_new_ver=params['enable_new_ver']
    ).to(get_device())

    # map_location lets a checkpoint saved on another device load onto the device in use
    full_model.load_state_dict(torch.load(weight_path, map_location=get_device()))

    # LoRA is applied only to the modules named Wq and Wv
    config = LoraConfig(
        r=8,
        lora_alpha=8,
        target_modules=["Wq", "Wv"],
        lora_dropout=lora_dropout,
        bias="lora_only"
    )

    model = get_peft_model(full_model, config)
    model.print_trainable_parameters()
    return model


# _run_training
def _run_training(args, model, loaders, paths, tensorboard_summary):
    """
    ----------
    Runs the baseline evaluation and the training epochs for one model.

    loaders is the (train_loader, train_loader_for_eval, test_loader) tuple from _build_loaders and
    paths is the dict from _prepare_output_paths. Per epoch this appends a row to results.csv,
    saves the best-accuracy / best-loss weights, writes periodic checkpoints and, when
    tensorboard_summary is not None, logs scalars to it.

    Assumes the continue_* arguments were already checked by _validate_args.
    ----------
    """

    train_loader, train_loader_for_eval, test_loader = loaders

    ##### Continuing from previous training session #####
    start_epoch = BASELINE_EPOCH
    if args.continue_weights is not None:
        model.load_state_dict(torch.load(args.continue_weights, map_location=get_device()))
        start_epoch = args.continue_epoch

    ##### Lr Scheduler vs static lr #####
    if args.lr is None:
        if args.continue_epoch is None:
            init_step = 0
        else:
            init_step = args.continue_epoch * len(train_loader)

        lr = LR_DEFAULT_START
        lr_stepper = LrStepTracker(args.d_model, SCHEDULER_WARMUP_STEPS, init_step)
    else:
        lr = args.lr

    ##### Not smoothing evaluation loss #####
    eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)

    ##### SmoothCrossEntropyLoss or CrossEntropyLoss for training #####
    if args.ce_smoothing is None:
        train_loss_func = eval_loss_func
    else:
        train_loss_func = SmoothCrossEntropyLoss(args.ce_smoothing, VOCAB_SIZE, ignore_index=TOKEN_PAD)

    ##### Optimizer #####
    opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)

    if args.lr is None:
        lr_scheduler = LambdaLR(opt, lr_stepper.step)
    else:
        lr_scheduler = None

    ##### Tracking best evaluation accuracy #####
    best_eval_acc        = 0.0
    best_eval_acc_epoch  = -1
    best_eval_loss       = float("inf")
    best_eval_loss_epoch = -1

    ##### Results reporting #####
    # The header is only written for a new file so that a continued run appends to existing results
    if not os.path.isfile(paths["results_file"]):
        with open(paths["results_file"], "w", newline="") as o_stream:
            writer = csv.writer(o_stream)
            writer.writerow(CSV_HEADER)

    ##### TRAIN LOOP #####
    for epoch in range(start_epoch, args.epochs):
        # Baseline has no training and acts as a base loss and accuracy (epoch 0 in a sense)
        if epoch > BASELINE_EPOCH:
            print(SEPERATOR)
            print("NEW EPOCH:", epoch+1)
            print(SEPERATOR)
            print("")

            # Train
            train_epoch(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, args.print_modulus)

            print(SEPERATOR)
            print("Evaluating:")
        else:
            print(SEPERATOR)
            print("Baseline model evaluation (Epoch 0):")

        # Eval
        train_loss, train_acc = eval_model(model, train_loader_for_eval, train_loss_func)
        eval_loss, eval_acc = eval_model(model, test_loader, eval_loss_func)

        # Learn rate
        lr = get_lr(opt)

        print("Epoch:", epoch+1)
        print("Avg train loss:", train_loss)
        print("Avg train acc:", train_acc)
        print("Avg eval loss:", eval_loss)
        print("Avg eval acc:", eval_acc)
        print(SEPERATOR)
        print("")

        new_best = False

        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            best_eval_acc_epoch = epoch+1
            torch.save(model.state_dict(), paths["best_acc_file"])
            new_best = True

        if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            best_eval_loss_epoch = epoch+1
            torch.save(model.state_dict(), paths["best_loss_file"])
            new_best = True

        # Writing out new bests
        if new_best:
            with open(paths["best_text"], "w") as o_stream:
                print("Best eval acc epoch:", best_eval_acc_epoch, file=o_stream)
                print("Best eval acc:", best_eval_acc, file=o_stream)
                # Blank separator belongs in the summary file, not on stdout
                print("", file=o_stream)
                print("Best eval loss epoch:", best_eval_loss_epoch, file=o_stream)
                print("Best eval loss:", best_eval_loss, file=o_stream)

        if tensorboard_summary is not None:
            tensorboard_summary.add_scalar("Avg_CE_loss/train", train_loss, global_step=epoch+1)
            tensorboard_summary.add_scalar("Avg_CE_loss/eval", eval_loss, global_step=epoch+1)
            tensorboard_summary.add_scalar("Accuracy/train", train_acc, global_step=epoch+1)
            tensorboard_summary.add_scalar("Accuracy/eval", eval_acc, global_step=epoch+1)
            tensorboard_summary.add_scalar("Learn_rate/train", lr, global_step=epoch+1)
            tensorboard_summary.flush()

        # epoch+1 is 0 on the baseline pass, so the untrained weights are checkpointed as well
        if (epoch+1) % args.weight_modulus == 0:
            epoch_str = str(epoch+1).zfill(PREPEND_ZEROS_WIDTH)
            path = os.path.join(paths["weights_folder"], "epoch_" + epoch_str + ".pickle")
            torch.save(model.state_dict(), path)

        with open(paths["results_file"], "a", newline="") as o_stream:
            writer = csv.writer(o_stream)
            writer.writerow([epoch+1, lr, train_loss, train_acc, eval_loss, eval_acc])

    # Sanity check just to make sure everything is gone
    if tensorboard_summary is not None:
        tensorboard_summary.flush()


# main
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Trains a model specified by command line arguments

    Three modes are supported:
      - normal training of a fresh MusicTransformer (LoRA_finetune off, multi_LoRA == 0)
      - a single LoRA fine-tune of a stored base model (LoRA_finetune on, multi_LoRA == 0)
      - multi_LoRA > 0 LoRA fine-tunes of the same base model, each written to its own numbered
        sub directory of output_dir (LoRA_finetune on)

    Invalid argument combinations print an ERROR line and return before any output is created.
    ----------
    """

    args = parse_train_args()
    print_train_args(args)

    if not _validate_args(args):
        return

    if args.force_cpu:
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    os.makedirs(args.output_dir, exist_ok=True)

    single_run = (args.multi_LoRA == 0)

    paths = None
    if single_run:
        ##### Output prep #####
        paths = _prepare_output_paths(args, args.output_dir)

        # NOTE(review): the training seed is only applied to single runs; multi-LoRA runs are not
        # seeded here - confirm that is intended
        if args.train_seed != -1:
            print("set training seed to:", args.train_seed)
            random.seed(args.train_seed)
            np.random.seed(args.train_seed)
            torch.manual_seed(args.train_seed)

    ##### Tensorboard #####
    tensorboard_summary = _open_tensorboard(args)

    try:
        ##### Datasets #####
        train_dataset, _, _ = create_epiano_datasets(args.input_dir, args.max_sequence, full_version=args.fixed_crop)

        # Use training set for evaluation
        train_dataset_for_eval = train_dataset
        val_dataset = train_dataset
        test_dataset = train_dataset
        print("train set size:", len(train_dataset))
        print("train_dataset_for_eval size:", len(train_dataset_for_eval))
        print("val set size:", len(val_dataset))
        print("test set size:", len(test_dataset))

        loaders = None
        if single_run:
            loaders = _build_loaders(args, train_dataset, train_dataset_for_eval, test_dataset,
                                     sampler_seed=args.train_sample_seed, saving_root=args.output_dir)

        # normal training (either original rpr or nn.linear rpr)
        if not args.LoRA_finetune:
            model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                                     d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                                     dropout=args.dropout, max_sequence=args.max_sequence, rpr=args.rpr,
                                     enable_new_ver=args.enable_new_ver).to(get_device())
            _run_training(args, model, loaders, paths, tensorboard_summary)
            return

        # The base model to fine-tune is selected by train_sample_seed
        base_model_path = "../full-ensemble-ten-trained-model/checkpoint/"
        params_path = base_model_path + str(args.train_sample_seed) + "/model_params.txt"
        weight_path = base_model_path + str(args.train_sample_seed) + "/results/best_acc_weights.pickle"
        print("multi LoRA value: ", args.multi_LoRA)

        if single_run:
            model = _build_lora_model(params_path, weight_path, lora_dropout=0.1)
            _run_training(args, model, loaders, paths, tensorboard_summary)
            return

        # NOTE(review): every multi-LoRA run logs to the same tensorboard writer under the same tags
        # and global steps, so the curves of different runs overlap - confirm that is acceptable
        for i in range(args.multi_LoRA):
            # get the sub dir num
            # if total true ensemble = 10 (train sample seed: 0-9) and multi LoRA = 3
            # then the sub dir will range from 0 to 29
            sub_dir_num = args.train_sample_seed * args.multi_LoRA + i
            print("sub dir number: ", sub_dir_num)
            sub_output_dir = os.path.join(args.output_dir, str(sub_dir_num))
            os.makedirs(sub_output_dir, exist_ok=True)

            sub_paths = _prepare_output_paths(args, sub_output_dir)

            if args.random_LoRA:
                # generate different seeds for each fine-tuning model so each gets different data
                # NOTE(review): train_sample_seed + i repeats across base models (e.g. seed 0, i 1
                # and seed 1, i 0 both give 52877) - confirm that is intended
                sampler_seed = 52876 + args.train_sample_seed + i
            else:
                sampler_seed = args.train_sample_seed

            sub_loaders = _build_loaders(args, train_dataset, train_dataset_for_eval, test_dataset,
                                         sampler_seed=sampler_seed, saving_root=sub_output_dir)

            model = _build_lora_model(params_path, weight_path, lora_dropout=0.0)
            _run_training(args, model, sub_loaders, sub_paths, tensorboard_summary)

        return
    finally:
        # Release the event file even when training stops early or raises
        if tensorboard_summary is not None:
            tensorboard_summary.flush()
            tensorboard_summary.close()


# Start training only when this file is executed as a script; importing it has no side effects
# beyond the definitions above.
if __name__ == "__main__":
    main()
