import os, sys, fnmatch
sys.path.append('./xds/xds_python/')

import argparse
import torch
import random
import numpy as np
from sklearn.model_selection import train_test_split
from pathlib import Path

from preprocess.data_loader import load_monkey_spike, load_spike_generator, load_sample_generator, concatenate_spike_token
from utils.padding import spike_preprocessing, spike_zero_padding
from config.align_config import AlignConfig
from experiment.fine_tune import fine_tuning_stage
from align.mmd import MMD_loss

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Passes ``seed`` to ``torch.manual_seed``, ``torch.cuda.manual_seed_all``,
    ``np.random.seed`` and ``random.seed``, then switches on cuDNN's
    deterministic flag.

    Args:
        seed: Integer seed handed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

## fine-tune

if __name__ == '__main__':
    # Fine-tuning sweep.
    #
    # For every (target-train ratio, source session, seed) combination this
    # script loads the pre-trained checkpoint of the source session, fine-tunes
    # it on each target session via `fine_tuning_stage`, saves the best model,
    # and appends per-seed and seed-averaged R2 / timing rows to an .xlsx file.

    # Imported up front so a missing package fails before any training runs
    # instead of after the first seed has finished.
    import openpyxl

    # parse arguments
    parser = argparse.ArgumentParser(description='finetune')
    parser.add_argument('-cuda_device', type=str, default='0', help='which gpu to use ')

    args = parser.parse_args()

    # Resolve the device once; it is used both to map the checkpoint on load
    # and to run fine-tuning, so a CPU-only host no longer fails in torch.load.
    device = torch.device("cuda:" + args.cuda_device) if torch.cuda.is_available() else torch.device('cpu')

    # dataset
    # Jango_2015_isometric_wrist_task, Spike_ISO_2012, Mihili_CO_2014, Chewie_CO_2016, Mihili_RT_2013_2014
    data_path = './datasets/Mihili_RT_2013_2014/'
    NHP_id = 'Mihili'
    save_NHP_id = 'Mihili_RT'
    id_len, date_len = len(NHP_id), 8
    # The recording date is read from each file name: skip the NHP id plus one
    # separator character, then take the next `date_len` characters.
    date_slice = slice(id_len + 1, id_len + date_len + 1)
    # Sessions are processed in file-name order.
    mat_list = np.sort(fnmatch.filter(os.listdir(data_path), "*.mat"))

    # pre-process
    bin_size, smooth_size = 0.05, 0.1
    start_time = 'gocue_time'

    # Union of unit (channel) names over all sessions; every session is
    # zero-padded against this common list further below.
    all_unit_names = set()
    for mat_name in mat_list:
        _, _, unit_names = load_monkey_spike(data_path, mat_name[date_slice], bin_size, smooth_size, start_time, NHP_id)
        all_unit_names |= set(unit_names)
    max_unit_names = np.sort(list(all_unit_names))

    # align config setting
    # NOTE(review): 0.02 appears twice, so that ratio is run twice and its
    # results are appended to the same files -- confirm this is intended.
    tgt_train_ratio_list = [0.02, 0.02, 0.03, 0.04, 0.06]
    aligner_method = 'likelihood'
    mmd_loss = None
    kernel_mul_dict = {
        'Chewie': 2.0,
        'Mihili': 1.0,
        'Mihili_RT': 2.0,
    }
    D_params, training_params = {}, {}
    if aligner_method in ('mmd', 'c_mmd'):
        kernel_mul, kernel_num = kernel_mul_dict[save_NHP_id], 5
        mmd_loss = MMD_loss(kernel_mul, kernel_num)
    elif aligner_method == 'gan':
        # discriminator training settings
        training_params['D_lr'] = 1e-3
        training_params['drop_out_D'] = 0.2

    # Loop-invariant experiment settings.
    # NOTE(review): seed 1 appears twice, so it is trained twice and counted
    # twice in the seed average -- confirm this is intended.
    seed_list = [1, 0, 1, 2, 3, 4]
    training_step = 3500        # step count embedded in the pre-trained checkpoint file name
    tgt_sess_start = 6          # index of the first session used as a target
    cur_idx = 1                 # 0: position 1: velocity 2: acceleration
    tgt_valid_ratio = 0.9 if NHP_id == 'Chewie' else 0.8
    fourier_flag = False
    batch_size = 256
    shuffle_flag = False
    learning_rate = 1e-4 if aligner_method == 'likelihood' else 5e-4

    for tgt_train_ratio in tgt_train_ratio_list:
        # With a zero ratio no target training data is selected; the run is
        # labelled 'no_align' and fine-tuning is limited to a single step.
        use_tgt_train = tgt_train_ratio > 0
        fine_tuning_step = 25 if use_tgt_train else 1
        aligner_method_name = aligner_method if use_tgt_train else 'no_align'

        for src_sess_last in range(0, len(mat_list)):
            # Single-session source; use list(range(src_sess_last)) for multi-session.
            src_session_idx = [src_sess_last]
            src_name = 'src' + ''.join('_' + str(idx) for idx in src_session_idx)

            # Target sessions exclude the source session(s). Previously an empty
            # target set crashed with NameError when results were written.
            tgt_session_idx = [
                tgt_idx for tgt_idx in range(tgt_sess_start, len(mat_list))
                if tgt_idx not in src_session_idx
            ]
            if not tgt_session_idx:
                print("no target sessions for source %s, skipping" % src_name)
                continue

            load_model_pth = './pre_train/{}/{}'.format(save_NHP_id, src_name)
            save_rel_pth = './rel/fine_tune/{}/{}/{}/{}/'.format(aligner_method_name, save_NHP_id, str(tgt_train_ratio), src_name)
            save_rel_file = '{}/fine_tune_results_src_{}_{}.xlsx'.format(save_rel_pth, src_sess_last, str(tgt_train_ratio))

            total_valid_r2_list = []
            total_time_list, total_time_per_epoch_list = [], []
            for seed in seed_list:
                # First element of each list is the row label written to the sheet.
                tgt_sess_list = ['Day']
                best_valid_r2_list = ['r2 score']
                total_fine_tuning_time_list, fine_tuning_time_per_epoch_list = ['total time'], ['time per epoch']

                # The checkpoint depends only on source and seed, so check for it
                # before loading any session data.
                load_model_file = '{}/{}/{}'.format(load_model_pth, seed, 'SiT_flow_full_model_{}.pt'.format(training_step))
                if not Path(load_model_file).is_file():
                    raise ValueError("File not found: {}".format(load_model_file))

                for tgt_idx in tgt_session_idx:
                    setup_seed(seed)

                    # target session
                    tgt_data_date = mat_list[tgt_idx][date_slice]
                    tgt_day_spike, tgt_day_cursor, tgt_unit_names = load_monkey_spike(data_path, tgt_data_date, bin_size, smooth_size, start_time, NHP_id)
                    # Expected to be a no-op because every session was already
                    # merged above; kept as a safeguard.
                    max_unit_names = np.sort(list(set(max_unit_names) | set(tgt_unit_names)))

                    # source session(s), zero-padded to the common unit list
                    src_day_spike, src_day_cursor_pos_xy = [], []
                    for src_idx in src_session_idx:
                        data_date = mat_list[src_idx][date_slice]

                        print("current srouce recording date: %s" % data_date)
                        print("data preparing...")
                        day_spike, day_cursor, unit_names = load_monkey_spike(data_path, data_date, bin_size, smooth_size, start_time, NHP_id)

                        src_day_spike.extend(spike_zero_padding(max_unit_names, unit_names, day_spike))
                        src_day_cursor_pos_xy.extend(day_cursor[cur_idx])

                    # zero-padding of the target session
                    print("current target recording date: %s" % tgt_data_date)
                    tgt_day_cursor_pos_xy = tgt_day_cursor[cur_idx]
                    tgt_day_spike = spike_zero_padding(max_unit_names, tgt_unit_names, tgt_day_spike)

                    # A fresh copy of the pre-trained checkpoint is loaded for every
                    # target session.
                    # NOTE(review): torch.load unpickles a full model object here;
                    # only load trusted checkpoints. Recent PyTorch releases may
                    # require weights_only=False for this -- confirm against the
                    # installed version.
                    pre_train_data = torch.load(load_model_file, map_location=device)
                    sit_flow = pre_train_data['model']
                    pre_train_config = pre_train_data['save_train_config']

                    # Split target trials; the held-out part is the validation set.
                    tgt_day_spike_train, tgt_day_spike_test, tgt_day_cursor_train, tgt_day_cursor_test = \
                        train_test_split(tgt_day_spike, tgt_day_cursor_pos_xy, test_size=tgt_valid_ratio, random_state=0)
                    if use_tgt_train:
                        # Keep only `tgt_train_ratio` of all target trials for training.
                        # NOTE(review): this can be 0 for very short sessions -- confirm
                        # that the generator handles an empty training set.
                        tgt_train_num = int(len(tgt_day_spike) * tgt_train_ratio)
                        tgt_day_spike_train = tgt_day_spike_train[:tgt_train_num]
                        tgt_day_cursor_train = tgt_day_cursor_train[:tgt_train_num]

                    # training for fine-tuning
                    window_size = sit_flow.window_size

                    src_train_generator = load_spike_generator(
                        day_spike=src_day_spike,
                        day_cursor=src_day_cursor_pos_xy,
                        window_size=window_size,
                        batch_size=batch_size,
                        is_shuffle=shuffle_flag,
                        fourier_flag=fourier_flag,
                    )

                    tgt_train_generator = load_spike_generator(
                        day_spike=tgt_day_spike_train,
                        day_cursor=tgt_day_cursor_train,
                        window_size=window_size,
                        batch_size=batch_size,
                        is_shuffle=shuffle_flag,
                        fourier_flag=fourier_flag,
                    )

                    # validation with target data
                    tgt_day_spike_test_format, tgt_cursor_test_format = \
                        concatenate_spike_token(tgt_day_spike_test, tgt_day_cursor_test, window_size, is_shuffle=shuffle_flag, fourier_flag=fourier_flag)
                    tgt_test_generator = (tgt_day_spike_test_format, tgt_cursor_test_format)

                    D_params['hidden_dim'] = sit_flow.hidden_size
                    align_config = AlignConfig(
                        cfg_scale=1.0,
                        num_sampling_steps=2,
                        aligner_method=aligner_method,
                        mmd_loss=mmd_loss,

                        # gan config
                        D_params=D_params,
                        training_params=training_params,

                        device=device,
                        src_train_generator=src_train_generator,
                        train_generator=tgt_train_generator,
                        valid_generator=tgt_test_generator,
                        model=sit_flow,
                        pre_train_config=pre_train_config,

                        tgt_train_ratio=tgt_train_ratio,
                        fine_tuning_step=fine_tuning_step,
                        learning_rate=learning_rate,
                        weight_decay=1e-5,
                        sample_every=1,
                        sampling_method='euler',
                    )

                    # fine-tuning
                    (best_valid_r2, valid_r2_final, valid_r2_curve), best_valid_model, fine_tune_results = fine_tuning_stage(align_config)

                    tgt_sess_list.append(tgt_data_date)
                    best_valid_r2_list.append(best_valid_r2)
                    total_fine_tuning_time_list.append(fine_tune_results['total_fine_tuning_time'])
                    fine_tuning_time_per_epoch_list.append(fine_tune_results['fine_tuning_time_per_epoch'])

                    # save alignment model
                    save_model_pth = './fine_tune/{}/{}/{}/{}/{}/{}/'.format(aligner_method_name, save_NHP_id, str(tgt_train_ratio), src_name, tgt_data_date, seed)
                    os.makedirs(save_model_pth, exist_ok=True)
                    save_model_file = '{}/SiT_flow_fine_tune_{}.pt'.format(save_model_pth, align_config.fine_tuning_step)
                    torch.save({
                        'model': best_valid_model,
                        'best_valid_r2': best_valid_r2,
                        'valid_r2_final': valid_r2_final,
                        'valid_r2_curve': valid_r2_curve,
                        'valid_log_curve': fine_tune_results['valid_log_curve'],
                        'tgt_day_spike_test': tgt_day_spike_test,
                        'tgt_day_cursor_test': tgt_day_cursor_test,
                    }, save_model_file)

                # Keep the numeric part (label stripped) for seed averaging.
                total_valid_r2_list.append(best_valid_r2_list[1:])
                total_time_list.append(total_fine_tuning_time_list[1:])
                total_time_per_epoch_list.append(fine_tuning_time_per_epoch_list[1:])

                # save alignment results: append this seed's rows to the workbook,
                # creating the file and sheet on first use
                os.makedirs(save_rel_pth, exist_ok=True)
                if Path(save_rel_file).exists():
                    wb = openpyxl.load_workbook(save_rel_file)
                else:
                    wb = openpyxl.Workbook()

                if save_NHP_id not in wb.sheetnames:
                    wb.create_sheet(title=save_NHP_id)
                ws = wb[save_NHP_id]
                ws.append([])
                ws.append(['seed={}'.format(str(seed))])
                ws.append(tgt_sess_list)
                ws.append(best_valid_r2_list)
                ws.append(total_fine_tuning_time_list)
                ws.append(fine_tuning_time_per_epoch_list)
                wb.save(save_rel_file)

            # average over random seeds (rows are seeds, columns are target sessions)
            total_r2_score = np.array(total_valid_r2_list)
            r2_score_avg = np.mean(total_r2_score, axis=0)
            r2_score_std = np.std(total_r2_score, axis=0)
            total_time_avg = np.mean(np.array(total_time_list), axis=0)
            total_time_per_epoch_avg = np.mean(np.array(total_time_per_epoch_list), axis=0)

            wb = openpyxl.load_workbook(save_rel_file)
            ws = wb[save_NHP_id]
            ws.append([])
            ws.append(tgt_sess_list)
            ws.append(['r2 avg'] + r2_score_avg.tolist())
            ws.append(['r2 std'] + r2_score_std.tolist())
            ws.append(['total time avg'] + total_time_avg.tolist())
            ws.append(['time per epoch avg'] + total_time_per_epoch_avg.tolist())
            wb.save(save_rel_file)