import os
import uuid
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import yaml

# Type alias: a list of tensors treated as one batch.
# NOTE(review): presumably the per-field components of a sampled batch
# (states, actions, rewards, ...) -- confirm against the buffer/trainer code.
TensorBatch = List[torch.Tensor]

# Environment-name fragments grouped as "goal" environments.
# NOTE(review): how these are matched against env names is not visible in
# this file -- confirm against the caller.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")

@dataclass
class TrainConfig:
    """Hyper-parameters and output locations for a TD3(+BC) training run.

    ``__post_init__`` derives ``eval_seed`` and ``name``, expands the result
    and checkpoint paths into per-environment / per-seed locations (creating
    the directories), and, for ``algo == "pes-td3-bc"``, rescales
    ``threshold_distance`` from a per-environment YAML table.
    """

    # Experiment
    device: str = "cuda:1"
    env: str = "halfcheetah-medium-v2"  # OpenAI gym environment name
    seed: int = 100  # Sets Gym, PyTorch and Numpy seeds
    algo: str = "td3-bc"  # One of: "td3-bc", "pes-td3-bc"
    online_ratio: float = 0.5
    online_initial_size: int = 10000
    # For "pes-td3-bc" this is overwritten with
    # threshold_coefficient * <per-env distance> when the env is listed in
    # the distance config file.
    threshold_distance: float = 10.
    threshold_coefficient: float = 2.  # Multiplier for the per-env distance

    heterogeneous: bool = True
    eval_freq: int = int(1e3)  # How often (time steps) we evaluate
    n_episodes: int = 10  # How many episodes run during evaluation
    max_timesteps: int = int(1e6)  # Max time steps to run environment
    load_model: bool = False  # Whether to load a previously saved model
    checkpoints_path: Optional[str] = "./checkpoints"  # Save path; None disables
    results_path: Optional[str] = "./results"  # Results save path; None disables
    max_sequence_length: int = 1000
    num_sequence: int = 10

    save_checkpoints_freq: int = int(1e4)
    offline_iterations: int = int(1e6)  # Number of offline updates
    online_iterations: int = int(1e6)  # Number of online updates

    # TD3
    buffer_size: int = 2_000_000  # Replay buffer size
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    expl_noise: float = 0.1  # Std of Gaussian exploration noise
    tau: float = 0.005  # Target network update rate
    policy_noise: float = 0.2  # Noise added to target actor during critic update
    noise_clip: float = 0.5  # Range to clip target actor noise
    policy_freq: int = 2  # Frequency of delayed actor updates
    # TD3 + BC
    alpha: float = 2.5  # Coefficient for Q function in actor loss
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    reward_scale: float = 1.0  # Reward scale for normalization
    reward_bias: float = 0.0  # Reward bias for normalization
    # PES: YAML mapping env name -> distance, read only when
    # algo == "pes-td3-bc". Relative paths resolve against the current
    # working directory. Kept as the last field so positional order of the
    # existing fields is unchanged.
    distance_config_path: str = "../configs/distance_sa.yaml"

    def __post_init__(self):
        """Derive run-specific attributes and prepare output directories.

        Sets ``eval_seed`` and ``name``. A non-None ``results_path`` becomes
        ``<results_path>_heter_<algo>/<env>/<seed>.txt``. A non-None
        ``checkpoints_path`` becomes ``<checkpoints_path>/<env>/<seed>``.
        Directories are created only for paths that are not ``None``.

        Raises:
            OSError: if ``algo == "pes-td3-bc"`` and the distance config
                cannot be opened, or if a directory cannot be created.
        """
        self.eval_seed = self.seed
        self.name = f"{self.env}-{self.seed}"

        if self.results_path is not None:
            # NOTE(review): the "heter_" tag is added regardless of
            # `heterogeneous` -- confirm this is intended.
            self.results_path = os.path.join(
                f"{self.results_path}_heter_{self.algo}",
                self.env,
                f"{self.seed}.txt",
            )

        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(
                self.checkpoints_path, self.env, str(self.seed)
            )

        if self.algo == "pes-td3-bc":
            with open(self.distance_config_path, "r", encoding="utf-8") as file:
                # safe_load is sufficient for a plain name -> number mapping;
                # an empty file yields None, so fall back to an empty dict.
                distance_dict = yaml.safe_load(file) or {}
            if self.env in distance_dict:
                self.threshold_distance = (
                    self.threshold_coefficient * distance_dict[self.env]
                )

        # exist_ok avoids the check-then-create race when several runs start
        # concurrently; None paths mean "do not write" and are skipped.
        if self.results_path is not None:
            os.makedirs(os.path.dirname(self.results_path), exist_ok=True)
        if self.checkpoints_path is not None:
            os.makedirs(self.checkpoints_path, exist_ok=True)