import numpy as np
import torch
import wandb
import torch.nn as nn
from typing import Dict, List, Union, Tuple, Optional, cast
import argparse
import random
import os
import gym
import d4rl
from torch.nn import functional as F

from offlinerlkit.nets import MLP
from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian
from offlinerlkit.utils.noise import GaussianNoise
# from offlinerlkit.utils.load_dataset import qlearning_dataset
from buffer import qlearning_dataset, dice_dataset, ReplayBuffer
from offlinerlkit.utils.scaler import StandardScaler
# from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger, make_log_dirs
# from offlinerlkit.policy_trainer import MFPolicyTrainer
from mf_policy_trainer import MFPolicyTrainer


from dw_cql import DWCQLPolicy

class MeanFlowConserveDiscrimanator(nn.Module):  # type: ignore
    """Discriminator with a state-action head and a state-only head.

    Two encoders produce feature vectors, one from ``concat(state, action)``
    and one from the state alone; each is followed by its own linear layer
    that maps the features to a single logit.

    Args:
        state_action_encoder: module applied to ``concat(x, actions)``.
        state_encoder: module applied to ``x``.
        discriminator_clip_ratio: logits are clipped to
            ``[-(1 + ratio), 1 + ratio]`` before the softmax. ``None``
            disables clipping.
        discriminator_temp: softmax temperature. ``None`` is treated as 1.

    The width of each linear head is read from the encoder's ``output_dim``
    attribute when it has one, and falls back to 256 otherwise.
    """

    _state_action_encoder: nn.Module
    _state_encoder: nn.Module
    _state_action_fc: nn.Linear
    _state_fc: nn.Linear

    # Head input width used when an encoder does not expose ``output_dim``.
    _DEFAULT_FEATURE_DIM = 256

    def __init__(self,
        state_action_encoder: nn.Module,
        state_encoder: nn.Module,
        discriminator_clip_ratio: Optional[float] = None,
        discriminator_temp: Optional[float] = None):
        super().__init__()

        self._state_action_encoder = state_action_encoder
        self._state_encoder = state_encoder

        # Size each head from its encoder instead of assuming 256 features.
        sa_feature_dim = getattr(
            state_action_encoder, "output_dim", self._DEFAULT_FEATURE_DIM)
        s_feature_dim = getattr(
            state_encoder, "output_dim", self._DEFAULT_FEATURE_DIM)
        self._state_action_fc = nn.Linear(int(sa_feature_dim), 1)
        self._state_fc = nn.Linear(int(s_feature_dim), 1)

        self._discriminator_clip_ratio = discriminator_clip_ratio
        self._discriminator_temp = discriminator_temp

    def forward(
        self, x: torch.Tensor, actions: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute discriminator logits.

        Args:
            x: state input; concatenated with ``actions`` along ``dim=1``
                when actions are given.
            actions: optional action input.

        Returns:
            ``(state_action_logits, state_logits)`` when ``actions`` is given,
            otherwise only ``state_logits``.
        """
        if actions is not None:
            z = torch.cat((x, actions), dim=1)
            sa_logits = cast(
                torch.Tensor,
                self._state_action_fc(self._state_action_encoder(z)))
            s_logits = cast(
                torch.Tensor, self._state_fc(self._state_encoder(x)))
            return sa_logits, s_logits
        return cast(torch.Tensor, self._state_fc(self._state_encoder(x)))

    def _normalized_ratio(self, logits: torch.Tensor) -> torch.Tensor:
        """Clip ``logits`` (if configured) and softmax them over ``dim=0``."""
        if self._discriminator_clip_ratio is not None:
            bound = 1 + self._discriminator_clip_ratio
            logits = torch.clip(logits, -bound, bound)
        temp = 1.0 if self._discriminator_temp is None else self._discriminator_temp
        # dim=0 is presumably the batch dimension, so the weights of one
        # batch sum to 1 -- confirm against the caller.
        return F.softmax(logits / temp, dim=0)

    def compute_normalized_ratio_and_logits(self,
            observation: torch.Tensor,
            actions: torch.Tensor,
            next_observation: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return softmax-normalized ratios together with the raw logits.

        Returns:
            A 5-tuple ``(normalized_dsa_ratio, normalized_ds_ratio, logits_sa,
            logits_s, logits_next_s)``. The state-action ratio is computed
            from ``logits_sa + logits_s``; the state ratio from ``logits_s``.
            ``logits_next_s`` is the state head evaluated on
            ``next_observation``.
        """
        # Call the module rather than ``forward`` so registered hooks run.
        logits_sa, logits_s = self(observation, actions)
        logits_next_s = self(next_observation)

        normalized_dsa_ratio = self._normalized_ratio(logits_sa + logits_s)
        normalized_ds_ratio = self._normalized_ratio(logits_s)

        return normalized_dsa_ratio, normalized_ds_ratio, logits_sa, logits_s, logits_next_s
        
        
        
def _str2bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` to a bool.

    ``argparse``'s ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so flags need an explicit
    parser.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _default_device(preferred_index=3):
    """Pick the default device string without naming a GPU that is absent."""
    if not torch.cuda.is_available():
        return "cpu"
    if torch.cuda.device_count() > preferred_index:
        return f"cuda:{preferred_index}"
    # Fewer GPUs than the preferred index: fall back to the first one.
    return "cuda:0"


def get_args(argv=None):
    """Build the command-line parser for the script and parse arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # --- general / SAC-style settings ---
    parser.add_argument("--algo-name", type=str, default="dw-cql")
    parser.add_argument("--task", type=str, default="hopper-medium-v2")
    parser.add_argument("--seed", type=int, default=3)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--target-entropy", type=int, default=None)
    # Parsed explicitly so that ``--auto-alpha False`` really disables it.
    parser.add_argument("--auto-alpha", type=_str2bool, default=True)
    parser.add_argument("--alpha-lr", type=float, default=1e-4)

    # --- CQL settings ---
    parser.add_argument("--cql-weight", type=float, default=5.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-q-backup", type=_str2bool, default=False)
    parser.add_argument("--deterministic-backup", type=_str2bool, default=True)
    parser.add_argument("--with-lagrange", type=_str2bool, default=False)
    parser.add_argument("--lagrange-threshold", type=float, default=10.0)
    parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
    parser.add_argument("--num-repeat-actions", type=int, default=10)

    # --- training loop settings ---
    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    parser.add_argument("--eval_episodes", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--device", type=str, default=_default_device())

    # --- discriminator / weighting / logging settings ---
    parser.add_argument("--weight_type", type=str, default='median')
    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--discriminator_clip_ratio", type=float, default=0.2)
    parser.add_argument("--discriminator_weight_temp", type=float, default=1.0)
    parser.add_argument("--discriminator_lr", type=float, default=1e-4)

    return parser.parse_args(argv)


def train(args=None):
    """Build the environment, discriminator, policy and buffer, then train.

    Args:
        args: parsed argument namespace as produced by ``get_args``. When
            ``None`` (the default) the command line is parsed at call time,
            not at import time. The namespace is mutated: ``obs_shape``,
            ``action_dim``, ``max_action`` and (with auto-alpha)
            ``target_entropy`` are written onto it.
    """
    if args is None:
        args = get_args()

    # create env and dataset
    env = gym.make(args.task)
    # traj_returns, threshhold = pick_datasets(env)
    dataset = qlearning_dataset(env)
    # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
    if 'antmaze' in args.task:
        # The dataset is indexed by key below (``dataset["observations"]``),
        # so the rewards are rescaled through the same mapping interface.
        dataset["rewards"] = (dataset["rewards"] - 0.5) * 4.0
    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)
    args.max_action = env.action_space.high[0]

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)

    # create discriminator (separate encoders for state and state-action)
    obs_dim = np.prod(args.obs_shape)
    state_encoder = MLP(input_dim=obs_dim, hidden_dims=[256, 256])
    state_action_encoder = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=[256, 256])
    discriminator = MeanFlowConserveDiscrimanator(
            state_action_encoder, state_encoder,
            args.discriminator_clip_ratio,
            args.discriminator_weight_temp)
    discriminator.to(args.device)
    discriminator_optim = torch.optim.Adam(discriminator.parameters(), lr=args.discriminator_lr)

    # create policy model
    actor_backbone = MLP(input_dim=obs_dim, hidden_dims=args.hidden_dims)
    critic1_backbone = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=args.hidden_dims)
    critic2_backbone = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=args.hidden_dims)
    dist = TanhDiagGaussian(
        latent_dim=getattr(actor_backbone, "output_dim"),
        output_dim=args.action_dim,
        unbounded=True,
        conditioned_sigma=True
    )
    actor = ActorProb(actor_backbone, dist, args.device)
    critic1 = Critic(critic1_backbone, args.device)
    critic2 = Critic(critic2_backbone, args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # ``is not None`` so that an explicit target entropy of 0 is honoured.
        if args.target_entropy is not None:
            target_entropy = args.target_entropy
        else:
            target_entropy = -np.prod(env.action_space.shape)

        args.target_entropy = target_entropy

        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = args.alpha

    # create policy
    policy = DWCQLPolicy(
        actor,
        critic1,
        critic2,
        actor_optim,
        critic1_optim,
        critic2_optim,
        discriminator=discriminator,
        discriminator_optim=discriminator_optim,
        weight_type=args.weight_type,
        action_space=env.action_space,
        tau=args.tau,
        gamma=args.gamma,
        alpha=alpha,
        cql_weight=args.cql_weight,
        temperature=args.temperature,
        max_q_backup=args.max_q_backup,
        deterministic_backup=args.deterministic_backup,
        with_lagrange=args.with_lagrange,
        lagrange_threshold=args.lagrange_threshold,
        cql_alpha_lr=args.cql_alpha_lr,
        # NOTE(review): keyword spelling kept as-is; it presumably matches
        # the parameter name declared by DWCQLPolicy -- confirm there.
        num_repeart_actions=args.num_repeat_actions
    )

    # create buffer
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    buffer.load_dataset(dataset)

    # log
    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
    # key: output file name, value: output handler type
    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    # Initialise wandb exactly once. A second ``wandb.init`` with a different
    # project/group used to follow this one; the first configuration is kept.
    if args.use_wandb:
        wandb.init(project="dw", group='dw',
        name=args.task, config=args)

    # create policy trainer
    policy_trainer = MFPolicyTrainer(
        policy=policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        use_wandb=args.use_wandb
    )

    # train
    policy_trainer.train()


# Script entry point: run training with train()'s default arguments.
if __name__ == "__main__":
    train()