import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import os
import logging
import random
import csv
import ipdb
import pandas as pd

from utils import set_seed_everywhere

import argparse
import matplotlib.pyplot as plt

from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer, discrete_action_match_scorer

from typing import Any, Callable, Iterator, List, Optional, Tuple, Union, cast
from d3rlpy.metrics.scorer import AlgoProtocol, evaluate_on_environment
from d3rlpy.preprocessing.stack import StackedObservation

from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split
from active_cql_traffic import ActiveCQL
from pbrl_traffic import PBRL
from d3rlpy.dataset import MDPDataset

import wandb
import d3rlpy
import gym
import gym_minigrid
from gym_minigrid.wrappers import ImgObsWrapper, FlatObsWrapper, FullyObsWrapper, FlatImageObsWrapper, RGBImgObsWrapper
from d3rlpy.envs import ChannelFirst
from torch.optim.lr_scheduler import CosineAnnealingLR

from array2gif import write_gif
from PIL import Image

def evaluate_on_benchmark_environments_viz(
    envs: List[gym.Env],
    n_trials: int = 2,
    epsilon: float = 0.0,
    render: bool = False,
    gif: bool = False,
    path: Optional[str] = None,
) -> Callable[..., dict]:
    """Build a scorer that rolls out ``algo`` in every env and reports mean reward.

    Args:
        envs: evaluation environments. Each one is used through ``reset``,
            ``step``, ``render``, ``action_space`` and ``nickname`` (the latter
            is the result key and the prefix of printed logs / gif names).
        n_trials: number of episodes run per environment.
        epsilon: probability of taking ``env.action_space.sample()`` instead of
            ``algo.predict([observation])[0]`` at each step.
        render: currently unused -- ``env.render('human')`` is called on every
            step regardless. NOTE(review): kept unconditional to preserve the
            existing behavior of callers; gate it on this flag if headless
            evaluation is needed.
        gif: when true, record ``env.render("rgb_array")`` frames for each
            episode and write them with ``write_gif`` (fps=3) to
            ``<algo._active_logger.logdir>/<nickname>_ep<algo.epoch>_<trial>.gif``.
        path: currently unused.

    Returns:
        ``scorer(algo, *args)`` which returns ``{env.nickname: mean episode
        reward over the n_trials episodes}``.
    """

    def scorer(algo: AlgoProtocol, *args: Any) -> dict:
        rewards_dict = {}
        for env in envs:
            episode_rewards = []
            for trial_itr in range(n_trials):
                # Frames of this episode only (filled when gif=True); local so
                # nothing carries over between episodes or scorer calls.
                frames = []
                observation = env.reset()
                episode_reward = 0.0
                episode_len = 0
                done = False

                while not done:
                    env.render('human')
                    if gif:
                        # Move the last axis of the rendered frame to the front
                        # before handing frames to write_gif.
                        frames.append(np.moveaxis(env.render("rgb_array"), 2, 0))
                    if np.random.random() < epsilon:
                        action = env.action_space.sample()
                    else:
                        action = algo.predict([observation])[0]

                    observation, reward, done, _ = env.step(action)
                    episode_reward += reward
                    episode_len += 1

                episode_rewards.append(episode_reward)
                print(env.nickname + '- ep len: ', episode_len)
                print(env.nickname + '- ep rew: ', episode_reward)

                if gif:
                    print("Saving gif... ", end="")
                    write_gif(np.array(frames), algo._active_logger.logdir + '/' + env.nickname + '_ep' + str(algo.epoch) + '_' + str(trial_itr) + ".gif", fps=3)
                    print("Done: ", trial_itr)

            # Average across all trials of this env (previously this took the
            # mean of only the last episode's scalar reward).
            rewards_dict[env.nickname] = np.mean(episode_rewards)
        return rewards_dict

    return scorer

def save_imgs(redstates_dict, redstates_count, step):
    """Save each image in ``redstates_dict`` as a PNG under ``gifs/runtime/``.

    Every value is converted with ``Image.fromarray`` and written to
    ``gifs/runtime/<step>_<redstates_count[key]>_<key>.png``. The directory is
    not created here, so it must already exist.
    """
    for key, pixels in redstates_dict.items():
        stem = '_'.join((str(step), str(redstates_count[key]), str(key)))
        Image.fromarray(pixels).save('gifs/runtime/' + stem + ".png")


def _load_split(split_dir):
    """Load the five per-step arrays stored in ``split_dir``.

    Returns ``(grid_states, rgb_states, actions, rewards, dones)`` read from
    ``states.npy``, ``rgb_states.npy``, ``actions.npy``, ``rewards.npy`` and
    ``dones.npy`` respectively, exactly as saved.
    """
    grid_states = np.load(split_dir + '/states.npy')
    rgb_states = np.load(split_dir + '/rgb_states.npy')
    actions = np.load(split_dir + '/actions.npy')
    rewards = np.load(split_dir + '/rewards.npy')
    dones = np.load(split_dir + '/dones.npy')
    return grid_states, rgb_states, actions, rewards, dones


def _drop_yellow_forward_episodes(grid_states, rgb_states, actions, rewards, dones):
    """Split the flat arrays into episodes and drop the flagged ones.

    An episode is dropped when any step has ``grid_states[t, 0, 0, 1] == 4``
    together with ``actions[t, 0] == 2`` (original author's note: "where tile
    is yellow and action is go forward"). Returns the re-concatenated
    ``(grid_states, rgb_states, actions, rewards, dones)`` of the kept episodes
    plus the list of dropped episode indices.
    """
    # Split points are the indices just after each done flag. The last one is
    # discarded -- presumably because the data ends on a done, which would
    # otherwise yield an empty trailing chunk (TODO confirm).
    terminal_indices = (np.arange(dones.shape[0])[dones.astype(int) == 1] + 1)[:-1]

    grid_eps = np.split(grid_states, terminal_indices)
    rgb_eps = np.split(rgb_states, terminal_indices)
    action_eps = np.split(actions, terminal_indices)
    reward_eps = np.split(rewards, terminal_indices)
    done_eps = np.split(dones, terminal_indices)

    dropped, kept = [], []
    for idx, (grid_ep, action_ep) in enumerate(zip(grid_eps, action_eps)):
        hits = (grid_ep[:, 0, 0, 1] == 4) & (action_ep[:, 0] == 2)
        if hits.sum() > 0:
            dropped.append(idx)
        else:
            kept.append(idx)

    def _join(eps):
        return np.concatenate([eps[i] for i in kept])

    filtered = (_join(grid_eps), _join(rgb_eps), _join(action_eps),
                _join(reward_eps), _join(done_eps))
    return filtered, dropped


def main(args):
    """Train an offline agent on the MiniGrid traffic dataset.

    The steps are:

    1. Load the ``args.dataset_type`` split of ``data/<args.env_data>``.
    2. Append transitions 18:36 of the ``expert_plus_yellow`` split.
    3. Build an ``MDPDataset`` from channel-first RGB frames plus the
       channel-first grid states.
    4. Initialise wandb and seed everything.
    5. Build the evaluation envs.
    6. Optionally re-weight episodes (``--prune_yellow``).
    7. Construct ``ActiveCQL`` or ``PBRL`` according to ``args.method`` and
       call ``fit``.

    Raises:
        ValueError: if the loaded arrays do not all have the same length.
    """
    data_type = args.dataset_type
    env_data = args.env_data

    grid_states, rgb_states, actions, rewards, dones = _load_split(
        'data/' + env_data + '/' + data_type)
    (grid_states_yellow, rgb_states_yellow, actions_yellow,
     rewards_yellow, dones_yellow) = _load_split('data/' + env_data + '/expert_plus_yellow')

    # Disabled experiment (was gated on args.prune_yellow): drop whole episodes
    # at the array level. Kept for reference; the active --prune_yellow logic
    # further below works on dataset episodes instead.
    if False:
        (grid_states, rgb_states, actions, rewards, dones), red_lists = \
            _drop_yellow_forward_episodes(grid_states, rgb_states, actions, rewards, dones)
        print('yellow list size: ', len(red_lists))
        print(red_lists)

    # Always append transitions 18:36 of the expert_plus_yellow split (this
    # used to sit behind a no_add_yellow flag). The yellow actions are reshaped
    # to a column so they stack with ``actions``.
    extra = slice(18, 36)
    rgb_states = np.concatenate((rgb_states, rgb_states_yellow[extra]))
    grid_states = np.concatenate((grid_states, grid_states_yellow[extra]))
    actions = np.concatenate((actions, actions_yellow[extra].reshape(-1, 1)))
    rewards = np.concatenate((rewards, rewards_yellow[extra]))
    dones = np.concatenate((dones, dones_yellow[extra]))

    # NOTE(review): the +0.001 offset is applied when --no_rewards_offset IS
    # passed, which reads inverted relative to the flag name. Kept as-is so
    # existing runs reproduce -- confirm intent.
    if args.no_rewards_offset:
        rewards = rewards + 0.001

    # Dataset diagnostics. The codes 10 / 11 in channel 0 and 0 / 1 in
    # channel 1 are presumably MiniGrid object / state encodings for the agent
    # and the traffic light -- confirm against the env definition.
    agent_tl_mask = grid_states[:, 8, 2, 0] == 10
    print('agent in light:', agent_tl_mask.sum())
    tl_mask = grid_states[:, :, :, 0] == 11.0
    tl_states = (grid_states[:, :, :, 1] * tl_mask).sum(axis=-1).sum(axis=-1)
    print('reds: ', (tl_states == 0).sum())
    print('greens: ', (tl_states == 1).sum())

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not (len(grid_states) == len(actions) == len(rewards) == len(dones)):
        raise ValueError("data length mismatch")

    print("states shape: ", grid_states.shape)
    if not args.confused:
        # Paint the top-left 8x8 pixel patch gray in every frame (presumably
        # hides a visual cue unless --confused is set -- confirm).
        rgb_states[:, 0:8, 0:8] = np.array([100, 100, 100])

    # Move the last axis to position 1 (channel-first) and make contiguous copies.
    states = np.ascontiguousarray(rgb_states.transpose(0, 3, 1, 2))
    meta_states = np.ascontiguousarray(grid_states.transpose(0, 3, 1, 2))

    # NOTE(review): this MDPDataset takes an extra positional array
    # (meta_states) after the observations -- project-specific signature.
    dataset = MDPDataset(
        states,
        meta_states,
        actions,
        rewards,
        dones,
        discrete_action=True)

    name = args.wandb_name + '_ncr_' + str(args.n_critics)
    run_name = name + '_seed' + str(args.seed)

    wandb.init(
        group=name,
        job_type=str(args.seed),
        project="workshop_traffic_rl",
        entity="causalsampling",
        config=args,
        mode=args.wandb_mode
    )
    wandb.run.name = run_name
    wandb.run.save()

    set_seed_everywhere(args.seed)
    d3rlpy.seed(args.seed)

    env_names = [
        'MiniGrid-Simple-No-Traffic-No-Switch-Red-v0',
        'MiniGrid-Simple-No-Traffic-No-Switch-Confusion-Green-v0',
        'MiniGrid-Simple-Stop-Agent-Switch-v0',
        # normal env: green with nothing else
        'MiniGrid-Simple-No-Traffic-No-Switch-v0',
    ]
    # Wrapper stack, innermost first: FullyObs -> RGBImg -> ImgObs -> ChannelFirst.
    eval_envs = [
        ChannelFirst(ImgObsWrapper(RGBImgObsWrapper(FullyObsWrapper(gym.make(env_name)))))
        for env_name in env_names
    ]

    train_episodes = dataset.episodes
    # NOTE(review): evaluation metrics are computed on the training episodes.
    test_episodes = dataset.episodes

    if args.prune_yellow:
        # Despite the flag name nothing is removed. Episodes whose ep_id is in
        # yellow_eps are kept once. Every other episode is repeated
        # args.data_multiply times, then the list is shuffled.
        yellow_eps = [92, 108, 125, 133, 192, 239, 447, 466, 285, 303, 317,
                      350, 364, 400, 410, 48, 61, 467, 478, 524, 544]
        pruned = []
        for ep_instance in train_episodes:
            if ep_instance.ep_id in yellow_eps:
                pruned.append(ep_instance)
            else:
                pruned.extend([ep_instance] * args.data_multiply)
        random.shuffle(pruned)
        train_episodes = pruned

    # Dispatch on --method. The former `elif method == 'pbrl'` referenced an
    # undefined name and raised NameError for --method pbrl.
    algo_cls = {'active': ActiveCQL, 'pbrl': PBRL}[args.method]
    cql = algo_cls(
        n_critics=args.n_critics,       ## use 5 for active
        encoder_factory='default',
        use_gpu=True,
        batch_size=args.batch_size,
        optim_factory=d3rlpy.models.optimizers.AdamFactory(eps=1e-2 / 32),
        alpha=args.alpha_cql,
        learning_rate=args.lr,
        scaler='pixel',
        target_update_interval=args.target_update_interval,
    )

    callback = None
    if args.decay_lr:
        # Build the implementation eagerly so its optimizer exists and can be
        # wrapped in a cosine LR schedule (T_max = 6000 scheduler steps). The
        # scheduler is stepped from the fit callback.
        first_episode = train_episodes[0]
        action_size = first_episode.get_action_size()
        observation_shape = tuple(first_episode.get_observation_shape())
        cql.create_impl(
            cql._process_observation_shape(observation_shape), action_size
        )
        scheduler = CosineAnnealingLR(cql.impl._optim, 6000)

        def callback(algo, epoch, total_step):
            scheduler.step()

    cql.fit(
        train_episodes,
        eval_episodes=test_episodes,
        save_interval=args.save_interval,
        n_epochs=args.epochs if not args.do_by_steps else None,
        n_steps=args.epochs * args.n_steps_per_epoch if args.do_by_steps else None,
        n_steps_per_epoch=args.n_steps_per_epoch if args.do_by_steps else None,
        scorers={
            'environment': evaluate_on_benchmark_environments_viz(eval_envs, gif=args.make_gif),
            'advantage': discounted_sum_of_advantage_scorer, # smaller is better
            'td_error': td_error_scorer, # smaller is better
            'value_scale': average_value_estimation_scorer, # smaller is better
        },
        net_args=args,
        experiment_name=run_name,
        callback=callback,
    )



def _build_arg_parser():
    """Return the CLI parser for this training script.

    Option names, registration order, types, defaults and choices define the
    run configuration that ``main`` receives.
    """
    parser = argparse.ArgumentParser()

    def add_flags(*flag_names):
        # Boolean switches, registered in the given order: absent -> False.
        for flag in flag_names:
            parser.add_argument(flag, action="store_true", default=False)

    # Data / run identity
    parser.add_argument('--dataset_type', type=str, default='expert', choices=['expert', 'mixed'])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--save_interval', type=int, default=200)

    # Batch and ensemble sizes
    parser.add_argument('--aq_size', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--n_critics', type=int, default=1)
    parser.add_argument('--target_update_interval', type=int, default=4)

    # Schedule and optimiser
    parser.add_argument('--n_steps_per_epoch', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--down_factor', type=float, default=1.0)
    parser.add_argument('--alpha_cql', type=float, default=4.0)
    add_flags('--decay_lr')

    # Logging / environment data
    parser.add_argument('--wandb_mode', choices=['dryrun', 'dryrun_offline', 'online'], default='dryrun')
    parser.add_argument('--wandb_name', default='')
    parser.add_argument('--env_data', type=str, default="MiniGrid-Simple-Stop-Light-Rarely-Switch-v0")
    add_flags('--sampler_manual_curr')

    add_flags(
        '--indep_ensemble',
        '--prune_yellow',
        '--no_rewards_offset',
        '--use_target_action',
        '--share_encoder',
        '--bootstrap_ens',
    )
    parser.add_argument('--gradual_period', type=int, default=1)

    # Active-sampling options
    add_flags('--do_by_steps')
    parser.add_argument('--datapath', type=str, default="/users/abcdef/code/activesampling/")
    aq_choices = [
        'tdper', 'random', 'mu_realadv_atari', 'manual', 'mu_indepadv_atari',
        'mu_both_atari', 'bald', 'relo_tdper', 'mu_meanadv_atari',
        'mu_max_atari', 'mu_mult_atari', 'mu_mean_atari', 'mu_combo_atari',
    ]
    parser.add_argument('--aqfunc', choices=aq_choices, default='random')
    parser.add_argument('--active_weights_power', type=float, default=1.0)
    parser.add_argument('--num_bkwds', type=int, default=1)
    add_flags('--make_gif', '--log_loss_on_all_data', '--flooding_loss')
    parser.add_argument('--manual_curr_coeff', type=int, default=10)
    add_flags('--tdper_with_cons', '--blur_traj')

    # Training details and misc
    parser.add_argument('--clip_grad', type=float, default=1.0)
    add_flags('--episodic')
    parser.add_argument('--select_eps', type=int, default=40)
    add_flags('--downweighting', '--augment', '--traffic', '--confused')
    parser.add_argument('--data_multiply', type=int, default=1)
    parser.add_argument('--method', type=str, default='active', choices=['active', 'pbrl'])  # 'active_il', 'il'
    parser.add_argument('--interpolate', type=float, default=0.5)

    return parser


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    main(args)

