import sys
sys.path.append("..")

from droptrak.output_function import BaseModelOutputClass
from droptrak.droptrak import DropTRAK
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

import torch
import torch.nn as nn
import os
import random

import sys
sys.path.append("./MusicTransformer_Pytorch")

from argument_funcs import parse_trak_args, print_trak_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, EPianoDatasetSampler

from utilities.constants import *

from tqdm import tqdm
import numpy as np

import time
import argparse

from peft import LoraConfig, get_peft_model


class MusicTransformerModelOutput(BaseModelOutputClass):
    """Model-output hooks that let DropTRAK score a MusicTransformer.

    Configuration lives in class attributes, which the driver script
    overwrites before passing this class to DropTRAK
    (``model_output_class=MusicTransformerModelOutput``). Both hooks are
    static methods and read those attributes directly.
    """

    # Device each (x, tgt) batch is moved to before the forward pass.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Start column (along the scored positions) from which per-token
    # log-probs are summed; only meaningful for the test set.
    start_idx = 0
    # Number of trailing windows/positions scored when not cherry-picking.
    last_k_tokens = 1
    # Length of the context window fed to the model for each prediction.
    num_prime = 256
    # 1-based index of a single position to score; negative disables
    # cherry-picking. NOTE(review): 0 falls into neither mode cleanly
    # (all windows are run, then the cherry-pick slicing is applied).
    cherrypick_token = -1

    def __init__(self):
        super().__init__()

    @staticmethod
    def _logits_and_targets(data, model):
        """Run the model over the configured windows and select matching targets.

        Args:
            data: ``(x, tgt)`` pair of token tensors indexed as ``[batch, pos]``.
            model: callable mapping a token window to per-position logits.

        Returns:
            ``(y, tgt)`` where ``y`` holds the logits of the last position of
            each evaluated window with the class dimension at dim 1 (the
            layout ``nn.CrossEntropyLoss`` expects), and ``tgt`` the target
            tokens aligned with ``y``'s last dimension.

        Raises:
            ValueError: if the configuration selects no window at all
                (e.g. ``cherrypick_token > last_k_tokens``).
        """
        cls = MusicTransformerModelOutput
        x, tgt = data
        x, tgt = x.to(cls.device), tgt.to(cls.device)

        chunks = []
        for i in range(cls.last_k_tokens):
            # In cherry-pick mode only the window ending at the chosen token runs.
            if cls.cherrypick_token > 0 and i != cls.cherrypick_token - 1:
                continue
            logits = model(x[:, i:cls.num_prime + i])
            # Move the vocab axis to dim 1 and keep only the window's last position.
            chunks.append(logits.permute(0, 2, 1)[:, :, -1:])
        if not chunks:
            raise ValueError(
                f"no window selected: cherrypick_token={cls.cherrypick_token}, "
                f"last_k_tokens={cls.last_k_tokens}"
            )
        # Single concatenation instead of growing the tensor inside the loop.
        y = torch.cat(chunks, dim=2)

        if cls.cherrypick_token < 0:
            k = cls.last_k_tokens
            y = y[:, :, -k:]
            tgt = tgt[:, -k:]
        else:
            # Window i = cherrypick_token - 1 ends at input index num_prime + i - 1;
            # presumably tgt[j] is the token that follows x[j] — TODO confirm.
            pos = cls.cherrypick_token + cls.num_prime - 2
            y = y[:, :, -1:]
            tgt = tgt[:, pos:pos + 1]
        return y, tgt

    @staticmethod
    def _log_prob(data, model):
        """Return, per example, the summed log-prob of the selected targets.

        Positions equal to ``TOKEN_PAD`` contribute 0 (``ignore_index``);
        only columns from ``start_idx`` onward are summed.
        """
        y, tgt = MusicTransformerModelOutput._logits_and_targets(data, model)
        loss_fn = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD, reduction="none")
        # Per-position log-probs, shape (batch, positions).
        logp_elements = -loss_fn(y, tgt)[:, MusicTransformerModelOutput.start_idx:]
        return torch.sum(logp_elements, dim=1)

    @staticmethod
    def model_output(data, model, *args, **kwargs):
        """Return the per-example margin ``log(p / (1 - p))``.

        ``p`` is the model's probability of the selected target tokens
        (the product over the scored positions).
        """
        logp = MusicTransformerModelOutput._log_prob(data, model)
        # -expm1(logp) == 1 - exp(logp), but stays accurate (no log(0)) as p -> 1.
        return logp - torch.log(-torch.expm1(logp))

    @staticmethod
    def get_out_to_loss_grad(data, model, *args, **kwargs):
        """Return ``1 - p`` as a detached ``(batch, 1)`` tensor.

        ``p`` is the same probability used by :meth:`model_output`.
        """
        logp = MusicTransformerModelOutput._log_prob(data, model)
        return (-torch.expm1(logp)).detach().unsqueeze(-1)

if __name__ == "__main__":
    enter_func = time.time()
    # Use the GPU when present (MusicTransformerModelOutput.device makes the same choice).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    args = parse_trak_args()

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr, dropout=0.1,
                             enable_new_ver=args.enable_new_ver).to(device)
    if args.LoRA_finetune:
        # Attach rank-8 LoRA adapters to the modules named "Wq" and "Wv".
        config = LoraConfig(
            r=8,
            lora_alpha=8,
            target_modules=["Wq", "Wv"],
            lora_dropout=0,
            bias="lora_only",
        )
        trak_model = get_peft_model(model, config)
    else:
        trak_model = model

    # Training set: sliding windows in order over each whole sequence, without overlap.
    train_dataset, _, _ = create_epiano_datasets(args.midi_root,
                                                 args.train_windows_size,
                                                 random_seq=False,
                                                 full_version=True)

    MusicTransformerModelOutput.last_k_tokens = 1  # training examples only score their last token
    MusicTransformerModelOutput.num_prime = args.num_prime
    MusicTransformerModelOutput.cherrypick_token = args.trak_cherrypick  # -1 when not cherry-picking
    print("MusicTransformerModelOutput.cherrypick_token", MusicTransformerModelOutput.cherrypick_token)

    sampler = EPianoDatasetSampler(train_dataset, seed=0, ratio=1,
                                   saving_root="./tmp", shuffle=False)

    # One checkpoint path per ensemble member.
    checkpoint_files = []
    if args.multi_LoRA == 0:
        # Each independent run's checkpoint is repeated args.ensemble times.
        for independent in range(args.independent):
            path = os.path.join(args.trak_model_checkpoints_root,
                                f"{independent}/" + args.trak_model_checkpoints_extension)
            checkpoint_files += [path for _ in range(args.ensemble)]
    else:
        # Each independent run owns max_multi consecutive LoRA checkpoint slots.
        max_multi = 25
        if args.multi_LoRA > max_multi:
            raise ValueError(
                f"multi_LoRA={args.multi_LoRA} exceeds the {max_multi} LoRA slots "
                "per independent run; checkpoints would overlap the next run"
            )
        for independent in range(args.independent):
            start_num = independent * max_multi
            for lora_num in range(args.multi_LoRA):
                checkpoint_files.append(
                    os.path.join(args.trak_model_checkpoints_root,
                                 f"{start_num + lora_num}/" + args.trak_model_checkpoints_extension)
                )

    print(checkpoint_files)

    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=1, sampler=sampler)
    # Test set: generated music, windowed the same way, without a train/val/test split.
    test_dataset = create_epiano_datasets(args.trak_generated_music_root,
                                          args.test_windows_size,
                                          random_seq=False,
                                          full_version=True,
                                          split=False)
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1)

    trak_start = time.time()
    print("preprocess time: ", trak_start - enter_func)

    trak = DropTRAK(model=trak_model,
                    model_checkpoints=checkpoint_files,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    model_output_class=MusicTransformerModelOutput,
                    device=device,
                    independent_num=args.independent,
                    dropout=args.dropout,
                    LoRA_finetune=args.LoRA_finetune,
                    LoRA_grad_only=args.LoRA_grad_only)

    trak_finish_build = time.time()
    print("build trak time: ", trak_finish_build - trak_start)

    # CUDA memory statistics are unavailable when running on CPU.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats("cuda")
    score = trak.score()  # alternative entry point: trak.q_drop_score()
    if use_cuda:
        peak_memory = torch.cuda.max_memory_allocated("cuda") / 1e6  # bytes -> MB
        print(f"Peak memory usage: {peak_memory} MB")

    trak_finish_score = time.time()
    print("trak score time: ", trak_finish_score - trak_finish_build)

    out_dir = "../time-mem-score"
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): the "_ensemble_" label records args.multi_LoRA, not args.ensemble — confirm intended.
    torch.save(score, os.path.join(
        out_dir,
        f"score_music{'_dropout' if args.dropout else ''}_ensemble_{args.multi_LoRA}_independent_{args.independent}_retry.pt",
    ))