from dataclasses import asdict, dataclass
import os
import uuid
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import yaml

# Type alias: a batch represented as a plain list of tensors.
TensorBatch = List[torch.Tensor]

# Environment families grouped as "with goal" (per the constant's name);
# the code that consumes this tuple is not visible in this section.
ENVS_WITH_GOAL = (
    "antmaze",
    "pen",
    "door",
    "hammer",
    "relocate",
)

@dataclass
class TrainConfig:
    """Hyper-parameters and run settings for CQL / PES-CQL training.

    ``__post_init__`` finalizes the config on construction:

    * for ``algo == "pes-cql"`` it rescales ``threshold_distance`` from a
      per-environment YAML table;
    * for AntMaze and Adroit-style environments (pen/hammer/door/relocate)
      it forces a fixed set of hyper-parameters;
    * it sets ``eval_seed`` to ``seed`` and derives ``name``;
    * it expands ``results_path`` / ``checkpoints_path`` into per-env,
      per-seed locations and creates the directories. A path left as
      ``None`` is kept as ``None`` and no directory is created for it.
    """

    # Run-specific settings that most often need changing
    device: str = "cuda:0"
    env: str = "antmaze-umaze-v0"  # OpenAI gym environment name
    seed: int = 300  # Sets Gym, PyTorch and Numpy seeds
    algo: str = "cql"  # "cql" or "pes-cql"
    online_ratio: float = 0.5
    online_initial_size: int = 10000
    # When algo == "pes-cql" and the env is listed in the distance YAML table,
    # this is overwritten with threshold_coefficient * <table value>.
    threshold_distance: float = 10.
    threshold_coefficient: float = 2.
    checkpoints_path: Optional[str] = "./checkpoints"  # Save path (None: no checkpoint dir)
    results_path: Optional[str] = "./results"  # Results save path (None: no results dir)
    load_model: bool = False  # if load model
    max_sequence_length: int = 1000
    num_sequence: int = 10
    heterogeneous: bool = True

    # Experiment
    eval_seed: int = 0  # Eval environment seed (overwritten with `seed`)
    eval_freq: int = int(1e4)  # How often (time steps) we evaluate
    save_checkpoints_freq: int = int(1e4)
    n_episodes: int = 10  # How many episodes run during evaluation
    offline_iterations: int = int(1e6)  # Number of offline updates
    online_iterations: int = int(1e6)  # Number of online updates
    # CQL
    buffer_size: int = 2_000_000  # Replay buffer size
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    alpha_multiplier: float = 1.0  # Multiplier for alpha in loss
    use_automatic_entropy_tuning: bool = True  # Tune entropy
    backup_entropy: bool = False  # Use backup entropy
    policy_lr: float = 3e-5  # Policy learning rate
    qf_lr: float = 3e-4  # Critics learning rate
    soft_target_update_rate: float = 5e-3  # Target network update rate
    bc_steps: int = int(0)  # Number of BC steps at start
    target_update_period: int = 1  # Frequency of target nets updates
    cql_alpha: float = 10.0  # CQL offline regularization parameter
    cql_alpha_online: float = 10.0  # CQL online regularization parameter
    cql_n_actions: int = 10  # Number of sampled actions
    cql_importance_sample: bool = True  # Use importance sampling
    cql_lagrange: bool = False  # Use Lagrange version of CQL
    cql_target_action_gap: float = -1.0  # Action gap
    cql_temp: float = 1.0  # CQL temperature
    cql_max_target_backup: bool = False  # Use max target backup
    cql_clip_diff_min: float = -np.inf  # Q-function lower loss clipping
    cql_clip_diff_max: float = np.inf  # Q-function upper loss clipping
    orthogonal_init: bool = True  # Orthogonal initialization
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    q_n_hidden_layers: int = 2  # Number of hidden layers in Q networks
    reward_scale: float = 1.0  # Reward scale for normalization
    reward_bias: float = 0.0  # Reward bias for normalization

    def __post_init__(self):
        """Derive env-specific settings, the run name, and output paths."""
        if self.algo == "pes-cql":
            self._load_threshold_distance()

        self._apply_env_overrides()

        self.eval_seed = self.seed
        self.name = f"{self.env}-{self.seed}"

        self._prepare_output_paths()

    def _load_threshold_distance(self) -> None:
        """Set threshold_distance from the per-env distance table, if listed."""
        # NOTE(review): this path is resolved against the current working
        # directory, not this file's location — confirm launch directory.
        with open("../configs/distance_sa.yaml", "r", encoding="utf-8") as file:
            # safe_load suffices for a plain mapping of env name -> number
            # (assumed table format); an empty file yields None, hence `or {}`.
            distance_dict = yaml.safe_load(file) or {}
        if self.env in distance_dict:
            self.threshold_distance = self.threshold_coefficient * distance_dict[self.env]

    def _apply_env_overrides(self) -> None:
        """Force the hyper-parameters used for AntMaze / Adroit-style envs."""
        # NOTE(review): these assignments overwrite any values passed to the
        # constructor for these envs — user overrides are silently ignored.
        if self.env.startswith("antmaze"):
            self.cql_alpha = 5.0
            self.cql_alpha_online = 5.0
            self.cql_clip_diff_min = -200.0
            self.cql_lagrange = True
            self.cql_max_target_backup = True
            self.cql_target_action_gap = 0.8
            self.n_episodes = 100
            self.normalize = False
            self.normalize_reward = True
            self.policy_lr = 0.0001
            self.q_n_hidden_layers = 5
            self.reward_scale = 10.0
            self.reward_bias = -5.0
        elif self.env.startswith(("pen", "hammer", "door", "relocate")):
            self.cql_clip_diff_min = -200.0
            self.cql_importance_sample = True
            self.cql_lagrange = False
            self.cql_max_target_backup = True
            self.cql_alpha = 1.0
            self.cql_alpha_online = 1.0
            self.cql_n_actions = 10
            self.cql_target_action_gap = 0.8
            self.cql_temp = 1.0
            self.normalize = False
            self.normalize_reward = False
            self.orthogonal_init = True
            self.policy_lr = 0.0001
            self.qf_lr = 0.0003
            self.soft_target_update_rate = 0.005
            self.target_update_period = 1
            self.q_n_hidden_layers = 3
            self.reward_scale = 1.0
            self.reward_bias = 0.0
            self.use_automatic_entropy_tuning = True

    def _prepare_output_paths(self) -> None:
        """Expand results/checkpoint paths per env and seed and create their dirs.

        A path that is ``None`` is left untouched and no directory is created.
        """
        if self.results_path is not None:
            # NOTE(review): the "heter_" tag is added even when
            # self.heterogeneous is False — confirm this is intended.
            self.results_path = os.path.join(
                self.results_path + "_heter_" + self.algo,
                self.env,
                f"{self.seed}.txt",
            )
            # exist_ok avoids a check-then-create race between concurrent runs.
            os.makedirs(os.path.dirname(self.results_path), exist_ok=True)

        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.env, str(self.seed))
            os.makedirs(self.checkpoints_path, exist_ok=True)
            
            
            