


from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import asdict, dataclass
import os
import sys
from pathlib import Path
import random
import uuid

from copy import deepcopy

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import time
from datetime import datetime

from torch.distributions import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform


from datasets import load_data
import utils

import logging
import pandas as pd

import algos.roil.ReverseBC as ReverseBC
import algos.roil as roil

import algos.roil.BCQ as BCQ

import algos.bc as bc

from visual_maze2d import draw_data

from sklearn.neighbors import NearestNeighbors
from sklearn.ensemble import IsolationForest

# from algos.td3bc import TD3_BC
from algos.oil import TD3_BC


# Type alias: a list of torch tensors (by its name, one sampled batch).
TensorBatch = List[torch.Tensor]

# Size of the trailing evaluation window. Presumably this is meant to match
# the last-10 summary produced at the end of train_method -- TODO confirm.
TEST_LAST_STEPS = 10








def train_method(args):
    """Train a TD3+BC policy on expert plus offline data and report its scores.

    Expert transitions are relabelled with reward 1 and offline transitions
    with reward 0 before training. Training alternates ``args.eval_freq``
    policy updates with one evaluation of ``policy.actor`` until
    ``args.max_timesteps`` updates have been requested. Afterwards the policy
    is saved and the last ``TEST_LAST_STEPS`` normalized scores are summarised.

    Args:
        args: Run configuration namespace (as built in the ``__main__`` block).
            Besides the hyper-parameters it must carry ``name``, ``log_path``
            and ``checkpoints_path``.

    Returns:
        The keyword-argument dict that was passed to ``TD3_BC``.

    Raises:
        ValueError: If ``args.eval_freq`` is not positive, because the training
            loop could then never reach ``args.max_timesteps``.
    """
    if args.eval_freq <= 0:
        raise ValueError(f"eval_freq must be positive, got {args.eval_freq}")

    data_e, data_o, env = load_data.get_offline_imitation_data(
        args.expert_data,
        args.offline_data,
        expert_num=args.expert_num,
        offline_exp=args.offline_exp,
    )

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    replay_buffer_e = utils.ReplayBuffer(state_dim, action_dim, args.buffer_size, args.device)
    replay_buffer_o = utils.ReplayBuffer(state_dim, action_dim, args.buffer_size, args.device)
    replay_buffer_e.convert_D3RL(data_e)
    replay_buffer_o.convert_D3RL(data_o)

    # Both buffers and the evaluation env share one set of statistics.
    state_mean, state_std = _state_statistics(replay_buffer_e, replay_buffer_o, args.normalize)
    replay_buffer_e.normalize_states(mean=state_mean, std=state_std)
    replay_buffer_o.normalize_states(mean=state_mean, std=state_std)
    env = utils.wrap_env(env, state_mean=state_mean, state_std=state_std)

    if args.log_path is not None:
        config_path = os.path.join(args.log_path, args.name)
        print(f"Config path: {config_path}")
        os.makedirs(config_path, exist_ok=True)
        with open(os.path.join(config_path, "args.yaml"), "w") as f:
            pyrallis.dump(args, f)

        print(f"Checkpoint path: {args.checkpoints_path}")
        os.makedirs(args.checkpoints_path, exist_ok=True)

    # Set seeds
    seed = args.seed
    utils.set_seed(seed, env)

    max_action = float(env.action_space.high[0])
    kwargs = _td3bc_kwargs(args, state_dim, action_dim, max_action)

    print("---------------------------------------")
    print(f"Training TD3+BC  Data: {args.offline_data} , Seed: {seed}")
    print("Expert:  {} transitation".format(replay_buffer_e.size))
    print("")
    print("Offline:  {} transitation".format(replay_buffer_o.size))
    print("kwargs: ", kwargs)
    print("---------------------------------------")

    policy = TD3_BC(**kwargs)
    print("learn TD3+BC on expert + offline buffer")

    # Reward relabelling: expert transitions get 1, offline transitions get 0.
    replay_buffer_e.reward = np.ones_like(replay_buffer_e.reward)
    replay_buffer_o.reward = np.zeros_like(replay_buffer_o.reward)

    evaluations = []
    training_iters = 0

    while training_iters < args.max_timesteps:
        log_dict = policy.train(
            replay_buffer_e,
            replay_buffer_o,
            iterations=int(args.eval_freq),
            batch_size=args.batch_size,
        )

        training_iters += args.eval_freq
        print(f"Training iterations: {training_iters}/{args.max_timesteps}:")

        actor_loss = np.mean(log_dict["actor_loss"])
        critic_loss = np.mean(log_dict["critic_loss"])
        actor_loss_q = np.mean(log_dict["actor_loss_q"])
        actor_loss_bc = np.mean(log_dict["actor_loss_bc"])

        print("actor_loss: ", actor_loss, "    critic_loss: ", critic_loss)
        print("actor_loss_q: ", actor_loss_q, "    actor_loss_bc: ", actor_loss_bc)

        eval_scores = utils.eval_actor(
            env,
            policy.actor,
            device=args.device,
            eval_episodes=args.eval_episodes,
            seed=args.seed,
        )
        eval_score = eval_scores.mean()
        normalized_eval_score = env.get_normalized_score(eval_score) * 100.0
        evaluations.append(normalized_eval_score)

        print("Reverse QL agent: ")
        print(f"Evaluation over {args.eval_episodes} episodes: "
              f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}")

        if args.mode != "debug":
            wandb.log(
                {
                    "actor_loss_train": actor_loss,
                    "critic_loss_train": critic_loss,
                    "d4rl_normalized_score": normalized_eval_score,
                },
                step=training_iters,
            )

    print("finish training")

    # ``METHOD`` is only bound when this file runs as a script. Falling back to
    # the same tag keeps an imported call from hitting a NameError after the
    # whole training run, right before the checkpoint would be written.
    method_name = globals().get("METHOD", "OIL-TD3BC")
    # The directory is otherwise only created when ``args.log_path`` is set.
    os.makedirs(args.checkpoints_path, exist_ok=True)
    ckp_path = os.path.join(args.checkpoints_path, method_name)
    policy.save(ckp_path)
    print("save model >>> ", ckp_path)

    _report_final_results(args, evaluations)

    return kwargs


def _state_statistics(replay_buffer_e, replay_buffer_o, normalize):
    """Return the (mean, std) pair used to normalise states of both buffers."""
    if not normalize:
        # Identity statistics: states are left unchanged.
        return 0, 1

    observations_all = np.concatenate(
        [replay_buffer_e.state, replay_buffer_o.state]
    ).astype(np.float32)
    state_mean = np.mean(observations_all, 0)
    # The offset presumably guards against dividing by a zero std for constant
    # state dimensions -- the division itself happens in utils (not shown here).
    state_std = np.std(observations_all, 0) + 1e-3
    return state_mean, state_std


def _td3bc_kwargs(args, state_dim, action_dim, max_action):
    """Assemble the constructor keyword arguments for ``TD3_BC``."""
    return {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "lr": args.lr,
        "wd": args.wd,
        "discount": args.discount,
        "tau": args.tau,
        "device": args.device,
        # TD3: noise settings are given relative to the action range.
        "policy_noise": args.policy_noise * max_action,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "bc_freq": args.bc_freq,
        # TD3 + BC
        "alpha": args.alpha,
    }


def _report_final_results(args, evaluations):
    """Summarise the trailing evaluations, save them as CSV and log to wandb."""
    print("--------------       results     -----------")

    least_table = utils.summary_table(evaluations[-TEST_LAST_STEPS:])
    print(least_table)

    # Mirror the ``log_path is not None`` guard used when dumping the config.
    if args.log_path is not None:
        csv_path = os.path.join(args.log_path, "least_test_seed{}.csv".format(args.seed))
        least_table.to_csv(csv_path)
        print("save results >>>", csv_path)

    print("least result mean +/- std: {} +/- {}".format(least_table["mean"][0], least_table["std"][0]))

    if args.mode != "debug":
        wandb.log({"least_evaluations": wandb.Table(dataframe=least_table)})
        wandb.run.summary["least_evaluations_average"] = least_table["mean"][0]
        wandb.finish()



import argparse




# Method tag used in run names, log directories and the checkpoint file name.
# Bound at module level because train_method reads it as a global.
METHOD = "OIL-TD3BC"


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Args:
        value: The raw argument; an actual ``bool`` is passed through.

    Returns:
        ``True`` or ``False`` according to the token.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    if token in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for this training script.

    Returns:
        An ``argparse.ArgumentParser`` with all options of the script.
    """
    parser = argparse.ArgumentParser()

    # Wandb logging
    parser.add_argument('--project', type=str, default='Offline-Imitation-Learning')
    # NOTE: the script overwrites ``args.group`` with a dataset-derived name.
    parser.add_argument('--group', type=str, default='BC-MuJoCo')

    # exp - data
    parser.add_argument('--expert-data', type=str, default='maze2d-large-expert-v1')
    parser.add_argument('--offline-data', type=str, default='maze2d-large-v1')
    parser.add_argument('--expert-num', type=int, default=10, help='number of expert episodes')
    parser.add_argument('--offline-exp', type=int, default=0, help='number of expert episodes in offline data')

    parser.add_argument('--model_ratio', type=float, default=0.5, help='')

    # exp - TD3+BC
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--wd', type=float, default=0.005, help='weight decay')
    parser.add_argument('--max_timesteps', type=int, default=500000, help='')

    parser.add_argument("--discount", type=float, default=0.99)  # Discount factor
    parser.add_argument("--expl_noise", type=float, default=0.1)
    parser.add_argument("--tau", type=float, default=0.005)  # Target network update rate

    parser.add_argument("--policy_noise", type=float, default=0.2)
    parser.add_argument("--noise_clip", type=float, default=0.5)
    parser.add_argument("--policy_freq", type=int, default=1)
    parser.add_argument("--bc_freq", type=int, default=1)

    # TD3 + BC
    parser.add_argument("--alpha", type=float, default=1.0)
    # str2bool instead of bool so that "--normalize False" really disables it.
    parser.add_argument('--normalize', type=str2bool, default=True, help='Normalize states')

    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for all networks')
    parser.add_argument('--buffer-size', type=int, default=10_000_000, help='Replay buffer size')

    parser.add_argument('--seed', type=int, default=0)

    # exp - testing
    parser.add_argument('--eval-episodes', type=int, default=10, help='How many episodes run during evaluation')
    parser.add_argument('--eval_freq', type=int, default=5000, help='How often (time steps) we evaluate')

    # logging
    parser.add_argument('--log-path', type=str, default='logs')
    parser.add_argument('--load-model', type=str, default="")

    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--mode', type=str, default='debug')

    return parser


if __name__ == "__main__":

    # Timestamp used to make run names and log files unique.
    CURRENT_TIME = str(datetime.now()).replace(" ", "-")

    args = build_parser().parse_args()

    # Runs are grouped by the dataset combination they were trained on.
    exp_name = f"{args.expert_data}-[{args.expert_num}]-{args.offline_data}"
    args.group = exp_name

    args.name = f"{METHOD}-seed[{args.seed}]-time-{CURRENT_TIME}"

    # Use --log-path as the root directory (default "logs"); makedirs also
    # creates the intermediate group directory.
    args.log_path = os.path.join(args.log_path, args.group, METHOD)
    os.makedirs(args.log_path, exist_ok=True)

    args.checkpoints_path = os.path.join(args.log_path, f"checkpoints-seed[{args.seed}]")

    # From here on stdout goes through the project's TextLogger with this file.
    log_file = os.path.join(args.log_path, f"{args.mode}-{args.name}.txt")
    sys.stdout = utils.TextLogger(filename=log_file)

    if args.mode != "debug":
        utils.wandb_init(vars(args))

    if args.mode in ["train", "debug"]:
        train_method(args)

    if args.mode != "debug":
        wandb.finish()