import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.distributions.normal import Normal
import copy



class OUNoise:
    """Ornstein-Uhlenbeck process used as temporally correlated noise.

    Each :meth:`sample` pulls the internal state towards ``mu`` with strength
    ``theta`` and perturbs it with Gaussian noise scaled by ``sigma``.
    Each :meth:`reset` (including the one performed by the constructor)
    multiplies ``sigma`` by ``sigma_decay``, never going below ``sigma_min``.
    """

    def __init__(self, size, seed, mu=0.0, theta=0.1, sigma=.1, sigma_min=0.05, sigma_decay=.99):
        """Initialize parameters and noise process.

        Args:
            size: Shape of the noise vector (passed to ``np.ones`` and to the
                normal sampler).
            seed: Seed for this process' private random generator. It is also
                forwarded to ``random.seed`` as before.
            mu: Mean the state reverts to.
            theta: Mean-reversion rate.
            sigma: Initial noise scale (decayed once by the constructor's
                ``reset`` call).
            sigma_min: Lower bound for ``sigma``.
            sigma_decay: Multiplicative decay applied to ``sigma`` per reset.
        """
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.sigma_min = sigma_min
        self.sigma_decay = sigma_decay
        # Kept for backward compatibility: earlier versions seeded the global
        # ``random`` module here and other code may rely on that side effect.
        random.seed(seed)
        # ``random.seed`` returns None, so store the seed itself.
        self.seed = seed
        # The samples are drawn with NumPy, so the seed only takes effect if
        # it drives a NumPy generator; use a private one to avoid touching
        # the global NumPy random state.
        self._rng = np.random.RandomState(seed)
        self.size = size
        self.reset()

    def reset(self):
        """Reset the internal state (= noise) to mean (mu) and decay sigma."""
        self.state = copy.copy(self.mu)
        # Reduce sigma from its current value towards sigma_min.
        self.sigma = max(self.sigma_min, self.sigma * self.sigma_decay)

    def sample(self):
        """Update internal state and return it as a noise sample."""
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * self._rng.standard_normal(self.size)
        # A new array is bound each step, so previously returned samples are
        # not modified by later calls.
        self.state = x + dx
        return self.state

def hard_update(target, source):
    """Copy every parameter value of ``source`` into ``target`` in place.

    Parameters are paired positionally, in the order yielded by
    ``parameters()`` on each module.
    """
    paired_params = zip(target.parameters(), source.parameters())
    for dst, src in paired_params:
        dst.data.copy_(src.data)

def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place.

    Every target parameter becomes ``target * (1 - tau) + source * tau``;
    parameters are paired positionally via ``parameters()``.
    """
    keep = 1.0 - tau
    with torch.no_grad():
        for dst, src in zip(target.parameters(), source.parameters()):
            blended = dst.data * keep + src.data * tau
            dst.data.copy_(blended)

class Actor(nn.Module):
    """Deterministic policy network: two ReLU hidden layers and a tanh head.

    The output has ``n_actions`` entries, each squashed to ``[-1, 1]`` by tanh.
    """

    def __init__(self, num_features, n_actions, hidden_size=128, device=torch.device("cpu")):
        super().__init__()
        self.device = device
        self.fc1 = nn.Linear(num_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, n_actions)
        # Move all registered layers to the requested device.
        self.to(device)

    def forward(self, x):
        """Return the tanh-squashed action for observation tensor ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return torch.tanh(self.output_layer(hidden))

class new_Actor(nn.Module):
    """Gaussian policy network with Ornstein-Uhlenbeck exploration noise.

    ``forward`` returns a tanh-squashed mean and a standard deviation. The
    standard deviation is ``exp(logstd)`` with a learnable ``logstd`` when
    ``args.traiable_std`` equals ``True``, and a tensor of ones otherwise.
    """

    def __init__(self, args, rank, num_features, n_actions, hidden_size=128, device=torch.device("cpu")):
        """Build the network.

        Args:
            args: Configuration object; only ``args.traiable_std`` is read
                here (the attribute name is spelled that way on ``args``).
            rank: Passed to ``OUNoise`` as its seed.
            num_features: Size of the observation vector.
            n_actions: Size of the action vector.
            hidden_size: Width of both hidden layers.
            device: Device the module is moved to.
        """
        super(new_Actor, self).__init__()
        self.args = args
        self.trainable_std = self.args.traiable_std
        self.device = device
        self.fc1 = nn.Linear(num_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, n_actions)
        # The explicit comparison is kept on purpose: a plain truthiness test
        # would behave differently for non-bool flag values.
        if self.trainable_std == True:  # noqa: E712
            self.logstd = nn.Parameter(torch.zeros(1, n_actions))
        self.noise = OUNoise(n_actions, rank)
        self.to(device)

    def forward(self, x):
        """Return ``(mu, std)`` of the action distribution for ``x``."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        mu = torch.tanh(self.output_layer(h2))
        if self.trainable_std == True:  # noqa: E712 -- must match __init__
            std = torch.exp(self.logstd)
        else:
            # exp(0) == 1: unit standard deviation with the shape of mu.
            std = torch.exp(torch.zeros_like(mu))
        return mu, std

    def action_with_exp(self, x):
        """Sample an exploratory action: ``Normal(mu + OU noise, std)``."""
        mu, sigma = self.forward(x)
        # The OU sample is a float64 NumPy array. Create the tensor with mu's
        # dtype and device so the addition does not promote the mean to
        # float64 or mix devices.
        noise = torch.as_tensor(self.noise.sample(), dtype=mu.dtype, device=mu.device)
        dist = Normal(mu + noise, sigma)
        return dist.sample()

    def action_without_exp(self, x):
        """Return the distribution mean (no noise, no sampling)."""
        mu, _ = self.forward(x)
        return mu

def data_wrap(x, device):
    """Return a detached float32 copy of ``x`` on ``device``.

    Args:
        x: Array-like data (list, NumPy array, scalar) or a ``torch.Tensor``.
        device: Target device for the result.

    Returns:
        A new ``torch.float`` tensor that does not share memory with ``x``
        and is not attached to any autograd graph.
    """
    if isinstance(x, torch.Tensor):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning on
        # every call; detach + copying ``to`` yields the same result silently.
        return x.detach().to(device=device, dtype=torch.float, copy=True)
    return torch.tensor(x, device=device, dtype=torch.float).detach()

class Critic(nn.Module):
    """Q-network with two ReLU hidden layers.

    With ``discrete_action=True`` the input is an observation and the output
    holds one Q-value per action. Otherwise the input is an observation
    concatenated with an action (``num_features + n_actions`` entries) and
    the output is a single Q-value.
    """

    def __init__(self, num_features, n_actions, hidden_size=128, device=torch.device("cpu"), discrete_action=False):
        super(Critic, self).__init__()
        self.n_actions = n_actions
        self.device = device
        if discrete_action:
            in_features, out_features = num_features, n_actions
        else:
            in_features, out_features = num_features + n_actions, 1
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, out_features)
        self.to(device)

    def forward(self, x):
        """Return the Q-value(s) for input tensor ``x``."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.output_layer(h2)

    def action_with_exp(self, x, epsilon):
        """Return epsilon-greedy actions for a discrete-action critic.

        Each entry is independently replaced by a uniformly random action
        index in ``[0, n_actions)`` with probability ``epsilon``; otherwise
        the greedy ``argmax`` action is kept.

        Args:
            x: Observation tensor; the last dimension is the feature axis.
            epsilon: Exploration probability in ``[0, 1]``.

        Returns:
            Integer tensor of action indices with the same shape and dtype
            as the result of :meth:`action_without_exp`.
        """
        q_val = self.forward(x)
        greedy = torch.argmax(q_val, dim=-1)
        shape = tuple(greedy.shape)
        # ``size=`` is required: a positional argument would be taken as the
        # lower bound of the uniform distribution instead of the shape.
        explore = np.asarray(np.random.uniform(size=shape) < epsilon)
        random_action = np.asarray(np.random.randint(self.n_actions, size=shape))
        explore = torch.as_tensor(explore, device=greedy.device)
        random_action = torch.as_tensor(random_action, dtype=greedy.dtype, device=greedy.device)
        # torch.where avoids arithmetic on a bool mask (``1 - mask`` raises
        # for bool tensors) and keeps the integer action dtype.
        return torch.where(explore, random_action, greedy)

    def action_without_exp(self, x):
        """Return the greedy action index (``argmax`` over the last axis)."""
        q_val = self.forward(x)
        return torch.argmax(q_val, dim=-1)




class Agent(object):
    """Group of ``n_ant`` independent learners behind one interface.

    ``method`` selects what is built:

    * ``'iddpg'``: one actor and one critic per agent, each with a target
      copy, plus one Adam optimizer per actor and per critic.
    * ``'iql'``: one discrete-action critic (Q-network) per agent with a
      target copy and one Adam optimizer per critic; no actors are created.
    """

    def __init__(self, state_space, n_actions, n_ant, method, device, args):
        """Create networks, target networks and optimizers.

        Args:
            state_space: Observation size handed to the networks.
            n_actions: Action size (continuous) or number of actions
                (discrete).
            n_ant: Number of agents.
            method: ``'iddpg'`` or ``'iql'``.
            device: Device for all networks and wrapped tensors.
            args: Configuration object; this class reads ``hql_alpha``,
                ``new_actor``, ``lr`` and ``critic_lr`` from it.
        """
        super(Agent, self).__init__()
        self.args = args
        self.device = device
        self.n_actions = n_actions
        self.n_ant = n_ant
        self.state_space = state_space
        self.actors = []
        self.critics = []
        self.actors_tar = []
        self.critics_tar = []
        self.method = method
        self.hql_alpha = args.hql_alpha
        if self.method == 'iddpg':
            for i in range(self.n_ant):
                if self.args.new_actor:
                    self.actors.append(new_Actor(args, i, self.state_space, self.n_actions, device=self.device))
                    self.actors_tar.append(new_Actor(args, i, self.state_space, self.n_actions, device=self.device))
                else:
                    self.actors.append(Actor(self.state_space, self.n_actions, device=self.device))
                    self.actors_tar.append(Actor(self.state_space, self.n_actions, device=self.device))
                self.critics.append(Critic(self.state_space, self.n_actions, device=self.device))
                self.critics_tar.append(Critic(self.state_space, self.n_actions, device=self.device))
            # Start every target actor as an exact copy of its actor.
            self.hard_update_policy()
            self.actor_optimizer = [torch.optim.Adam(self.actors[i].parameters(), lr=self.args.lr)
                                    for i in range(self.n_ant)]
        elif self.method == 'iql':
            for i in range(self.n_ant):
                self.critics.append(Critic(self.state_space, self.n_actions, device=self.device, discrete_action=True))
                self.critics_tar.append(Critic(self.state_space, self.n_actions, device=self.device, discrete_action=True))
        # Start every target critic as an exact copy of its critic.
        self.hard_update_critic()
        self.critic_optimizer = [torch.optim.Adam(self.critics[i].parameters(), lr=self.args.critic_lr)
                                 for i in range(self.n_ant)]

    def calc_actions(self, X, epsilon=None):
        """Compute one action per agent and return them as a NumPy array.

        Args:
            X: Observations indexed by agent along the first axis (``X[i]``
                is fed to agent ``i``); converted with ``data_wrap``.
            epsilon: Only used for ``'iql'``. ``None`` selects greedy
                actions, otherwise epsilon-greedy exploration is used.

        Returns:
            ``np.array`` of per-agent actions, or an empty list when
            ``method`` is neither ``'iddpg'`` nor ``'iql'``.
        """
        X = data_wrap(X, self.device)
        actions = []
        # Acting needs no autograd graph; the results are detached anyway.
        with torch.no_grad():
            if self.method == 'iddpg':
                for i in range(self.n_ant):
                    if self.args.new_actor:
                        action_i = self.actors[i].action_with_exp(X[i])
                    else:
                        action_i = self.actors[i](X[i])
                    actions.append(action_i.detach().cpu().numpy())
                actions = np.array(actions)
            elif self.method == 'iql':
                for i in range(self.n_ant):
                    if epsilon is None:
                        action_i = self.critics[i].action_without_exp(X[i])
                    else:
                        action_i = self.critics[i].action_with_exp(X[i], epsilon)
                    actions.append(action_i.detach().cpu().numpy())
                actions = np.array(actions)
        return actions

    def calc_next_Q(self, next_X, reward, done, gamma):
        """Return the per-agent TD targets ``reward + gamma * (1 - done) * Q'``.

        ``Q'`` comes from the target networks: the target critic evaluated at
        the target actor's action for ``'iddpg'``, otherwise the maximum of
        the target critic's outputs over the last axis. ``reward`` and
        ``done`` are not indexed by agent; the same values are used for every
        agent. Inputs are expected to be tensors already (they are not
        wrapped here).

        Returns:
            List with one detached target tensor per agent.
        """
        targets = []
        for i in range(self.n_ant):
            if self.method == 'iddpg':
                if self.args.new_actor:
                    action_i = self.actors_tar[i].action_without_exp(next_X[i])
                else:
                    action_i = self.actors_tar[i](next_X[i])
                input_i = torch.cat([next_X[i], action_i], dim=-1)
                next_Q_i = self.critics_tar[i](input_i)
            else:
                # Use a dedicated name for the Q-values: reusing the result
                # list's name would replace the accumulator with a tensor.
                q_values_i = self.critics_tar[i](next_X[i])
                next_Q_i = torch.max(q_values_i, dim=-1, keepdim=True)[0]
            next_Q_i = next_Q_i.detach()
            targets.append(reward + next_Q_i * gamma * (1 - done))
        return targets

    def train_actors(self, X):
        """Run one policy-gradient step per agent; return the mean loss."""
        mean_policy_loss = 0
        for i in range(self.n_ant):
            mean_policy_loss += self.train_actors_i(X[i], i)
        mean_policy_loss /= self.n_ant
        return mean_policy_loss

    def train_actors_i(self, X_i, i):
        """Update actor ``i`` to maximize its critic's value on ``X_i``.

        The loss is ``-mean(Q_i(X_i, actor_i(X_i)))``. Only the actor's
        optimizer is stepped.

        Returns:
            The loss tensor of this step.
        """
        if self.args.new_actor:
            action_i = self.actors[i].action_without_exp(X_i)
        else:
            action_i = self.actors[i](X_i)
        input_i = torch.cat([X_i, action_i], dim=-1)
        Q_i = self.critics[i](input_i)
        policy_loss_i = -Q_i.mean()
        self.actor_optimizer[i].zero_grad()
        policy_loss_i.backward()
        self.actor_optimizer[i].step()
        return policy_loss_i

    def train_critic(self, S, A, target_Q):
        """Run one critic step per agent; return the mean loss."""
        mean_loss = 0
        for i in range(self.n_ant):
            mean_loss += self.train_critic_i(S[i], A[i], target_Q[i], i)
        mean_loss /= self.n_ant
        return mean_loss

    def train_critic_i(self, S_i, A_i, target_Q_i, i):
        """Regress critic ``i`` towards ``target_Q_i`` with a squared error.

        Returns:
            The mean loss tensor, or ``0`` (after printing a message) when
            ``method`` is neither ``'iddpg'`` nor ``'iql'``.
        """
        target_Q_i = data_wrap(target_Q_i, self.device)
        if self.method == 'iddpg':
            input_i = torch.cat([S_i, A_i], dim=-1)
            Q_i = self.critics[i](input_i)
        elif self.method == 'iql':
            # NOTE(review): this regresses max_a Q(s, a) and ignores A_i;
            # standard Q-learning would use Q(s, A_i) -- confirm intent.
            Q_list = self.critics[i](S_i)
            Q_i = torch.max(Q_list, dim=-1, keepdim=True)[0]
        else:
            print('not implemented method !!!!')
            return 0

        if self.method == 'hddpg':
            # NOTE(review): unreachable at present -- 'hddpg' already
            # returned above. Kept in case that method gets wired in.
            compare_flag = (target_Q_i > Q_i).type(torch.float)
            # Weight 1 where the target exceeds the estimate, hql_alpha
            # elsewhere.
            weight = compare_flag * (1 - self.hql_alpha) + self.hql_alpha * torch.ones_like(compare_flag).to(self.device)
            weight = weight.detach()
            loss_i = weight * ((Q_i - target_Q_i) ** 2)
        else:
            loss_i = (Q_i - target_Q_i) ** 2
        loss_i = loss_i.mean()
        self.critic_optimizer[i].zero_grad()
        loss_i.backward()
        self.critic_optimizer[i].step()
        return loss_i

    def update(self, tau=0.005):
        """Soft-update all target networks towards their online networks.

        Args:
            tau: Interpolation factor passed to ``soft_update``.
        """
        for i in range(self.n_ant):
            soft_update(self.critics_tar[i], self.critics[i], tau)
        # Iterate over the actors that exist: 'iql' creates none, so
        # indexing by agent number would fail there.
        for actor_tar, actor in zip(self.actors_tar, self.actors):
            soft_update(actor_tar, actor, tau)

    def hard_update_policy(self):
        """Copy every actor's parameters into its target actor."""
        for i in range(self.n_ant):
            hard_update(self.actors_tar[i], self.actors[i])

    def hard_update_critic(self):
        """Copy every critic's parameters into its target critic."""
        for i in range(self.n_ant):
            hard_update(self.critics_tar[i], self.critics[i])


