import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

import time

from algs.rep_learn import kron, weight_init, Feature, Discriminator


class RepLearn(object):
    """Adversarial representation learning for a feature map ``phi(obs, action)``.

    :meth:`update` runs ``num_update`` rounds. In round ``t``:

    * :meth:`feature_learning` regresses ``phi`` onto the (frozen) targets
      returned by ``self.discriminators.get_till(next_obs, t)``;
    * :meth:`adv_learning` fits closed-form ridge-regression heads on top of
      ``phi`` and ``phi_tilde`` to ``self.discriminators.get_one(next_obs, t)``
      and steps the discriminator on the difference of their MSE losses.
    """

    def __init__(
        self,
        obs_dim,
        state_dim,
        action_dim,
        hidden_dim,
        num_update,
        num_feature_update,
        num_adv_update,
        device,
        discriminator_lr=1e-3,
        discriminator_beta=0.9,
        feature_lr=1e-3,
        feature_beta=0.9,
        weight_lr=1e-3,
        weight_beta=0.9,
        batch_size=128,
        lamb=0.1,
        tau=1,
    ):
        """Build the feature networks, the discriminator and their optimizers.

        Args:
            obs_dim: observation dimension (forwarded to ``Feature`` and
                ``Discriminator``).
            state_dim: used only to compute ``feature_dim``.
            action_dim: forwarded to ``Feature``; also used for ``feature_dim``.
            hidden_dim: forwarded to ``Discriminator``.
            num_update: rounds per :meth:`update` call; also forwarded to
                ``Discriminator``.
            num_feature_update: gradient steps per :meth:`feature_learning`.
            num_adv_update: gradient steps per :meth:`adv_learning`.
            device: torch device that all modules and tensors are placed on.
            discriminator_lr, discriminator_beta: Adam lr / beta1 for the
                discriminator.
            feature_lr, feature_beta: Adam lr / beta1 for ``phi`` and
                ``phi_tilde``.
            weight_lr: stored as ``self.weight_lr``; not read by this class.
            weight_beta: accepted for API compatibility; unused.
            batch_size: replay-buffer sample size.
            lamb: ridge regularisation strength in :meth:`adv_learning`.
            tau: forwarded to ``Feature``.
        """
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # NOTE(review): assumed to equal the output width of ``Feature``; it
        # sizes the ridge-regression Gram matrices in adv_learning.
        self.feature_dim = state_dim * action_dim

        self.device = device
        self.lamb = lamb

        self.num_feature_update = num_feature_update
        self.num_adv_update = num_adv_update
        self.num_update = num_update

        self.batch_size = batch_size

        self.phi = Feature(obs_dim, action_dim, device, tau=tau).to(device)
        self.phi_tilde = Feature(obs_dim, action_dim, device, tau=tau).to(device)

        self.phi_optimizer = torch.optim.Adam(
            self.phi.parameters(), lr=feature_lr, betas=(feature_beta, 0.999)
        )
        self.phi_tilde_optimizer = torch.optim.Adam(
            self.phi_tilde.parameters(), lr=feature_lr, betas=(feature_beta, 0.999)
        )

        self.feature_lr = feature_lr
        self.feature_beta = feature_beta

        self.discriminator_lr = discriminator_lr
        self.discriminator_beta = discriminator_beta

        self.weight_lr = weight_lr

        self.discriminators = Discriminator(
            self.obs_dim, self.hidden_dim, self.num_update
        ).to(self.device)
        self.dis_optimizer = torch.optim.Adam(
            self.discriminators.parameters(),
            lr=self.discriminator_lr,
            betas=(self.discriminator_beta, 0.999),
        )

    def feature_learning(self, replay_buffer, T):
        """Fit ``phi`` to the discriminator targets of round ``T``.

        ``phi`` is re-initialised with ``reset_weights(T)`` and given a fresh
        Adam optimizer. It is then trained for ``num_feature_update`` steps to
        minimise ``mse(phi.predict(obs, actions), get_till(next_obs, T))``,
        with the discriminator output treated as a constant.

        Args:
            replay_buffer: object whose ``sample(batch_size=...)`` returns
                ``(obs, actions, rewards, next_obs)``.
            T: current round index.

        Returns:
            float: MSE loss of the last step, or ``0.0`` when
            ``num_feature_update`` is 0.
        """
        self.phi.reset_weights(T)
        # New optimizer instance: no Adam state from earlier rounds is reused.
        self.phi_optimizer = torch.optim.Adam(
            self.phi.parameters(), lr=self.feature_lr, betas=(self.feature_beta, 0.999)
        )

        last_loss = 0.0
        for _ in range(self.num_feature_update):
            obs, actions, _rewards, next_obs = replay_buffer.sample(
                batch_size=self.batch_size
            )

            # Targets are frozen here; only phi is optimised.
            with torch.no_grad():
                target = self.discriminators.get_till(next_obs, T)

            out = self.phi.predict(obs, actions)
            assert out.shape == target.shape, (out.shape, target.shape)

            loss = F.mse_loss(out, target)
            self.phi_optimizer.zero_grad()
            loss.backward()
            self.phi_optimizer.step()

            last_loss = loss.item()

        return last_loss

    def adv_learning(self, replay_buffer, T):
        """Adversarially update the discriminator output ``get_one(., T)``.

        For every batch, ridge-regression weights are solved in closed form
        (without gradients):

            w       = (phi^T phi + lamb*I)^-1 phi^T d
            w_tilde = (phi_tilde^T phi_tilde + lamb*I)^-1 phi_tilde^T d

        The discriminator is then stepped on
        ``mse(phi_tilde @ w_tilde, d) - mse(phi @ w, d)``. Minimising this
        pushes it towards outputs that ``phi`` fits worse than ``phi_tilde``.

        Args:
            replay_buffer: object whose ``sample(batch_size=...)`` returns
                ``(obs, actions, rewards, next_obs)``.
            T: current round index.

        Returns:
            float: adversarial loss of the last step, or ``0.0`` when
            ``num_adv_update`` is 0.
        """
        # Loop-invariant ridge term, created directly on the target device.
        ridge = self.lamb * torch.eye(self.feature_dim, device=self.device)

        last_loss = 0.0
        for _ in range(self.num_adv_update):
            obs, actions, _rewards, next_obs = replay_buffer.sample(
                batch_size=self.batch_size
            )

            # NOTE(review): both feature maps are evaluated under no_grad, so
            # phi_tilde receives no gradient and phi_tilde_optimizer.step()
            # below has no effect — confirm this is intended.
            with torch.no_grad():
                target = self.discriminators.get_one(next_obs, T)

                feature = self.phi(obs, actions)
                sigma = torch.matmul(feature.T, feature) + ridge

                feature_tilde = self.phi_tilde(obs, actions)
                sigma_tilde = torch.matmul(feature_tilde.T, feature_tilde) + ridge

                w = torch.matmul(
                    torch.inverse(sigma), torch.sum(torch.mul(feature, target), 0)
                )
                w_tilde = torch.matmul(
                    torch.inverse(sigma_tilde),
                    torch.sum(torch.mul(feature_tilde, target), 0),
                )

            # Recompute the discriminator output with autograd enabled so the
            # loss can be back-propagated into the discriminator.
            dis_out = self.discriminators.get_one(next_obs, T).squeeze()

            phi_out = torch.matmul(feature, w)
            phi_tilde_out = torch.matmul(feature_tilde, w_tilde)

            loss = F.mse_loss(phi_tilde_out, dis_out) - F.mse_loss(phi_out, dis_out)

            self.dis_optimizer.zero_grad()
            self.phi_tilde_optimizer.zero_grad()
            loss.backward()
            self.dis_optimizer.step()
            self.phi_tilde_optimizer.step()

            last_loss = loss.item()

        return last_loss

    def update(self, replay_buffer):
        """Run one full representation-learning update.

        Re-initialises the discriminator with ``weight_init``, then runs
        ``num_update`` rounds of :meth:`feature_learning` followed by
        :meth:`adv_learning`.

        Returns:
            tuple: ``(total_seconds, adv_seconds)``. These are the wall-clock
            time of the whole call and the summed wall-clock time spent in
            :meth:`adv_learning`. Loss values are not returned.
        """
        function_start = time.time()
        # NOTE(review): the discriminator is re-initialised but dis_optimizer
        # is not recreated, so Adam state from earlier calls is kept.
        self.discriminators.apply(weight_init)

        adv_times = []
        for t in range(self.num_update):
            self.feature_learning(replay_buffer, t)

            adv_start = time.time()
            self.adv_learning(replay_buffer, t)
            adv_times.append(time.time() - adv_start)

        return time.time() - function_start, np.sum(adv_times)












