from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import comet_ml
import json
import numpy as np
import os
import pdb
import random
import time
import torch
from model import get_model_class
from time import gmtime, strftime
from data_process.data_factory import data_provider
from data_process.data_factory import data_dict




# Load the pretrained Qwen2 causal language model once, at import time.
# Both dtype selection and device placement are left to transformers ("auto").
llm_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B-Instruct", torch_dtype="auto", device_map="auto"
)

# Weight matrix of the LLM's input-embedding layer, cast to the LLM's own dtype.
# start_training() passes it to the model constructor.
# NOTE(review): presumably shaped (vocab_size, hidden_dim) - TODO confirm.
_input_embedding_layer = llm_model.get_input_embeddings()
word_embeds = _input_embedding_layer.weight.to(llm_model.dtype)
    

    
def main(device, config, save_dir, logger, data_init_loc, args, master=None):
    """Train one model and persist its config and run summary under ``save_dir``.

    Args:
        device: Device identifier forwarded to ``start_training``.
        config: Parsed JSON config. Its ``'vqvae_config'`` entry is handed to
            training and then replaced by the config that training returns.
        save_dir: Output directory; ``checkpoints/`` and ``configs/`` are created
            inside it, and ``master.json`` is written at its top level.
        logger: Logger object (or None) forwarded to ``start_training``.
        data_init_loc: Data initialisation location, forwarded to training.
        args: Parsed command-line arguments, forwarded to training.
        master: Optional dict of run metadata to store in ``master.json``. When
            omitted, the module-level ``master`` defined by the script entry point
            is used if present, otherwise an empty dict. The dict is copied, so the
            caller's object is never mutated.
    """
    # Create the checkpoints folder (existing files inside may be overwritten).
    checkpoint_dir = os.path.join(save_dir, 'checkpoints')
    if os.path.exists(checkpoint_dir):
        print('Checkpoint Directory Already Exists - if continue will overwrite files inside. Press c to continue.')
    else:
        # exist_ok guards against the directory appearing between check and create.
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Run training; it returns the (possibly updated) model config and a summary.
    vqvae_config, summary = start_training(device=device, vqvae_config=config['vqvae_config'], save_dir=save_dir,
                                           logger=logger, data_init_loc=data_init_loc, args=args)

    # Keep the config that was actually used so the saved copy matches the run.
    config['vqvae_config'] = vqvae_config
    print('CONFIG FILE TO SAVE:', config)

    # Create the configs folder.
    config_dir = os.path.join(save_dir, 'configs')
    if os.path.exists(config_dir):
        print('Saved Config Directory Already Exists - if continue will overwrite files inside. Press c to continue.')
    else:
        os.makedirs(config_dir, exist_ok=True)

    # Save the json copy of the config.
    with open(os.path.join(config_dir, 'config_file.json'), 'w') as f:
        json.dump(config, f, indent=4)

    # Resolve the run metadata. Falling back to the module global keeps the
    # original script behaviour, while avoiding a NameError when main() is
    # called from code that never defined ``master``.
    if master is None:
        master = globals().get('master', {})
    master_record = dict(master)  # copy: do not leak one run's summary into the next

    # Save the master file.
    summary['log_path'] = save_dir
    master_record['summaries'] = summary
    print('MASTER FILE:', master_record)
    with open(os.path.join(save_dir, 'master.json'), 'w') as f:
        json.dump(master_record, f, indent=4)


def start_training(device, vqvae_config, save_dir, logger, data_init_loc, args):
    """Seed the RNGs, build (or load) the model, train it and save the final model.

    Args:
        device: Device identifier; recorded in the summary and passed to training.
        vqvae_config: Model/training config dict. Reads ``'model_name'`` and the
            optional ``'general_seed'`` and ``'pretrained'`` entries; a sampled
            seed is written back under ``'general_seed'`` when none is given.
        save_dir: Run directory; the final model goes to
            ``<save_dir>/checkpoints/final_model.pth``.
        logger: Logger object (or None) forwarded to ``train_model``.
        data_init_loc: Data initialisation location; only recorded in the summary.
        args: Parsed command-line arguments forwarded to ``train_model``.

    Returns:
        Tuple ``(vqvae_config, summary)`` where ``summary`` holds the seed, the
        device, the data location, the config and the total training time.
    """
    summary = {}

    # Sample and fix a random seed if not set. The sampled value must be stored
    # under the same key that is read back on the next line; storing it under a
    # different key made every seed-less config fail with a KeyError.
    if 'general_seed' not in vqvae_config:
        vqvae_config['general_seed'] = random.randint(0, 9999)

    general_seed = vqvae_config['general_seed']
    summary['general_seed'] = general_seed
    torch.manual_seed(general_seed)
    random.seed(general_seed)
    np.random.seed(general_seed)
    # If another random library is used, its seed needs to be set here too.

    torch.backends.cudnn.deterministic = True

    summary['data initialization location'] = data_init_loc
    summary['device'] = device  # add the cpu/gpu to the summary

    # Build a fresh model from the configured class and the LLM word embeddings.
    model_class = get_model_class(vqvae_config['model_name'].lower())
    model = model_class(vqvae_config, word_embeds)

    # 'pretrained', when set, is the path of a previously saved full model that
    # replaces the freshly built one. A missing key is treated as "not pretrained".
    pretrained_path = vqvae_config.get('pretrained')
    if pretrained_path:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code - only point this at checkpoints you trust.
        model = torch.load(pretrained_path)

    # Report the size of the model that will actually be trained.
    print('Total # trainable parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad))

    summary['vqvae_config'] = vqvae_config  # add the model information to the summary

    # Train the model and time it.
    start_time = time.time()
    model = train_model(model, device, vqvae_config, save_dir, logger, args=args)

    # Once the model has trained, save the full pytorch model.
    torch.save(model, os.path.join(save_dir, 'checkpoints', 'final_model.pth'))

    summary['total_time'] = round(time.time() - start_time, 3)
    return vqvae_config, summary


def _prepare_batch(batch_x, seq_len, device):
    """Turn one loader batch into the float tensor fed to the model.

    The batch has dims 1 and 2 swapped and is reshaped to
    ``(batch * n_vars, seq_len, -1)`` so each variable becomes its own series.
    When the batch has no third dimension, or cannot be reshaped that way, a
    singleton dimension is appended at position 2 instead.
    NOTE(review): assumes ``batch_x`` is (batch, seq_len, n_vars) - TODO confirm.
    """
    try:
        bs, n_vars = batch_x.size(0), batch_x.size(2)
        series = torch.reshape(torch.transpose(batch_x, 1, 2), (bs * n_vars, seq_len, -1))
    except (IndexError, RuntimeError):
        # IndexError: dimension 2 does not exist; RuntimeError: incompatible shape.
        series = batch_x.unsqueeze(2)
    # as_tensor converts dtype/device without re-wrapping an existing tensor
    # (torch.tensor(tensor) warns and always copies).
    return torch.as_tensor(series, dtype=torch.float, device=device)


def train_model(model, device, vqvae_config, save_dir, logger, args):
    """Run the epoch loop: train, validate, and write checkpoints.

    Args:
        model: Model exposing ``configure_optimizers(lr=...)`` and
            ``shared_eval(batch, optimizer, mode, comet_logger=...)``.
        device: Device the model and every batch are moved to.
        vqvae_config: Config dict; reads ``'learning_rate'``, ``'dataset'`` and
            ``'num_training_updates'`` (used as the number of epochs).
        save_dir: Run directory; checkpoints are written to ``<save_dir>/checkpoints``.
        logger: Logger object (or None) passed to ``shared_eval``.
        args: Parsed arguments; reads ``batch_size``, ``base_path`` and ``seq_len``.

    Returns:
        The model after the last epoch.
    """
    optimizer = model.configure_optimizers(lr=vqvae_config['learning_rate'])

    model.to(device)
    start_time = time.time()

    dataset = vqvae_config['dataset']
    print('BATCHSIZE:', args.batch_size)
    train_loader, vali_loader, _ = create_datloaders(
        batchsize=args.batch_size, dataset=dataset, base_path=args.base_path, args=args)

    best_val_loss = np.inf  # lowest mean validation loss seen so far
    for epoch in range(int(vqvae_config['num_training_updates'])):
        train_all_loss, train_recon_loss = [], []
        val_all_loss, val_recon_loss = [], []

        # ---- training pass ----
        # NOTE(review): shared_eval in 'train' mode presumably performs the
        # backward pass and optimizer step; nothing in this loop does.
        model.train()
        for i, (batch_x, *_) in enumerate(train_loader):
            batch = _prepare_batch(batch_x, args.seq_len, device)
            loss, vq_loss, recon_error, *_ = model.shared_eval(
                batch, optimizer, 'train', comet_logger=logger)
            train_all_loss.append(loss.item())
            train_recon_loss.append(recon_error.item())

            if i % 10 == 0:
                log = 'Iter: {:03d}, Train ALL_Loss: {:.4f} Recon_Loss: {:.4f} VQ_Loss: {:.4f}'
                print(log.format(i, train_all_loss[-1], train_recon_loss[-1], vq_loss.item()), flush=True)

        mean_train_all_loss = np.mean(train_all_loss)
        mean_train_recon_loss = np.mean(train_recon_loss)

        # ---- validation pass ----
        # No gradients are needed here, so skip building the autograd graph.
        model.eval()
        with torch.no_grad():
            for batch_x, *_ in vali_loader:
                batch = _prepare_batch(batch_x, args.seq_len, device)
                loss, _vq_loss, recon_error, *_ = model.shared_eval(
                    batch, optimizer, 'val', comet_logger=logger)
                val_all_loss.append(loss.item())
                val_recon_loss.append(recon_error.item())

        mean_val_all_loss = np.mean(val_all_loss)
        mean_val_recon_loss = np.mean(val_recon_loss)

        log = 'Epoch: {:03d}, Train ALL Loss: {:.4f}, Train_Recon_Loss: {:.4f} Valid ALL Loss: {:.4f} Valid_Recon_Loss: {:.4f}'
        print(log.format(epoch, mean_train_all_loss, mean_train_recon_loss, mean_val_all_loss, mean_val_recon_loss), flush=True)

        # Keep the model with the lowest validation loss seen so far.
        if mean_val_all_loss <= best_val_loss:
            best_val_loss = mean_val_all_loss
            torch.save(model, os.path.join(save_dir, f'checkpoints/best_vqvae_{dataset}_model.pth'))
            print('Better Validation Loss, Model Saved!')

        # Periodic local checkpoint every 10 epochs.
        if epoch % 10 == 0:
            torch.save(model, os.path.join(save_dir, f'checkpoints/model_epoch_{epoch}.pth'))
            print('Saved model from epoch ', epoch)

    print('total time: ', round(time.time() - start_time, 3))
    return model


def create_datloaders(batchsize=100, dataset="dummy", base_path='dummy', args=None):
    """Build the train/validation/test loaders via ``data_provider``.

    Args:
        batchsize: Unused; kept so existing callers keep working.
        dataset: Dataset name. Only used for a console message: recognised names
            are echoed, anything else prints ``'Not done yet'``.
        base_path: Unused; kept so existing callers keep working.
        args: Argument namespace passed unchanged to ``data_provider``, which is
            what actually selects and loads the data.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    known_datasets = ('weather', 'electricity', 'traffic', 'ETTh1',
                      'ETTm1', 'ETTh2', 'ETTm2', 'all')
    print(dataset if dataset in known_datasets else 'Not done yet')

    # data_provider returns (dataset, loader); only the loaders are needed here.
    _, train_loader = data_provider(args, flag='train')
    _, val_loader = data_provider(args, flag='val')
    _, test_loader = data_provider(args, flag='test')

    return train_loader, val_loader, test_loader


if __name__ == '__main__':
    # Script entry point: parse the command line, load the JSON config, pick the
    # device, then train one model per .csv file found under --root_path.
    parser = argparse.ArgumentParser()

    parser.add_argument('--config_path', type=str,
                        required=False, default='./scripts/fast_ETTh1_vq3_qw2.json',
                        help='path to specific config file once already in the config folder')

    parser.add_argument('--model_init_num_gpus', type=int,
                        required=False, default=0,
                        help='number of gpus to use, 0 indexed, so if you want 1 gpu say 0')
    parser.add_argument('--data_init_cpu_or_gpu', type=str,
                        required=False, default='cpu',
                        help='the data initialization location')
    parser.add_argument('--comet_log', action='store_true',
                        required=False,
                        help='whether to log to comet online')
    parser.add_argument('--comet_tag', type=str,
                        required=False, default='pipeline',
                        help='the experimental tag to add to comet - this should be the person running the exp')
    parser.add_argument('--comet_name', type=str,
                        required=False, default='vqvae_all',
                        help='the experiment name to add to comet')
    parser.add_argument('--save_path', type=str,
                        required=False, default='./model_save/paper_all_MM/',
                        help='where were going to save the checkpoints')

    parser.add_argument('--base_path', type=str,
                        default='./data', help='saved revin data to train model(original data)')
    parser.add_argument('--batch_size', type=int,
                        required=False, default=64,
                        help='batch_size')

    parser.add_argument('--data', type=str, required=False, default='custom', help='dataset type')

    parser.add_argument('--root_path', type=str, default='root', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='data', help='data file')

    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')

    parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=0, help='start token length')
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')

    # Formers
    parser.add_argument('--enc_in', type=int, default=1,
                        help='encoder input size')  # DLinear with --individual, use this hyperparameter as the number of channels
    parser.add_argument('--embed', type=str, default=None,
                        help='time features encoding, options:[timeF, fixed, learned]')

    args = parser.parse_args()

    # Load the JSON config file.
    config_file = args.config_path
    print('Config folder:\t {}'.format(config_file))
    with open(config_file, 'r') as f:
        config = json.load(f)
    print('Running Config:', config_file)

    # The save directory name encodes the main hyper-parameters of the run.
    vq_cfg = config['vqvae_config']
    save_folder_name = (
        f"CD{vq_cfg['embedding_dim']}_CW{vq_cfg['num_embeddings']}"
        f"_CF{vq_cfg['compression_factor']}_BS{args.batch_size}"
        f"_ITR{vq_cfg['num_training_updates']}_seq_len{args.seq_len}"
        f"_pred_len{args.pred_len}all_vq3_1_outof_100_final"
    )
    # os.path.join works whether or not --save_path ends with a separator.
    save_dir = os.path.join(args.save_path, save_folder_name)

    # Run metadata; main() reads this module-level dict when writing master.json.
    master = {
        'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()),
        'config file': config_file,
        'save directory': save_dir,
        'gpus': args.model_init_num_gpus,
    }

    # Set up the comet logger (only when requested on the command line).
    if args.comet_log:
        comet_logger = comet_ml.Experiment(
            api_key=config['comet_config']['api_key'],
            project_name=config['comet_config']['project_name'],
            workspace=config['comet_config']['workspace'],
        )
        comet_logger.add_tag(args.comet_tag)
        comet_logger.set_name(args.comet_name)
    else:
        print('PROBLEM: not saving to comet')
        comet_logger = None

    # Set up GPU / CPU. --model_init_num_gpus is used as the CUDA device index.
    if torch.cuda.is_available() and args.model_init_num_gpus >= 0:
        # Explicit check instead of assert, which is stripped under ``python -O``.
        if args.model_init_num_gpus >= torch.cuda.device_count():
            raise ValueError(
                'GPU index {} is out of range: only {} CUDA device(s) available'.format(
                    args.model_init_num_gpus, torch.cuda.device_count()))
        device = 'cuda:{:d}'.format(args.model_init_num_gpus)
    else:
        device = 'cpu'

    # Where to init data for training (cpu or gpu); training itself runs on ``device``.
    data_init_loc = device if args.data_init_cpu_or_gpu == 'gpu' else 'cpu'

    # Train one model per CSV file. The walk root is captured first because
    # args.root_path is overwritten for every file below.
    search_root = args.root_path
    for root, _, files in os.walk(search_root):
        for filename in files:
            if not filename.endswith(".csv"):
                continue

            file_path = os.path.join(root, filename)
            print(f"file path:{file_path}")

            # Dataset name = file name without extension. Splitting the whole
            # path on '.' returned an empty name for paths such as './data/x.csv'.
            csv_file = os.path.splitext(filename)[0]
            print('file name', csv_file)

            dataset_save_dir = os.path.join(save_dir, csv_file)
            print('save_dir', dataset_save_dir)

            # Names registered in data_dict select their own loader; everything
            # else is handled as a 'custom' dataset.
            is_registered = csv_file in data_dict
            args.data = csv_file if is_registered else 'custom'
            args.root_path = root
            args.data_path = filename
            print('root', args.root_path)
            if is_registered:
                print('data_path', args.data_path)

            main(device, config, dataset_save_dir, comet_logger, data_init_loc, args)

    