import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional
from tqdm import trange


@dataclass
class TrainConfig:
    """Configuration for a training run (defaults target AWAC on D4RL).

    Construction has side effects, performed in ``__post_init__``:

    * ``name`` is always overwritten with ``"<env_name>-<seed>"``.
    * ``results_path`` (if not ``None``) is expanded to
      ``<results_path>/<algo>/<env_name>/<seed>.txt`` and its parent directory
      is created.
    * ``checkpoints_path`` (if not ``None``) is expanded to
      ``<checkpoints_path>/<env_name>/<seed>`` and that directory is created.
    * For a few environment families (matched by substring of ``env_name``)
      some hyperparameters are overwritten with fixed presets, regardless of
      the values passed to the constructor.
    """

    project: str = "CORL"
    group: str = "AWAC-D4RL"
    name: str = "AWAC"  # Replaced by "<env_name>-<seed>" in __post_init__
    algo: str = "awac"
    env_name: str = "halfcheetah-medium-expert-v2"
    seed: int = 100
    eval_seed: int = 0  # Eval environment seed
    test_seed: int = 69
    device: str = "cuda"

    checkpoints_path: Optional[str] = "./checkpoints"  # Save path
    results_path: Optional[str] = "./results"  # Results save path
    save_checkpoints_freq: int = int(1e4)

    online_ratio: float = 0.5  # Proportion of online data during training
    online_initial_size: int = 10000
    threshold_distance: float = 10.
    threshold_coefficient: float = 2
    load_model: bool = True  # Whether to load a model
    max_sequence_length: int = 1000
    num_sequence: int = 10

    buffer_size: int = 2_000_000
    offline_iterations: int = int(1e6)  # Number of offline updates
    online_iterations: int = int(1e6)  # Number of online updates
    batch_size: int = 256
    eval_frequency: int = 10000
    n_test_episodes: int = 10
    normalize_reward: bool = False

    hidden_dim: int = 256
    learning_rate: float = 3e-4
    gamma: float = 0.99
    tau: float = 5e-3
    awac_lambda: float = 1.0

    def _prepare_output_paths(self) -> None:
        """Expand the output paths and create the directories they need.

        A path set to ``None`` is left as ``None`` and nothing is created for
        it, so either output can be disabled independently.
        """
        if self.results_path is not None:
            self.results_path = os.path.join(
                self.results_path, self.algo, self.env_name, f"{self.seed}.txt"
            )
            # exist_ok avoids a check-then-create race between parallel runs.
            os.makedirs(os.path.dirname(self.results_path), exist_ok=True)

        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(
                self.checkpoints_path, self.env_name, str(self.seed)
            )
            os.makedirs(self.checkpoints_path, exist_ok=True)

    def _apply_env_overrides(self) -> None:
        """Overwrite hyperparameters with the preset of each matching family.

        Presets are applied in the listed order, so if ``env_name`` matches
        several keys the later preset wins for the fields it sets.
        """
        shared = {
            "awac_lambda": 0.1,
            "gamma": 0.99,
            "hidden_dim": 256,
            "learning_rate": 0.0003,
        }
        presets = (
            ("antmaze", {**shared, "n_test_episodes": 100,
                         "normalize_reward": True, "tau": 0.005}),
            # The "door" preset intentionally leaves ``tau`` untouched.
            ("door", {**shared, "n_test_episodes": 10,
                      "normalize_reward": False}),
            ("pen", {**shared, "n_test_episodes": 10,
                     "normalize_reward": False, "tau": 0.005}),
            ("hammer", {**shared, "n_test_episodes": 10,
                        "normalize_reward": False, "tau": 0.005}),
            ("relocate", {**shared, "n_test_episodes": 10,
                          "normalize_reward": False, "tau": 0.005}),
        )
        for key, overrides in presets:
            # NOTE(review): this is a plain substring test, so e.g. "pen" also
            # matches names containing "open" — confirm that is acceptable.
            if key in self.env_name:
                for field_name, value in overrides.items():
                    setattr(self, field_name, value)

    def __post_init__(self):
        """Derive the run name, prepare output paths and apply env presets."""
        self.name = f"{self.env_name}-{self.seed}"
        self._prepare_output_paths()
        self._apply_env_overrides()