import argparse
import os
import pickle
import random
import torch
from ctrls.ctrl_darkroom import (
    DarkroomTransformerController
)
from utils import (
    build_darkroom_data_filename,
    build_darkroom_model_filename,
)
from net import Transformer
import gym
import numpy as np
import time
import common_args
from dpt_envs import darkroom_env
from utils import (

    build_darkroom_data_filename,
   
)




def rollin_mdp(env, rollin_type, controller=None):
    """Roll out one context episode of ``env.horizon`` transitions in ``env``.

    When ``controller`` is given, actions come from ``controller.act(state)``
    starting from ``env.reset()``. Otherwise ``rollin_type`` selects how the
    transitions are generated:

    * ``'uniform'``: at every step the state is re-drawn with
      ``env.sample_state()`` and the action with ``env.sample_action()``, so
      consecutive transitions are not chained.
    * ``'expert'``: starting from ``env.reset()``, follow ``env.opt_action``.

    Args:
        env: environment exposing ``horizon``, ``reset()``,
            ``transit(state, action) -> (next_state, reward)`` and the
            sampling / optimal-action methods used by the chosen roll-in.
        rollin_type: ``'uniform'`` or ``'expert'``; ignored when a
            ``controller`` is supplied.
        controller: optional policy object with an ``act(state)`` method.

    Returns:
        Tuple ``(states, actions, next_states, rewards)`` of numpy arrays,
        each built from ``env.horizon`` per-step entries. (A past debug run
        on Darkroom with horizon 100 recorded shapes (100, 2), (100, 5),
        (100, 2) and (100,).)

    Raises:
        NotImplementedError: if no controller is given and ``rollin_type``
            is not a supported roll-in type.
    """
    states, actions, next_states, rewards = [], [], [], []

    state = env.reset()
    for _ in range(env.horizon):
        if controller is not None:
            action = controller.act(state)
        elif rollin_type == 'uniform':
            # Uniform roll-in ignores the previous next_state on purpose.
            state = env.sample_state()
            action = env.sample_action()
        elif rollin_type == 'expert':
            action = env.opt_action(state)
        else:
            raise NotImplementedError(f'unknown rollin_type: {rollin_type!r}')
        next_state, reward = env.transit(state, action)

        states.append(state)
        actions.append(action)
        next_states.append(next_state)
        rewards.append(reward)
        state = next_state

    return (
        np.array(states),
        np.array(actions),
        np.array(next_states),
        np.array(rewards),
    )



def generate_mdp_histories_from_envs(envs, n_hists, n_samples, rollin_type, controller=None):
    """Build (context, query, label) samples for every environment in ``envs``.

    For each env, ``n_hists`` context episodes are produced with
    ``rollin_mdp``. Each context is paired with ``n_samples`` query states
    drawn from ``env.sample_state()`` and labelled with
    ``env.opt_action(query_state)``.

    Args:
        envs: iterable of environments (must support ``len``).
        n_hists: number of context episodes per env.
        n_samples: number of query states per context episode.
        rollin_type: forwarded to ``rollin_mdp``.
        controller: optional policy forwarded to ``rollin_mdp``.

    Returns:
        List of dicts with keys 'query_state', 'optimal_action',
        'context_states', 'context_actions', 'context_next_states',
        'context_rewards', 'goal', plus 'perm_index' when the env defines it.
        The ``n_samples`` dicts built from one episode share the same context
        array objects (no copies are made).
    """
    header = (
        ('n_samples: ', n_samples),
        ('n_hists: ', n_hists),
        ('rollin type: ', rollin_type),
        ('len envs: ', len(envs)),
    )
    for label, value in header:
        print(label, value)

    samples = []
    for env in envs:
        for _ in range(n_hists):
            ctx_states, ctx_actions, ctx_next_states, ctx_rewards = rollin_mdp(
                env, rollin_type=rollin_type, controller=controller)
            for _ in range(n_samples):
                query = env.sample_state()
                record = dict(
                    query_state=query,
                    optimal_action=env.opt_action(query),
                    context_states=ctx_states,
                    context_actions=ctx_actions,
                    context_next_states=ctx_next_states,
                    context_rewards=ctx_rewards,
                    goal=env.goal,
                )
                # Permuted Darkroom envs expose which permutation they use.
                if hasattr(env, 'perm_index'):
                    record['perm_index'] = env.perm_index
                samples.append(record)
        # Progress report, checked once per env; it only fires when the
        # running total happens to be a multiple of 500.
        if not len(samples) % 500:
            print('traj: ', len(samples))

    return samples



def generate_darkroom_ood_histories(goals, dim, horizon, **kwargs):
    """Build robust Darkroom envs over an (alpha, prob) grid and roll them out.

    Every env is constructed as
    ``DarkroomEnv(dim, goal, horizon, prob, True, alpha)``. The meaning of
    ``prob`` and ``alpha`` is defined by ``darkroom_env.DarkroomEnv``, which
    is not visible here.

    Two layouts are used:

    * exactly 100 goals: consecutive blocks of 5 goals are assigned to
      successive (alpha, prob) settings, iterating probabilities fastest;
    * otherwise: every goal is instantiated under every (alpha, prob) setting.

    NOTE(review): there are 16 alphas × 5 probabilities = 80 settings, so the
    100-goal layout only fills the first 20 settings (the first 4 alphas);
    all later slices are empty. Also, any goal array with exactly 100 rows
    takes that layout, not only the eval set. Confirm both are intended.

    NOTE(review): the alpha grid is fixed here, so the darkroom_ood,
    darkroom_ood8 and darkroom_ood16 callers all get the same env grid.

    Args:
        goals: array of goal positions, one per row.
        dim: grid size forwarded to ``DarkroomEnv``.
        horizon: episode length forwarded to ``DarkroomEnv``.
        **kwargs: forwarded to ``generate_mdp_histories_from_envs``.

    Returns:
        The list produced by ``generate_mdp_histories_from_envs``.
    """
    prob_grid = [0.01, 0.05, 0.1, 0.15, 0.2]
    is_robust = True
    # Earlier alpha grids, kept for reference:
    # darkroom_ood (4 settings):
    #   [[1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 10.0],
    #    [5.0, 1.0, 5.0, 1.0, 1.0], [1.0, 5.0, 1.0, 5.0, 1.0]]
    # darkroom_ood8 (8 settings): the first 8 rows of the grid below.
    # Current grid: 16 settings.
    alpha_grid = [
        [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 10.0],
        [5.0, 1.0, 5.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0, 1.0],
        [3.0, 3.0, 3.0, 1.0, 3.0], [3.0, 3.0, 1.0, 3.0, 3.0],
        [3.0, 1.0, 3.0, 3.0, 3.0], [1.0, 3.0, 3.0, 3.0, 3.0],
        [5.0, 1.0, 1.0, 1.0, 5.0], [1.0, 5.0, 1.0, 1.0, 5.0],
        [1.0, 1.0, 5.0, 1.0, 5.0], [1.0, 1.0, 1.0, 5.0, 5.0],
        [4.0, 4.0, 4.0, 1.0, 1.0], [4.0, 4.0, 1.0, 4.0, 1.0],
        [4.0, 1.0, 4.0, 4.0, 1.0], [1.0, 4.0, 4.0, 4.0, 1.0],
    ]

    def _build(goal_subset, prob, alpha):
        # One robust env per goal under a single (alpha, prob) setting.
        return [
            darkroom_env.DarkroomEnv(dim, g, horizon, prob, is_robust, alpha)
            for g in goal_subset
        ]

    env_list = []
    if len(goals) == 100:
        cursor = 0
        for alpha in alpha_grid:
            for prob in prob_grid:
                env_list.extend(_build(goals[cursor:cursor + 5], prob, alpha))
                cursor += 5
                print('len: ', len(env_list), cursor)
    else:
        for alpha in alpha_grid:
            for prob in prob_grid:
                env_list.extend(_build(goals, prob, alpha))
                print('len: ', len(env_list))

    return generate_mdp_histories_from_envs(env_list, **kwargs)


def generate_darkroom_histories(goals, dim, horizon, controller, **kwargs):
    """Roll out one standard Darkroom env per goal.

    Each env is ``DarkroomEnv(dim, goal, horizon, 0.0, False)``, i.e. with
    ``prob=0.0`` and ``robust=False`` (the arguments the OOD generator sets
    to non-trivial values).

    Args:
        goals: iterable of goal positions.
        dim: grid size forwarded to ``DarkroomEnv``.
        horizon: episode length forwarded to ``DarkroomEnv``.
        controller: optional policy forwarded to the roll-in (may be None).
        **kwargs: forwarded to ``generate_mdp_histories_from_envs``.

    Returns:
        The list produced by ``generate_mdp_histories_from_envs``.
    """
    env_list = []
    for goal in goals:
        env_list.append(darkroom_env.DarkroomEnv(dim, goal, horizon, 0.0, False))
    return generate_mdp_histories_from_envs(env_list, controller=controller, **kwargs)





if __name__ == '__main__':
    # Seed the RNGs so datasets are reproducible. torch is seeded too in case
    # the darkroom_mix transformer controller samples actions with torch.
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    args = vars(parser.parse_args())
    print("Args: ", args)

    env = args['env']
    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    horizon = args['H']
    dim = args['dim']
    dataset_coeff = args['dataset_coeff']  # e.g. 20

    ood_envs = ('darkroom_ood', 'darkroom_ood8', 'darkroom_ood16')
    id_envs = ('darkroom_id', 'darkroom_mix')
    # darkroom_mix repeats goals based on this fixed env count, not on --envs.
    n_mix_envs = 26667

    if env in ood_envs or env in id_envs:
        rollin_type = 'uniform'
    elif env == 'darkroom_heldout':
        rollin_type = 'expert'
    else:
        raise NotImplementedError(f'unsupported env: {env}')

    config = {
        'n_hists': n_hists,
        'n_samples': n_samples,
        'horizon': horizon,
        'dim': dim,
        'rollin_type': rollin_type,
    }
    start_time = time.time()

    # Every cell of the dim x dim grid is a candidate goal. Shuffle with a
    # fixed local seed (the global RNG is untouched) and split 80/20.
    goals = np.array([[(j, i) for i in range(dim)]
                      for j in range(dim)]).reshape(-1, 2)
    np.random.RandomState(seed=0).shuffle(goals)
    train_test_split = int(.8 * len(goals))
    train_goals = goals[:train_test_split]
    test_goals = goals[train_test_split:]
    # Tile the test goals 100 // len(test_goals) times for the eval split.
    eval_goals = np.array(test_goals.tolist() * int(100 // len(test_goals)))

    # How many times each train/test goal is repeated.
    if env in ood_envs:
        # OOD generation multiplies envs by the (alpha, prob) grid, so fewer
        # goal repeats are requested here.
        n_repeat = n_envs // (dim * dim * dataset_coeff)
    elif env == 'darkroom_mix':
        n_repeat = n_mix_envs // (dim * dim)
    else:
        n_repeat = n_envs // (dim * dim)
    train_goals = np.repeat(train_goals, n_repeat, axis=0)
    test_goals = np.repeat(test_goals, n_repeat, axis=0)

    print('len eval goals: ', eval_goals.shape)

    controller = None
    if env == 'darkroom_mix':
        # darkroom_mix rolls in with a pretrained transformer policy.
        device = 'cuda:0'
        # model_config only selects which checkpoint file is loaded: it is
        # passed to build_darkroom_model_filename, whose result names the file.
        model_config = {
            'shuffle': True,
            'lr': 0.001,
            'dropout': 0,
            'n_embd': 32,
            'n_layer': 4,
            'n_head': 4,
            'n_envs': n_envs,
            'n_hists': 1,
            'n_samples': 1,
            'horizon': 100,
            'dim': 10,
            'seed': 0,
        }
        transformer_config = {
            'horizon': horizon,
            'n_embd': 32,
            'n_layer': 4,
            'n_head': 4,
            'state_dim': 2,
            'action_dim': 5,
            'dropout': 0,
            'test': True,
        }
        filename = build_darkroom_model_filename(env, model_config)
        model = Transformer(transformer_config).to(device)

        epoch = 50
        model_path = f'models/{filename}_epoch{epoch}.pt'
        # map_location keeps loading working even if the checkpoint was saved
        # from a different device than the one used here.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model.eval()

        # Initial context: one all-zero placeholder transition, sized by the
        # state_dim (2) and action_dim (5) above.
        batch = {
            'context_states': torch.zeros((1, 1, 2)).float().to(device),
            'context_actions': torch.zeros((1, 1, 5)).float().to(device),
            'context_next_states': torch.zeros((1, 1, 2)).float().to(device),
            'context_rewards': torch.zeros((1, 1, 1)).float().to(device),
        }
        controller = DarkroomTransformerController(model, batch_size=1, sample=True)
        controller.set_batch(batch)

    if env in ood_envs:
        train_trajs = generate_darkroom_ood_histories(train_goals, **config)
        test_trajs = generate_darkroom_ood_histories(test_goals, **config)
        eval_trajs = generate_darkroom_ood_histories(eval_goals, **config)
    else:
        train_trajs = generate_darkroom_histories(train_goals, controller=controller, **config)
        test_trajs = generate_darkroom_histories(test_goals, controller=controller, **config)
        eval_trajs = generate_darkroom_histories(eval_goals, controller=controller, **config)

    print('len eval trajs: ', len(eval_trajs))

    train_filepath = build_darkroom_data_filename(env, n_envs, config, mode=0)
    test_filepath = build_darkroom_data_filename(env, n_envs, config, mode=1)
    # The eval split is always labelled with 100 envs.
    eval_filepath = build_darkroom_data_filename(env, 100, config, mode=2)

    # NOTE(review): presumably build_darkroom_data_filename writes under
    # datasets/ — confirm against utils.
    os.makedirs('datasets', exist_ok=True)
    for filepath, trajs in ((train_filepath, train_trajs),
                            (test_filepath, test_trajs),
                            (eval_filepath, eval_trajs)):
        with open(filepath, 'wb') as file:
            pickle.dump(trajs, file)

    print('data collection time: ', time.time() - start_time)