import os
import numpy as np
import torch
import sys
from src.utils.arrays import to_torch
import src.utils as utils
from src.data.trajectory import TrajectoryDataset, Q_TrajectoryDataset, H_TrajectoryDataset
from src.data.sequence import DiscretizedDataset
from src.models.decision_transformer import DecisionTransformerGPT
from src.models.bundled_transformer import BundledTransformerGPT
from src.models.bcq_transformer import BCQ_TransformerGPT
from src.models.gdt_transformer import GDT_TransformerGPT
from src.models.trajectory_transformer import TrajectoryTransformerGPT
from src.models.gpt import FullBlock
from src.splt_models.splt_transformer import SPLTTransformerGPT
import time
import wandb
import config.offline as offline
import random
from src.rl_models.iql import IQLTrainer
import argparse
import math
from math import log, e
import pickle

def argmax_sg(current_id, traj_rew, terminal_id=39): # find the subgoal corresponding to current state
    """Pick the future step whose mean reward since `current_id` is largest.

    For every candidate end `index` in `current_id + 2 .. len(traj_rew) - 1`
    the mean of `traj_rew[current_id + 1 : index]` is computed; the candidate
    with the strictly largest mean wins and `index - 1` is reported.

    Args:
        current_id: index of the current step inside the trajectory.
        traj_rew: per-step rewards of one trajectory (list or 1-D array).
        terminal_id: index reported when `current_id` is one of the last two
            steps of `traj_rew`. Defaults to 39, the value that used to be
            hard-coded here (presumably max_path_length - 1 -- TODO confirm).

    Returns:
        `(max_reward, argmax_id)`. Both stay at the sentinel -1000 when no
        candidate exists or no candidate mean exceeds -1000.
    """
    max_reward = -1000
    argmax_id = -1000
    n_steps = len(traj_rew)

    for index in range(current_id + 2, n_steps):
        accu_reward = np.mean(traj_rew[current_id + 1: index])
        # strict comparison: the earliest candidate wins ties
        if accu_reward > max_reward:
            max_reward = accu_reward
            argmax_id = index - 1

    # The last two steps have no candidate window (the loop above is empty),
    # so they fall back to the final reward and the terminal index.
    if current_id in [n_steps - 2, n_steps - 1]:
        max_reward = traj_rew[-1]
        argmax_id = terminal_id

    return max_reward, int(argmax_id)


def V_config_subgoal(obs_segmented, rewards_segmented, data, args): # store all the valid subgoals in a vector
    """Assign a subgoal observation to every valid (non-padded) step.

    For each trajectory the valid steps are scored with `args.gpt.vf`
    (presumably a pretrained value function -- TODO confirm). Starting from
    step 0, the subgoal of a step is the observation at the highest-scoring
    step strictly after the current segment start; once that step is reached
    the search restarts from it. The final valid step is its own subgoal.

    Args:
        obs_segmented: array `[num_traj, max_length, obs_dim]` of padded
            observations.
        rewards_segmented: array `[num_traj, max_length, 1]` of padded rewards;
            a reward equal to -1000 marks the first padded step.
        data: dict holding the flat `'observations'` array, used only for the
            normalisation statistics.
        args: namespace providing `device` and `gpt` (an object with `vf`).

    Returns:
        Array of the subgoal observations of all valid steps, concatenated in
        trajectory order.
    """
    # Scalar mean/std pooled over every observation dimension.
    observation_mean = to_torch(np.mean(data['observations']), device=args.device)
    observation_std = to_torch(np.std(data['observations']), device=args.device)

    n_traj, max_length, obs_dim = obs_segmented.shape

    subgoals = []
    for num_traj in range(n_traj):
        # The first step whose reward equals the padding value ends the trajectory.
        is_padding = rewards_segmented[num_traj, :max_length, 0] == -1000
        last_step = int(np.argmax(is_padding)) if is_padding.any() else max_length

        # One observation per batch entry; the width comes from the data
        # instead of a hard-coded 4.
        obs_torch = to_torch(obs_segmented[num_traj, :last_step, :], device=args.device).reshape(-1, 1, obs_dim)
        obs_torch = (obs_torch - observation_mean) / observation_std

        vfs = args.gpt.vf(obs_torch).detach().cpu().numpy().flatten()
        assert(len(vfs) == last_step)

        start = 0
        pointer_v = 0
        labelled = 0
        while pointer_v < len(vfs) - 1:
            # best-valued step strictly after the current segment start
            pointer_v = int(np.argmax(vfs[start + 1:])) + start + 1

            # every step in [start, pointer_v) shares that step as its subgoal
            span = pointer_v - start
            subgoals.extend([obs_segmented[num_traj, pointer_v, :]] * span)
            labelled += span

            start = pointer_v

        # the last valid step points at itself
        subgoals.append(obs_segmented[num_traj, len(vfs) - 1, :])
        labelled += 1
        assert(len(vfs) == labelled)

    return np.array(subgoals)

def segment(observations, terminals, max_path_length, observation_dim): # build a matrix [num_traj, max_length] of each information (e.g., state, action, reward, subgoal)
    """
        segment `observations` into trajectories according to `terminals`

        A trajectory is closed when its terminal flag is truthy or when it
        reaches `max_path_length` steps. Each trajectory is then padded with
        -1000 up to `max_path_length`.

        Args:
            observations: per-step values; every step must hold exactly
                `observation_dim` elements (scalars when `observation_dim` is 1).
            terminals: per-step flags (numpy values), same length as `observations`.
            max_path_length: maximum number of steps per trajectory.
            observation_dim: number of features stored per step.

        Returns:
            `(trajectories_pad, path_lengths)`: an array of shape
            `[num_traj, max_path_length, observation_dim]` and the list of the
            unpadded trajectory lengths. Empty input yields an array with zero
            trajectories and an empty list.
    """
    assert len(observations) == len(terminals)

    trajectories = []
    current = []
    for obs, term in zip(observations, terminals):
        current.append(obs)
        if term.squeeze() or (len(current) >= max_path_length):
            trajectories.append(np.stack(current, axis=0))
            current = []

    # keep a trailing trajectory that was never terminated
    if current:
        trajectories.append(np.stack(current, axis=0))

    path_lengths = [len(traj) for traj in trajectories]

    # With no steps there is no trajectory to take the dtype from; fall back
    # to the dtype of the input itself instead of raising IndexError.
    if trajectories:
        pad_dtype = trajectories[0].dtype
    else:
        pad_dtype = np.asarray(observations).dtype

    ## pad trajectories to be of equal length
    # (ones * -1000 rather than np.full, so the result dtype follows the
    # usual arithmetic promotion exactly as before)
    trajectories_pad = np.ones((len(trajectories), max_path_length, observation_dim), dtype=pad_dtype) * (-1000)

    for i, (traj, path_length) in enumerate(zip(trajectories, path_lengths)):
        trajectories_pad[i, :path_length, :] = traj.reshape(-1, observation_dim)

    return trajectories_pad, path_lengths

class Parser(utils.Parser):
    """Script-specific options declared on top of `utils.Parser`."""
    # Name of the offline dataset to train on.
    # NOTE(review): config_data only recognises 'stop30' ... 'stop80' and
    # 'stop_random'; this default is not one of them -- confirm it is always overridden.
    dataset: str = 'idm-uniform07'
    # Partial-observation mode applied in config_data: 'constant', 'zero_value'
    # or 'random_noise' rewrite far-away observations; any other value
    # (such as 'whole') leaves the observations untouched.
    visiable: str = 'whole'
    # presumably the module path of the base configuration consumed by utils.Parser -- TODO confirm
    config: str = 'config.offline'
    # Algorithm to train; the rest of this script also handles 'BCQ' and 'GDT'.
    algo: str = 'DT' # DT, TT, SPLT, IQL, BC,

#######################
######## setup ########
#######################

def config_data(args, run): # select the offline dataset for training
    """Download the requested dataset artifact and prepare it for training.

    Steps:
      1. fetch `your-folder/data_<dataset>:v0` through `run.use_artifact` and
         load `data_<dataset>.npz` into a dict;
      2. add a trailing axis to `'actions'` and alias `'dones'` as `'terminals'`;
      3. for the 'BCQ' and 'GDT' algorithms, compute `'subgoals'` from the
         (still unmodified) observations;
      4. optionally overwrite observation columns 2 and 3 of every step where
         `obs[2] - obs[0]` reaches the visibility threshold, according to
         `args.visiable`.

    Args:
        args: namespace providing `dataset`, `algo`, `visiable` (and, for
            'BCQ'/'GDT', whatever `V_config_subgoal` needs).
        run: wandb run used to resolve the dataset artifact.

    Returns:
        dict of numpy arrays describing the dataset.

    Raises:
        ValueError: if `args.dataset` is not one of the known dataset names.
    """
    known_datasets = ('stop80', 'stop70', 'stop60', 'stop50', 'stop40', 'stop30', 'stop_random')
    if args.dataset not in known_datasets:
        # previously this fell through and crashed later on an unbound `data_file`
        raise ValueError(
            f'unknown dataset {args.dataset!r}; expected one of {known_datasets}'
        )

    artifact_name = 'data_' + args.dataset
    artifact = run.use_artifact('your-folder/' + artifact_name + ':v0', type='dataset')
    wandb_dataset_dir = artifact.download()
    data_file = wandb_dataset_dir + '/' + artifact_name + '.npz'

    with open(data_file, 'rb') as f:
        # dict(...) materialises every array while the file is still open
        data = dict(np.load(f))

    vis_threshold = 25 #36 # 47
    data['actions'] = data['actions'][:, None]
    data['terminals'] = data['dones']

    if args.algo in ['BCQ', 'GDT']:
        obs_segmented, *_ = segment(data['observations'], data['terminals'] , 40, 4)
        rewards_segmented, *_ = segment(data['rewards'], data['terminals'] , 40, 1)
        data['subgoals'] = V_config_subgoal(obs_segmented, rewards_segmented, data, args)

    # further experiments for partial observation
    n_steps = len(data['actions'])
    observations = data['observations']
    if args.visiable == 'constant':
        print('[train/visiable] constant')
        rows = observations[:n_steps]  # view: writes land in data['observations']
        too_far = rows[:, 2] - rows[:, 0] >= vis_threshold
        rows[too_far, 2] = rows[too_far, 0] + vis_threshold
        rows[too_far, 3] = rows[too_far, 1]
    elif args.visiable == 'zero_value':
        print('[train/visiable] zero value')
        rows = observations[:n_steps]
        too_far = rows[:, 2] - rows[:, 0] >= vis_threshold
        rows[too_far, 2] = 0
        rows[too_far, 3] = 0
    elif args.visiable == 'random_noise':
        print('[train/visiable] random noise')
        # kept as an explicit loop so the draws from `random` happen in the
        # same order as before for a given seed
        for i in range(n_steps):
            if observations[i, 2] - observations[i, 0] >= vis_threshold:
                observations[i, 2] = vis_threshold + 9 * random.random()
                observations[i, 3] = 10 * random.random()
    return data

def config_algo(args, run, dataset, obs_dim, act_dim, transition_dim, stats, block_size, reward_scale, data): # select algorithm
    """Build the model for `args.algo`, move it to the device and log its config.

    A `utils.Config` is created for the model class matching `args.algo`
    ('DT', 'TT', 'GDT', 'BCQ', 'BC', 'IQL' or 'SPLT'), instantiated, moved to
    `args.device`, and `model_config.pkl` under `args.savepath` is logged as a
    wandb artifact.

    Args:
        args: namespace with the hyper-parameters read below (`algo`, `device`,
            `savepath`, architecture / dropout / loss-weight fields, ...).
        run: wandb run used to log the config artifact.
        dataset: unused here; kept for interface compatibility.
        obs_dim, act_dim, transition_dim: dimensions forwarded to the model.
        stats: mapping with `<name>_mean` / `<name>_std` entries.
        block_size: context length forwarded to the model.
        reward_scale: forwarded to the IQL trainer only.
        data: raw dataset dict; `data['subgoals']` is read for 'GDT' and 'BCQ'.

    Returns:
        The instantiated model.

    Raises:
        ValueError: if `args.algo` is not one of the supported algorithms.
    """

    # The helpers below are evaluated lazily, inside the selected branch only,
    # so an algorithm never touches stats/args fields it does not use.
    def moments(*names):
        # `<name>_mean` / `<name>_std` tensors, in the order the names are given
        out = {}
        for stat_name in names:
            for suffix in ('mean', 'std'):
                key = f'{stat_name}_{suffix}'
                out[key] = to_torch(stats[key], device=args.device)
        return out

    def subgoal_moments():
        # scalar statistics pooled over every subgoal dimension
        return dict(
            subgoal_mean=to_torch(np.mean(data['subgoals']), device=args.device),
            subgoal_std=to_torch(np.std(data['subgoals']), device=args.device),
        )

    def architecture():
        return dict(
            block_size=block_size,
            n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd*args.n_head,
        )

    def dropout():
        return dict(
            embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop,
        )

    def sweep_loss_weights():
        return dict(
            action_weight=args.action_weight, reward_weight=args.reward_weight, value_weight=args.value_weight,
        )

    savepath = (args.savepath, 'model_config.pkl')
    dimensions = dict(observation_dim=obs_dim, action_dim=act_dim, transition_dim=transition_dim)

    if args.algo == 'DT':
        print('using DT')
        model_config = utils.Config(
            DecisionTransformerGPT,
            action_tanh=True,
            **moments('observation', 'action', 'return'),
            savepath=savepath,
            ## architecture
            **architecture(),
            ## dimensions
            **dimensions,
            ## dropout probabilities
            **dropout(),
        )

    elif args.algo == 'TT':
        print('using TT')
        model_config = utils.Config(
            TrajectoryTransformerGPT,
            savepath=savepath,
            ## discretization
            vocab_size=args.N,
            ## architecture
            **architecture(),
            ## dimensions
            **dimensions,
            ## loss weighting
            **sweep_loss_weights(),
            ## dropout probabilities
            **dropout(),
        )

    elif args.algo == 'GDT':
        model_config = utils.Config(
            GDT_TransformerGPT,
            **moments('observation'),
            **subgoal_moments(),
            **moments('action'),
            res=False,
            action_tanh=True,
            savepath=savepath,
            ## architecture
            **architecture(),
            ## dimensions
            **dimensions,
            ## dropout probabilities
            **dropout(),
        )

    elif args.algo == 'BCQ':
        model_config = utils.Config(
            BCQ_TransformerGPT,
            **moments('observation'),
            **subgoal_moments(),
            savepath=savepath,
            ## architecture
            **architecture(),
            ## dimensions
            **dimensions,
            ## dropout probabilities
            **dropout(),
        )

    elif args.algo == 'BC':
        model_config = utils.Config(
            BundledTransformerGPT,
            **moments('observation', 'action', 'reward', 'value'),
            res=False,
            action_tanh=True,
            savepath=savepath,
            ## architecture
            **architecture(),
            ## dimensions
            **dimensions,
            ## loss weighting: only the action loss is active
            action_weight=1.0,
            observation_weight=0.,
            reward_weight=0.,
            value_weight=0.,
            ## dropout probabilities
            **dropout(),
        )

    elif args.algo == 'IQL':
        print('using IQL')
        model_config = utils.Config(
            IQLTrainer,
            **moments('action', 'observation'),
            savepath=savepath,
            ## architecture
            block_size=block_size,
            n_hidden_layers=2,
            embedding_dim=256,
            reward_scale=reward_scale,
            quantile=0.7,
            discount=args.discount,
            soft_target_tau=5.e-3,
            alpha=10.,
            clip_score=100.,
            max_log_std=0,
            min_log_std=-6,
            ## dimensions
            **dimensions,
        )

    elif args.algo == 'SPLT':
        model_config = utils.Config(
            SPLTTransformerGPT,
            encoder_class=BundledTransformerGPT,
            decoder_class=BundledTransformerGPT,
            encoder_config={'block_class': FullBlock},
            beta=1.e-3,
            world_latent_dim=2,
            policy_latent_dim=3,
            res=True,
            action_tanh=False,
            **moments('observation', 'observation_diff', 'action', 'reward', 'value'),
            savepath=savepath,
            ## architecture
            **architecture(),
            ## dimensions
            **dimensions,
            ## loss weighting
            **sweep_loss_weights(),
            ## dropout probabilities
            **dropout(),
        )

    else:
        # previously an unknown algorithm crashed below on an unbound `model_config`
        raise ValueError(f'unsupported algo {args.algo!r}')

    model = model_config()
    model.to(args.device)

    artifact = wandb.Artifact('model_config', type='config')
    artifact.add_file(args.savepath+ '/model_config.pkl',  'model_config.pkl')
    run.log_artifact(artifact)

    return model

def config_trainer(args, run, block_size, dataset): # build either transformer-based or rl-based trainer
    """Create the trainer matching `args.algo` and log its config to wandb.

    Transformer algorithms ('DT', 'TT', 'BC', 'SPLT', 'BCQ', 'GDT') get a
    `utils.Trainer`; 'IQL' gets a `utils.RLTrainer`. The pickled config
    `trainer_config.pkl` under `args.savepath` is logged as a wandb artifact.

    Args:
        args: namespace providing `algo`, `savepath`, `batch_size`, `device`
            and, for transformer algorithms, `learning_rate` and `lr_decay`.
        run: wandb run used to log the config artifact.
        block_size: context length; used to size the learning-rate schedule.
        dataset: training dataset; only its length is read.

    Returns:
        The instantiated trainer.

    Raises:
        ValueError: if `args.algo` is not one of the supported algorithms.
    """
    transformer_algo = ['DT', 'TT', 'BC', 'SPLT', 'BCQ', 'GDT']
    savepath = (args.savepath, 'trainer_config.pkl')

    if args.algo in transformer_algo:
        warmup_tokens = len(dataset) * block_size ## number of tokens seen per epoch
        final_tokens = 20 * warmup_tokens

        trainer_config = utils.Config(
            utils.Trainer,
            savepath=savepath,
            # optimization parameters
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            betas=(0.9, 0.95),
            grad_norm_clip=1.0,
            weight_decay=0.1, # only applied on matmul weights
            # learning rate decay: linear warmup followed by cosine decay to 10% of original
            lr_decay=args.lr_decay,
            warmup_tokens=warmup_tokens,
            final_tokens=final_tokens,
            ## dataloader
            num_workers=4,
            device=args.device,
        )

    elif args.algo == 'IQL':
        trainer_config = utils.Config(
            utils.RLTrainer,
            savepath=savepath,
            # optimization parameters
            batch_size=args.batch_size,
            learning_rate=3.e-4,
            betas=(0.9, 0.95),
            weight_decay=0., # only applied on matmul weights
            ## dataloader
            num_workers=4,
            device=args.device,
        )

    else:
        # previously an unknown algorithm crashed below on an unbound `trainer_config`
        raise ValueError(f'unsupported algo {args.algo!r}')

    trainer = trainer_config()

    artifact = wandb.Artifact('trainer_config', type='config')
    artifact.add_file(args.savepath+ '/trainer_config.pkl',  'trainer_config.pkl')
    run.log_artifact(artifact)

    return trainer




def experiment(config_sweep): # select training name and its parser
    """Run one training job of a wandb sweep.

    The function parses the arguments, prefixes the save path and experiment
    name with an algorithm-specific folder, starts a wandb run, seeds `random`
    and numpy, optionally loads a pretrained model (for 'BCQ' and 'GDT'),
    builds the dataset, model and trainer, and then trains while periodically
    saving the model state to disk and to wandb.

    Args:
        config_sweep: sweep configuration dict. It is updated in place with the
            parsed hyper-parameters before being passed to `wandb.init`.

    Raises:
        ValueError: if `args.algo` is not one of the supported algorithms.
    """
    # Earlier variants picked the parser section per algorithm:
    #   'bc_train' for DT / BC / BCQ / GDT, 'train' for SPLT / IQL, 'tt_train' for TT.
    name = 'bc_train'

    # trajectory length used for every dataset class below
    max_path_length = 40

    # Artifact versions of the pretrained model loaded for 'BCQ' / 'GDT',
    # indexed by the sweep parameter `iql_index`.
    # Active set (stop_sign = 40, data_index = [646, 650, 652, 656, 658]):
    model_index = [990, 994, 996, 1000, 1002 ]
    state_index = [133, 126, 125, 127, 126 ]
    # Alternative sets used in other runs (data_index / model_index / state_index):
    #   BCQ, stop_sign 30:         [907, 648, 651, 655, 657] / [1298, 992, 995, 999, 1001] / [179, 125, 124, 126, 125]
    #   IQL, stop_sign 30:         [657, 648, 651, 655, 657] / [1001, 992, 995, 999, 1001] / [131, 125, 124, 126, 125]
    #   no discount, stop_sign 30: [856, 856, 651, 655, 657] / [1217, 1219, 995, 999, 1001] / [170, 165, 124, 126, 125]
    #   stop_sign 50:              [598, 603, 609, 616, 623] / [955, 960, 966, 973, 979] / [120, 115, 114, 116, 115]
    #   stop_sign 60:              [599, 604, 610, 618, 625] / [956, 961, 967, 975, 981] / [121, 116, 116, 118, 118]
    #   stop_sign 70:              [600, 605, 613, 619, 627] / [957, 962, 970, 976, 983] / [122, 118, 118, 120, 119]
    #   stop_sign 80:              [601, 606, 611, 617, 624] / [958, 963, 968, 974, 980] / [123, 117, 115, 117, 116]
    #   stop_sign 80 (second set): [602, 607, 614, 620, 626] / [959, 964, 971, 977, 982] / [124, 119, 117, 119, 117]

    # folder inserted into the save path / experiment name for each algorithm
    algo_prefix = {
        'DT': 'dt',
        'BCQ': 'bcq',
        'GDT': 'gdt',
        'TT': 'tt',
        'BC': 'bt',
        'SPLT': 'splt_bt',
        'IQL': 'iql',
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args = Parser().parse_args(name)

    if args.algo not in algo_prefix:
        # fail before any directory, wandb run or download is created; this
        # used to surface much later as an unbound `transition_dim`
        raise ValueError(
            f'unsupported algo {args.algo!r}; expected one of {sorted(algo_prefix)}'
        )

    prefix = algo_prefix[args.algo]
    split_path = args.savepath.split('/')
    args.savepath = '/'.join(split_path[:1] + [prefix] + split_path[1:])
    args.exp_name = prefix + '/' + args.exp_name
    if args.algo == 'IQL':
        args.subsampled_sequence_length = 2

    utils.serialization.mkdir(args.savepath)

    # mirror the parsed hyper-parameters into the sweep config logged by wandb
    config_sweep['device'] = device
    for key in (
        'discount', 'n_layer', 'n_head', 'n_embd', 'n_epochs_ref', 'logbase',
        'batch_size', 'learning_rate', 'lr_decay',
        'embd_pdrop', 'resid_pdrop', 'attn_pdrop',
        'step', 'subsampled_sequence_length', 'termination_penalty',
        'action_weight', 'reward_weight', 'value_weight', 'visiable',
    ):
        config_sweep[key] = getattr(args, key)

    run = wandb.init(sync_tensorboard=False, config=config_sweep, name=args.exp_name,  #wandb.config
    save_code=True)

    args.seed = list(wandb.config['param']['seed'].values())[0]

    random.seed(args.seed)
    np.random.seed(args.seed)

    print('current seed: ', args.seed)

    if args.algo in ['BCQ', 'GDT']: # ['IQDT', 'IQSPLT', 'BCQ']: IQL, ['GDT']: bcq
        args.index = list(wandb.config['param']['iql_index'].values())[0]
        model_artifact = run.use_artifact('your-folder/model_config:v'+str(model_index[args.index]), type='config')
        config_path = model_artifact.download()
        # NOTE(security): unpickling runs arbitrary code; only load artifacts
        # produced by trusted runs of this project.
        with open(config_path+'/model_config.pkl', 'rb') as f:
            pretrained_config = pickle.load(f)

        state_name = 'state_48_seed'+str(args.seed)
        state_artifact = run.use_artifact('your-folder/'+state_name+':v'+str(state_index[args.index]), type='model')
        state_path = state_artifact.download()
        state = torch.load(state_path+'/'+state_name+'.pt')

        gpt = pretrained_config()
        gpt.to(args.device)
        gpt.load_state_dict(state, strict=True)
        args.gpt = gpt
        print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
        print(pretrained_config)

    #######################
    ####### dataset #######
    #######################

    data = config_data(args, run)

    sequence_length = args.subsampled_sequence_length * args.step # 10 * 1

    if args.algo == 'TT':
        dataset_config = utils.Config(
            DiscretizedDataset,
            savepath=(args.savepath, 'data_config.pkl'),
            env=None,
            dataset=data,
            N=args.N,
            penalty=args.termination_penalty,
            sequence_length=sequence_length,
            step=args.step,
            discount=args.discount,
            discretizer=args.discretizer,
            max_path_length=max_path_length,
            timeouts=False,
        )
    else:
        if args.algo in [ 'BCQ', 'GDT']:
            print('bcq: ', data['subgoals'].shape)
            dataset_class = H_TrajectoryDataset
        else:
            dataset_class = TrajectoryDataset

        dataset_config = utils.Config(
            dataset_class,
            savepath=(args.savepath, 'data_config.pkl'),
            env=None,
            dataset=data,
            penalty=args.termination_penalty,
            sequence_length=sequence_length,
            step=args.step,
            discount=args.discount,
            max_path_length=max_path_length,
            timeouts=False,
        )

    artifact = wandb.Artifact('data_config', type='config')
    artifact.add_file(args.savepath + '/data_config.pkl',  'data_config.pkl')
    run.log_artifact(artifact)

    dataset = dataset_config()
    obs_dim = dataset.observation_dim
    act_dim = dataset.action_dim

    reward_scale = 1. # only for IQL

    if args.algo in ['DT', 'GDT']:
        transition_dim = 3
    elif args.algo == 'BCQ':
        transition_dim = 1
    elif args.algo == 'TT':
        transition_dim = dataset.joined_dim
    elif args.algo == 'SPLT' or args.algo == 'BC':
        transition_dim = 2
    elif args.algo == 'IQL':
        transition_dim = obs_dim + act_dim # 5 = 4 + 1
    else:
        raise ValueError(f'unsupported algo {args.algo!r}')

    stats = dataset.get_stats()

    #######################
    ######## model ########
    #######################

    block_size = args.subsampled_sequence_length * transition_dim - 1
    print(
        f'Dataset size: {len(dataset)} | '
        f'Joined dim: {transition_dim} '
        f'(observation: {obs_dim}, action: {act_dim}) | Block size: {block_size}'
    )

    model = config_algo(args, run, dataset, obs_dim, act_dim, transition_dim, stats, block_size, reward_scale, data)
    print('config model')

    #######################
    ####### trainer #######
    #######################

    trainer = config_trainer(args, run, block_size, dataset)
    print('config trainer')

    #######################
    ###### main loop ######
    #######################

    ## scale number of epochs to keep number of updates constant
    start_time = time.time()
    n_epochs = int((1e6 / len(dataset) * args.n_epochs_ref))
    print('n_epochs: ', n_epochs, args.n_epochs_ref)

    for epoch in range(n_epochs):

        print(f'\nEpoch: {epoch} / {n_epochs} | {args.dataset} | {args.exp_name}')

        trainer.train(model, dataset)

        # checkpoint every 10th epoch and after the final one
        if (epoch % 10 == 0) or (epoch == n_epochs - 1):

            statepath = os.path.join(args.savepath, f'state_{epoch}.pt')
            print(f'Saving model to {statepath}')

            ## save state to disk
            state = model.state_dict()
            torch.save(state, statepath)
            artifact = wandb.Artifact(f'state_{epoch}_seed{args.seed}' , type='model')
            artifact.add_file(statepath, f'state_{epoch}_seed{args.seed}.pt')
            run.log_artifact(artifact)

    print('total training time: ', time.time() - start_time)


if __name__ == '__main__':
    # Register a sweep from the offline config, then let one agent execute
    # `experiment` five times against it.
    wandb_project = offline.base['WANDB_PROJECT']
    config_sweep = offline.sweep_config

    sweep_id = wandb.sweep(config_sweep, project=wandb_project)
    config_sweep['sweep_id'] = sweep_id

    def run_experiment():
        # the agent invokes this without arguments for every sweep trial
        return experiment(config_sweep=config_sweep)

    wandb.agent(sweep_id, function=run_experiment, project=wandb_project, count=5)
    print('end sweeping')

