import hydra
import wandb
import random
import minari
import os
import numpy as np
import gymnasium as gym
from pathlib import Path
from datetime import datetime
from omegaconf import DictConfig
from pathlib import Path

import torch
import torch.distributions as td
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model import DecisionTransformer, DecisionConvTransformer
from utils import KMeansEpisodicTrajectoryDataset, MaxEpisodicTrajectoryDataset, VisionKMeansEpisodicTrajectoryDataset, VisionMaxEpisodicTrajectoryDataset, get_test_start_state_goals, get_lr, AntmazeWrapper, PreprocessObservationWrapper 
from gymnasium.wrappers import PixelObservationWrapper
from torch import distributions, nn
from torch.nn import functional as F
from functools import partial
from typing import List, Optional, Type, Any, Dict, Optional, Tuple, Union

def set_seed(seed):
    """Seed every random number generator this script draws from.

    Covers Python's ``random`` module, NumPy's global generator, PyTorch's CPU
    generator and the generators of all CUDA devices, in that order.

    Args:
        seed: integer seed applied to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def _make_eval_env(cfg, render_mode):
    """Build the evaluation environment for ``cfg.env_name``.

    Vision runs get the same pixel pipeline that ``train`` builds: a top-down
    camera, ``PixelObservationWrapper`` and 64x64 grayscale preprocessing.
    The rollout loop in ``eval_env`` reads ``obs['pixels']``, so that
    pipeline is mandatory for both PointMaze and AntMaze.

    Raises:
        NotImplementedError: for environments other than the PointMaze /
            AntMaze variants listed below.
    """
    if cfg.env_name in ['PointMaze_UMaze-v3', 'PointMaze_Medium-v3', 'PointMaze_Large-v3']:
        if not cfg.vision:
            return gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode)
        # Pixel observations require an offscreen rgb_array renderer.
        env = gym.make(cfg.env_name, continuing_task=False, render_mode='rgb_array')
        # Move the camera further out for larger maze maps (len(maze_map) > 8).
        env.point_env.mujoco_renderer.default_cam_config = {
            "distance": 14 if len(env.maze.maze_map) > 8 else 8.8,
            "elevation": -90,
            "lookat": [0, 0, 0,],
        }
        env = PixelObservationWrapper(env, pixels_only=False)
        return PreprocessObservationWrapper(env, shape=64, grayscale=True)

    if cfg.env_name in ['AntMaze_UMaze-v4', 'AntMaze_Medium-v4', 'AntMaze_Large-v4']:
        if not cfg.vision:
            return AntmazeWrapper(env=gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode))
        # Same wrapper order and camera settings as train(), so evaluation frames
        # match the frames the model was trained on.
        env = AntmazeWrapper(env=gym.make(cfg.env_name, continuing_task=False, render_mode='rgb_array'))
        env = PixelObservationWrapper(env, pixels_only=False)
        env = PreprocessObservationWrapper(env, shape=64, grayscale=True)
        camera_distance = {
            'AntMaze_UMaze-v4': 20,
            'AntMaze_Medium-v4': 32,
            'AntMaze_Large-v4': 40,
        }[cfg.env_name]
        env.ant_env.camera_id = -1
        env.ant_env.mujoco_renderer.default_cam_config = {
            "distance": camera_distance,
            "elevation": -90,
            "lookat": [0, 0, 0,],
        }
        return env

    raise NotImplementedError


def eval_env(cfg, model, device, render=False):
    """Roll out ``model`` from every test start/goal pair and report average rewards.

    Args:
        cfg: run config. Reads ``env_name``, ``vision``, ``num_eval_ep`` and
            ``context_len`` here, and is forwarded to
            ``get_test_start_state_goals``.
        model: policy whose ``forward(states, proprios, actions, goals)``
            returns per-timestep action predictions. It must expose
            ``act_dim``, ``state_dim`` and ``goal_dim``. It is left in eval
            mode afterwards.
        device: torch device on which the rollout buffers are allocated.
        render: when True, non-vision environments are created with
            ``render_mode='human'``. Vision environments always render
            ``rgb_array`` frames for their pixel observations.

    Returns:
        A dict with these keys:
        - ``eval/<name>_avg_reward`` and ``eval/<name>_avg_ep_len`` for each
          start/goal entry;
        - ``eval/avg_reward``, the overall per-episode mean.
        Averages are 0.0 when no episodes were run.
    """
    render_mode = 'human' if render else None
    env = _make_eval_env(cfg, render_mode)

    try:
        test_start_state_goal = get_test_start_state_goals(cfg)

        model.eval()
        results = dict()
        eval_batch_size = 1
        max_steps = env.spec.max_episode_steps
        # The state buffer stores pixel frames, so this loop requires a 'pixels'
        # observation (the vision pipeline built above).
        h, w, c = env.observation_space['pixels'].shape

        with torch.no_grad():
            cum_reward = 0
            for ss_g in test_start_state_goal:
                total_reward = 0
                total_timesteps = 0
                print(ss_g['name'] + ':')
                for _ in range(cfg.num_eval_ep):
                    # Zero placeholders for a whole episode, filled one step at a time.
                    m_actions = torch.zeros((eval_batch_size, max_steps, model.act_dim),
                                            dtype=torch.float32, device=device)
                    # NOTE(review): proprio width is state_dim - 2; presumably the
                    # 2 dropped dims are the x/y position. Confirm against the model.
                    m_proprios = torch.zeros((eval_batch_size, max_steps, model.state_dim - 2),
                                             dtype=torch.float32, device=device)
                    m_states = torch.zeros((eval_batch_size, max_steps, c, h, w),
                                           dtype=torch.float32, device=device)
                    m_goals = torch.zeros((eval_batch_size, max_steps, model.goal_dim),
                                          dtype=torch.float32, device=device)

                    obs, _ = env.reset(options=ss_g)

                    for t in range(max_steps):
                        total_timesteps += 1

                        # (h, w, c) frame -> (1, h, w). This assumes a single
                        # grayscale channel (c == 1).
                        m_states[0, t] = torch.tensor(obs['pixels'], dtype=torch.float32,
                                                      device=device).squeeze(-1).unsqueeze(0)
                        m_proprios[0, t] = torch.tensor(obs['observation'], dtype=torch.float32, device=device)
                        m_goals[0, t] = torch.tensor(obs['desired_goal'], dtype=torch.float32, device=device)

                        # Condition on at most the last context_len steps, up to and including t.
                        start = max(0, t - cfg.context_len + 1)
                        act_preds = model.forward(m_states[:, start:t + 1],
                                                  m_proprios[:, start:t + 1],
                                                  m_actions[:, start:t + 1],
                                                  m_goals[:, start:t + 1])

                        # Act with the prediction for the most recent timestep.
                        act = act_preds[0, -1].detach()

                        obs, running_reward, terminated, truncated, _ = env.step(act.cpu().numpy())

                        # Record the executed action so later steps can condition on it.
                        m_actions[0, t] = act

                        total_reward += running_reward

                        if terminated or truncated:
                            break

                    print('Achievied goal: ', tuple(obs['achieved_goal'].tolist()))
                    print('Desired goal: ', tuple(obs['desired_goal'].tolist()))

                print("=" * 60)
                cum_reward += total_reward
                n_eval = cfg.num_eval_ep
                results['eval/' + str(ss_g['name']) + '_avg_reward'] = total_reward / n_eval if n_eval else 0.0
                results['eval/' + str(ss_g['name']) + '_avg_ep_len'] = total_timesteps / n_eval if n_eval else 0.0

            n_episodes = cfg.num_eval_ep * len(test_start_state_goal)
            results['eval/avg_reward'] = cum_reward / n_episodes if n_episodes else 0.0
    finally:
        # Release the renderer/simulator even if a rollout raises.
        env.close()
    return results

class VAE(nn.Module):
    """Conditional VAE over goals, conditioned on state and action.

    The encoder maps ``[state, action, goal]`` to a diagonal Gaussian over a
    ``latent_dim``-dimensional code. The decoder maps ``[action, state, z]``
    back to a ``goal_dim`` vector, squashed into (0, 1) by a final sigmoid.

    NOTE(review): because of the sigmoid, the decoder can only reconstruct
    goals that lie in (0, 1). Presumably goals are normalised upstream;
    confirm this against the dataset.
    """

    def __init__(self, state_dim: int,
                 action_dim: int,
                 goal_dim: int,
                 latent_dim: Optional[int],
                 max_action: float,
                 hidden_dim: int = 750):
        """Build the encoder/decoder MLPs.

        Args:
            state_dim: size of the trailing state dimension.
            action_dim: size of the trailing action dimension.
            goal_dim: size of the trailing goal dimension (the decoder output).
            latent_dim: latent code size. ``None`` means ``2 * action_dim``.
            max_action: stored as ``self.max_action``; it is not used inside
                this class.
            hidden_dim: width of the hidden layers.
        """
        super().__init__()
        if latent_dim is None:
            latent_dim = 2 * action_dim

        self.encoder_shared = nn.Sequential(nn.Linear(state_dim + action_dim + goal_dim, hidden_dim),
                                            nn.ReLU(),
                                            nn.Linear(hidden_dim, hidden_dim),
                                            nn.ReLU())

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(state_dim + action_dim + latent_dim, hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(hidden_dim, goal_dim),
                                     nn.Sigmoid())

        self.max_action = max_action
        self.latent_dim = latent_dim
        # Kept for backward compatibility. Tensors created internally follow
        # the device of the inputs instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, state: torch.Tensor,
                action: torch.Tensor,
                goal: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample ``z`` with the reparameterisation trick, and decode.

        Returns:
            ``(reconstruction, mean, std)``. The last two are the posterior
            parameters.
        """
        mean, std = self.encode(state, action, goal)
        z = mean + std * torch.randn_like(std)
        u = self.decode(state, action, z)
        return u, mean, std

    @staticmethod
    def _tile(x: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Insert a sample axis at dim 1 and repeat ``x`` ``num_samples`` times along it."""
        return x.unsqueeze(1).repeat(1, num_samples, *([1] * (x.dim() - 1)))

    def importance_sampling_estimator(self, state: torch.Tensor,
                                      action: torch.Tensor,
                                      goal: torch.Tensor,
                                      beta: float,
                                      num_samples: int = 500) -> torch.Tensor:
        """Importance-sampled estimate of ``log p(goal | state, action)``.

        Draws ``num_samples`` latents from the posterior q(z | x). It then
        computes ``logsumexp(log p(x|z) + log p(z) - log q(z|x)) - log S`` over
        those samples, using a Gaussian decoder with std ``sqrt(beta / 4)``.

        Inputs share a leading batch dimension and any middle dimensions. For
        inputs of shape ``(B, ..., D)`` the result has shape ``(B, ...)``.
        """
        mean, std = self.encode(state, action, goal)

        mean_enc = self._tile(mean, num_samples)   # (B, S, ..., latent)
        std_enc = self._tile(std, num_samples)     # (B, S, ..., latent)
        z = mean_enc + std_enc * torch.randn_like(std_enc)

        state = self._tile(state, num_samples)     # (B, S, ..., state_dim)
        action = self._tile(action, num_samples)   # (B, S, ..., action_dim)
        goal_ = self._tile(goal, num_samples)      # (B, S, ..., goal_dim)
        mean_dec = self.decode(state, action, z)

        # log q(z | x)
        log_qzx = td.Normal(loc=mean_enc, scale=std_enc).log_prob(z)
        # log p(z) under a standard-normal prior (same device/dtype as z)
        log_pz = td.Normal(loc=torch.zeros_like(z), scale=torch.ones_like(z)).log_prob(z)
        # log p(x | z) with a fixed decoder std of sqrt(beta / 4)
        std_dec = torch.full_like(mean_dec, float(np.sqrt(beta / 4)))
        log_pxz = td.Normal(loc=mean_dec, scale=std_dec).log_prob(goal_)

        w = log_pxz.sum(-1) + log_pz.sum(-1) - log_qzx.sum(-1)   # (B, S, ...)
        return w.logsumexp(dim=1) - np.log(num_samples)

    def encode(self, state: torch.Tensor,
               action: torch.Tensor,
               goal: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the posterior ``(mean, std)`` for ``[state, action, goal]``."""
        z = self.encoder_shared(torch.cat([state, action, goal], -1))
        mean = self.mean(z)
        # Clamp log-std for numerical stability before exponentiating.
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        return mean, std

    def decode(self, state: torch.Tensor,
               action: torch.Tensor,
               z: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode ``[action, state, z]`` to a goal-sized output in (0, 1).

        When ``z`` is None, it is sampled from N(0, 1) and clipped to
        [-0.5, 0.5]. The sample has the same leading dimensions as ``state``
        and lives on the same device.
        """
        if z is None:
            z = torch.randn(*state.shape[:-1], self.latent_dim,
                            device=state.device, dtype=state.dtype).clamp(-0.5, 0.5)
        return self.decoder(torch.cat([action, state, z], -1))

    def save_model(self, path: str):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_model(self, path: str):
        """Load a ``state_dict`` saved by ``save_model``.

        Tensors are loaded onto the CPU first. ``load_state_dict`` then copies
        them onto whatever device this module's parameters live on, so
        GPU-saved checkpoints also load on CPU-only hosts.
        """
        self.load_state_dict(torch.load(path, map_location="cpu"))
        
def _make_train_env(cfg):
    """Resolve the vision environment for ``cfg.dataset_name`` and build it.

    Side effects: sets ``cfg.vision = True`` and ``cfg.env_name``. Also sets
    ``cfg.nclusters`` to the dataset default when it is ``None``.

    Raises:
        NotImplementedError: for dataset names other than the six supported
            ``v-pointmaze-*`` / ``v-antmaze-*`` variants.
    """
    # dataset_name -> (env_name, default nclusters)
    pointmaze = {
        "v-pointmaze-umaze-v0": ('PointMaze_UMaze-v3', 20),
        "v-pointmaze-medium-v0": ('PointMaze_Medium-v3', 40),
        "v-pointmaze-large-v0": ('PointMaze_Large-v3', 80),
    }
    # dataset_name -> (env_name, top-down camera distance, default nclusters)
    antmaze = {
        "v-antmaze-umaze-v0": ('AntMaze_UMaze-v4', 20, 20),
        "v-antmaze-medium-v0": ('AntMaze_Medium-v4', 32, 40),
        "v-antmaze-large-v0": ('AntMaze_Large-v4', 40, 80),
    }

    if cfg.dataset_name in pointmaze:
        env_name, default_nclusters = pointmaze[cfg.dataset_name]
        cfg.vision = True
        cfg.env_name = env_name
        cfg.nclusters = default_nclusters if cfg.nclusters is None else cfg.nclusters
        env = gym.make(cfg.env_name, continuing_task=False, render_mode='rgb_array')
        # Move the camera further out for larger maze maps (len(maze_map) > 8).
        env.point_env.mujoco_renderer.default_cam_config = {
            "distance": 14 if len(env.maze.maze_map) > 8 else 8.8,
            "elevation": -90,
            "lookat": [0, 0, 0,],
        }
        env = PixelObservationWrapper(env, pixels_only=False)
        return PreprocessObservationWrapper(env, shape=64, grayscale=True)

    if cfg.dataset_name in antmaze:
        env_name, distance, default_nclusters = antmaze[cfg.dataset_name]
        cfg.vision = True
        cfg.env_name = env_name
        cfg.nclusters = default_nclusters if cfg.nclusters is None else cfg.nclusters
        env = AntmazeWrapper(gym.make(cfg.env_name, continuing_task=False, render_mode='rgb_array'))
        env = PixelObservationWrapper(env, pixels_only=False)
        env = PreprocessObservationWrapper(env, shape=64, grayscale=True)
        env.ant_env.camera_id = -1
        env.ant_env.mujoco_renderer.default_cam_config = {
            "distance": distance,
            "elevation": -90,
            "lookat": [0, 0, 0,],
        }
        return env

    raise NotImplementedError


def _next_batch(data_iter, data_loader):
    """Return ``(batch, data_iter)``, restarting ``data_loader`` once the iterator is exhausted."""
    try:
        return next(data_iter), data_iter
    except StopIteration:
        data_iter = iter(data_loader)
        return next(data_iter), data_iter


def _train_or_load_vae(cfg, vae, model_path, device, data_iter, data_loader):
    """Load ``vae`` from ``model_path`` if that file exists; otherwise train it and save it there.

    Returns:
        The data iterator. Training may have restarted it, so the caller must
        keep using the returned one.
    """
    if os.path.exists(model_path):
        print("Loading existing VAE model...")
        vae.load_model(model_path)
        return data_iter

    print("Training VAE!!!!!!!!!!")
    vae_optimizer = torch.optim.Adam(vae.parameters(), lr=cfg.vae_lr)
    # Losses of the most recent iteration. Stays empty when vae_iterations == 0.
    log_dict = {}
    for t in range(int(cfg.vae_iterations)):
        batch, data_iter = _next_batch(data_iter, data_loader)
        # NOTE(review): this site unpacks 5 items per batch. The shape probe in
        # train() unpacks 6 and the DT loop unpacks 4. At most one of these can
        # match what VisionMaxEpisodicTrajectoryDataset yields; confirm in utils.
        _timesteps, states, goal, actions, _traj_mask = batch

        if cfg.vision:
            # Drop the trailing channel axis and insert one after T. Presumably
            # (B, T, H, W, 1) -> (B, T, 1, H, W); TODO confirm.
            states = states.to(device).squeeze(-1).unsqueeze(2)
        else:
            states = states.to(device)
        # Tile the goal across the context window (original note: B x T x goal_dim).
        goal = goal.to(device).repeat(1, cfg.context_len, 1)
        actions = actions.to(device)

        recon, mean, std = vae(states, actions, goal)
        recon_loss = F.mse_loss(recon, goal)
        # KL(N(mean, std^2) || N(0, 1)), averaged over all elements.
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + cfg.beta * KL_loss

        vae_optimizer.zero_grad()
        vae_loss.backward()
        vae_optimizer.step()

        log_dict["VAE/reconstruction_loss"] = recon_loss.item()
        log_dict["VAE/KL_loss"] = KL_loss.item()
        log_dict["VAE/vae_loss"] = vae_loss.item()
        log_dict["vae_iter"] = t
    print("train vae results:", log_dict)
    print("Saving VAE model...")
    vae.save_model(model_path)
    return data_iter


def train(cfg, hydra_cfg):
    """Train a DecisionConvTransformer on a vision maze dataset.

    Steps:
    1. Seed everything and build the pixel environment from
       ``cfg.dataset_name``.
    2. Train the goal VAE, or load it from ``cfg.vae_model_path``.
    3. Run ``cfg.max_train_iters`` iterations of
       ``cfg.num_updates_per_iter`` action-MSE updates each. After every
       iteration, call ``eval_env`` and optionally log to wandb.
    4. Save periodic and best-so-far checkpoints when ``cfg.save_snapshot``
       is set.
    5. Write the last evaluation results to a log file in the Hydra output
       directory.

    Args:
        cfg: Hydra run config; mutated in place by ``_make_train_env``.
        hydra_cfg: the Hydra runtime config, used for ``runtime.output_dir``.
    """
    set_seed(cfg.seed)
    device = torch.device(cfg.device)

    # The output dir is always needed: the final log file is written there
    # even when snapshots are disabled.
    log_path = Path(hydra_cfg['runtime']['output_dir'])
    if cfg.save_snapshot:
        checkpoint_path = log_path / 'checkpoint'
        checkpoint_path.mkdir(exist_ok=True)
    best_eval_returns = 0

    start_time = datetime.now().replace(microsecond=0)
    time_elapsed = start_time - start_time
    start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")

    env = _make_train_env(cfg)
    env.action_space.seed(cfg.seed)
    env.observation_space.seed(cfg.seed)

    print(cfg.nclusters)

    train_dataset = VisionMaxEpisodicTrajectoryDataset(cfg.dataset_name, cfg.datasize, cfg.context_len,
                                                       cfg.augment_data, cfg.augment_prob, cfg.nclusters,
                                                       cfg.vision)
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
    )
    train_data_iter = iter(train_data_loader)

    model = DecisionConvTransformer(cfg.env_name, env, cfg.n_blocks, cfg.embed_dim, cfg.context_len,
                                    cfg.n_heads, cfg.drop_p, goal_dim=train_dataset.goal_dim).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wt_decay)
    # Linear warm-up over cfg.warmup_steps scheduler steps, then a constant base lr.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps + 1) / cfg.warmup_steps, 1),
    )

    # Draw one batch purely to size the VAE.
    batch, train_data_iter = _next_batch(train_data_iter, train_data_loader)
    # NOTE(review): 6-item unpack here vs 5 in the VAE loop and 4 in the DT loop.
    # Confirm which arity the dataset actually yields.
    _, state, _proprio, goal, action, _ = batch

    state_dim = state.shape[-1]
    action_dim = action.shape[-1]
    goal_dim = goal.shape[-1]
    max_action = float(env.action_space.high[0])

    vae = VAE(state_dim, action_dim, goal_dim, 2 * goal_dim, max_action, cfg.vae_hidden_dim).to(device)

    vae_dir = Path(cfg.vae_model_path) / cfg.dataset_name
    if vae_dir.exists():
        print(f"Directory '{vae_dir}' already exists.")
    else:
        vae_dir.mkdir(parents=True, exist_ok=True)
        print(f"Directory '{vae_dir}' created.")
    model_path = str(vae_dir / 'vae_model.pth')

    train_data_iter = _train_or_load_vae(cfg, vae, model_path, device, train_data_iter, train_data_loader)
    # NOTE(review): the VAE is not consulted anywhere below in this function.
    vae.eval()

    results = {}  # stays empty if max_train_iters == 0
    total_updates = 0
    for i_train_iter in range(cfg.max_train_iters):
        log_action_losses = []
        model.train()

        for _ in range(cfg.num_updates_per_iter):
            batch, train_data_iter = _next_batch(train_data_iter, train_data_loader)
            # NOTE(review): 4-item unpack; see the arity note above.
            states, proprio, goals, actions = batch

            if cfg.vision:
                # Presumably (B, T, H, W, 1) -> (B, T, 1, H, W); TODO confirm.
                states = states.to(device).squeeze(-1).unsqueeze(2)
            else:
                states = states.to(device)
            proprio = proprio.to(device)
            goals = goals.to(device).repeat(1, cfg.context_len, 1)   # tile goal across the context window
            actions = actions.to(device)

            action_preds = model.forward(
                states=states,
                proprio=proprio,
                actions=actions,
                goals=goals,
            )
            action_loss = F.mse_loss(action_preds, actions)

            optimizer.zero_grad()
            action_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            scheduler.step()

            log_action_losses.append(action_loss.detach().cpu().item())

        now = datetime.now().replace(microsecond=0)
        iter_time = now - start_time - time_elapsed
        time_elapsed = now - start_time

        total_updates += cfg.num_updates_per_iter
        mean_action_loss = np.mean(log_action_losses)

        results = eval_env(cfg, model, device, render=cfg.render)

        log_str = ("=" * 60 + '\n' +
                   "time elapsed: " + str(time_elapsed) + '\n' +
                   "num of updates: " + str(total_updates) + '\n' +
                   "train action loss: " + format(mean_action_loss, ".5f"))

        print(results)
        print(log_str)

        if cfg.wandb_log:
            log_data = {
                'time': iter_time.total_seconds(),
                'time_elapsed': time_elapsed.total_seconds(),
                'total_updates': total_updates,
                'mean_action_loss': mean_action_loss,
                'lr': get_lr(optimizer),
                'training_iter': i_train_iter,
            }
            log_data.update(results)
            wandb.log(log_data)

        if cfg.save_snapshot and (1 + i_train_iter) % cfg.save_snapshot_interval == 0:
            torch.save(model.state_dict(), checkpoint_path / (str(i_train_iter) + '.pt'))

        # Track the best return even without snapshots; only save when enabled.
        if results['eval/avg_reward'] >= best_eval_returns:
            best_eval_returns = results['eval/avg_reward']
            if cfg.save_snapshot:
                print("=" * 60)
                print("saving best model!")
                print("=" * 60)
                torch.save(model.state_dict(), checkpoint_path / 'best.pt')
                print("*******************************************************************************")
                print("total_updates:!!!!!!!!!!!!!!!!!", total_updates)
                print("*******************************************************************************")
                print("*******************************************************************************")
                print("best_eval_returns:!!!!!!!!!!!!!!!!!", best_eval_returns)
                print("*******************************************************************************")

    # Persist the last evaluation results.
    log_data = results
    new_log_filename = ('max_return_' if cfg.augment_data else 'no_augment_') + 'log.txt'
    print("log_path:", log_path)
    log_file_path = os.path.join(log_path, new_log_filename)
    log_path.mkdir(parents=True, exist_ok=True)
    with open(log_file_path, 'w', encoding='utf-8') as file:
        for key, value in log_data.items():
            file.write(f'{key}: {value}\n')
        print(f'Data has been written to {log_file_path}')

    print("=" * 60)
    print("finished training!")
    print("=" * 60)
    end_time = datetime.now().replace(microsecond=0)
    total_time = str(end_time - start_time)
    end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
    print("started training at: " + start_time_str)
    print("finished training at: " + end_time_str)
    print("total training time: " + total_time)
    print("*******************************************************************************")
    print("best_eval_returns:!!!!!!!!!!!!!!!!!", best_eval_returns)
    print("*******************************************************************************")
    print("=" * 60)

@hydra.main(config_path='cfgs', config_name='vision_max_dt', version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: optionally start a wandb run, then call ``train``.

    When ``cfg.wandb_log`` is set, ``cfg.wandb_dir`` falls back to the Hydra
    output directory. The run is created under the project named
    ``cfg.dataset_name``.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    if cfg.wandb_log:
        from omegaconf import OmegaConf

        if cfg.wandb_dir is None:
            cfg.wandb_dir = hydra_cfg['runtime']['output_dir']

        # dict(cfg) is shallow: nested groups would stay DictConfig objects.
        # to_container converts the whole tree to plain, fully resolved
        # Python containers that wandb can record.
        wandb.init(
            project=cfg.dataset_name,
            entity=cfg.wandb_entity,
            config=OmegaConf.to_container(cfg, resolve=True),
            dir=cfg.wandb_dir,
            group=cfg.wandb_group_name,
        )
        wandb.run.name = cfg.wandb_run_name

    train(cfg, hydra_cfg)
        
# Script entry point. main() is wrapped by @hydra.main, which supplies cfg
# from config 'vision_max_dt' under 'cfgs'.
if __name__ == "__main__":
    main()