"""
Heavily based on https://github.com/alimama-tech/AuctionNet/blob/main/strategy_train_env/bidding_train_env/baseline/cql/cql.py#L5
"""
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
import torch
import os
from bidding_train_env.common.utils import normalize_state_test
import pickle


class Q(nn.Module):
    """
    State-action value network Q(s, a).

    The observation and the action are each projected to a 64-wide embedding
    by their own linear layer; the two embeddings are concatenated and passed
    through a three-layer MLP that emits one scalar per sample.
    """

    def __init__(self, dim_observation, dim_action):
        super().__init__()
        self.dim_observation = dim_observation
        self.dim_action = dim_action

        embed_width = 64
        # Separate input projections for observation and action.
        self.obs_FC = nn.Linear(dim_observation, embed_width)
        self.action_FC = nn.Linear(dim_action, embed_width)
        # MLP head on the concatenated (2 * 64 = 128 wide) embedding.
        self.FC1 = nn.Linear(2 * embed_width, 64)
        self.FC2 = nn.Linear(64, 64)
        self.FC3 = nn.Linear(64, 1)

    def forward(self, obs, acts):
        """
        Evaluate Q for a batch of (observation, action) pairs.

        obs:  batch_size * obs_dim
        acts: batch_size * act_dim
        Returns a tensor whose last dimension has size 1.
        """
        # Note: no non-linearity is applied to the embeddings before the concat.
        joint = torch.cat([self.obs_FC(obs), self.action_FC(acts)], dim=-1)
        hidden = F.relu(self.FC1(joint))
        hidden = F.relu(self.FC2(hidden))
        return self.FC3(hidden)


class V(nn.Module):
    """
    State value network V(s).

    A four-layer MLP (dim_observation -> 128 -> 64 -> 32 -> 1) with ReLU
    after every layer except the last.
    """

    def __init__(self, dim_observation):
        super().__init__()
        self.FC1 = nn.Linear(dim_observation, 128)
        self.FC2 = nn.Linear(128, 64)
        self.FC3 = nn.Linear(64, 32)
        self.FC4 = nn.Linear(32, 1)

    def forward(self, obs):
        """
        Evaluate V for a batch of observations.

        obs: batch_size * obs_dim
        Returns a tensor whose last dimension has size 1.
        """
        h128 = F.relu(self.FC1(obs))
        h64 = F.relu(self.FC2(h128))
        h32 = F.relu(self.FC3(h64))
        # The output layer is linear: values are unbounded.
        return self.FC4(h32)




class CQL_Critic(nn.Module):
    """
    CQL critic:
    Bellman TD + conservative regularization
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        gamma=0.99,
        tau=0.01,
        V_lr=1e-4,
        critic_lr=1e-4,
        network_random_seed=1,
        expectile=0.7,
        temperature=3.0,
        cql_alpha=1.0,
        num_random=10,
        min_action=-1.0,
        max_action=1.0,
    ):
        """
        Build the value net, the twin critics with their targets, and the optimizers.

        Args:
            state_dim: observation width fed to the V and Q networks.
            act_dim: action width fed to the Q networks.
            gamma: discount factor used in the Bellman target.
            tau: interpolation factor of the soft target update.
            V_lr: Adam learning rate of the value network.
            critic_lr: Adam learning rate of both critics.
            network_random_seed: seed set on torch's global RNG before the
                networks are created.
            expectile: expectile used by the value-network regression loss.
            temperature: stored only; no method of this class reads it.
            cql_alpha: weight of the conservative (CQL) term in the Q loss.
            num_random: number of uniformly sampled actions per state for the
                CQL term.
            min_action: lower bound of the uniform action sampling range.
            max_action: upper bound of the uniform action sampling range.
        """
        super().__init__()

        # Problem dimensions.
        self.num_of_states = state_dim
        self.num_of_actions = act_dim

        # Optimisation hyperparameters.
        self.GAMMA = gamma
        self.tau = tau
        self.V_lr = V_lr
        self.critic_lr = critic_lr
        self.network_random_seed = network_random_seed
        self.expectile = expectile
        self.temperature = temperature

        # CQL-specific hyperparameters.
        self.cql_alpha = cql_alpha
        self.num_random = num_random
        self.min_action = min_action
        self.max_action = max_action

        # Seed the global RNG before any layer exists so that the initial
        # weights are reproducible. The construction order below matters for
        # that reason and must not be shuffled.
        torch.random.manual_seed(network_random_seed)

        # Value network (same structure as in IQL).
        self.value_net = V(state_dim)

        # Twin online critics; each target starts as an exact copy of its critic.
        self.critic1 = Q(state_dim, act_dim)
        self.critic2 = Q(state_dim, act_dim)
        self.critic1_target = Q(state_dim, act_dim)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target = Q(state_dim, act_dim)
        self.critic2_target.load_state_dict(self.critic2.state_dict())

        # One optimizer per trainable network; targets are never optimised directly.
        self.value_optimizer = Adam(self.value_net.parameters(), lr=V_lr)
        self.critic1_optimizer = Adam(self.critic1.parameters(), lr=critic_lr)
        self.critic2_optimizer = Adam(self.critic2.parameters(), lr=critic_lr)

        self.deterministic_action = True

        # Move every network to the GPU when one is available.
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            for net in (
                self.critic1,
                self.critic2,
                self.critic1_target,
                self.critic2_target,
                self.value_net,
            ):
                net.cuda()
            self.FloatTensor = torch.cuda.FloatTensor
        else:
            self.FloatTensor = torch.FloatTensor

    def take_critics(self, state, action, normalize_indices):
        """
        Score (state, action) with the online critics after normalising the state.

        Switches the module to eval mode, reads ``normalize_dict.pkl`` from
        ``<root>/saved_model/CQL_critic`` (``<root>`` being four directory
        levels above this file), normalises ``state`` with the project helper
        ``normalize_state_test`` and evaluates both critics.

        Args:
            state: raw state, passed unchanged to ``normalize_state_test``.
            action: action tensor; its device is also used for the state tensor.
            normalize_indices: forwarded to ``normalize_state_test``.

        Returns:
            The ``[q1, q2]`` list produced by ``get_critic``.
        """
        self.eval()

        # Walk four levels up from this file to find the directory holding saved_model/.
        root_dir = os.path.abspath(__file__)
        for _ in range(4):
            root_dir = os.path.dirname(root_dir)
        critic_dir = os.path.join(root_dir, "saved_model", "CQL_critic")

        # NOTE(review): pickle executes arbitrary code on load; this file must
        # come from a trusted training run.
        with open(os.path.join(critic_dir, "normalize_dict.pkl"), "rb") as fh:
            stats = pickle.load(fh)

        normalized = normalize_state_test(state, stats, normalize_indices)
        state_tensor = torch.tensor(normalized, dtype=torch.float32).to(action.device)
        return self.get_critic(state_tensor, action)

    def forward(self, state, action):
        """
        Return target Q1, Q2 (for direct use of target networks externally)
        """
        q_target_1 = self.critic1_target(state, action)
        q_target_2 = self.critic2_target(state, action)
        return [q_target_1, q_target_2]

    def get_critic(self, state, action):
        # Use current critic (not target) for action Q-value evaluation
        q1 = self.critic1(state, action)
        q2 = self.critic2(state, action)
        return [q1, q2]

    def step(self, states, actions, rewards, next_states, dones, update=True):
        """
        Train model
        """

        if update:
            # Update value network
            self.value_optimizer.zero_grad()
            value_loss = self.calc_value_loss(states, actions)
            value_loss.backward()
            self.value_optimizer.step()

            # Update Q networks
            critic1_loss, critic2_loss = self.calc_q_loss(
                states, actions, rewards, dones, next_states
            )
            self.critic1_optimizer.zero_grad()
            critic1_loss.backward()
            self.critic1_optimizer.step()

            self.critic2_optimizer.zero_grad()
            critic2_loss.backward()
            self.critic2_optimizer.step()

            # soft update target
            self.update_target(self.critic1, self.critic1_target)
            self.update_target(self.critic2, self.critic2_target)
        else:
            value_loss = self.calc_value_loss(states, actions)
            critic1_loss, critic2_loss = self.calc_q_loss(
                states, actions, rewards, dones, next_states
            )

        return (
            critic1_loss.cpu().data.numpy(),
            critic2_loss.cpu().data.numpy(),
            value_loss.cpu().data.numpy(),
        )

    # --------- V loss use expectile regression ---------
    def calc_value_loss(self, states, actions):
        with torch.no_grad():
            q1 = self.critic1_target(states, actions)
            q2 = self.critic2_target(states, actions)
            min_Q = torch.min(q1, q2)

        value = self.value_net(states)
        value_loss = self.l2_loss(min_Q - value, self.expectile).mean()
        return value_loss

    # --------- CQL  Q-loss ---------
    def calc_q_loss(self, states, actions, rewards, dones, next_states):
        """
        TD(Bellman) + CQL conservative :
        L_Q = MSE(Q - y) + alpha * ( CQL1 + CQL2 )
         CQL = logsumexp(Q(s, a')) - Q(s, a_data)
        a' random sample。
        """
        device = states.device
        batch_size = states.shape[0]

        #  y = r + gamma * (1-done) * V(s')
        with torch.no_grad():
            next_v = self.value_net(next_states)
            q_target = rewards + (self.GAMMA * (1.0 - dones) * next_v)

        q1 = self.critic1(states, actions)
        q2 = self.critic2(states, actions)

        bellman_loss1 = ((q1 - q_target) ** 2).mean()
        bellman_loss2 = ((q2 - q_target) ** 2).mean()

        
        random_actions = torch.empty(
            batch_size * self.num_random, self.num_of_actions, device=device
        ).uniform_(self.min_action, self.max_action) # [B * num_random, act_dim]

        states_repeat = (
            states.unsqueeze(1)
            .repeat(1, self.num_random, 1)
            .reshape(batch_size * self.num_random, -1)
        )

        q1_rand = self.critic1(states_repeat, random_actions).view(
            batch_size, self.num_random, 1
        )
        q2_rand = self.critic2(states_repeat, random_actions).view(
            batch_size, self.num_random, 1
        )

        q1_data = q1.view(batch_size, 1, 1)
        q2_data = q2.view(batch_size, 1, 1)
        cat_q1 = torch.cat([q1_rand, q1_data], dim=1)
        cat_q2 = torch.cat([q2_rand, q2_data], dim=1)

        cql1_loss = (torch.logsumexp(cat_q1, dim=1) - q1_data.squeeze(-1)).mean()
        cql2_loss = (torch.logsumexp(cat_q2, dim=1) - q2_data.squeeze(-1)).mean()

        # Q-loss = TD + alpha * CQL
        critic1_loss = bellman_loss1 + self.cql_alpha * cql1_loss
        critic2_loss = bellman_loss2 + self.cql_alpha * cql2_loss

        return critic1_loss, critic2_loss

    def update_target(self, local_model, target_model):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(
                (1.0 - self.tau) * target_param.data + self.tau * local_param.data
            )

    def save_net(self, save_path):
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        checkpoint = {
            "model_state_dict": self.state_dict(),
            "critic1_optimizer": self.critic1_optimizer.state_dict(),
            "critic2_optimizer": self.critic2_optimizer.state_dict(),
            "value_optimizer": self.value_optimizer.state_dict(),
        }
        file_path = os.path.join(save_path, "cql_critic.pt")
        torch.save(checkpoint, file_path)

    def save_jit(self, save_path):
        """
        Export the module as TorchScript to ``<save_path>/cql_model.pth``.

        Note: ``self.cpu()`` moves this module to the CPU in place, so the
        model stays on the CPU after this call.

        Args:
            save_path: target directory; created (with parents) if missing.
        """
        # exist_ok avoids the check-then-create race when several workers save
        # into the same directory at once.
        os.makedirs(save_path, exist_ok=True)
        jit_model = torch.jit.script(self.cpu())
        torch.jit.save(jit_model, os.path.join(save_path, "cql_model.pth"))

    def load_net(self, load_path="saved_model/CQL_critic/cql_critic.pt", device="cuda:0"):
        checkpoint = torch.load(load_path, map_location=device)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.critic1_optimizer.load_state_dict(checkpoint["critic1_optimizer"])
        self.critic2_optimizer.load_state_dict(checkpoint["critic2_optimizer"])
        self.value_optimizer.load_state_dict(checkpoint["value_optimizer"])
        print(f"CQL Checkpoint loaded on {device}")

    def l2_loss(self, diff, expectile=0.8):
        weight = torch.where(diff > 0, expectile, (1 - expectile))
        return weight * (diff ** 2)


if __name__ == '__main__':
    # Smoke test: train the critic for a few steps on random data.
    state_dim, act_dim = 3, 1
    # state_dim and act_dim are required constructor arguments.
    model = CQL_Critic(state_dim=state_dim, act_dim=act_dim)
    # The constructor moves the networks to CUDA when it is available, so the
    # input tensors have to live on the same device.
    device = torch.device("cuda" if model.use_cuda else "cpu")
    step_num = 100
    batch_size = 1000
    for i in range(step_num):
        states = np.random.uniform(2, 5, size=(batch_size, state_dim))
        next_states = np.random.uniform(2, 5, size=(batch_size, state_dim))
        actions = np.random.uniform(-1, 1, size=(batch_size, act_dim))
        rewards = np.random.uniform(0, 1, size=(batch_size, 1))
        terminals = np.zeros((batch_size, 1))
        states, next_states, actions, rewards, terminals = (
            torch.tensor(batch, dtype=torch.float, device=device)
            for batch in (states, next_states, actions, rewards, terminals)
        )
        # step() returns three losses: critic1, critic2 and value.
        q1_loss, q2_loss, v_loss = model.step(states, actions, rewards, next_states, terminals)
        print(f'step:{i} q1_loss:{q1_loss} q2_loss:{q2_loss} v_loss:{v_loss}')
