import os, sys, fnmatch
sys.path.append('./xds/xds_python/')

import numpy as np
import random
from copy import deepcopy
import torch
import argparse
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

from preprocess.data_loader import load_monkey_spike, load_spike_generator, load_sample_generator, concatenate_spike_token
from utils.padding import spike_preprocessing, spike_zero_padding
from model.vanilla_iTransfomer import ConditionModel
from config.model_config import ModelConfig
from config.train_config import TrainConfig, TrainEmbedderConfig
from flow.models.SiT_models import SiT
from experiment.train import training_stage, train_dynamic_embedder

def setup_seed(seed):
    """Seed every random number generator this script touches.

    The same ``seed`` is applied to Python's ``random`` module, NumPy's
    global generator and PyTorch (CPU and all CUDA devices). The cuDNN
    ``deterministic`` flag is switched on afterwards.

    Args:
        seed: Integer seed shared by all generators.
    """
    # The seeders are independent of each other, so their order is irrelevant.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

## pre-train

def session_date(mat_name, id_len, date_len=8):
    """Return the recording-date part of a session file name.

    The date is read as the ``date_len`` characters that start one character
    after the ``id_len``-character animal id (i.e. ``<id>?<date>...``).

    Args:
        mat_name: File name of one recording session.
        id_len: Length of the animal id prefix.
        date_len: Number of characters in the date field.

    Returns:
        The date substring.
    """
    return mat_name[id_len + 1:id_len + date_len + 1]


def select_source_sessions(first_idx, n_sessions):
    """Return the source-session indices used for one pre-training run.

    A run pre-trains on session ``first_idx`` and the session right after it.
    Indices past the last recording are dropped, so the run that starts on the
    last session uses that single session instead of indexing out of range.

    Args:
        first_idx: Index of the first source session.
        n_sessions: Total number of available sessions.

    Returns:
        List of valid session indices, in ascending order.
    """
    return [idx for idx in (first_idx, first_idx + 1) if idx < n_sessions]


def add_avg_and_std(add_list):
    """Append the mean and standard deviation of ``add_list[1:]`` in place.

    The first element is treated as a row label and is excluded from the
    statistics.

    Args:
        add_list: ``[label, value, value, ...]``.

    Returns:
        The same list object, now ending with ``mean`` and ``std``.
    """
    values = add_list[1:]
    add_list.extend([np.mean(values), np.std(values)])
    return add_list


def mean_below_threshold(values, threshold=1e-3):
    """Return the mean of the entries of ``values`` that are below ``threshold``.

    Args:
        values: Sequence of numbers.
        threshold: Exclusive upper bound for an entry to be counted.

    Returns:
        The mean of the selected entries, or ``None`` when no entry qualifies
        (this avoids NumPy's "mean of empty slice" NaN and its warning).
    """
    values = np.array(values)
    selected = values[values < threshold]
    if selected.size == 0:
        return None
    return np.mean(selected)


if __name__ == '__main__':
    # Needed only for the result summary; imported up front so that a missing
    # package is reported before any training time has been spent.
    import openpyxl

    # parse arguments
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('-cuda_device', type=str, default='0', help='which gpu to use ')
    args = parser.parse_args()

    # dataset
    # Jango_2015_isometric_wrist_task, Spike_ISO_2012, Mihili_CO_2014, Chewie_CO_2016, Mihili_RT_2013_2014
    data_path = './datasets/Mihili_CO_2014/'
    NHP_id = 'Mihili'
    save_NHP_id = NHP_id  # e.g. 'Mihili_RT' to keep results of another task apart
    id_len, date_len = len(NHP_id), 8
    # Sessions are processed in file-name order.
    mat_list = np.sort(fnmatch.filter(os.listdir(data_path), "*.mat"))
    n_sessions = len(mat_list)

    # pre-process
    bin_size, smooth_size = 0.05, 0.1
    start_time = 'gocue_time'

    # hyper-parameters
    window_size = 6 if NHP_id == 'Chewie' else 5
    hidden_size = 64 if NHP_id == 'Chewie' else 32
    invert_flag = True
    src_train_ratio = 0.9
    # Steps of the main flow-model training (also handed to ModelConfig).
    training_step = 3500
    # Steps of the dynamic-embedder warm-up. Kept under its own name so that
    # it never overwrites `training_step` between seeds / sessions.
    embedder_training_step = 600 if NHP_id == 'Chewie' else 200
    fourier_flag = False
    batch_size = 256
    # index into the cursor data; 0: position 1: velocity 2: accelerated speed
    cur_idx = 1
    seed_list = [0, 1, 2, 3, 4]

    device = torch.device("cuda:" + args.cuda_device) if torch.cuda.is_available() else torch.device('cpu')

    ## source session pre-train

    # Union of the unit names over all sessions; every session is zero-padded
    # to this common set below.
    max_unit_names = []
    for mat_name in mat_list:
        _, _, unit_names = load_monkey_spike(
            data_path, session_date(mat_name, id_len, date_len),
            bin_size, smooth_size, start_time, NHP_id)
        max_unit_names = np.sort(list(set(max_unit_names) | set(unit_names)))

    for src_sess_last in range(n_sessions):
        # Two consecutive sessions, clamped to the available recordings.
        src_session_idx = select_source_sessions(src_sess_last, n_sessions)

        # target session: the one following `src_sess_last`, or that session
        # itself when it is the last recording
        tgt_idx = src_sess_last + 1 if src_sess_last + 1 < n_sessions else src_sess_last
        tgt_data_date = session_date(mat_list[tgt_idx], id_len, date_len)
        tgt_day_spike, tgt_day_cursor, tgt_unit_names = load_monkey_spike(
            data_path, tgt_data_date, bin_size, smooth_size, start_time, NHP_id)
        max_unit_names = np.sort(list(set(max_unit_names) | set(tgt_unit_names)))

        # Collect the zero-padded trials of all source sessions.
        src_day_spike, src_day_cursor_pos_xy = [], []
        for src_idx in src_session_idx:
            data_date = session_date(mat_list[src_idx], id_len, date_len)

            print("current srouce recording date: %s" % data_date)
            print("data preparing...")
            day_spike, day_cursor, unit_names = load_monkey_spike(
                data_path, data_date, bin_size, smooth_size, start_time, NHP_id)

            src_day_spike.extend(spike_zero_padding(max_unit_names, unit_names, day_spike))
            src_day_cursor_pos_xy.extend(day_cursor[cur_idx])

        # zero-padding of the target session
        print("current target recording date: %s" % tgt_data_date)
        tgt_day_cursor_pos_xy = tgt_day_cursor[cur_idx]
        tgt_day_spike = spike_zero_padding(max_unit_names, tgt_unit_names, tgt_day_spike)

        # Output folders are named after the source-session indices,
        # e.g. 'src_0_1'. They do not depend on the seed.
        src_name = 'src' + ''.join('_' + str(idx) for idx in src_session_idx)
        save_model_root = './pre_train/{}/{}/'.format(save_NHP_id, src_name)
        save_rel_pth = './rel/pre_train/{}/{}/'.format(save_NHP_id, src_name)

        # Per-seed results; the first entry of each list is its row label.
        best_valid_r2_list = ['valid r2']
        total_train_time_list, train_time_per_epoch = ['total time'], ['time per epoch']
        lya_max_list = ['lya max']
        for seed in seed_list:
            setup_seed(seed=seed)

            # train/validation split of the source trials
            # NOTE(review): random_state is fixed, so every seed trains on the
            # same split and only the other sources of randomness change.
            train_day_spike, test_day_spike, train_day_cursor, test_day_cursor = train_test_split(
                src_day_spike, src_day_cursor_pos_xy,
                test_size=1.0 - src_train_ratio, random_state=0)

            # tokenize
            train_generator = load_spike_generator(day_spike=train_day_spike,
                                                   day_cursor=train_day_cursor,
                                                   window_size=window_size,
                                                   batch_size=batch_size,
                                                   is_shuffle=False,
                                                   fourier_flag=fourier_flag)
            valid_day_spike_format, valid_day_cursor_format = concatenate_spike_token(
                test_day_spike, test_day_cursor, window_size,
                is_shuffle=False, fourier_flag=fourier_flag)
            valid_generator = (valid_day_spike_format, valid_day_cursor_format)

            # Target-session tokens. Built for convenience, but not passed to
            # either training config below (validation uses the source split).
            test_day_spike_format, test_day_cursor_format = concatenate_spike_token(
                tgt_day_spike, tgt_day_cursor_pos_xy, window_size,
                is_shuffle=False, fourier_flag=fourier_flag)
            test_generator = (test_day_spike_format, test_day_cursor_format)

            # vanilla Transformer (encoder-only)
            # With invert_flag the roles of channel and time axes are swapped.
            context_size = window_size
            n_chan = train_day_spike[0].shape[1] if not invert_flag else context_size
            seq_len = context_size if not invert_flag else train_day_spike[0].shape[1]
            configs = ModelConfig(
                seq_len=seq_len,
                enc_in=n_chan,
                training_step=training_step,
                e_layers=2,
                factor=1,
            )
            # Not used afterwards. Kept on purpose: NOTE(review) constructing
            # it presumably draws from the seeded RNG, so removing it could
            # change the initial weights of `flow_model` for a given seed.
            transformer_model = ConditionModel(configs)

            # SiT model settings
            flow_model = SiT(
                in_channels=n_chan if not invert_flag else seq_len,
                window_size=context_size,
                hidden_size=hidden_size,
                out_dim=2,
                depth=5,
                mlp_ratio=2.0,
                model_config=configs,
                invert_flag=invert_flag,
            )

            # Warm up the dynamic embedder of the flow model.
            train_embedder_config = TrainEmbedderConfig(
                device=device,
                train_generator=train_generator,
                valid_generator=valid_generator,
                model=flow_model.dynamic_embedder.transformer_model,
                invert_flag=invert_flag,
                training_step=embedder_training_step,
                update_flag=True,
            )
            initial_model = train_dynamic_embedder(train_embedder_config)

            # training process settings
            # Solver names listed by the original author for `sampling_method`:
            # dopri8, dopri5, bosh3, fehlberg2, adaptive_heun, euler,
            # midpoint, heun3, rk4
            train_config = TrainConfig(
                cfg_scale=1.0,
                device=device,
                train_generator=train_generator,
                valid_generator=valid_generator,
                model=flow_model,
                path_type="Linear",
                loss_weight=None,
                prediction="velocity",
                training_step=training_step,
                weight_decay=1e-5,
                num_sampling_steps=2,
                sampling_method="euler",  # ["dopri5", "rk4"]
                sample_every=20,
                ema_decay=0.99,
            )

            (best_valid_r2, valid_r2_epoch, valid_r2_curve), model, pre_train_results = training_stage(train_config)
            # NOTE(review): the row is labelled 'valid r2' but stores
            # `valid_r2_epoch`, not `best_valid_r2` - confirm this is intended.
            best_valid_r2_list.append(valid_r2_epoch)

            # One checkpoint folder per seed.
            save_model_pth = os.path.join(save_model_root, str(seed))
            os.makedirs(save_model_pth, exist_ok=True)
            save_model_file = os.path.join(
                save_model_pth, 'SiT_flow_full_model_{}.pt'.format(train_config.training_step))

            save_train_config = {
                # sampler settings
                'fourier_flag': fourier_flag,
                'path_type': train_config.path_type,
                'prediction': train_config.prediction,
                'loss_weight': train_config.loss_weight,
                'train_eps': train_config.train_eps,
                'sample_eps': train_config.sample_eps,
                # other
                'training_step': train_config.training_step,
                'sampling_method': train_config.sampling_method,
                'num_sampling_steps': train_config.num_sampling_steps,
            }
            torch.save({
                'model': model,
                'save_train_config': save_train_config,
                'best_valid_r2': best_valid_r2,
                'valid_r2_epoch': valid_r2_epoch,
                'valid_r2_curve': valid_r2_curve,
                'total_train_time': pre_train_results['total_train_time'],
                'train_time_per_epoch': pre_train_results['train_time_per_epoch'],
                'lya_max_final': pre_train_results['lya_max_final'],
            }, save_model_file)

            # save pre-train results
            total_train_time_list.append(pre_train_results['total_train_time'])
            train_time_per_epoch.append(pre_train_results['train_time_per_epoch'])
            lya_max_list.append(pre_train_results['lya_max_final'])

        # pre-train results: one workbook per run, appended to if it exists
        os.makedirs(save_rel_pth, exist_ok=True)
        save_rel_file = os.path.join(
            save_rel_pth, 'pre_train_results_src_{}.xlsx'.format(src_sess_last))
        if os.path.exists(save_rel_file):
            wb = openpyxl.load_workbook(save_rel_file)
        else:
            wb = openpyxl.Workbook()

        if save_NHP_id in wb.sheetnames:
            ws = wb[save_NHP_id]
        else:
            ws = wb.create_sheet(save_NHP_id)
        ws.append([])
        ws.append(['src sess last = {}'.format(src_sess_last)])

        # Each row: label, one value per seed, then mean and std over seeds.
        ws.append(add_avg_and_std(best_valid_r2_list))
        ws.append(add_avg_and_std(total_train_time_list))
        ws.append(add_avg_and_std(train_time_per_epoch))
        # For 'lya max' only the mean of the runs below 1e-3 is appended;
        # the cell stays empty when no run is below that threshold.
        lya_max_list.append(mean_below_threshold(lya_max_list[1:], 1e-3))
        ws.append(lya_max_list)
        wb.save(save_rel_file)