import numpy as np
import torch
import wandb
import torch.nn as nn
from typing import Dict, List, Union, Tuple, Optional
import argparse
import random
import os
import gym
import d4rl

from offlinerlkit.nets import MLP
from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian
from offlinerlkit.utils.noise import GaussianNoise
# from offlinerlkit.utils.load_dataset import qlearning_dataset
from buffer import qlearning_dataset, dice_dataset, ReplayBuffer
from offlinerlkit.utils.scaler import StandardScaler
# from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger, make_log_dirs
# from offlinerlkit.policy_trainer import MFPolicyTrainer
from mf_policy_trainer import MFPolicyTrainer
import suboptimal_offline_datasets


from osd_cql import OSDCQLPolicy

class ValueNetwork(nn.Module):
    """Scalar value head stacked on top of a feature backbone.

    The backbone turns observations (optionally concatenated with flattened
    actions) into a latent vector, and a final linear layer maps that latent
    vector to a single value per sample.

    Args:
        backbone: Feature extractor; must expose an ``output_dim`` attribute
            giving the size of its output features.
        device: Device the backbone and the output layer are moved to.
        is_pos_output: When True, the output is squared and therefore
            non-negative.
    """

    def __init__(self, backbone: nn.Module,  device: str = "cpu", is_pos_output: bool=False) -> None:
        super().__init__()

        self.device = torch.device(device)
        self.backbone = backbone.to(device)
        latent_dim = getattr(backbone, "output_dim")
        self.last = nn.Linear(latent_dim, 1).to(device)
        self.is_pos_output = is_pos_output

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        actions: Optional[Union[np.ndarray, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Compute one value per sample.

        Args:
            obs: Batch of observations (array or tensor); converted to a
                float32 tensor on ``self.device``.
            actions: Optional batch of actions. When given, it is flattened
                from dimension 1 onward and concatenated to ``obs`` along
                dimension 1 before being fed to the backbone.

        Returns:
            The output of the final linear layer (last dimension of size 1),
            squared when ``is_pos_output`` is set.
        """
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        if actions is not None:
            actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32).flatten(1)
            obs = torch.cat([obs, actions], dim=1)
        logits = self.backbone(obs)
        values = self.last(logits)
        if self.is_pos_output:
            values = torch.square(values)
        return values

    def load_params(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Copy matching entries of ``state_dict`` into this network.

        Only the keys present in this module's own state dict are read, so
        extra entries in ``state_dict`` are ignored.

        Args:
            state_dict: Mapping from parameter/buffer name to tensor.

        Raises:
            KeyError: If ``state_dict`` lacks a key this module owns.
        """
        # Assigning into the dict returned by ``self.state_dict()`` only
        # mutates a temporary mapping and never touches the real parameters;
        # go through ``load_state_dict`` so the values are actually copied.
        selected = {key: state_dict[key] for key in self.state_dict().keys()}
        self.load_state_dict(selected)
        
        
        
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Boolean options go through ``_str2bool`` because ``type=bool`` treats
    every non-empty string (including ``"False"``) as True.

    Returns:
        argparse.Namespace holding the parsed hyperparameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="osd-cql")
    parser.add_argument("--task", type=str, default="hopper-random-medium-0.5-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--target-entropy", type=int, default=None)
    parser.add_argument("--auto-alpha", type=_str2bool, default=True)
    parser.add_argument("--alpha-lr", type=float, default=1e-4)

    parser.add_argument("--cql-weight", type=float, default=5.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-q-backup", type=_str2bool, default=False)
    parser.add_argument("--deterministic-backup", type=_str2bool, default=True)
    parser.add_argument("--with-lagrange", type=_str2bool, default=False)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
    parser.add_argument("--num-repeat-actions", type=int, default=10)

    parser.add_argument("--epoch", type=int, default=1500)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    parser.add_argument("--eval_episodes", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--device", type=str, default="cuda:3" if torch.cuda.is_available() else "cpu")

    parser.add_argument("--is_osd", type=_str2bool, default=True)
    parser.add_argument("--osd_alpha", type=float, default=1.0)
    parser.add_argument("--osd_beta", type=float, default=1e-3)
    parser.add_argument("--weight_type", type=str, default='median')
    parser.add_argument("--osd_batch_size", type=int, default=512)
    # Parsed as a list of ints so a CLI override has the same element type as
    # the default instead of arriving as one raw string.
    parser.add_argument('--osd_hidden_sizes', type=int, nargs='*', default=(256, 256))
    parser.add_argument('--lower', type=float, default=0.1)
    parser.add_argument('--higher', type=float, default=10.)
    parser.add_argument("--osd-lr", type=float, default=3e-4)
    parser.add_argument("--warmup_epoch", type=int, default=500)
    parser.add_argument("--use_wandb", type=int, default=0)

    return parser.parse_args()


def train(args=None):
    """Set up the environment, networks, policy, buffer and logger, then train.

    Args:
        args: Parsed hyperparameter namespace. When None (the default), the
            command line is parsed with ``get_args()`` at call time, so that
            importing this module does not parse ``sys.argv`` as a side
            effect of evaluating the default argument.
    """
    if args is None:
        args = get_args()

    # create env and dataset
    env = gym.make(args.task)
    dataset = qlearning_dataset(env)
    # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
    if 'antmaze' in args.task:
        # The dataset is indexed by key further down (dataset["observations"]),
        # so the rewards are accessed by key here as well.
        # NOTE(review): assumes qlearning_dataset returns a mapping - confirm.
        dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0
    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)
    args.max_action = env.action_space.high[0]

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)

    # state-only value network (nu) and scalar multiplier passed to the policy
    nu_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.osd_hidden_sizes)
    nu_network = ValueNetwork(nu_backbone, args.device)
    lam_v = torch.zeros(1, requires_grad=True, device=args.device)
    nu_optimizer = torch.optim.Adam(nu_network.parameters(), lr=args.osd_lr)
    lam_v_optimizer = torch.optim.Adam([lam_v], lr=args.osd_lr)

    # create policy model
    actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims)
    critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims)
    critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims)
    dist = TanhDiagGaussian(
        latent_dim=getattr(actor_backbone, "output_dim"),
        output_dim=args.action_dim,
        unbounded=True,
        conditioned_sigma=True
    )
    actor = ActorProb(actor_backbone, dist, args.device)
    critic1 = Critic(critic1_backbone, args.device)
    critic2 = Critic(critic2_backbone, args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Compare against None so an explicit target entropy of 0 is honoured
        # instead of being replaced by the default.
        target_entropy = args.target_entropy if args.target_entropy is not None \
            else -np.prod(env.action_space.shape)

        args.target_entropy = target_entropy

        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = args.alpha

    # create policy
    policy = OSDCQLPolicy(
        actor,
        critic1,
        critic2,
        actor_optim,
        critic1_optim,
        critic2_optim,
        nu_network=nu_network,
        nu_optimizer=nu_optimizer,
        lam_v=lam_v,
        lam_v_optimizer=lam_v_optimizer,
        is_osd=args.is_osd,
        osd_alpha=args.osd_alpha,
        lower=args.lower,
        higher=args.higher,
        osd_beta=args.osd_beta,
        weight_type=args.weight_type,
        action_space=env.action_space,
        tau=args.tau,
        gamma=args.gamma,
        alpha=alpha,
        cql_weight=args.cql_weight,
        temperature=args.temperature,
        max_q_backup=args.max_q_backup,
        deterministic_backup=args.deterministic_backup,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        cql_alpha_lr=args.cql_alpha_lr,
        # Keyword spelling is kept exactly as the call site had it; it must
        # match the parameter name declared by OSDCQLPolicy.
        num_repeart_actions=args.num_repeat_actions
    )

    # create buffer
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    buffer.load_dataset(dataset)

    # log
    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
    # key: output file name, value: output handler type
    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    if args.use_wandb:
        wandb.init(project="sun", group='osd-cql',
        name='seed-'+str(args.seed)+'-'+args.task, config=args)

    # create policy trainer
    policy_trainer = MFPolicyTrainer(
        policy=policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        use_wandb=args.use_wandb,
        osd_batch_size=args.osd_batch_size,
        warmup_epoch=args.warmup_epoch
    )

    # train
    policy_trainer.train()


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()