import os
import time

import gym
import numpy as np
import pickle
import time

import torch
from tqdm import tqdm
import wandb
import os
import smart_settings

np.set_printoptions(precision=3, suppress=True)

from src.smodice_pytorch import SMODICE
from src.discriminator_pytorch import Discriminator_SA
import src.utils as utils

# Env families for which run() builds the expert env as "<env_name>-expert-<version>".
MUJOCO = ['hopper', 'walker2d', 'halfcheetah', 'ant']


def _load_mismatch_expert_traj(config):
    """Load the expert demonstration used when ``config['mismatch']`` is set.

    Reads ``envs/demos/{env_name}_{dataset}.pkl``. For env names containing
    ``'ant'`` the first 1000 entries of each field are kept; otherwise element
    ``[0]`` of each field is used (presumably the first stored trajectory --
    TODO confirm the demo file layout).

    Args:
        config: run configuration; reads ``env_name`` and ``dataset``.

    Returns:
        dict with ``'observations'``, ``'actions'`` and ``'next_observations'``
        as numpy arrays.
    """
    demo_file = f"envs/demos/{config['env_name']}_{config['dataset']}.pkl"
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted demo files.
    with open(demo_file, 'rb') as f:
        demo = pickle.load(f)
    keys = ('observations', 'actions', 'next_observations')
    if 'ant' in config['env_name']:
        return {key: np.array(demo[key][:1000]) for key in keys}
    return {key: np.array(demo[key][0]) for key in keys}


def _make_disc_input(observations, actions, state_only, device):
    """Build the discriminator input tensor for a batch.

    Args:
        observations: numpy array or CPU tensor of observations (copied with
            ``np.array`` before conversion).
        actions: numpy array or CPU tensor of actions; ignored when
            ``state_only`` is true.
        state_only: if true, return the observations only; otherwise
            concatenate observations and actions along dim 1.
        device: device the resulting tensor is moved to.

    Returns:
        torch.Tensor on ``device``.
    """
    obs_for_disc = torch.from_numpy(np.array(observations)).to(device)
    if state_only:
        return obs_for_disc
    act_for_disc = torch.from_numpy(np.array(actions)).to(device)
    return torch.cat([obs_for_disc, act_for_disc], dim=1)


def run(config):
    """Pretrain a discriminator, then train a SMODICE agent on offline data.

    Pipeline:
      1. Build the gym environments ``{env_name}-{dataset}-{version}`` (or the
         ``solo12`` env/data from ``utils.get_solo_env_and_data``) and seed them.
      2. Load the expert trajectory and the offline dataset, then standardize
         the expert observations with the offline dataset statistics.
      3. Train (or restore from ``pretrain_ckpt``) the discriminator and save a
         pretrain checkpoint.
      4. Precompute discriminator rewards for the whole offline dataset.
      5. Run the SMODICE training loop with wandb logging, periodic evaluation
         and checkpointing, resuming from the checkpoint returned by
         ``utils.maybe_load_last_checkpoint`` when there is one.

    Args:
        config: ``smart_settings`` configuration, read both by key
            (``config['seed']``) and by attribute (``config.batch_size``).
            ``config['target_entropy']`` is filled in when it is ``None`` and
            an entropy constraint is enabled.

    Returns:
        ``{'done': 1.0}`` once training finishes.

    Raises:
        Exception: if ``config['algo_type']`` is not ``'smodice'``.
    """
    # Load checkpoints if they exist.
    ckpt, pretrain_ckpt = utils.maybe_load_last_checkpoint()

    ############### ENVIRONMENTS AND DATA ###########################
    offline_dataset, expert_dataset = None, None
    if config.env_name == 'solo12':
        expert_env, expert_dataset, offline_dataset, actions_max, actions_min = utils.get_solo_env_and_data(config)
        env = expert_env
        # NOTE(review): np/torch RNGs are not seeded on this path -- confirm intended.
    else:
        version = 'v0' if 'kitchen' in config['env_name'] else 'v2'
        # In the mismatch setting the environment is the '-random' variant.
        env_dataset = 'random' if config['mismatch'] else config['dataset']
        env = gym.make(f"{config['env_name']}-{env_dataset}-{version}")

        if config['env_name'] in MUJOCO:
            print('Making env expert...', f"{config['env_name']}-expert-{version}")
            expert_env = gym.make(f"{config['env_name']}-expert-{version}")
        elif config['env_name'] == 'kitchen':
            expert_env = gym.make("kitchen-complete-v0")
        else:
            expert_env = gym.make(f"{config['env_name']}-{config['dataset']}-{version}")

        # Seeding
        np.random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        env.seed(config['seed'])
        expert_env.seed(config['seed'])

    # Load the expert trajectory.
    if config['mismatch']:
        expert_traj = _load_mismatch_expert_traj(config)
    else:
        print('Loading expert dataset...')
        expert_traj = next(utils.sequence_dataset(expert_env, dataset=expert_dataset))

    # Load the offline dataset (optionally mixed with expert trajectories).
    preprocessing_kwargs = dict(
        standardize_observation=config['standardize_obs'],
        absorbing_state=config['absorbing_state'],
        standardize_reward=config['standardize_reward'],
    )
    if config['num_expert_traj'] == 0:
        initial_obs_dataset, dataset, dataset_statistics = utils.dice_dataset(
            env, dataset=offline_dataset, **preprocessing_kwargs)
    else:
        initial_obs_dataset, dataset, dataset_statistics = utils.dice_combined_dataset(
            expert_env, env,
            num_expert_traj=config['num_expert_traj'],
            num_offline_traj=config['num_offline_traj'],
            offline_dataset=offline_dataset,
            expert_dataset=expert_dataset,
            **preprocessing_kwargs)

    # Standardize expert observations with the offline statistics, then
    # optionally add the absorbing state.
    if config['standardize_obs']:
        obs_mean = dataset_statistics['observation_mean']
        obs_std = dataset_statistics['observation_std']

        def _standardize(values):
            # Slicing the statistics to the array's feature count allows expert
            # arrays with fewer features than the dataset statistics.
            dim = values.shape[1]
            return (values - obs_mean[:dim]) / (obs_std[:dim] + 1e-10)

        expert_traj['observations'] = _standardize(expert_traj['observations'])
        if 'next_observations' in expert_traj:
            expert_traj['next_observations'] = _standardize(expert_traj['next_observations'])
    if config['absorbing_state']:
        expert_traj = utils.add_absorbing_state(expert_traj)
    if config['use_policy_entropy_constraint'] or config['use_data_policy_entropy_constraint']:
        if config['target_entropy'] is None:
            # Default to -prod(action_space.shape) when not configured.
            config['target_entropy'] = -np.prod(env.action_space.shape)

    # Create inputs for the discriminator; the absorbing state adds one feature.
    state_dim = dataset_statistics['observation_dim'] + (1 if config['absorbing_state'] else 0)
    action_dim = 0 if config['state'] else dataset_statistics['action_dim']
    disc_idxs = list(range(state_dim)) if config['disc_idxs'] == 'all' else config['disc_idxs']

    expert_input = expert_traj['observations'][:, disc_idxs]
    # The offline observations are deliberately NOT sliced with disc_idxs here
    # (original author: they are used differently for diversity and classifier).
    offline_input = dataset['observations']

    ################## CREATE AGENT ########################
    print(f'Creating {config["algo_type"]} agent!')
    if config['algo_type'] != 'smodice':
        raise Exception('Agent not available!')
    agent = SMODICE(
        observation_spec=state_dim,
        action_spec=dataset_statistics['action_dim'],
        config=config,
    )

    ##################### PRETRAINING #######################
    print('disc cutoff', disc_idxs, action_dim, config['hidden_sizes'][0])
    discriminator = Discriminator_SA(disc_idxs, action_dim, hidden_dim=config['hidden_sizes'][0],
                                     device=config['device'], grad_pen_coef=config.grad_pen_coef)
    dataset_expert = torch.utils.data.TensorDataset(torch.FloatTensor(expert_input))
    dataset_offline = torch.utils.data.TensorDataset(torch.FloatTensor(offline_input))
    expert_loader = torch.utils.data.DataLoader(dataset_expert, batch_size=config.pretrain_batch_size,
                                                shuffle=True, pin_memory=True)
    offline_loader = torch.utils.data.DataLoader(dataset_offline, batch_size=config.pretrain_batch_size,
                                                 shuffle=True, pin_memory=True)
    if config.compile_models:
        discriminator = torch.compile(discriminator)

    if pretrain_ckpt is not None:
        discriminator.load_state_dict(pretrain_ckpt['discriminator_state_dict'])
    elif config['disc_type'] == 'learned' and not config.debug:
        for i in tqdm(range(config['disc_iterations'])):
            loss = discriminator.update(expert_loader, offline_loader)
            if i % 100 == 0:
                wandb.log({'pretrain/classifier_loss': loss, 'pretrain_iter': i})

    ################# PRETRAIN CHECKPOINTS ###########################
    start_iteration = 0
    save_pretrained = config.pretrained_models_path == 'save'
    default_models_path = os.path.join(
        config.working_dir,
        f'{config.env_name}_{config.dataset}_{config.num_expert_traj}_pretrain_last.pth.tar')
    if pretrain_ckpt is None or save_pretrained:
        print('Saving pretrain checkpoint...')
        pretrain_state = {'discriminator_state_dict': discriminator.state_dict()}
        utils.save_checkpoint(pretrain_state, 'pretrain_last.pth.tar')
        if save_pretrained or not os.path.exists(default_models_path):
            torch.save(pretrain_state, default_models_path)

    if ckpt:
        agent.load_state_dict(ckpt['state_dict'])
        start_iteration = ckpt['iteration'] + 1

    # start_time measures total training time; last_start_time is reset after
    # every log and is used for the iterations/second rate.
    start_time = time.time()
    last_start_time = time.time()

    ######### PRECOMPUTE REWARDS #######
    def _precompute_classifier_rewards(chunk_size=2048):
        """Return discriminator rewards for every offline transition, in chunks."""
        n_samples = dataset_statistics['N']
        rewards = []
        for start in range(0, n_samples, chunk_size):
            stop = min(n_samples, start + chunk_size)
            with torch.no_grad():
                if config['disc_type'] == 'zero':
                    # Constant-zero rewards; no discriminator query needed.
                    reward = torch.zeros(stop - start)
                else:
                    disc_input = _make_disc_input(dataset['observations'][start:stop],
                                                  dataset['actions'][start:stop],
                                                  config['state'], discriminator.device)
                    reward = discriminator.predict_reward(disc_input)
            rewards.append(reward.cpu().numpy())
        rewards = np.concatenate(rewards)
        if config['normalize_classifier_reward']:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        return rewards

    classifier_rewards = _precompute_classifier_rewards()
    assert classifier_rewards.shape[0] == dataset_statistics['N']
    dataset['classifier_rewards'] = classifier_rewards

    # Loader over the whole dataset; used to compute w_e/e_v for checkpoints.
    random_loader = utils.SmodiceLoader(config.batch_size,
                                        dataset=dataset,
                                        initial_obs_dataset=initial_obs_dataset,
                                        dataset_statistics=dataset_statistics,
                                        shuffle=True)

    ######## ORIGINAL SAMPLER ###########
    def _sample_minibatch(batch_size, reward_scale):
        """Uniformly sample initial observations and transitions as torch tensors."""
        initial_indices = np.random.randint(0, dataset_statistics['N_initial_observations'], batch_size)
        indices = np.random.randint(0, dataset_statistics['N'], batch_size)
        sampled_dataset = (
            initial_obs_dataset['initial_observations'][initial_indices],
            dataset['observations'][indices],
            dataset['actions'][indices],
            dataset['rewards'][indices] * reward_scale,
            dataset['next_observations'][indices],
            dataset['terminals'][indices],
            dataset['experts'][indices],
        )
        return tuple(map(torch.from_numpy, sampled_dataset))

    batch_keys = ('initial_observations', 'observations', 'actions', 'rewards',
                  'next_observations', 'terminals', 'experts')
    # Passed to utils.evaluate; only set for kitchen environments.
    max_steps = 280 if 'kitchen' in config['env_name'] else None

    for iteration in tqdm(range(start_iteration, config['total_iterations'] + 1), ncols=70, desc='DOI',
                          initial=start_iteration, total=config['total_iterations'] + 1, ascii=True,
                          disable=os.environ.get("DISABLE_TQDM", False)):
        batch = dict(zip(batch_keys, _sample_minibatch(config.batch_size, 0.1)))

        # The agent is trained on discriminator rewards, not batch['rewards'].
        # NOTE(review): disc_type == 'zero' only affects the precomputed
        # classifier_rewards; this path always queries the discriminator.
        with torch.no_grad():
            disc_input = _make_disc_input(batch['observations'], batch['actions'],
                                          config['state'], discriminator.device)
            reward = discriminator.predict_reward(disc_input)

        # Perform gradient descent
        train_result = agent.train_step(batch['initial_observations'],
                                        batch['observations'],
                                        batch['actions'],
                                        reward,
                                        batch['next_observations'],
                                        batch['terminals'])
        train_result['train_iter'] = iteration
        if iteration % config.wandb_log_interval == 0:
            # Includes 'train/train_iter' via train_result['train_iter'].
            wandb.log({f"train/{k}": v for k, v in train_result.items()})

        if iteration % config['log_iterations'] != 0:
            continue

        # Logging, evaluation and checkpointing.
        for k, v in train_result.items():
            if isinstance(v, torch.Tensor):
                train_result[k] = v.cpu().detach().numpy()

        eval_result = utils.evaluate(env, agent, dataset_statistics, absorbing_state=config['absorbing_state'],
                                     iteration=iteration, max_steps=max_steps)
        train_result['eval/task_reward'] = eval_result[0]
        train_result['eval/task_std'] = eval_result[1]

        # Importance weights of expert vs. offline samples in this batch; w_e is
        # dropped from train_result because it is heavy to save.
        importance_weights = {}
        batch_w_e = train_result.pop('w_e', None)
        if batch_w_e is not None:
            expert_index = (batch['experts'] == 1).nonzero(as_tuple=False).cpu()
            offline_index = (batch['experts'] == 0).nonzero(as_tuple=False).cpu()
            w_e_expert = batch_w_e[expert_index].mean()
            w_e_offline = batch_w_e[offline_index].mean()
            importance_weights = {
                'w_e': batch_w_e.mean(),
                'w_e_expert': w_e_expert,
                'w_e_offline': w_e_offline,
                'w_e_ratio': w_e_expert / w_e_offline,
            }
        train_result['iter_per_sec'] = config['log_iterations'] / (time.time() - last_start_time)

        all_w_e, e_v = agent.compute_all_w_e(random_loader)
        train_ckpt = {
            'state_dict': agent.get_state_dict(),
            'iteration': iteration,
            'w_e': all_w_e,
            'e_v': e_v,
            'classifier_rewards': classifier_rewards,
        }
        utils.save_checkpoint(train_ckpt, 'last_ckpt.pth.tar')
        if save_pretrained:
            torch.save(train_ckpt, os.path.join(
                config.working_dir,
                f'{config.env_name}_{config.dataset}_{config.num_expert_traj}_smodice_last.pth.tar'))

        wandb.log({**train_result, **importance_weights})

        if not int(os.environ.get('DISABLE_STDOUT', 0)):
            print('=======================================================')
            print(f'- {"eval/task_reward":23s}:{train_result["eval/task_reward"]}')
            print(f'iteration={iteration} (elapsed_time={time.time() - start_time:.2f}s, '
                  f'{train_result["iter_per_sec"]:.2f}it/s)')
            print('=======================================================', flush=True)

        last_start_time = time.time()

    print('Training finished')
    return {'done': 1.0}

from smodice.utils import Ticker
import sys
def _main(argv):
    """Load the settings file named by ``argv[1]``, train, then close wandb.

    Returns whatever ``run`` returns.
    """
    settings = smart_settings.load(argv[1])
    # Ticker keeps global state, so it is initialised before training starts.
    Ticker(settings)
    Ticker.device = 'cuda'
    outcome = run(settings)
    wandb.finish(0)
    return outcome


if __name__ == "__main__":
    metrics = _main(sys.argv)