

# source: https://github.com/sfujim/TD3_BC
# https://arxiv.org/pdf/2106.06860.pdf
import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb



from torch.distributions import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform


# Type alias for a list of tensors (not referenced in the code visible here).
TensorBatch = List[torch.Tensor]

# Clamp range for the actor's pre-tanh Gaussian mean (see Actor._get_outputs).
MEAN_MIN, MEAN_MAX = -9.0, 9.0
# Clamp range for the actor's log standard deviation (see Actor._get_outputs).
LOG_STD_MIN, LOG_STD_MAX = -5, 2
# Log-probability bounds; not referenced in the code visible here.
LOG_PI_NORM_MAX = 10
LOG_PI_NORM_MIN = -20

# Margin used to clip actions strictly inside (-1, 1) before evaluating the
# tanh-squashed log-density (see Actor.get_log_density).
EPS = 1e-7


# @dataclass
# class TrainConfig:
#     # Experiment
#     device: str = "cuda"
#     env: str = "halfcheetah-medium-expert-v2"  # OpenAI gym environment name
#     seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
#     eval_freq: int = int(5e3)  # How often (time steps) we evaluate
#     n_episodes: int = 10  # How many episodes run during evaluation
#     max_timesteps: int = int(1e6)  # Max time steps to run environment
#     checkpoints_path: Optional[str] = None  # Save path
#     load_model: str = ""  # Model load file name, "" doesn't load
#     # TD3
#     buffer_size: int = 2_000_000  # Replay buffer size
#     batch_size: int = 256  # Batch size for all networks
#     discount: float = 0.99  # Discount factor
#     expl_noise: float = 0.1  # Std of Gaussian exploration noise
#     tau: float = 0.005  # Target network update rate
#     policy_noise: float = 0.2  # Noise added to target actor during critic update
#     noise_clip: float = 0.5  # Range to clip target actor noise
#     policy_freq: int = 2  # Frequency of delayed actor updates
#     # TD3 + BC
#     alpha: float = 2.5  # Coefficient for Q function in actor loss
#     normalize: bool = True  # Normalize states
#     normalize_reward: bool = False  # Normalize reward
#     # Wandb logging
#     project: str = "CORL"
#     group: str = "TD3_BC-D4RL"
#     name: str = "TD3_BC"

#     def __post_init__(self):
#         self.name = f"{self.name}-{self.env}-{str(uuid.uuid4())[:8]}"
#         if self.checkpoints_path is not None:
#             self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average the parameters of ``source`` into ``target`` in place.

    Every target parameter is overwritten with
    ``(1 - tau) * target + tau * source``; parameters are paired in
    ``parameters()`` order.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-feature ``(mean, std + eps)`` of ``states`` over axis 0.

    ``eps`` is added to the standard deviation so the result can be used as a
    divisor when normalising.
    """
    spread = states.std(0) + eps
    centre = states.mean(0)
    return centre, spread



# class Actor(nn.Module):
#     def __init__(self, state_dim: int, action_dim: int, max_action: float):
#         super(Actor, self).__init__()

#         self.net = nn.Sequential(
#             nn.Linear(state_dim, 256),
#             nn.ReLU(),
#             nn.Linear(256, 256),
#             nn.ReLU(),
#             nn.Linear(256, action_dim),
#             nn.Tanh(),
#         )

#         self.max_action = max_action

#     def forward(self, state: torch.Tensor) -> torch.Tensor:
#         return self.max_action * self.net(state)

#     @torch.no_grad()
#     def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
#         state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
#         return self(state).cpu().data.numpy().flatten()

class Actor(nn.Module):
    """Tanh-squashed Gaussian policy network.

    Two hidden ReLU layers of width 256 feed separate heads for the Gaussian
    mean and log standard deviation. Samples are squashed with ``tanh`` so
    actions lie in (-1, 1); no ``max_action`` scaling is applied here.
    """

    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()

        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mu_head = nn.Linear(256, action_dim)
        self.sigma_head = nn.Linear(256, action_dim)

    def _get_outputs(self, state):
        """Return ``(distribution, tanh(mean))`` for a batch of states.

        The mean is clamped to [MEAN_MIN, MEAN_MAX] and the log-std to
        [LOG_STD_MIN, LOG_STD_MAX] before the tanh-transformed Normal is built.
        """
        a = F.relu(self.fc1(state))
        a = F.relu(self.fc2(a))
        mu = torch.clip(self.mu_head(a), MEAN_MIN, MEAN_MAX)
        log_sigma = torch.clip(self.sigma_head(a), LOG_STD_MIN, LOG_STD_MAX)
        sigma = torch.exp(log_sigma)

        # cache_size=1 lets log_prob() of a freshly drawn rsample reuse the
        # cached pre-tanh value instead of inverting tanh.
        a_distribution = TransformedDistribution(
            Normal(mu, sigma), TanhTransform(cache_size=1)
        )
        a_tanh_mode = torch.tanh(mu)
        return a_distribution, a_tanh_mode

    def forward(self, state):
        """Draw a reparameterised action for a batch of states.

        Returns:
            ``(action, logp_pi, a_tanh_mode)``: a sample from the squashed
            Gaussian, its log-probability summed over the last dimension,
            and the deterministic action ``tanh(mean)``.
        """
        a_dist, a_tanh_mode = self._get_outputs(state)
        action = a_dist.rsample()
        logp_pi = a_dist.log_prob(action).sum(axis=-1)
        return action, logp_pi, a_tanh_mode

    def get_log_density(self, state, action):
        """Return the per-dimension log-density of ``action`` under the policy.

        ``action`` is first clipped to [-1 + EPS, 1 - EPS] so the inverse of
        tanh stays finite. The result is NOT summed over action dimensions.
        """
        a_dist, _ = self._get_outputs(state)
        action_clip = torch.clip(action, -1. + EPS, 1. - EPS)
        logp_action = a_dist.log_prob(action_clip)
        return logp_action

    @torch.no_grad()
    def select_action(self, state, device):
        """Return the deterministic action ``tanh(mean)`` for one state.

        ``state`` is flattened to a single row, moved to ``device`` and the
        action is returned as a flat numpy array. The mode is read directly
        from ``_get_outputs`` so no sample is drawn (and the global torch RNG
        is not advanced) during evaluation.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        _, action = self._get_outputs(state)
        return action.cpu().data.numpy().flatten()

    def act(self, state):
        """Return the deterministic action ``tanh(mean)`` for a tensor of states.

        Unlike ``select_action`` this does not disable autograd, so the result
        remains differentiable w.r.t. the actor parameters. No sample is drawn.
        """
        _, action = self._get_outputs(state)
        return action



class Critic(nn.Module):
    """Q-network mapping a (state, action) pair to a scalar value.

    State and action are concatenated along dim 1 and fed through a
    256-256 ReLU MLP with a single linear output unit.
    """

    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(state_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        # The attribute must stay ``net`` so state_dict keys remain "net.<i>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, state, action):
        """Return Q(state, action); for 2-D inputs the shape is (batch, 1)."""
        joint = torch.cat([state, action], dim=1)
        return self.net(joint)


class TD3_BC:
    """Offline actor-critic learner mixing behaviour cloning with Q maximisation.

    Two critics (each with a target copy) are trained on the concatenation of
    batches drawn from two buffers. The actor (a stochastic tanh-Gaussian
    policy) is trained on a loss built from:

    * a behaviour-cloning term: the negative log-likelihood of the actions in
      ``buffer_e``, applied when ``total_it % bc_freq == 0``;
    * a Q term: ``-alpha * lmbda * mean(Q1(s, tanh(mean)))`` over both
      buffers, applied when ``total_it % policy_freq == 0``. ``lmbda`` is
      ``1 / mean|Q|``, multiplied by ``mean|BC loss|`` when the BC term is
      active in the same step so both terms are on a comparable scale.

    ``policy_noise`` and ``noise_clip`` are stored but not used by the update
    below: target actions are sampled from the stochastic target actor.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        lr=3e-4,
        wd=5e-3,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        policy_freq: int = 2,
        bc_freq: int = 2,
        alpha: float = 2.5,
        device: str = "cpu",
    ):
        # Actor uses weight decay ``wd``; the critics use plain Adam.
        actor = Actor(state_dim, action_dim).to(device)
        actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr, weight_decay=wd)

        critic_1 = Critic(state_dim, action_dim).to(device)
        critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=lr)
        critic_2 = Critic(state_dim, action_dim).to(device)
        critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=lr)

        self.actor = actor
        self.actor_target = copy.deepcopy(actor)
        self.actor_optimizer = actor_optimizer
        self.critic_1 = critic_1
        self.critic_1_target = copy.deepcopy(critic_1)
        self.critic_1_optimizer = critic_1_optimizer
        self.critic_2 = critic_2
        self.critic_2_target = copy.deepcopy(critic_2)
        self.critic_2_optimizer = critic_2_optimizer

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise  # unused by the current update
        self.noise_clip = noise_clip  # unused by the current update
        self.policy_freq = policy_freq
        self.bc_freq = bc_freq
        self.alpha = alpha

        self.total_it = 0
        self.device = device

    def train(self, buffer_e, buffer_o, iterations=1, batch_size=256):
        """Run ``iterations`` update steps and return the logged losses.

        Args:
            buffer_e: object whose ``sample(batch_size)`` returns a 6-tuple
                ``(state, action, next_state, reward, not_done, _)`` of
                tensors; used for the critics, the Q term and the BC term.
            buffer_o: same ``sample`` contract; used for the critics and the
                Q term only.
            iterations: number of update steps to run.
            batch_size: number of transitions drawn from EACH buffer per step.

        Returns:
            Dict of lists. ``"critic_loss"`` gets one entry per step;
            ``"actor_loss_bc"`` / ``"actor_loss_q"`` get an entry on steps
            where that term is active, and ``"actor_loss"`` on steps where
            the actor is updated at all.
        """
        self.actor.train()

        log_dict = {"actor_loss": [], "critic_loss": [], "actor_loss_q": [], "actor_loss_bc": []}

        for _step in range(iterations):
            self.total_it += 1

            state_e, action_e, next_state_e, reward_e, not_done_e, _ = buffer_e.sample(batch_size)
            state_o, action_o, next_state_o, reward_o, not_done_o, _ = buffer_o.sample(batch_size)

            # Critics and the Q term see both buffers; BC sees only buffer_e.
            state = torch.cat([state_e, state_o], dim=0)
            action = torch.cat([action_e, action_o], dim=0)
            next_state = torch.cat([next_state_e, next_state_o], dim=0)
            reward = torch.cat([reward_e, reward_o], dim=0)
            not_done = torch.cat([not_done_e, not_done_o], dim=0)

            log_dict["critic_loss"].append(
                self._update_critics(state, action, next_state, reward, not_done)
            )

            do_bc = self.total_it % self.bc_freq == 0
            do_q = self.total_it % self.policy_freq == 0
            # Only step the actor when at least one loss term is scheduled;
            # otherwise there is no tensor to log or back-propagate.
            if do_bc or do_q:
                actor_loss = self._actor_loss(state, state_e, action_e, do_bc, do_q, log_dict)
                log_dict["actor_loss"].append(actor_loss.item())

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

            # Polyak-average all target networks every step.
            soft_update(self.critic_1_target, self.critic_1, self.tau)
            soft_update(self.critic_2_target, self.critic_2, self.tau)
            soft_update(self.actor_target, self.actor, self.tau)

        return log_dict

    def _update_critics(self, state, action, next_state, reward, not_done):
        """Take one clipped double-Q step on both critics; return the loss value."""
        with torch.no_grad():
            # Sample the next action from the stochastic target actor. Its
            # tanh output already lies in (-1, 1); the clamp only matters
            # when max_action < 1.
            perturbed_actions, _, _ = self.actor_target(next_state)
            next_action = perturbed_actions.clamp(-self.max_action, self.max_action)

            target_q1 = self.critic_1_target(next_state, next_action)
            target_q2 = self.critic_2_target(next_state, next_action)
            target_q = torch.min(target_q1, target_q2)
            target_q = reward + not_done * self.discount * target_q

        current_q1 = self.critic_1(state, action)
        current_q2 = self.critic_2(state, action)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        return critic_loss.item()

    def _actor_loss(self, state, state_e, action_e, do_bc, do_q, log_dict):
        """Build the actor loss from the enabled terms; at least one must be on."""
        actor_loss = 0.0
        lmbda2 = None

        if do_bc:
            log_pi_e = self.actor.get_log_density(state_e, action_e)
            bckl_loss = -torch.sum(log_pi_e, 1)
            # Inverse BC magnitude, used to rescale the Q term below.
            lmbda2 = 1.0 / bckl_loss.abs().mean().detach()
            log_dict["actor_loss_bc"].append(bckl_loss.mean().item())
            actor_loss = actor_loss + bckl_loss.mean()

        if do_q:
            # Evaluate Q at the deterministic action tanh(mean).
            _, _, pi = self.actor(state)
            q = self.critic_1(state, pi)
            lmbda = 1.0 / q.abs().mean().detach()
            if lmbda2 is not None:
                lmbda = lmbda / lmbda2
            log_dict["actor_loss_q"].append(-q.mean().item())
            actor_loss = actor_loss + self.alpha * (-lmbda * q.mean())

        return actor_loss

    def state_dict(self) -> Dict[str, Any]:
        """Return online networks, optimizers and ``total_it`` (targets excluded)."""
        return {
            "critic_1": self.critic_1.state_dict(),
            "critic_1_optimizer": self.critic_1_optimizer.state_dict(),
            "critic_2": self.critic_2.state_dict(),
            "critic_2_optimizer": self.critic_2_optimizer.state_dict(),
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore from :meth:`state_dict`; targets are reset to the loaded nets."""
        self.critic_1.load_state_dict(state_dict["critic_1"])
        self.critic_1_optimizer.load_state_dict(state_dict["critic_1_optimizer"])
        self.critic_1_target = copy.deepcopy(self.critic_1)

        self.critic_2.load_state_dict(state_dict["critic_2"])
        self.critic_2_optimizer.load_state_dict(state_dict["critic_2_optimizer"])
        self.critic_2_target = copy.deepcopy(self.critic_2)

        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.actor_target = copy.deepcopy(self.actor)

        self.total_it = state_dict["total_it"]

    def save(self, filename):
        """Write each network / optimizer state to ``filename + <suffix>``.

        ``total_it`` and the target networks are not saved.
        """
        torch.save(self.critic_1.state_dict(), filename + "_critic_1")
        torch.save(self.critic_1_optimizer.state_dict(), filename + "_critic_1_optimizer")
        torch.save(self.critic_2.state_dict(), filename + "_critic_2")
        torch.save(self.critic_2_optimizer.state_dict(), filename + "_critic_2_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Load the files written by :meth:`save`; targets are reset to copies.

        Tensors are remapped to the CPU when CUDA is unavailable; otherwise
        ``torch.load`` restores them to the devices they were saved from.
        NOTE: ``torch.load`` unpickles its input -- only load trusted files.
        """
        map_location = None if torch.cuda.is_available() else torch.device("cpu")

        def _load(suffix):
            return torch.load(filename + suffix, map_location=map_location)

        self.critic_1.load_state_dict(_load("_critic_1"))
        self.critic_1_optimizer.load_state_dict(_load("_critic_1_optimizer"))
        self.critic_1_target = copy.deepcopy(self.critic_1)

        self.critic_2.load_state_dict(_load("_critic_2"))
        self.critic_2_optimizer.load_state_dict(_load("_critic_2_optimizer"))
        self.critic_2_target = copy.deepcopy(self.critic_2)

        self.actor.load_state_dict(_load("_actor"))
        self.actor_optimizer.load_state_dict(_load("_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)



