import argparse
import os
import random
from datetime import datetime
import time

import d4rl
import gym
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from torch import nn
from typing import Tuple
import torch.distributions as td
from pathlib import Path
from torch.nn import functional as F

from data import D4RLTrajectoryDataset
from trainer import ReinFormerTrainer
from eval import Reinformer_eval

class VAE(nn.Module):
    """Conditional variational auto-encoder over actions given states.

    The encoder maps ``(state, action)`` to a diagonal Gaussian posterior over a
    latent ``z``; the decoder maps ``(state, z)`` back to an action bounded by
    ``max_action``. ``state`` and ``action`` share all leading axes (e.g.
    ``[B, D]`` or ``[B, T, D]``); only the last axis is a feature axis.

    Args:
        state_dim: Size of the last axis of ``state``.
        action_dim: Size of the last axis of ``action``.
        latent_dim: Size of ``z``; ``None`` selects ``2 * action_dim``.
        max_action: Decoded actions lie in ``[-max_action, max_action]``.
        hidden_dim: Width of the hidden layers of encoder and decoder.
    """

    def __init__(self, state_dim: int,
                 action_dim: int,
                 latent_dim: int,
                 max_action: float,
                 hidden_dim: int = 750):
        super().__init__()
        if latent_dim is None:
            latent_dim = 2 * action_dim

        self.encoder_shared = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)
        # The final Tanh (scaled by max_action in ``decode``) spans the whole
        # symmetric action range; a Sigmoid head could only emit values in
        # (0, 1) and never reconstruct negative actions. Parameterised module
        # indices are unchanged, so old state_dicts still load (but weights
        # trained with the Sigmoid head should be retrained).
        self.decoder = nn.Sequential(
            nn.Linear(state_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

        self.max_action = max_action
        self.latent_dim = latent_dim
        # Kept for backward compatibility only: the module is NOT moved here
        # automatically, and ``decode`` follows the input tensor's device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, state: torch.Tensor,
                action: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reconstruct ``action`` through a reparameterised latent sample.

        Returns:
            ``(reconstruction, posterior_mean, posterior_std)``.
        """
        mean, std = self.encode(state, action)
        z = mean + std * torch.randn_like(std)
        return self.decode(state, z), mean, std

    def importance_sampling_estimator(self, state: torch.Tensor,
                                      action: torch.Tensor,
                                      beta: float,
                                      num_samples: int = 5) -> torch.Tensor:
        """Importance-sampled estimate of ``log p(action | state)``.

        With ``z_l ~ q(z | state, action)`` for ``l = 1..L`` (``L = num_samples``)::

            log p(a|s) ~= logsumexp_l[log p(a|s,z_l) + log p(z_l) - log q(z_l|s,a)] - log L

        where ``p(z) = N(0, I)`` and ``p(a|s,z) = N(decode(s, z), sqrt(beta / 4))``.
        Note the result is a single L-sample estimate, not an expectation over
        repeated draws.

        Args:
            state: ``[B, ..., state_dim]``; the first axis is treated as batch.
            action: ``[B, ..., action_dim]`` with the same leading axes.
            beta: KL weight used in training; sets the decoder std.
            num_samples: Number of latent samples ``L`` (must be >= 1).

        Returns:
            Tensor with shape ``state.shape[:-1]``.

        Raises:
            ValueError: if ``num_samples < 1``.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        mean, std = self.encode(state, action)

        def tile(x: torch.Tensor) -> torch.Tensor:
            # Insert a sample axis after the batch axis: [B, ...] -> [B, S, ...].
            return x.unsqueeze(1).expand(x.shape[0], num_samples, *x.shape[1:])

        mean_enc = tile(mean)
        std_enc = tile(std)
        z = mean_enc + std_enc * torch.randn_like(std_enc)   # [B, S, ..., latent]
        action_rep = tile(action)
        mean_dec = self.decode(tile(state), z)                # [B, S, ..., action]
        std_dec = torch.full_like(mean_dec, float(np.sqrt(beta / 4)))

        # Sum per-dimension log-densities over the feature axis.
        log_qzx = td.Normal(loc=mean_enc, scale=std_enc).log_prob(z).sum(-1)
        log_pz = td.Normal(loc=torch.zeros_like(z),
                           scale=torch.ones_like(z)).log_prob(z).sum(-1)
        log_pxz = td.Normal(loc=mean_dec, scale=std_dec).log_prob(action_rep).sum(-1)

        log_w = log_pxz + log_pz - log_qzx                    # [B, S, ...]
        return log_w.logsumexp(dim=1) - np.log(num_samples)

    def encode(self, state: torch.Tensor,
               action: torch.Tensor
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the posterior ``(mean, std)`` of ``q(z | state, action)``."""
        hidden = self.encoder_shared(torch.cat([state, action], -1))
        mean = self.mean(hidden)
        # Clamped for numerical stability.
        log_std = self.log_std(hidden).clamp(-4, 15)
        return mean, torch.exp(log_std)

    def decode(self, state: torch.Tensor,
               z: torch.Tensor = None) -> torch.Tensor:
        """Decode an action from ``state`` and latent ``z``.

        When ``z`` is None, a latent is drawn from N(0, I) clipped to
        [-0.5, 0.5], on the state's device and with the state's leading shape.
        """
        if z is None:
            z = torch.randn((*state.shape[:-1], self.latent_dim),
                            device=state.device,
                            dtype=state.dtype).clamp(-0.5, 0.5)
        return self.max_action * self.decoder(torch.cat([state, z], -1))

    def save_model(self, path: str):
        """Save this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_model(self, path: str):
        """Load a ``state_dict`` from ``path`` onto this module's current device."""
        device = next(self.parameters()).device
        self.load_state_dict(torch.load(path, map_location=device))

# Print progress every this many VAE steps / trainer updates instead of every step.
_LOG_INTERVAL = 1000


def _resolve_d4rl_env(env_name: str, dataset: str) -> str:
    """Return the versioned D4RL id (e.g. ``hopper-medium-v2``).

    Raises:
        ValueError: if ``env_name`` is not one of the supported families.
    """
    if env_name == "kitchen":
        return f"{env_name}-{dataset}-v0"
    if env_name in ("pen", "door", "hammer", "relocate", "maze2d"):
        return f"{env_name}-{dataset}-v1"
    if env_name in ("halfcheetah", "hopper", "walker2d", "antmaze"):
        return f"{env_name}-{dataset}-v2"
    raise ValueError(f"Unsupported env: {env_name!r}")


def _next_batch(data_iter, data_loader):
    """Fetch the next batch, restarting ``data_loader`` when an epoch ends.

    Returns:
        ``(batch, data_iter)``: the loader's 7-tuple ``(timesteps, states,
        next_states, actions, returns_to_go, rewards, traj_mask)`` and the
        (possibly fresh) iterator to use for subsequent calls.
    """
    try:
        return next(data_iter), data_iter
    except StopIteration:
        data_iter = iter(data_loader)
        return next(data_iter), data_iter


def _train_or_load_vae(vae, model_path, data_iter, data_loader, variant, device):
    """Load VAE weights from ``model_path`` if present, else train and save them.

    Returns:
        The (possibly advanced) data iterator.
    """
    if os.path.exists(model_path):
        print("Loading existing VAE model...")
        vae.load_model(model_path)
        return data_iter

    print("Training VAE!!!!!!!!!!")
    vae_optimizer = torch.optim.Adam(vae.parameters(), lr=variant["vae_lr"])
    num_iters = int(variant["vae_iterations"])
    log_dict = {}  # stays empty when num_iters == 0
    for t in range(num_iters):
        batch, data_iter = _next_batch(data_iter, data_loader)
        _, states, _, actions, *_ = batch
        states = states.to(device)
        actions = actions.to(device)

        recon, mean, std = vae(states, actions)
        recon_loss = F.mse_loss(recon, actions)
        # KL(q(z|s,a) || N(0, I)), averaged over all elements.
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + variant["beta"] * kl_loss

        vae_optimizer.zero_grad()
        vae_loss.backward()
        vae_optimizer.step()

        # .item() forces a device sync, so only read losses when reporting.
        if t % _LOG_INTERVAL == 0 or t == num_iters - 1:
            log_dict = {
                "VAE/reconstruction_loss": recon_loss.item(),
                "VAE/KL_loss": kl_loss.item(),
                "VAE/vae_loss": vae_loss.item(),
                "vae_iter": t,
            }
            print("Training VAE Steps:", t, log_dict)
    print("train vae results:", log_dict)
    print("Saving VAE model...")
    vae.save_model(model_path)
    return data_iter


def experiment(variant):
    """Train ReinFormer with returns-to-go replaced by VAE log-likelihoods.

    A state-conditioned action VAE is trained (or loaded from a per-env cache),
    then each trainer update uses its importance-sampled
    ``log p(action | state)`` in place of the dataset's returns-to-go.

    Args:
        variant: Dict of options as produced by the CLI in ``__main__``.
            ``batch_size`` and ``num_eval_ep`` may be overwritten in place for
            some env/dataset combinations.

    Raises:
        ValueError: for an unsupported ``env`` or ``model_type``.
    """
    seed = variant["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    env_name = variant["env"]
    dataset = variant["dataset"]
    d4rl_env = _resolve_d4rl_env(env_name, dataset)

    model_type = variant["model_type"]
    if model_type != "reinformer":
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    if dataset == "complete":
        variant["batch_size"] = 16
    if env_name in ("kitchen", "maze2d", "antmaze"):
        variant["num_eval_ep"] = 100
    if env_name == "hopper" and dataset in ("medium", "medium-replay"):
        variant["batch_size"] = 256

    use_wandb = variant.get("use_wandb", False)
    dataset_path = os.path.join(variant["dataset_dir"], f"{d4rl_env}.pkl")
    device = torch.device(variant["device"])

    start_time = datetime.now().replace(microsecond=0)
    print("=" * 60)
    print("start time: " + start_time.strftime("%y-%m-%d-%H-%M-%S"))
    print("=" * 60)

    traj_dataset = D4RLTrajectoryDataset(
        env_name, dataset_path, variant["context_len"], device
    )
    traj_data_loader = DataLoader(
        traj_dataset,
        batch_size=variant["batch_size"],
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )
    data_iter = iter(traj_data_loader)
    state_mean, state_std = traj_dataset.get_state_stats()

    env = gym.make(d4rl_env)
    env.seed(seed)

    trainer = ReinFormerTrainer(
        state_dim=env.observation_space.shape[0],
        act_dim=env.action_space.shape[0],
        device=device,
        variant=variant,
    )

    def evaluator(model):
        return_mean, _, _, _ = Reinformer_eval(
            model=model,
            device=device,
            context_len=variant["context_len"],
            env=env,
            state_mean=state_mean,
            state_std=state_std,
            num_eval_ep=variant["num_eval_ep"],
            max_test_ep_len=variant["max_eval_ep_len"],
        )
        return env.get_normalized_score(return_mean) * 100

    # VAE dims are taken from the data the VAE will actually consume.
    batch, data_iter = _next_batch(data_iter, traj_data_loader)
    _, states, _, actions, *_ = batch
    action_dim = actions.shape[-1]
    max_action = float(env.action_space.high[0])
    vae = VAE(
        states.shape[-1], action_dim, 2 * action_dim, max_action, variant["vae_hidden_dim"]
    ).to(device)

    # Cache per D4RL env id: keying by dataset alone would reuse e.g. a
    # halfcheetah-medium VAE for walker2d-medium (same state/action dims).
    path = Path(variant["vae_model_path"]) / d4rl_env
    if path.exists():
        print(f"Directory '{path}' already exists.")
    else:
        path.mkdir(parents=True, exist_ok=True)
        print(f"Directory '{path}' created.")
    model_path = os.path.join(path, "vae_model.pth")
    data_iter = _train_or_load_vae(
        vae, model_path, data_iter, traj_data_loader, variant, device
    )

    vae.eval()
    max_train_iters = variant["max_train_iters"]
    num_updates_per_iter = variant["num_updates_per_iter"]
    normalized_d4rl_score_list = []
    for train_iter in range(1, max_train_iters + 1):
        t1 = time.time()
        print("iter", train_iter)
        for epoch in range(num_updates_per_iter):
            if epoch % _LOG_INTERVAL == 0:
                print("epoch", epoch)
            batch, data_iter = _next_batch(data_iter, traj_data_loader)
            timesteps, states, next_states, actions, _, rewards, traj_mask = batch

            # The VAE is frozen: no autograd graph is needed for its estimate.
            with torch.no_grad():
                returns_to_go = vae.importance_sampling_estimator(
                    states.to(device), actions.to(device), variant["beta"]
                )

            loss = trainer.train_step(
                timesteps=timesteps,
                states=states,
                next_states=next_states,
                actions=actions,
                returns_to_go=returns_to_go,
                rewards=rewards,
                traj_mask=traj_mask,
            )
            if use_wandb:
                wandb.log(data={"training/loss": loss})
        t2 = time.time()
        normalized_d4rl_score = evaluator(model=trainer.model)
        t3 = time.time()
        normalized_d4rl_score_list.append(normalized_d4rl_score)
        if use_wandb:
            wandb.log(
                data={
                    "training/time": t2 - t1,
                    "evaluation/score": normalized_d4rl_score,
                    "evaluation/time": t3 - t2,
                }
            )

    if use_wandb and normalized_d4rl_score_list:
        wandb.log(
            data={
                "evaluation/max_score": max(normalized_d4rl_score_list),
                "evaluation/last_score": normalized_d4rl_score_list[-1],
            }
        )
    print(normalized_d4rl_score_list)
    print("=" * 60)
    print("finished training!")
    end_time = datetime.now().replace(microsecond=0)
    print("finished training at: " + end_time.strftime("%y-%m-%d-%H-%M-%S"))
    print("=" * 60)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Each entry is (flag, keyword options for add_argument); registration
    # order is preserved so --help output is unchanged.
    argument_specs = (
        ("--model_type", dict(choices=["reinformer"], default="reinformer")),
        ("--env", dict(type=str, default="hopper")),
        ("--dataset", dict(type=str, default="medium")),
        ("--num_eval_ep", dict(type=int, default=10)),
        ("--max_eval_ep_len", dict(type=int, default=1000)),
        ("--dataset_dir", dict(type=str, default="d4rl_dataset/")),
        ("--context_len", dict(type=int, default=5)),
        ("--n_blocks", dict(type=int, default=4)),
        ("--embed_dim", dict(type=int, default=256)),
        ("--n_heads", dict(type=int, default=8)),
        ("--dropout_p", dict(type=float, default=0.1)),
        ("--grad_norm", dict(type=float, default=0.25)),
        ("--tau", dict(type=float, default=0.99)),
        ("--batch_size", dict(type=int, default=128)),
        ("--lr", dict(type=float, default=1e-4)),
        ("--wd", dict(type=float, default=1e-4)),
        ("--warmup_steps", dict(type=int, default=5000)),
        ("--max_train_iters", dict(type=int, default=10)),
        ("--num_updates_per_iter", dict(type=int, default=5000)),
        ("--device", dict(type=str, default="cuda")),
        ("--seed", dict(type=int, default=2024)),
        ("--init_temperature", dict(type=float, default=0.1)),
        ("--use_wandb", dict(action='store_true', default=False)),
        # VAE options.
        ("--vae_hidden_dim", dict(type=int, default=750)),
        ("--vae_iterations", dict(type=int, default=100000)),
        # NOTE(review): a boolean flag despite its name; experiment() builds
        # the VAE with a latent size of 2 * action_dim and does not read it.
        ("--vae_latent_dim", dict(action='store_true', default=False)),
        ("--vae_lr", dict(type=float, default=0.001)),
        ("--beta", dict(type=float, default=0.5)),
        ("--vae_model_path", dict(type=str, default="vae_model/")),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    # Module-level on purpose: kept as a global for code that reads `args`.
    args = parser.parse_args()

    if args.use_wandb:
        wandb.init(
            project="Reinformer",
            name=f"{args.env}-{args.dataset}",
            config=vars(args),
        )

    experiment(vars(args))
