# source: https://github.com/nakamotoo/Cal-QL/tree/main
# https://arxiv.org/pdf/2303.05479.pdf
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.distributions import Normal, TanhTransform, TransformedDistribution
import yaml

# Type alias: a batch is handed around as a plain list of tensors.
TensorBatch = List[torch.Tensor]

# Environment-name fragments for task families listed as having a goal.
# How these are matched against an env name is decided by the code that
# reads this tuple (not visible here).
ENVS_WITH_GOAL: Tuple[str, ...] = (
    "antmaze",
    "pen",
    "door",
    "hammer",
    "relocate",
)

@dataclass
class TrainConfig:
    """Hyper-parameters and run settings for (PES-)Cal-QL training.

    ``__post_init__`` finalises the config:

    * builds ``name`` as ``"<name>-<env>-<8 random hex chars>"``;
    * forces ``eval_seed`` to equal ``seed``;
    * turns ``results_path`` / ``checkpoints_path`` (when not ``None``) into
      per-env, per-seed paths and creates their directories;
    * for ``algo == "pes-calql"``, scales a per-env distance read from
      ``../configs/distance_sa.yaml`` into ``threshold_distance``;
    * for AntMaze envs, overwrites many CQL/Cal-QL fields with fixed values.
    """

    # Experiment
    algo: str = "calql"  # Tag appended to results_path; "pes-calql" loads distance thresholds
    device: str = "cuda:0"
    env: str = "halfcheetah-medium-v2"  # OpenAI gym environment name
    seed: int = 300  # Sets Gym, PyTorch and Numpy seeds
    eval_seed: int = 300  # Eval environment seed (overwritten with `seed` in __post_init__)
    eval_freq: int = int(1e4)  # How often (time steps) we evaluate
    n_episodes: int = 10  # How many episodes run during evaluation
    offline_iterations: int = int(1e6)  # Number of offline updates
    online_iterations: int = int(1e6)  # Number of online updates
    load_model: bool = True  # Whether to load a saved model (not used inside this class)
    checkpoints_path: Optional[str] = "./checkpoints"  # Save path; None disables
    results_path: Optional[str] = "./results"  # Results save path; None disables
    save_checkpoints_freq: int = int(1e4)
    online_initial_size: int = 1000
    threshold_distance: float = 10.  # Replaced for "pes-calql" when env is in distance_sa.yaml
    threshold_coefficient: float = 3  # Multiplier applied to the per-env distance from the yaml
    max_sequence_length: int = 1000
    num_sequence: int = 10

    # CQL
    buffer_size: int = 2_000_000  # Replay buffer size
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    alpha_multiplier: float = 1.0  # Multiplier for alpha in loss
    use_automatic_entropy_tuning: bool = True  # Tune entropy
    backup_entropy: bool = False  # Use backup entropy
    policy_lr: float = 3e-5  # Policy learning rate
    qf_lr: float = 3e-4  # Critics learning rate
    soft_target_update_rate: float = 5e-3  # Target network update rate
    bc_steps: int = int(0)  # Number of BC steps at start
    target_update_period: int = 1  # Frequency of target nets updates
    cql_alpha: float = 10.0  # CQL offline regularization parameter
    cql_alpha_online: float = 10.0  # CQL online regularization parameter
    cql_n_actions: int = 10  # Number of sampled actions
    cql_importance_sample: bool = True  # Use importance sampling
    cql_lagrange: bool = False  # Use Lagrange version of CQL
    cql_target_action_gap: float = -1.0  # Action gap
    cql_temp: float = 1.0  # CQL temperature
    cql_max_target_backup: bool = False  # Use max target backup
    cql_clip_diff_min: float = -np.inf  # Q-function lower loss clipping
    cql_clip_diff_max: float = np.inf  # Q-function upper loss clipping
    orthogonal_init: bool = True  # Orthogonal initialization
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    q_n_hidden_layers: int = 2  # Number of hidden layers in Q networks
    reward_scale: float = 1.0  # Reward scale for normalization
    reward_bias: float = 0.0  # Reward bias for normalization
    # Cal-QL
    mixing_ratio: float = 0.5  # Data mixing ratio for online tuning
    is_sparse_reward: bool = False  # Use sparse reward
    # Logging
    # Run-name prefix. __post_init__ appends the env and a random suffix.
    # Declared last so existing positional construction is unaffected.
    name: str = "Cal-QL"

    def __post_init__(self) -> None:
        """Derive run name and output paths, then apply per-env overrides."""
        # Unique run name, e.g. "Cal-QL-hopper-medium-v2-1a2b3c4d".
        self.name = f"{self.name}-{self.env}-{str(uuid.uuid4())[:8]}"
        # Evaluation deliberately reuses the training seed.
        self.eval_seed = self.seed

        # Each output path is optional; only build/create it when provided.
        if self.results_path is not None:
            # "<results_path>_<algo>/<env>/<seed>.txt"
            self.results_path = os.path.join(
                f"{self.results_path}_{self.algo}", self.env, f"{self.seed}.txt"
            )
            os.makedirs(os.path.dirname(self.results_path), exist_ok=True)

        if self.checkpoints_path is not None:
            # "<checkpoints_path>/<env>/<seed>"
            self.checkpoints_path = os.path.join(
                self.checkpoints_path, self.env, str(self.seed)
            )
            os.makedirs(self.checkpoints_path, exist_ok=True)

        if self.algo == "pes-calql":
            # NOTE: path is resolved relative to the current working directory.
            with open("../configs/distance_sa.yaml", "r") as file:
                # safe_load suffices for a plain mapping; `or {}` tolerates an empty file.
                distance_dict = yaml.safe_load(file) or {}

            if self.env in distance_dict:
                self.threshold_distance = self.threshold_coefficient * distance_dict[self.env]

        # AntMaze: fixed hyper-parameters. These assignments unconditionally
        # replace any user-supplied values for the listed fields, including
        # the reward scale/bias and reward normalisation.
        if 'antmaze' in self.env:
            self.alpha_multiplier = 1.0
            self.backup_entropy = False
            self.batch_size = 256
            self.bc_steps = 0
            self.cql_alpha = 5.0
            self.cql_alpha_online = 5.0
            self.cql_clip_diff_max = np.inf
            self.cql_clip_diff_min = -200
            self.cql_importance_sample = True
            self.cql_lagrange = True
            self.cql_max_target_backup = True
            self.cql_n_actions = 10
            self.cql_target_action_gap = 0.8
            self.cql_temp = 1.0
            self.discount = 0.99
            self.n_episodes = 100
            self.normalize = False
            self.normalize_reward = True
            self.orthogonal_init = True
            self.policy_lr = 0.0001
            self.qf_lr = 0.0003
            self.soft_target_update_rate = 0.005
            self.target_update_period = 1
            self.q_n_hidden_layers = 5
            self.reward_scale = 10.0
            self.reward_bias = -5.0
            self.use_automatic_entropy_tuning = True
            self.is_sparse_reward = True