"""
Heavily based on https://github.com/alimama-tech/AuctionNet/blob/main/strategy_train_env/bidding_train_env/baseline/iql/iql.py#L8
"""
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
import torch
import os
from bidding_train_env.common.utils import normalize_state_test
import pickle


class Q(nn.Module):
    """
    IQL-Q net.

    Embeds the observation and the action separately (64 units each),
    concatenates the two embeddings and maps them through a small MLP to a
    single scalar per sample.
    """

    def __init__(self, dim_observation, dim_action):
        super().__init__()
        self.dim_observation = dim_observation
        self.dim_action = dim_action

        # Input embeddings; no activation is applied to them before the concat.
        self.obs_FC = nn.Linear(dim_observation, 64)
        self.action_FC = nn.Linear(dim_action, 64)
        # MLP head over the 128-wide concatenated embedding.
        self.FC1 = nn.Linear(128, 64)
        self.FC2 = nn.Linear(64, 64)
        self.FC3 = nn.Linear(64, 1)

    # obs: batch_size * obs_dim
    def forward(self, obs, acts):
        """Return the Q estimate for each (obs, acts) pair; last dim has size 1."""
        joint = torch.cat([self.obs_FC(obs), self.action_FC(acts)], dim=-1)
        hidden = F.relu(self.FC1(joint))
        hidden = F.relu(self.FC2(hidden))
        return self.FC3(hidden)


class V(nn.Module):
    """
    IQL-V net.

    A four-layer MLP (obs_dim -> 128 -> 64 -> 32 -> 1) with ReLU after every
    layer except the last, producing one scalar per observation.
    """

    def __init__(self, dim_observation):
        super().__init__()
        self.FC1 = nn.Linear(dim_observation, 128)
        self.FC2 = nn.Linear(128, 64)
        self.FC3 = nn.Linear(64, 32)
        self.FC4 = nn.Linear(32, 1)

    # obs: batch_size * obs_dim
    def forward(self, obs):
        """Return the state-value estimate for each observation; last dim has size 1."""
        h1 = F.relu(self.FC1(obs))
        h2 = F.relu(self.FC2(h1))
        h3 = F.relu(self.FC3(h2))
        return self.FC4(h3)




class IQL_Critic(nn.Module):
    """
    IQL model (critic part only).

    Holds a state-value network, two Q networks and a soft-updated target copy
    of each Q network, together with one Adam optimizer per trained network.
    ``forward`` returns the two *target* Q estimates rather than an action.
    """

    def __init__(self, state_dim, act_dim, gamma=0.99, tau=0.01, V_lr=1e-4, critic_lr=1e-4,
                 network_random_seed=1, expectile=0.7, temperature=3.0):
        """
        Build the networks and optimizers.

        Args:
            state_dim: size of the observation vector fed to V and Q.
            act_dim: size of the action vector fed to Q.
            gamma: discount factor used in the Q target.
            tau: interpolation factor for the soft target update.
            V_lr: learning rate of the value-network optimizer.
            critic_lr: learning rate of both Q-network optimizers.
            network_random_seed: seed passed to ``torch.random.manual_seed``
                before the networks are created (this seeds the global RNG).
            expectile: expectile used by the value loss.
            temperature: stored on the instance; not used inside this class.
        """
        super().__init__()
        self.num_of_states = state_dim
        self.num_of_actions = act_dim
        self.V_lr = V_lr
        self.critic_lr = critic_lr
        self.network_random_seed = network_random_seed
        self.expectile = expectile
        self.temperature = temperature
        torch.random.manual_seed(self.network_random_seed)
        # Creation order is significant: it fixes the seeded initial weights.
        self.value_net = V(self.num_of_states)
        self.critic1 = Q(self.num_of_states, self.num_of_actions)
        self.critic2 = Q(self.num_of_states, self.num_of_actions)
        # Targets start as exact copies of their online critics.
        self.critic1_target = Q(self.num_of_states, self.num_of_actions)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target = Q(self.num_of_states, self.num_of_actions)
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.GAMMA = gamma
        self.tau = tau
        self.value_optimizer = Adam(self.value_net.parameters(), lr=self.V_lr)
        self.critic1_optimizer = Adam(self.critic1.parameters(), lr=self.critic_lr)
        self.critic2_optimizer = Adam(self.critic2.parameters(), lr=self.critic_lr)
        self.deterministic_action = True
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            # Moves every sub-network in place, so the optimizers keep valid
            # references to the same parameter objects.
            self.cuda()
        self.FloatTensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        # Lazily filled by take_critics() so the pickle is read only once.
        self._normalize_dict = None

    def take_critics(self, state, action, normalize_indices):
        """
        Normalize ``state`` with the saved statistics and score ``(state, action)``.

        The normalization dictionary is read from
        ``<4 directories above this file>/saved_model/IQL_critic/normalize_dict.pkl``
        on the first call and cached on the instance afterwards.

        Args:
            state: raw state, passed to ``normalize_state_test``.
            action: action tensor; its device is used for the normalized state.
            normalize_indices: forwarded to ``normalize_state_test``.

        Returns:
            The list ``[q_target_1, q_target_2]`` produced by ``forward``.
        """
        self.eval()
        # Usage method
        normalize_dict = getattr(self, "_normalize_dict", None)
        if normalize_dict is None:
            default_critic_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), "saved_model", "IQL_critic")
            norm_path = os.path.join(default_critic_path, "normalize_dict.pkl")
            # NOTE(security): pickle.load executes arbitrary code; this file
            # must only ever come from a trusted training run.
            with open(norm_path, "rb") as f:
                normalize_dict = pickle.load(f)
            # NOTE(review): caching assumes normalize_state_test does not
            # mutate normalize_dict -- confirm against its implementation.
            self._normalize_dict = normalize_dict
        state_norm = normalize_state_test(state, normalize_dict, normalize_indices)
        state_norm = torch.tensor(state_norm, dtype=torch.float32).to(action.device)
        assess_values = self.get_critic(state_norm, action)
        return assess_values

    def forward(self, state, action):
        """
        Return Q1, Q2 instead of action

        Both values come from the *target* critics.
        """
        q_target_1 = self.critic1_target(state, action)
        q_target_2 = self.critic2_target(state, action)
        return_value = [q_target_1, q_target_2]
        return return_value

    def step(self, states, actions, rewards, next_states, dones):
        """
        Train model

        Runs one optimizer step on the value network, then one on each critic,
        then soft-updates both target critics.

        Returns:
            ``(critic1_loss, critic2_loss, value_loss)`` as NumPy arrays.
        """
        self.value_optimizer.zero_grad()
        value_loss = self.calc_value_loss(states, actions)
        value_loss.backward()
        self.value_optimizer.step()

        # The critic targets below use the value network that was just updated.
        critic1_loss, critic2_loss = self.calc_q_loss(states, actions, rewards, dones, next_states)
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        self.critic2_optimizer.step()

        self.update_target(self.critic1, self.critic1_target)
        self.update_target(self.critic2, self.critic2_target)

        return (critic1_loss.detach().cpu().numpy(),
                critic2_loss.detach().cpu().numpy(),
                value_loss.detach().cpu().numpy())

    def get_critic(self, state, action):
        """Return ``forward(state, action)``, i.e. ``[q_target_1, q_target_2]``."""
        return_value = self.forward(state, action)
        return return_value

    def calc_value_loss(self, states, actions):
        """
        Expectile loss between min(target Q1, target Q2) and V(states).

        The target-critic outputs are computed without gradient, so only the
        value network receives gradients from this loss.
        """
        with torch.no_grad():
            q1 = self.critic1_target(states, actions)
            q2 = self.critic2_target(states, actions)
            min_Q = torch.min(q1, q2)

        value = self.value_net(states)
        value_loss = self.l2_loss(min_Q - value, self.expectile).mean()
        return value_loss

    def calc_q_loss(self, states, actions, rewards, dones, next_states):
        """
        Mean-squared TD error of both online critics.

        The target is ``rewards + GAMMA * (1 - dones) * V(next_states)``,
        computed without gradient. ``rewards`` and ``dones`` are combined with
        the value output by ordinary broadcasting, so they should have the same
        shape as that output (last dimension of size 1).

        Returns:
            ``(critic1_loss, critic2_loss)`` as scalar tensors.
        """
        with torch.no_grad():
            next_v = self.value_net(next_states)
            q_target = rewards + (self.GAMMA * (1 - dones) * next_v)

        q1 = self.critic1(states, actions)
        q2 = self.critic2(states, actions)
        critic1_loss = ((q1 - q_target) ** 2).mean()
        critic2_loss = ((q2 - q_target) ** 2).mean()
        return critic1_loss, critic2_loss

    def update_target(self, local_model, target_model):
        """Soft update: ``target <- (1 - tau) * target + tau * local``, parameter-wise."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_((1. - self.tau) * target_param.data + self.tau * local_param.data)

    def save_net(self, save_path):
        """
        Save the module state dict and the three optimizer state dicts to
        ``<save_path>/iql_critic.pt``, creating ``save_path`` if needed.
        """
        os.makedirs(save_path, exist_ok=True)
        checkpoint = {
            "model_state_dict": self.state_dict(),
            "critic1_optimizer": self.critic1_optimizer.state_dict(),
            "critic2_optimizer": self.critic2_optimizer.state_dict(),
            "value_optimizer": self.value_optimizer.state_dict(),
        }
        file_path = os.path.join(save_path, "iql_critic.pt")
        torch.save(checkpoint, file_path)

    def save_jit(self, save_path):
        """
        Script the model and save it to ``<save_path>/iql_model.pth``.

        Side effect: ``self.cpu()`` moves this model to the CPU and it is not
        moved back afterwards.
        """
        os.makedirs(save_path, exist_ok=True)
        jit_model = torch.jit.script(self.cpu())
        torch.jit.save(jit_model, f'{save_path}/iql_model.pth')

    def load_net(self, load_path="saved_model/IQL_critic/iql_critic.pt", device='cuda:0'):
        """
        Restore the module and optimizer states written by ``save_net``.

        Args:
            load_path: path of the checkpoint file.
            device: ``map_location`` for ``torch.load``. If a CUDA device is
                requested but CUDA is not available, the checkpoint is mapped
                to the CPU instead of raising.
        """
        if torch.device(device).type == "cuda" and not torch.cuda.is_available():
            # The default 'cuda:0' would otherwise make loading impossible on
            # CPU-only hosts.
            device = "cpu"
        checkpoint = torch.load(load_path, map_location=device)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.critic1_optimizer.load_state_dict(checkpoint["critic1_optimizer"])
        self.critic2_optimizer.load_state_dict(checkpoint["critic2_optimizer"])
        self.value_optimizer.load_state_dict(checkpoint["value_optimizer"])
        print(f"Checkpoint loaded on {device}")

    def l2_loss(self, diff, expectile=0.8):
        """
        Asymmetric squared error: ``diff ** 2`` weighted by ``expectile`` where
        ``diff > 0`` and by ``1 - expectile`` elsewhere. Returned element-wise.
        """
        weight = torch.where(diff > 0, expectile, (1 - expectile))
        return weight * (diff ** 2)


if __name__ == '__main__':
    # Smoke test: train the critic on random transitions and print the losses.
    state_dim, act_dim = 3, 1
    # IQL_Critic requires both dimensions; they must match the batches below.
    model = IQL_Critic(state_dim=state_dim, act_dim=act_dim)
    # The model moves itself to CUDA when available, so build batches there too.
    device = next(model.parameters()).device
    step_num = 100
    batch_size = 1000
    for i in range(step_num):
        states = np.random.uniform(2, 5, size=(batch_size, state_dim))
        next_states = np.random.uniform(2, 5, size=(batch_size, state_dim))
        actions = np.random.uniform(-1, 1, size=(batch_size, act_dim))
        rewards = np.random.uniform(0, 1, size=(batch_size, 1))
        terminals = np.zeros((batch_size, 1))
        states, next_states, actions, rewards, terminals = (
            torch.tensor(array, dtype=torch.float, device=device)
            for array in (states, next_states, actions, rewards, terminals)
        )
        # step() returns (critic1_loss, critic2_loss, value_loss).
        q1_loss, q2_loss, v_loss = model.step(states, actions, rewards, next_states, terminals)
        print(f'step:{i} q1_loss:{q1_loss} q2_loss:{q2_loss} v_loss:{v_loss}')
