from lqr_env import LQREnv, RandController, LQRController
import darkroom_env
import lqr_env
import bandit_env
import numpy as np
import os
import pickle
from IPython import embed
import scipy
import matplotlib.pyplot as plt
import os
import random
from evals import eval_bandit

# Largest arm-selection probability of every rollin_bandit call. rollin_bandit
# appends to this list; the __main__ 'bandit' branch plots it as a histogram
# and reports the fraction of values >= .99.
max_probs = []

def rollin(env):
    """Legacy LQR rollin; no longer supported.

    The LQR pipeline has been retired. The previous body (uniform random
    state/action pairs passed through ``env.transit``) sat behind an
    unconditional ``raise`` and was unreachable, so it has been removed.

    Args:
        env: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("rollin: the legacy LQR rollin is no longer supported")




def rollin_dr(env, rollin='uniform'):
    """Collect ``env.H`` independent transitions from a darkroom environment.

    Each step draws a fresh (state, action) pair. The next state is not fed
    back in, so this is a set of transitions rather than a trajectory. Each
    pair is passed through ``env.transit``.

    Args:
        env: Darkroom env exposing ``H``, ``transit(x, u)`` and the samplers
            used by the selected mode.
        rollin: Sampling mode.
            ``'uniform'``: ``x = env.sample_x()``, ``u = env.sample_u()``.
            ``'stitch'``: ``x = env.sample_stitch_x()``,
            ``u = env.sample_stitch_opt_a(x)``.

    Returns:
        Tuple ``(xs, us, xps, rs)`` of numpy arrays with one row per step.

    Raises:
        NotImplementedError: For ``'expert'`` or any other unsupported mode.
    """
    if rollin == 'uniform':
        def draw():
            return env.sample_x(), env.sample_u()
    elif rollin == 'stitch':
        def draw():
            x = env.sample_stitch_x()
            return x, env.sample_stitch_opt_a(x)
    else:
        # 'expert' is reserved but not implemented yet.
        raise NotImplementedError(f"rollin_dr: unsupported rollin mode {rollin!r}")

    xs, us, xps, rs = [], [], [], []
    for _ in range(env.H):
        x, u = draw()
        xp, r = env.transit(x, u)
        xs.append(x)
        us.append(u)
        xps.append(xp)
        rs.append(r)

    return np.array(xs), np.array(us), np.array(xps), np.array(rs)

def rollin_bandit(env, cov=0.0, orig=2):
    """Roll out ``env.H_context`` single-arm pulls from a randomly drawn arm distribution.

    The arm distribution is the mixture ``(1 - cov) * p + cov * onehot(target)``,
    where ``p ~ Dirichlet(1, ..., 1)`` over ``env.dim`` arms. ``orig`` selects
    the variant:

    * ``9``: ``cov`` is the argument and ``target`` is ``env.opt_a_index``.
    * ``10``: ``cov`` is redrawn uniformly from {0.0, 0.1, ..., 1.0} and
      ``target`` is ``env.opt_a_index``.
    * ``14``: ``cov`` is redrawn as in 10 and ``target`` is a uniformly
      random arm.

    The largest probability of the mixture is appended to the module-level
    ``max_probs`` list.

    Args:
        env: Bandit env exposing ``H_context``, ``dim``, ``opt_a_index``
            (variants 9/10) and ``transit(x, u)``.
        cov: Weight on the point mass (variant 9 only; overridden by 10/14).
        orig: Variant id; see above.

    Returns:
        Tuple ``(xs, us, xps, rs)`` of numpy arrays with one row per pull.
        Each ``x`` is ``[1]`` and each ``u`` is a one-hot vector of length
        ``env.dim``.

    Raises:
        NotImplementedError: If ``orig`` is not 9, 10 or 14 (including the
            default 2).
    """
    if orig not in (9, 10, 14):
        raise NotImplementedError(
            f"rollin_bandit: unsupported orig={orig!r}; supported values are 9, 10 and 14")

    n_arms = env.dim
    # RNG draw order (cov, Dirichlet, random target) matches the original
    # per-variant code, so seeded runs reproduce the same datasets.
    if orig in (10, 14):
        cov = np.random.choice([0.0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0])
    probs = np.random.dirichlet(np.ones(n_arms))
    if orig == 14:
        target = np.random.choice(np.arange(n_arms))
    else:
        target = env.opt_a_index
    point_mass = np.zeros(n_arms)
    point_mass[target] = 1.0
    probs = (1 - cov) * probs + cov * point_mass

    max_probs.append(np.max(probs))

    xs, us, xps, rs = [], [], [], []
    for _ in range(env.H_context):
        x = np.array([1])
        u = np.zeros(n_arms)
        u[np.random.choice(np.arange(n_arms), p=probs)] = 1.0
        xp, r = env.transit(x, u)
        xs.append(x)
        us.append(u)
        xps.append(xp)
        rs.append(r)

    return np.array(xs), np.array(us), np.array(xps), np.array(rs)



def rollin_linear_bandit(env, warm_start=False):
    """Collect one rollin by running Thompson sampling online on a single env.

    Builds ``bandit_env.ThompsonSamplingPolicy2`` with noise std ``env.var``,
    prior mean 0.0 and prior variance 1.0. The policy is deployed for
    ``env.H_context`` steps with ``eval_bandit.deploy_online``.

    Args:
        env: Linear bandit env exposing ``H_context`` and ``var``.
        warm_start: Forwarded to ``ThompsonSamplingPolicy2``.

    Returns:
        Tuple ``(xs, us, xps, rs)`` taken from the deployment metadata.
        ``rs`` keeps only column 0 of ``meta['rs']``.
    """
    H = env.H_context
    policy = bandit_env.ThompsonSamplingPolicy2(
        env,
        std=env.var,
        sample=True,
        prior_mean=0.0,
        prior_var=1.0,
        warm_start=warm_start,
    )
    _, meta = eval_bandit.deploy_online(env, policy, H=H, horizon=H)
    return meta['xs'], meta['us'], meta['xps'], meta['rs'][:, 0]



def rollin_linear_bandit_vec(envs, warm_start=False):
    """Collect one rollin per env by running batched Thompson sampling on all envs at once.

    A single ``bandit_env.ThompsonSamplingPolicy2`` is built from ``envs[0]``
    with noise std ``envs[0].var``, prior mean 0.0, prior variance 1.0 and
    ``batch_size=len(envs)``. It is deployed for ``envs[0].H_context`` steps
    on a ``bandit_env.BanditEnvVec`` wrapping every env, via
    ``eval_bandit.deploy_online_numpy_vec``. Because the hyper-parameters
    come from ``envs[0]``, the envs presumably share ``H_context`` and
    ``var``. NOTE(review): confirm.

    Args:
        envs: Non-empty sequence of linear bandit envs.
        warm_start: Forwarded to ``ThompsonSamplingPolicy2``.

    Returns:
        Tuple ``(xs, us, xps, rs)`` from the deployment metadata. ``rs`` keeps
        only index 0 along axis 2 of ``meta['rs']``.

    Raises:
        ValueError: If ``envs`` is empty.
    """
    if len(envs) == 0:
        raise ValueError("rollin_linear_bandit_vec: envs must contain at least one environment")

    first = envs[0]
    H = first.H_context
    policy = bandit_env.ThompsonSamplingPolicy2(
        first,
        std=first.var,
        sample=True,
        prior_mean=0.0,
        prior_var=1.0,
        warm_start=warm_start,
        batch_size=len(envs),
    )
    vec_env = bandit_env.BanditEnvVec(envs)
    _, meta = eval_bandit.deploy_online_numpy_vec(vec_env, policy, H=H, horizon=H)
    return meta['xs'], meta['us'], meta['xps'], meta['rs'][:, :, 0]


def rollin_bandit_topk(env):
    """Collect ``env.H_context`` pulls, each activating ``env.k`` distinct random arms.

    For every step the arm indices ``0..env.dim-1`` are shuffled and the first
    ``env.k`` are set to 1.0 in an otherwise-zero action vector. The context
    is always ``[1]``.

    Returns:
        Tuple ``(xs, us, xps, rs)`` of numpy arrays with one row per step.
    """
    columns = ([], [], [], [])
    for _step in range(env.H_context):
        context = np.array([1])
        arm_order = np.arange(env.dim)
        np.random.shuffle(arm_order)
        action = np.zeros(env.dim)
        action[arm_order[:env.k]] = 1.0
        next_context, reward = env.transit(context, action)
        for column, value in zip(columns, (context, action, next_context, reward)):
            column.append(value)

    xs, us, xps, rs = (np.array(column) for column in columns)
    return xs, us, xps, rs


def generate_histories(n_envs, n_hists, n_samples, H, dx=None, du=None):
    """Legacy LQR dataset generator (the LQR pipeline is retired).

    For each of ``n_envs`` systems sampled with ``lqr_env.sample(dx, du)``,
    the function builds ``n_hists`` rollin contexts with ``rollin``. Each
    context is paired with ``n_samples`` states drawn uniformly from
    ``[-1, 1]^dx`` and labelled by ``LQRController.act``. ``rollin`` currently
    always raises NotImplementedError, so in practice this only succeeds when
    ``n_envs == 0`` or ``n_hists == 0``.

    Args:
        n_envs: Number of LQR systems to sample.
        n_hists: Rollin contexts per system.
        n_samples: Query states per context.
        H: Horizon passed to ``LQREnv``.
        dx, du: State/action dimensions. When omitted, they fall back to the
            module-level globals ``dx``/``du`` set by the ``__main__`` block,
            which is what this function previously relied on implicitly.

    Returns:
        List of trajectory dicts.

    Raises:
        NameError: If samples are needed but ``dx``/``du`` were neither passed
            nor defined as module globals.
    """
    if dx is None:
        dx = globals().get('dx')
    if du is None:
        du = globals().get('du')
    if n_envs > 0 and (dx is None or du is None):
        raise NameError(
            "generate_histories: dx/du were not passed and are not defined as module globals")

    trajs = []
    for i in range(n_envs):
        if i % 100 == 0:
            print(f"Env: {i}")

        A, B, Q, R = lqr_env.sample(dx, du)
        opt = LQRController(A, B, Q, R)
        env = LQREnv(A, B, Q, R, H)

        for _ in range(n_hists):
            rollin_xs, rollin_us, rollin_xps, rollin_rs = rollin(env)
            for _ in range(n_samples):
                x = np.random.uniform(-1, 1, dx)
                trajs.append({
                    'state': x,
                    'action': opt.act(x),
                    'rollin_xs': rollin_xs,
                    'rollin_us': rollin_us,
                    'rollin_xps': rollin_xps,
                    'rollin_rs': rollin_rs,
                    'Q': Q,
                    'matrices': (A, B, Q, R),
                })

    return trajs


def generate_dr_histories(n_envs, n_hists, n_samples, H, dim):
    """Sample ``n_envs`` darkroom envs and build their histories with uniform rollins.

    Each env comes from ``darkroom_env.sample(dim, H)``. Generation is
    delegated to ``generate_dr_histories_from_envs`` with its default
    ``rollin`` mode.
    """
    sampled_envs = []
    for _ in range(n_envs):
        sampled_envs.append(darkroom_env.sample(dim, H))

    return generate_dr_histories_from_envs(
        envs=sampled_envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
    )


def generate_dr_histories_for_goals(goals, n_hists, n_samples, H, dim, rollin='uniform'):
    """Build darkroom histories for one ``DarkroomEnv(dim, goal, H)`` per entry of ``goals``.

    ``rollin`` is forwarded unchanged to ``generate_dr_histories_from_envs``.
    """
    goal_envs = []
    for goal in goals:
        goal_envs.append(darkroom_env.DarkroomEnv(dim, goal, H))

    return generate_dr_histories_from_envs(
        envs=goal_envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
        rollin=rollin,
    )


def generate_dr_histories_from_envs(envs, n_hists, n_samples, H, dim, rollin='uniform'):
    """Build darkroom training examples from the given envs.

    For every env, ``n_hists`` rollin contexts are drawn with
    ``rollin_dr(env, rollin=rollin)``. Each context is reused for ``n_samples``
    examples whose query state comes from ``env.sample_x()`` and whose label
    is ``env.opt_a(state)``. ``H`` and ``dim`` are accepted for signature
    parity but are not used here.

    Returns:
        List of ``len(envs) * n_hists * n_samples`` example dicts.
    """
    dataset = []
    for room in envs:
        for _ in range(n_hists):
            ctx_xs, ctx_us, ctx_xps, ctx_rs = rollin_dr(room, rollin=rollin)
            for _ in range(n_samples):
                query = room.sample_x()
                dataset.append({
                    'state': query,
                    'action': room.opt_a(query),
                    'rollin_xs': ctx_xs,
                    'rollin_us': ctx_us,
                    'rollin_xps': ctx_xps,
                    'rollin_rs': ctx_rs,
                    'goal': room.goal,
                })
    return dataset


def generate_dr_stitch_histories_for_goals(goals, n_hists, n_samples, H, dim, rollin=None, eval=False):
    """Build stitch-darkroom histories for one ``DarkroomEnvStitch`` per goal.

    Args:
        goals: Iterable of goal cells; one env is built per entry.
        n_hists: Rollin contexts per env.
        n_samples: Query states per context.
        H: Horizon.
        dim: Grid size.
        rollin: Rollin mode forwarded to ``rollin_dr``. When None (the
            default), ``'stitch'`` is used for eval envs and ``'uniform'``
            otherwise, which matches the previous behaviour. Previously an
            explicitly passed value was silently ignored.
        eval: Forwarded to ``darkroom_env.DarkroomEnvStitch``.

    Returns:
        List of example dicts from ``generate_dr_stitch_histories_from_envs``.
    """
    if rollin is None:
        rollin = 'stitch' if eval else 'uniform'

    envs = [darkroom_env.DarkroomEnvStitch(dim, goal, H, eval=eval) for goal in goals]

    return generate_dr_stitch_histories_from_envs(
        envs=envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
        rollin=rollin,
    )


def generate_dr_stitch_histories_from_envs(envs, n_hists, n_samples, H, dim, rollin='uniform'):
    """Build stitch-darkroom training examples from the given envs.

    This is the same as the plain darkroom variant, except that query states
    come from ``env.sample_opt_x()`` instead of ``env.sample_x()``. Labels are
    ``env.opt_a(state)``. ``H`` and ``dim`` are accepted for signature parity
    but are not used here.

    Returns:
        List of ``len(envs) * n_hists * n_samples`` example dicts.
    """
    dataset = []
    for room in envs:
        for _ in range(n_hists):
            ctx_xs, ctx_us, ctx_xps, ctx_rs = rollin_dr(room, rollin=rollin)
            for _ in range(n_samples):
                query = room.sample_opt_x()
                dataset.append({
                    'state': query,
                    'action': room.opt_a(query),
                    'rollin_xs': ctx_xs,
                    'rollin_us': ctx_us,
                    'rollin_xps': ctx_xps,
                    'rollin_rs': ctx_rs,
                    'goal': room.goal,
                })
    return dataset


def generate_bandit_histories(n_envs, n_hists, n_samples, H, dim, var=0.0, cov=0.0, type='uniform', orig=2):
    """Sample ``n_envs`` bandits and build their training histories.

    Each env comes from ``bandit_env.sample(dim, H, var, type=type)``.
    ``cov`` and ``orig`` are forwarded to ``rollin_bandit`` through
    ``generate_bandit_histories_from_envs``.
    """
    sampled_envs = []
    for _ in range(n_envs):
        sampled_envs.append(bandit_env.sample(dim, H, var, type=type))

    return generate_bandit_histories_from_envs(
        envs=sampled_envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
        var=var,
        cov=cov,
        orig=orig,
    )


def generate_linear_bandit_histories(n_envs, n_hists, n_samples, H, dim, lin_d, var=0.0, warm_start=False, arm_seed=1234):
    """Build linear-bandit training examples from Thompson-sampling rollins.

    All envs share one arm matrix of shape ``(dim, lin_d)``. It is drawn from
    a ``RandomState`` seeded with ``arm_seed`` and scaled by ``1/sqrt(lin_d)``,
    so calls with the same seed see the same arms. Each history is produced by
    one batched call to ``rollin_linear_bandit_vec`` over all envs.

    Args:
        n_envs: Number of envs, each sampled by ``bandit_env.sample_linear``.
        n_hists: Rollin contexts per env.
        n_samples: Examples per context. The query state is always ``[1]``
            and the label is ``env.opt_a``.
        H: Context horizon passed to ``sample_linear``.
        dim: Number of arms.
        lin_d: Feature dimension of each arm.
        var: Reward noise passed to ``sample_linear``.
        warm_start: Forwarded to the Thompson sampling policy.
        arm_seed: Seed for the shared arm matrix (previously hard-coded 1234).

    Returns:
        List of ``n_envs * n_hists * n_samples`` example dicts. It is empty
        when ``n_envs`` or ``n_hists`` is 0; previously those cases crashed.
    """
    print("Generating linear bandit histories...")

    rng = np.random.RandomState(seed=arm_seed)
    arms = rng.normal(size=(dim, lin_d)) / np.sqrt(lin_d)

    print("Sampling...")
    envs = [bandit_env.sample_linear(arms, H, var) for _ in range(n_envs)]

    if not envs or n_hists <= 0:
        # Nothing to roll out. Without this guard the batched rollin indexes
        # envs[0] and np.stack receives an empty list.
        return []

    print("Generating histories...")
    batches = [rollin_linear_bandit_vec(envs, warm_start=warm_start) for _ in range(n_hists)]
    # Stack each of xs/us/xps/rs along axis 1 so that [i, j] below selects
    # env i, history j.
    rollin_xs_all, rollin_us_all, rollin_xps_all, rollin_rs_all = (
        np.stack(parts, axis=1) for parts in zip(*batches)
    )

    trajs = []
    for i, env in enumerate(envs):
        print('Generating linear bandit histories for env {}/{}'.format(i+1, n_envs))
        for j in range(n_hists):
            rollin_xs = rollin_xs_all[i, j]
            rollin_us = rollin_us_all[i, j]
            rollin_xps = rollin_xps_all[i, j]
            rollin_rs = rollin_rs_all[i, j]

            for _ in range(n_samples):
                trajs.append({
                    'state': np.array([1]),
                    'action': env.opt_a,
                    'rollin_xs': rollin_xs,
                    'rollin_us': rollin_us,
                    'rollin_xps': rollin_xps,
                    'rollin_rs': rollin_rs,
                    'means': env.means,
                    'arms': env.arms,
                    'theta': env.theta,
                    'var': env.var,
                })
    return trajs


def generate_topk_bandit_histories(n_envs, n_hists, n_samples, H, dim, var=0.0, k=1):
    """Sample ``n_envs`` top-k bandits and build their training histories.

    Each env comes from ``bandit_env.sample_topk(dim, H, var, k)``. Generation
    is delegated to ``generate_topk_bandit_histories_from_envs``.
    """
    sampled_envs = []
    for _ in range(n_envs):
        sampled_envs.append(bandit_env.sample_topk(dim, H, var, k))

    return generate_topk_bandit_histories_from_envs(
        envs=sampled_envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
    )


def generate_topk_bandit_histories_from_envs(envs, n_hists, n_samples, H, dim):
    """Build top-k bandit training examples from ``rollin_bandit_topk`` contexts.

    Each context is reused for ``n_samples`` examples with query state ``[1]``
    and label ``env.opt_a``. ``H`` and ``dim`` are accepted for signature
    parity but are not used here.

    Returns:
        List of ``len(envs) * n_hists * n_samples`` example dicts.
    """
    examples = []
    for topk_env in envs:
        for _ in range(n_hists):
            ctx_xs, ctx_us, ctx_xps, ctx_rs = rollin_bandit_topk(topk_env)
            for _ in range(n_samples):
                examples.append({
                    'state': np.array([1]),
                    'action': topk_env.opt_a,
                    'rollin_xs': ctx_xs,
                    'rollin_us': ctx_us,
                    'rollin_xps': ctx_xps,
                    'rollin_rs': ctx_rs,
                    'means': topk_env.means,
                    'k': topk_env.k,
                    'var': topk_env.var,
                })
    return examples


def generate_bandit_histories_for_arms(arms, n_hists, n_samples, H, dim, var=0.0, cov=0.0, orig=2):
    """Generates bandit histories for a list of arms.

    One env is built per entry of ``arms`` with
    ``bandit_env.sample_for_arm(arm, dim, H, var)``. To generate multiple
    environments for a single arm, pass in ``[arm] * n_envs`` for ``arms``.

    Args:
        arms: Iterable of arm indices, one env per entry.
        n_hists: Rollin contexts per env.
        n_samples: Examples per context.
        H: Context horizon.
        dim: Number of arms.
        var: Reward noise.
        cov: Forwarded to ``rollin_bandit``.
        orig: Rollin variant forwarded to ``rollin_bandit``. Previously it was
            never forwarded, so ``rollin_bandit``'s default always applied.
            The default here keeps that behaviour.

    Returns:
        List of example dicts from ``generate_bandit_histories_from_envs``.
    """
    envs = [bandit_env.sample_for_arm(arm, dim, H, var) for arm in arms]

    return generate_bandit_histories_from_envs(
        envs=envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
        var=var,
        cov=cov,
        orig=orig,
    )


def generate_bandit_histories_from_envs(envs, n_hists, n_samples, H, dim, var=0.0, cov=0.0, orig=2):
    """Build bandit training examples from ``rollin_bandit`` contexts.

    For every env and each of ``n_hists`` histories, one context is drawn with
    ``rollin_bandit(env, cov=cov, orig=orig)``. That context is reused for
    ``n_samples`` examples with query state ``[1]``, label ``env.opt_a``, a
    zero ``'Q'`` placeholder and the env's ``means``. ``H``, ``dim`` and
    ``var`` are accepted for signature parity but are not used here.

    Returns:
        List of ``len(envs) * n_hists * n_samples`` example dicts.
    """
    examples = []
    for arm_env in envs:
        for _ in range(n_hists):
            ctx_xs, ctx_us, ctx_xps, ctx_rs = rollin_bandit(arm_env, cov=cov, orig=orig)
            for _ in range(n_samples):
                query = np.array([1])
                examples.append({
                    'state': query,
                    'action': arm_env.opt_a,
                    'rollin_xs': ctx_xs,
                    'rollin_us': ctx_us,
                    'rollin_xps': ctx_xps,
                    'rollin_rs': ctx_rs,
                    'Q': np.zeros(query.shape),
                    'means': arm_env.means,
                })
    return examples
            

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=1000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--lin_d", type=int, required=False, default=1, help="Linear Dimension")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top k subset")
    parser.add_argument("--orig", type=int, required=False, default=2,
                        help="Bandit rollin variant forwarded to rollin_bandit (supported: 9, 10, 14)")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Bandit arm variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage of optimal arm")
    parser.add_argument("--env", type=str, required=True, help="Environment")

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    env = args['env']
    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    H = args['H']
    orig = args['orig']

    # dx/du are read as module globals by the legacy LQR helper generate_histories.
    dx = args['dim']
    du = args['dim']
    dim = args['dim']
    var = args['var']
    cov = args['cov']
    k = args['k']
    lin_d = args['lin_d']

    n_envs_tr = int(.8 * n_envs)
    n_envs_te = n_envs - n_envs_tr

    # Every dataset file name starts with this prefix. Each branch sets the
    # env-specific `suffix` that precedes `_train.pkl` / `_test.pkl`.
    prefix = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}'

    if env == 'bandit':
        train_trajs = generate_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, var=var, cov=cov, orig=orig)
        test_trajs = generate_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, var=var, cov=cov, orig=orig)
        suffix = f'_var{var}_cov{cov}_orig{orig}'

        # Histogram of the max arm probability of every rollin (train and test).
        plt.hist(max_probs, bins=100)
        os.makedirs('figs', exist_ok=True)
        plt.savefig('figs/max_probs.png')

        # compute fraction of max probs that are at least .99
        print(f'Fraction of max probs at least .99: {np.mean(np.array(max_probs) >= .99)}')

    elif env == 'bandit_ood':
        # These calls use generate_bandit_histories_for_arms's default `orig`.
        # Train split: 90% of envs are built for arm indices in the first half
        # of range(dim), 10% for arm indices in the second half.
        n_envs_first = int(.9 * n_envs_tr)
        n_envs_second = n_envs_tr - n_envs_first
        first_envs = list(range(dim // 2)) * (n_envs_first // (dim // 2))
        second_envs = list(range(dim // 2, dim)) * (n_envs_second // (dim // 2))
        train_trajs = generate_bandit_histories_for_arms(first_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        train_trajs += generate_bandit_histories_for_arms(second_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        assert len(train_trajs) == n_envs_tr * n_hists * n_samples

        # Test split: 50/50 between the two halves.
        n_envs_first = int(.5 * n_envs_te)
        n_envs_second = n_envs_te - n_envs_first
        first_envs = list(range(dim // 2)) * (n_envs_first // (dim // 2))
        second_envs = list(range(dim // 2, dim)) * (n_envs_second // (dim // 2))
        test_trajs = generate_bandit_histories_for_arms(first_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        test_trajs += generate_bandit_histories_for_arms(second_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        suffix = f'_var{var}_cov{cov}'

    elif env == 'bandit_topk':
        train_trajs = generate_topk_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, k=k, var=var)
        test_trajs = generate_topk_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, k=k, var=var)
        suffix = f'_var{var}_k{k}'

    elif env == 'bandit_thompson':
        # `orig` was previously not forwarded here, so this branch always hit
        # rollin_bandit's unsupported default.
        train_trajs = generate_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, var=var, cov=cov, type='bernoulli', orig=orig)
        test_trajs = generate_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, var=var, cov=cov, type='bernoulli', orig=orig)
        suffix = f'_var{var}_cov{cov}'

    elif env == 'linear_bandit':
        warm_start = False
        train_trajs = generate_linear_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, lin_d, var=var, warm_start=warm_start)
        test_trajs = generate_linear_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, lin_d, var=var, warm_start=warm_start)
        suffix = f'_dlin{lin_d}_var{var}_ws{warm_start}'

    elif env == 'darkroom':
        train_trajs = generate_dr_histories(n_envs_tr, n_hists, n_samples, H, dim)
        test_trajs = generate_dr_histories(n_envs_te, n_hists, n_samples, H, dim)
        suffix = ''

    elif env == 'darkroom_heldout':
        # All dim x dim goal cells, shuffled with a fixed seed and split 80/20,
        # so train and test goals are disjoint.
        goals = np.array([[(j, i) for i in range(dim)] for j in range(dim)]).reshape(-1, 2)
        np.random.RandomState(seed=0).shuffle(goals)
        train_test_split = int(.8 * len(goals))
        train_goals = goals[:train_test_split]
        test_goals = goals[train_test_split:]
        train_goals = np.repeat(train_goals, n_envs // (dim * dim), axis=0)
        test_goals = np.repeat(test_goals, n_envs // (dim * dim), axis=0)

        train_trajs = generate_dr_histories_for_goals(train_goals, n_hists, n_samples, H, dim)
        test_trajs = generate_dr_histories_for_goals(test_goals, n_hists, n_samples, H, dim)
        suffix = ''

    elif env == 'darkroom_stitch':
        goals = [np.array([dim // 2, dim - 1]), np.array([dim - 1, dim // 2])]
        train_goals = np.repeat(goals, n_envs_tr // len(goals), axis=0)
        test_goals = np.repeat(goals, n_envs_te // len(goals), axis=0)
        assert len(train_goals) + len(test_goals) == n_envs
        train_trajs = generate_dr_stitch_histories_for_goals(train_goals, n_hists, n_samples, H, dim, eval=False)
        test_trajs = generate_dr_stitch_histories_for_goals(test_goals, n_hists, n_samples, H, dim, eval=True)
        suffix = ''

    elif env == 'darkroom_expert':
        # rollin_dr has no 'expert' rollin mode yet.
        raise NotImplementedError("darkroom_expert: expert rollins are not implemented")

    else:
        # The LQR pipeline is retired; any other env name is unsupported.
        raise NotImplementedError(f"Unsupported env: {env!r}")

    filepath_tr = f'{prefix}{suffix}_train.pkl'
    filepath_te = f'{prefix}{suffix}_test.pkl'

    os.makedirs('datasets', exist_ok=True)
    with open(filepath_tr, 'wb') as file:
        pickle.dump(train_trajs, file)
    with open(filepath_te, 'wb') as file:
        pickle.dump(test_trajs, file)

    print(f"Saved to {filepath_tr}.")
    print(f"Saved to {filepath_te}.")
