import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

from algs.base_learner import kron, weight_init, Feature, Discriminator, BaseLearner

class RepLearn(BaseLearner):
    """Learner that trains discriminators adversarially against a ridge fit in ``phi``-space.

    Construction is delegated entirely to ``BaseLearner``; this class adds
    :meth:`adv_learning`.
    """

    def __init__(self,
                 obs_dim,
                 state_dim,
                 action_dim,
                 hidden_dim,
                 num_update,
                 num_feature_update,
                 num_adv_update,
                 device,
                 **kwargs):
        super().__init__(obs_dim,
                         state_dim,
                         action_dim,
                         hidden_dim,
                         num_update,
                         num_feature_update,
                         num_adv_update,
                         device,
                         **kwargs)

    def adv_learning(self, replay_buffer, T):
        """Run ``self.num_adv_update`` adversarial updates of discriminator ``T``.

        Each iteration samples two batches from ``replay_buffer`` and then:

        1. Fits ``self.phi_tilde`` to the (detached) discriminator output on
           ``next_obs`` with an MSE loss plus ``torch.norm`` of
           ``self.phi_tilde.weights.weight``, taking one
           ``phi_tilde_optimizer`` step.
        2. Builds a ridge-regression-style prediction of the discriminator
           output from the frozen features ``self.phi``. The Gram matrix
           ``Sigma = F^T F + lamb * I`` comes from the first batch. The cross
           term ``sum_i phi2_i * d2_i`` comes from the second, separately
           sampled batch.
        3. Takes one ``dis_optimizer`` step on
           ``mse(phi_tilde_out, dis_out) - mse(target, dis_out)``. This pushes
           the discriminator toward outputs that ``phi_tilde`` fits well but
           the ``phi``-based ridge fit does not.

        NOTE(review): reads ``num_adv_update``, ``batch_size``,
        ``discriminators``, ``phi``, ``phi_tilde``, ``lamb``, ``feature_dim``,
        ``device``, ``phi_tilde_optimizer`` and ``dis_optimizer`` from
        ``self``. These are presumably set up by ``BaseLearner``; confirm.

        Args:
            replay_buffer: Object whose ``sample(batch_size=...)`` returns
                ``(obs, actions, rewards, next_obs)``.
            T: Forwarded to ``self.discriminators.get_one`` to select the
                discriminator being trained (presumably an index; confirm).

        Returns:
            list[float]: One entry per iteration, equal to the negated
            discriminator loss ``mse(target, dis_out) - mse(phi_tilde_out, dis_out)``.
        """
        # Ridge regularizer. It is loop-invariant, so allocate it once,
        # directly on the target device.
        ridge = self.lamb * torch.eye(self.feature_dim, device=self.device)
        loss_list = []

        for _ in range(self.num_adv_update):
            obs, actions, _, next_obs = replay_buffer.sample(batch_size=self.batch_size)
            obs2, actions2, _, next_obs2 = replay_buffer.sample(batch_size=self.batch_size)

            # The representation phi is not trained here, so no graph is needed.
            with torch.no_grad():
                feature = self.phi(obs, actions)
                feature2 = self.phi(obs2, actions2)
                Sigma = torch.matmul(feature.T, feature) + ridge

            # Evaluate the discriminator once per batch, with grad. The
            # phi_tilde step below does not modify the discriminator, so these
            # values stay valid for the discriminator loss. phi_tilde regresses
            # onto a detached copy so none of its gradient leaks into the
            # discriminator.
            dis_out = self.discriminators.get_one(next_obs, T)
            dis_out2 = self.discriminators.get_one(next_obs2, T)

            # phi_tilde step: fit the current discriminator output.
            phi_tilde_pred = self.phi_tilde.predict(obs, actions)
            tilde_loss = (F.mse_loss(phi_tilde_pred, dis_out.detach())
                          + torch.norm(self.phi_tilde.weights.weight))

            self.phi_tilde_optimizer.zero_grad()
            tilde_loss.backward()
            self.phi_tilde_optimizer.step()

            # Use the freshly updated phi_tilde as a constant in the
            # discriminator objective.
            with torch.no_grad():
                phi_tilde_out = self.phi_tilde.predict(obs, actions).squeeze()

            dis_flat = dis_out.squeeze()

            # Ridge-style weights from batch 2, evaluated on batch-1 features.
            # NOTE(review): the elementwise product assumes dis_out2 is
            # (batch, 1), so it broadcasts across the feature dimension; confirm.
            ridge_weights = torch.linalg.solve(Sigma, torch.sum(feature2 * dis_out2, 0))
            target = torch.matmul(feature, ridge_weights)

            dis_loss = F.mse_loss(phi_tilde_out, dis_flat) - F.mse_loss(target, dis_flat)

            self.dis_optimizer.zero_grad()
            dis_loss.backward()
            self.dis_optimizer.step()

            loss_list.append(-dis_loss.item())

        return loss_list















