import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np

class Actor(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    Two ReLU hidden layers feed two linear heads: one for the Gaussian mean
    and one for its log standard deviation (clamped to ``[-20, 2]``).
    """

    def __init__(self, s_dim, a_dim, h=256):
        """Build the trunk and heads.

        Args:
            s_dim: Input feature size of the first linear layer.
            a_dim: Output size of both the mean and log-std heads.
            h: Width of the two hidden layers.
        """
        super().__init__()
        # Shared trunk.
        self.fc1 = nn.Linear(s_dim, h)
        self.fc2 = nn.Linear(h, h)
        # Distribution heads.
        self.mu = nn.Linear(h, a_dim)
        self.log_std = nn.Linear(h, a_dim)

    def forward(self, s, deterministic=False):
        """Produce a squashed action and its log-probability.

        Args:
            s: State tensor (assumed to have a trailing dimension of
                ``s_dim``, as required by ``fc1``).
            deterministic: When true, use the Gaussian mean as the
                pre-squash value instead of a reparameterised sample.

        Returns:
            Tuple ``(a, logp)``: ``a`` is ``tanh`` of the pre-squash value;
            ``logp`` is the Gaussian log-density of that value minus the
            tanh change-of-variables term, summed over the last dimension
            with ``keepdim=True``.
        """
        hidden = F.relu(self.fc2(F.relu(self.fc1(s))))
        mean = self.mu(hidden)
        log_sigma = self.log_std(hidden).clamp(-20, 2)
        gaussian = torch.distributions.Normal(mean, log_sigma.exp())

        if deterministic:
            pre_tanh = mean
        else:
            pre_tanh = gaussian.rsample()
        action = torch.tanh(pre_tanh)

        # log|d tanh(z)/dz| = log(1 - tanh(z)^2); the epsilon keeps the log
        # finite when the action saturates at +/-1.
        squash_correction = torch.log(1 - action.pow(2) + 1e-6)
        log_prob = gaussian.log_prob(pre_tanh) - squash_correction
        return action, log_prob.sum(-1, keepdim=True)

class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action).

    Each head is a three-layer MLP with ReLU activations that maps the
    concatenated state-action vector to a single scalar.
    """

    def __init__(self, s_dim, a_dim, h=256):
        """Create both Q-heads.

        Args:
            s_dim: Size of the state part of the concatenated input.
            a_dim: Size of the action part of the concatenated input.
            h: Width of the two hidden layers in each head.
        """
        super().__init__()
        joint_dim = s_dim + a_dim
        self.q1 = self._build_head(joint_dim, h)
        self.q2 = self._build_head(joint_dim, h)

    @staticmethod
    def _build_head(in_dim, h):
        """Return one Linear-ReLU-Linear-ReLU-Linear head with scalar output."""
        layers = [
            nn.Linear(in_dim, h),
            nn.ReLU(),
            nn.Linear(h, h),
            nn.ReLU(),
            nn.Linear(h, 1),
        ]
        return nn.Sequential(*layers)

    def forward(self, s, a):
        """Evaluate both heads.

        Args:
            s: State tensor.
            a: Action tensor; concatenated with ``s`` along the last axis.

        Returns:
            Tuple ``(q1, q2)`` with the output of each head.
        """
        joint = torch.cat((s, a), dim=-1)
        first = self.q1(joint)
        second = self.q2(joint)
        return first, second

# ---------- Replay ----------
class Replay:
    """Fixed-capacity circular store of (s, a, r, s2, d) transitions.

    Storage is preallocated as float32 NumPy arrays. Once ``max`` entries
    have been written, new transitions overwrite the oldest slots.
    """

    def __init__(self, s_dim, a_dim, size=1000000):
        """Allocate the backing arrays.

        Args:
            s_dim: Width of each stored state row.
            a_dim: Width of each stored action row.
            size: Maximum number of transitions retained.
        """

        def column(width):
            return np.zeros((size, width), np.float32)

        self.s = column(s_dim)
        self.a = column(a_dim)
        self.r = column(1)
        self.s2 = column(s_dim)
        self.d = column(1)
        self.ptr = 0    # next slot to write
        self.size = 0   # number of valid rows
        self.max = size

    def store(self, s, a, r, s2, d):
        """Write one transition at the cursor and advance it (wrapping)."""
        slot = self.ptr
        self.s[slot] = s
        self.a[slot] = a
        self.r[slot] = r
        self.s2[slot] = s2
        self.d[slot] = d
        self.ptr = (slot + 1) % self.max
        self.size = min(self.size + 1, self.max)

    def sample(self, batch=256):
        """Draw ``batch`` row indices uniformly (with replacement) from the
        filled region and return the matching rows.

        Returns:
            Dict with keys ``"s"``, ``"a"``, ``"r"``, ``"s2"``, ``"d"``,
            each mapping to an array of ``batch`` rows.
        """
        rows = np.random.randint(0, self.size, size=batch)
        fields = ("s", "a", "r", "s2", "d")
        return {name: getattr(self, name)[rows] for name in fields}

# ---------- SAC ----------
class SAC:
    """Soft Actor-Critic agent with twin critics and a learned temperature.

    Holds an actor, a twin-Q critic, a target copy of the critic that is
    Polyak-averaged towards the online critic, and a scalar ``log_alpha``
    trained so that the policy's log-probability tracks ``target_entropy``.
    """

    def __init__(self, s_dim, a_dim, lr=3e-4, gamma=0.99, tau=0.005, device="cpu"):
        """Build networks and optimizers.

        Args:
            s_dim: State dimension passed to the actor and critics.
            a_dim: Action dimension passed to the actor and critics.
            lr: Adam learning rate shared by actor, critic and temperature.
            gamma: Discount factor used in the critic target.
            tau: Interpolation weight of the online critic in the soft
                target update.
            device: Torch device on which modules and batches are placed.
        """
        self.actor = Actor(s_dim, a_dim).to(device)
        self.critic = Critic(s_dim, a_dim).to(device)
        self.critic_t = Critic(s_dim, a_dim).to(device)
        # The target starts as an exact copy of the online critic.
        self.critic_t.load_state_dict(self.critic.state_dict())

        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # The temperature is optimised in log space, so alpha = exp(log_alpha)
        # is always positive.
        self.log_alpha = torch.tensor(0.0, requires_grad=True, device=device)
        self.opt_alpha = torch.optim.Adam([self.log_alpha], lr=lr)
        self.target_entropy = -a_dim

        self.gamma = gamma
        self.tau = tau
        self.device = device

    @property
    def alpha(self):
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def _to_tensor(self, x):
        """Convert array-like ``x`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(x, dtype=torch.float32, device=self.device)

    @torch.no_grad()
    def act(self, s, deterministic=False):
        """Select an action for a single state.

        Args:
            s: One unbatched state (array-like); a batch axis is added
                before calling the actor.
            deterministic: Forwarded to the actor. When true the actor uses
                its mean instead of sampling.

        Returns:
            The first row of the actor's action output as a NumPy array.
        """
        obs = self._to_tensor(s).unsqueeze(0)
        # Forward the flag: without it, evaluation rollouts would silently
        # keep sampling stochastic actions.
        action, _ = self.actor(obs, deterministic=deterministic)
        return action[0].cpu().numpy()

    def update(self, buf, batch=256):
        """Run one gradient step on critic, actor and temperature, then
        soft-update the target critic.

        Args:
            buf: Object whose ``sample(batch)`` returns a mapping with keys
                ``"s"``, ``"a"``, ``"r"``, ``"s2"``, ``"d"``.
            batch: Number of transitions requested from ``buf``.

        Returns:
            Tuple ``(critic_loss, actor_loss, alpha_loss)`` as Python floats.
        """
        sample = buf.sample(batch)
        s, a, r, s2, done = (
            self._to_tensor(sample[key]) for key in ("s", "a", "r", "s2", "d")
        )

        # ---- Critic: regress both heads onto the soft Bellman target ----
        with torch.no_grad():
            a2, logp2 = self.actor(s2)
            q1_t, q2_t = self.critic_t(s2, a2)
            q_t = torch.min(q1_t, q2_t) - self.alpha * logp2
            # (1 - done) removes the bootstrap term on terminal transitions.
            y = r + (1 - done) * self.gamma * q_t
        q1, q2 = self.critic(s, a)
        c_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)
        self.opt_c.zero_grad()
        c_loss.backward()
        self.opt_c.step()

        # ---- Actor: maximise min-Q minus alpha-weighted log-prob ----
        a_new, logp = self.actor(s)
        q1_pi, q2_pi = self.critic(s, a_new)
        q_pi = torch.min(q1_pi, q2_pi)
        # alpha is a constant for this loss, so detach it; it is trained
        # only by the temperature loss below. The critic gradients that
        # this backward pass produces are cleared by the next
        # opt_c.zero_grad().
        alpha = self.alpha.detach()
        a_loss = (alpha * logp - q_pi).mean()
        self.opt_a.zero_grad()
        a_loss.backward()
        self.opt_a.step()

        # ---- Temperature: drive E[logp] towards -target_entropy ----
        alpha_loss = -(self.log_alpha * (logp + self.target_entropy).detach()).mean()
        self.opt_alpha.zero_grad()
        alpha_loss.backward()
        self.opt_alpha.step()

        # ---- Polyak averaging: target <- (1 - tau) * target + tau * online ----
        with torch.no_grad():
            for p, pt in zip(self.critic.parameters(), self.critic_t.parameters()):
                pt.mul_(1 - self.tau).add_(p, alpha=self.tau)

        return float(c_loss.item()), float(a_loss.item()), float(alpha_loss.item())