# source: https://github.com/gwthomas/IQL-PyTorch
# https://arxiv.org/pdf/2110.06169.pdf
import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
import yaml

@dataclass
class TrainConfig:
    """Hyper-parameters and output paths for (PES-)IQL offline-to-online training.

    ``__post_init__`` derives the final output paths from the raw settings,
    creates the output directories, loads the per-environment distance
    threshold when ``algo == "pes-iql"``, and finally applies per-environment
    hyper-parameter presets. Those presets overwrite whatever was passed for
    the same fields (e.g. from the command line).
    """

    # Experiment
    algo: str = "pes-iql"  # iql, pes-iql
    device: str = "cuda:0"
    env: str = "halfcheetah-medium-v2"  # OpenAI gym environment name
    seed: int = 300  # Sets Gym, PyTorch and Numpy seeds
    eval_seed: int = 0  # Eval environment seed (overwritten with `seed` in __post_init__)
    eval_freq: int = int(1e4)  # How often (time steps) we evaluate
    n_episodes: int = 10  # How many episodes run during evaluation
    offline_iterations: int = int(1e6)  # Number of offline updates
    online_iterations: int = int(1e6)  # Number of online updates
    save_checkpoints_freq: int = int(1e4)

    online_ratio: float = 0.5  # online data ratio for fine-tuning
    online_initial_size: int = 10000
    # Fallback threshold; replaced for pes-iql when the env has an entry in the distance config.
    threshold_distance: float = 10.0
    # Multiplies the per-env distance (pes-iql). Also becomes a results-path
    # component via str(), so the int default 2 yields ".../2/...".
    threshold_coefficient: float = 2
    checkpoints_path: Optional[str] = "./checkpoints"  # checkpoint save path (None disables)
    results_path: Optional[str] = "./results"  # results save path (None disables)
    load_model: bool = True  # whether load pre-trained model
    max_sequence_length: int = 1000
    num_sequence: int = 5

    # IQL
    actor_dropout: float = 0.0  # Dropout in actor network
    buffer_size: int = 2_000_000  # Replay buffer size
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate
    beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
    iql_tau: float = 0.7  # Coefficient for asymmetric loss
    expl_noise: float = 0.03  # Std of Gaussian exploration noise
    noise_clip: float = 0.5  # Range to clip noise
    iql_deterministic: bool = False  # Use deterministic actor
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    vf_lr: float = 3e-4  # V function learning rate
    qf_lr: float = 3e-4  # Critic learning rate
    actor_lr: float = 3e-4  # Actor learning rate

    # PES-IQL: YAML mapping env name -> base distance. The path is resolved
    # relative to the current working directory. Kept last so positional
    # field order is unchanged.
    distance_config_path: str = "../configs/distance_sa.yaml"

    def __post_init__(self):
        # Evaluation always reuses the training seed; a user-supplied eval_seed is ignored.
        self.eval_seed = self.seed
        self.name = f"{self.env}-{self.seed}"

        if self.results_path is not None:
            # <results_path>_<algo>/<env>/<threshold_coefficient>/<seed>.txt
            self.results_path = os.path.join(
                f"{self.results_path}_{self.algo}",
                self.env,
                str(self.threshold_coefficient),
                f"{self.seed}.txt",
            )
            # exist_ok avoids the check-then-create race between concurrent runs.
            os.makedirs(os.path.dirname(self.results_path), exist_ok=True)

        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.env, str(self.seed))
            os.makedirs(self.checkpoints_path, exist_ok=True)

        if self.algo == "pes-iql":
            self._load_threshold_distance()

        self._apply_env_overrides()

    def _load_threshold_distance(self) -> None:
        """Set ``threshold_distance`` from the YAML table, scaled by ``threshold_coefficient``.

        Leaves ``threshold_distance`` unchanged when ``env`` has no entry.
        """
        # safe_load suffices for a plain name -> value mapping; an empty file loads as None.
        with open(self.distance_config_path, "r") as file:
            distance_dict = yaml.safe_load(file) or {}
        if self.env in distance_dict:
            self.threshold_distance = self.threshold_coefficient * distance_dict[self.env]

    def _apply_env_overrides(self) -> None:
        """Overwrite hyper-parameters with presets for known environment families.

        A preset applies when its keyword is a substring of ``env``. Presets are
        applied in table order, so a later match wins for overlapping fields.
        """
        antmaze = dict(
            actor_lr=3e-4,
            batch_size=256,
            beta=10.0,
            buffer_size=2_000_000,
            discount=0.99,
            iql_deterministic=False,
            iql_tau=0.9,
            normalize=True,
            normalize_reward=True,
            qf_lr=3e-4,
            tau=0.005,
            vf_lr=3e-4,
        )
        # door / pen / hammer / relocate share one identical preset.
        adroit = dict(
            actor_lr=3e-4,
            actor_dropout=0.1,
            batch_size=256,
            beta=3.0,
            buffer_size=2_000_000,
            discount=0.99,
            iql_deterministic=False,
            iql_tau=0.8,
            n_episodes=10,
            normalize=True,
            normalize_reward=False,
            qf_lr=3e-4,
            tau=0.005,
            vf_lr=3e-4,
        )
        presets = (
            ("antmaze", antmaze),
            ("door", adroit),
            ("pen", adroit),
            ("hammer", adroit),
            ("relocate", adroit),
        )
        for keyword, overrides in presets:
            if keyword in self.env:
                for field_name, value in overrides.items():
                    setattr(self, field_name, value)