import os
import pickle
import numpy as np
import argparse
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures, SplineTransformer
from sklearn.pipeline import make_pipeline

import misc


def _split_evenly(seq, n_parts):
    """Split ``seq`` into ``n_parts`` contiguous chunks of near-equal length.

    Uses the same sizing rule as ``np.array_split`` (the first
    ``len(seq) % n_parts`` chunks get one extra element), but slices ``seq``
    directly. As a result, a list of heterogeneous per-step tuples is never
    coerced into a NumPy array, which can fail on ragged data.
    """
    base, extra = divmod(len(seq), n_parts)
    chunks = []
    start = 0
    for part_i in range(n_parts):
        stop = start + base + (1 if part_i < extra else 0)
        chunks.append(seq[start:stop])
        start = stop
    return chunks


def _segment_episode(ep_data, seg_len):
    """Cut one episode into segments of at most ``seg_len`` steps.

    Each step is expected to be ``(obs, action, state, next_obs, reward, done, info)``.

    Returns:
        A ``(vec_obs_segments, action_segments)`` pair. Each element is a list
        with one 2-D array per segment, of shape ``(segment_len, feat_dim)``.
        Both lists are empty for an empty episode.
    """
    n_steps = len(ep_data)
    if n_steps == 0:
        return [], []
    # Ceiling division gives the fewest segments with none longer than seg_len.
    # `n_steps // seg_len + 1` would add a needlessly short extra segment
    # whenever n_steps is an exact multiple of seg_len.
    n_splits = (n_steps + seg_len - 1) // seg_len
    segments = _split_evenly(ep_data, n_splits)
    # NOTE(review): obs[1] is presumably the vector part of a multi-part
    # observation — confirm against the data recorder.
    # reshape(len, -1) guarantees a 2-D (time, feature) array, even for
    # scalar-per-step quantities.
    vec_obs = [np.stack([step[0][1] for step in seg]).reshape(len(seg), -1)
               for seg in segments]
    action = [np.stack([step[1] for step in seg]).reshape(len(seg), -1)
              for seg in segments]
    return vec_obs, action


def _fit_and_plot_episode(ep_split_data, dt, degree=3, alpha=1e-3):
    """Fit one ridge-regularised polynomial per segment and plot data vs. fit.

    Args:
        ep_split_data: non-empty list of 2-D arrays ``(segment_len, feat_dim)``.
        dt: time between consecutive steps.
        degree: polynomial degree of each segment's fit.
        alpha: Ridge regularisation strength.

    Returns:
        ``(models, fig)``. ``models`` holds one fitted pipeline per segment.
        ``fig`` has one subplot row per feature: raw data is blue and the fit
        is dashed green.
    """
    feat_dim = ep_split_data[0].shape[1]
    # squeeze=False keeps `axes` 2-D even when feat_dim == 1; otherwise
    # plt.subplots returns a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(feat_dim, 1, figsize=(6.4 * 2, 4.8 * feat_dim),
                             squeeze=False)
    models = []
    current_t = 0.
    for split_data in ep_split_data:
        # Column vector of within-segment times: every segment is fit from t=0.
        ts = np.arange(split_data.shape[0])[:, None] * dt

        model = make_pipeline(PolynomialFeatures(degree=degree), Ridge(alpha=alpha))
        model.fit(ts, split_data)
        pred_split_data = model.predict(ts)
        models.append(model)

        for feat_i in range(feat_dim):
            ax = axes[feat_i, 0]
            ax.plot(current_t + ts, split_data[:, feat_i], c='b', linewidth=3)
            ax.plot(current_t + ts, pred_split_data[:, feat_i], c='g',
                    linewidth=2, linestyle='dashed')

        # Shift the plotting origin so consecutive segments are drawn back to back.
        current_t += ts[-1, 0] + dt

    fig.tight_layout()
    return models, fig


def main():
    """Fit piecewise polynomials to recorded episodes and plot the fits.

    Reads a pickled list of episodes (``--pkl-path``). It also loads the
    ``<pkl stem>_eval_config.yaml`` that sits next to it. Each episode is cut
    into segments of at most ``--seg-len`` steps, and a cubic polynomial is fit
    per segment for every key in ``--fit-keys``. If ``--out-dir`` is given, one
    figure per (key, episode) is written there.

    Returns:
        dict mapping each fitted key to a per-episode list of per-segment
        fitted sklearn pipelines. Empty episodes map to an empty list.
    """
    # parse arguments and read data
    parser = argparse.ArgumentParser()
    parser.add_argument('--pkl-path', type=str, required=True)
    # NOTE(review): --mode is parsed but not used anywhere below.
    parser.add_argument('--mode', type=str, default='neuron_to_action')
    parser.add_argument('--out-dir', type=str, default=None)
    parser.add_argument('--seg-len', type=int, default=30)
    parser.add_argument('--fit-keys', nargs='+', default=['action'],
                        choices=['vec_obs', 'action'],
                        help='which recorded quantities to fit and plot')
    parser.add_argument('--debug', action='store_true',
                        help='drop into pdb after each episode is plotted')
    args = parser.parse_args()
    if args.seg_len < 1:
        parser.error('--seg-len must be a positive integer')

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(args.pkl_path, 'rb') as f:
        data = pickle.load(f)

    f_prefix = os.path.splitext(args.pkl_path)[0]
    # NOTE(review): config is loaded (so a missing file still fails fast) but
    # is not used yet.
    config = misc.load_yaml(f_prefix + '_eval_config.yaml')

    n_episodes = len(data)
    print(f'Total number of episodes: {n_episodes}')

    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # process data into segments; one entry per episode keeps indices aligned
    all_seg_data = dict(vec_obs=[], action=[])
    for ep_data in data:
        vec_obs_segs, action_segs = _segment_episode(ep_data, args.seg_len)
        all_seg_data['vec_obs'].append(vec_obs_segs)
        all_seg_data['action'].append(action_segs)

    # fit polynomials
    # NOTE(review): dt assumes steps are 1/50 s apart (50 Hz) — confirm.
    dt = 1 / 50.
    piecewise_model = dict()
    for k in dict.fromkeys(args.fit_keys):  # de-duplicate, keep order
        piecewise_model[k] = []  # one entry per episode
        for ep_i, ep_split_data in enumerate(all_seg_data[k]):
            if not ep_split_data:
                # Empty episode: placeholder keeps episode indices aligned.
                piecewise_model[k].append([])
                continue
            models, fig = _fit_and_plot_episode(ep_split_data, dt)
            piecewise_model[k].append(models)

            if args.out_dir:
                fig.savefig(os.path.join(args.out_dir, f'{k}_ep_{ep_i:02d}.png'))
            if args.debug:
                import pdb; pdb.set_trace()
            # Release the figure; otherwise one figure per episode accumulates.
            plt.close(fig)

    return piecewise_model


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()
