import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

from algs.base_learner import kron, weight_init, Feature, Discriminator, BaseLearner

class RepLearn(BaseLearner):
    """Learner that adds an adversarial discriminator update to ``BaseLearner``.

    Construction is delegated entirely to ``BaseLearner``; the only behavior
    defined here is :meth:`adv_learning`.

    NOTE(review): the previous summary read "SAC+AE algorithm", but nothing
    in this class is visibly SAC- or autoencoder-specific -- confirm the
    intended description.
    """

    def __init__(self,
                 obs_dim,
                 state_dim,
                 action_dim,
                 hidden_dim,
                 num_update,
                 num_feature_update,
                 num_adv_update,
                 device,
                 **kwargs):
        """Forward every argument unchanged to ``BaseLearner.__init__``."""
        super().__init__(obs_dim,
                         state_dim,
                         action_dim,
                         hidden_dim,
                         num_update,
                         num_feature_update,
                         num_adv_update,
                         device,
                         **kwargs)

    def adv_learning(self, replay_buffer, T):
        """Run ``self.num_adv_update`` optimizer steps on the discriminator.

        Every step draws two batches from ``replay_buffer``:

        * batch 1 provides the features used for the regularised matrix
          ``Sigma = Phi^T Phi + lamb * I`` and the discriminator outputs
          that are being fitted;
        * batch 2 provides ``(feature, discriminator output)`` pairs whose
          feature-weighted sum is mapped through ``Sigma^{-1}`` and the
          batch-1 features to form the regression target.

        Args:
            replay_buffer: Object whose ``sample(batch_size=...)`` returns a
                4-tuple ``(obs, actions, rewards, next_obs)``.
            T: Passed straight through to ``self.discriminators.get_one``
                (presumably a step/index selecting a discriminator -- TODO
                confirm against ``Discriminator``).

        Returns:
            float: The loss of the last update step, or ``0.0`` when
            ``self.num_adv_update`` is zero (previously this case raised
            ``NameError``).
        """
        last_loss = 0.0

        for _ in range(self.num_adv_update):
            obs, actions, _rewards, next_obs = replay_buffer.sample(
                batch_size=self.batch_size)
            obs2, actions2, _rewards2, next_obs2 = replay_buffer.sample(
                batch_size=self.batch_size)

            # The feature network is held fixed during the discriminator step.
            with torch.no_grad():
                feature = self.phi(obs, actions)
                Sigma = torch.matmul(feature.T, feature) + self.lamb * torch.eye(
                    self.feature_dim, device=self.device)
                feature2 = self.phi(obs2, actions2)

            dis_out = self.discriminators.get_one(next_obs, T).squeeze()
            # Use the second batch's own next observations so that every
            # feature2 row is paired with the discriminator output of the
            # transition it came from (this previously reused `next_obs`).
            dis_out2 = self.discriminators.get_one(next_obs2, T)

            # Apply Sigma^{-1} to the summed vector first, then project the
            # batch-1 features onto it; this is the same product as
            # (feature @ Sigma^{-1}) @ v with a cheaper association.
            # The target is deliberately built outside `no_grad`, so
            # gradients also flow through `dis_out2`, as before.
            weights = torch.matmul(torch.inverse(Sigma),
                                   torch.sum(feature2 * dis_out2, 0))
            target = torch.matmul(feature, weights)

            loss = F.mse_loss(dis_out, target)

            self.dis_optimizer.zero_grad()
            loss.backward()
            self.dis_optimizer.step()

            last_loss = loss.item()

        return last_loss






