######################## KOOPMAN ENCODER #####################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import os
import sys
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import json
import shutil
import torchvision
from termcolor import colored

import time
from skimage.util.shape import view_as_windows
from torch.utils.data import Dataset, DataLoader

import pickle
import ast 
import imageio

from rest_utils import *
# from sac_utils import *



def make_agent(obs_shape, action_shape, config, device):
    """Build the agent selected by ``config['agent']['name']``.

    Supported names are ``'curl_sac'`` (``CurlSacAgent``) and
    ``'curl_sac_koopmanlqr'`` (``ComplexCurlSacKoopmanAgent``). Both agents
    are constructed with the same keyword arguments, read from
    ``config['agent']``, ``config['env']['encoder_type']`` and
    ``config['log_interval']``.

    Args:
        obs_shape: observation shape forwarded to the agent.
        action_shape: action shape forwarded to the agent.
        config: nested configuration dictionary.
        device: device forwarded to the agent.

    Returns:
        The constructed agent instance.

    Raises:
        ValueError: if the configured agent name is not supported.
    """
    agent_name = config['agent']['name']
    if agent_name == 'curl_sac':
        agent_cls = CurlSacAgent
    elif agent_name == 'curl_sac_koopmanlqr':
        agent_cls = ComplexCurlSacKoopmanAgent
    else:
        # The previous `assert '<message>'` asserted a non-empty string, which
        # is always truthy, so an unknown name silently returned None.
        raise ValueError('agent is not supported: %s' % agent_name)

    agent_cfg = config['agent']
    return agent_cls(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        config=config,
        hidden_dim=agent_cfg['hidden_dim'],
        discount=agent_cfg['discount'],
        init_temperature=agent_cfg['init_temperature'],
        alpha_lr=agent_cfg['alpha_lr'],
        alpha_beta=agent_cfg['alpha_beta'],
        actor_lr=agent_cfg['actor_lr'],
        actor_beta=agent_cfg['actor_beta'],
        actor_log_std_min=agent_cfg['actor_log_std_min'],
        actor_log_std_max=agent_cfg['actor_log_std_max'],
        actor_update_freq=agent_cfg['actor_update_freq'],
        critic_lr=agent_cfg['critic_lr'],
        critic_beta=agent_cfg['critic_beta'],
        critic_tau=agent_cfg['critic_tau'],
        critic_target_update_freq=agent_cfg['critic_target_update_freq'],
        encoder_type=config['env']['encoder_type'],
        encoder_feature_dim=agent_cfg['encoder_feature_dim'],
        encoder_lr=agent_cfg['encoder_lr'],
        encoder_tau=agent_cfg['encoder_tau'],
        num_layers=agent_cfg['num_layers'],
        num_filters=agent_cfg['num_filters'],
        log_interval=config['log_interval'],
        detach_encoder=agent_cfg['detach_encoder'],
        curl_latent_dim=agent_cfg['curl_latent_dim'],
    )



class KoopmanActor(nn.Module):
    """Koopman LQR actor.

    Observations are encoded by the encoder returned from ``make_encoder``;
    the latent is passed to a ``KoopmanLQR`` trunk whose output is used
    directly as the action mean. The log standard deviation is a single
    learnable scalar broadcast over the action dimensions.
    """
    def __init__(
        self, obs_shape, action_shape, hidden_dim, encoder_type,
        encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters,
        config
    ):
        """
        obs_shape / encoder_type / encoder_feature_dim / num_layers / num_filters:
            forwarded to ``make_encoder``.
        action_shape:   ``action_shape[0]`` is the control dimension.
        hidden_dim:     accepted for interface compatibility; not used here.
        log_std_min, log_std_max: bounds applied to the log std in ``forward``.
        config:         reads ``config['device']``, ``config['agent']['u_encode_dim']``
                        and ``config['koopman']['koopman_goal_image_path']``.
        """
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, output_logits=True
        )

        self.action_shape = action_shape
        self.config = config
        self.device = torch.device(self.config['device'])

        # XL: set up goal reference for different situations
        goal_meta = self.config['koopman']['koopman_goal_image_path']
        if isinstance(goal_meta, str) and goal_meta.endswith(".pkl"):
            # NOTE: unpickling executes arbitrary code from the file; only
            # point this setting at trusted goal files.
            with open(goal_meta, "rb") as f:
                self.goal_obs = torch.from_numpy(pickle.load(f)).unsqueeze(0)
        elif isinstance(goal_meta, list):
            self.goal_obs = torch.from_numpy(
                np.array(goal_meta, dtype=np.float32)
            ).unsqueeze(0).to(self.device)
        else:
            # No goal observation: forward() regulates towards a zero latent.
            self.goal_obs = None

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.log_std_init = torch.nn.Parameter(torch.Tensor([1.0]).log()) # XL: initialize a log_std
        self.u_encode_dim = self.config['agent']['u_encode_dim']
        # XL: Koopman control module as trunk
        self.trunk = KoopmanLQR(k=encoder_feature_dim,
                                T=5,
                                g_dim=encoder_feature_dim,
                                u_dim=action_shape[0],
                                g_goal=None,
                                u_affine=None,
                                device=self.device,
                                u_encode_dim=self.u_encode_dim,
                                config=self.config
                                )

        self.trunk.to(self.device)
        self.encoder.to(self.device)

        self.outputs = dict()
        self.apply(weight_init)

    def forward(
        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
    ):
        """Return ``(mu, pi, log_pi, log_std)`` for a batch of observations.

        ``pi`` is ``None`` when ``compute_pi`` is false and ``log_pi`` is
        ``None`` when ``compute_log_pi`` is false. ``mu``, ``pi`` and
        ``log_pi`` are passed through ``squash`` before being returned.

        Raises:
            ValueError: if ``compute_log_pi`` is requested without
                ``compute_pi``; the log-probability is evaluated on the noise
                drawn for ``pi`` (this combination used to fail with a
                NameError on ``noise``).
        """
        if compute_log_pi and not compute_pi:
            raise ValueError(
                "compute_log_pi=True requires compute_pi=True: the "
                "log-probability is computed from the sampled noise"
            )

        obs = self.encoder(obs, detach=detach_encoder)

        # XL: encode the goal images to be used in self.trunk
        half_dim = self.trunk._g_dim // 2
        if self.goal_obs is None:
            # The trunk works on a complex latent with g_dim // 2 entries (its
            # Q matrix has that size), so the zero goal must have that length
            # too; a g_dim-long goal does not match Q in the Riccati recursion.
            self.trunk._g_goal = torch.zeros(
                half_dim, dtype=torch.complex64, device=self.device
            )
        else:
            goal_obs = self.encoder(self.goal_obs.to(self.device), detach=detach_encoder)
            # first half of the features is the real part, second half the imaginary part
            goal_obs = goal_obs[:, :half_dim] + 1j * goal_obs[:, half_dim:].to(self.device)
            self.trunk._g_goal = goal_obs.squeeze(0)

        # XL: do not chunk to 2 parts as LQR directly gives mu; use constant log_std
        broadcast_shape = list(obs.shape[:-1]) + [self.action_shape[0]]
        mu = self.trunk(obs)
        log_std = self.log_std_init + torch.zeros(*broadcast_shape).to(self.device)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)

        # kept for histogram logging in log()
        self.outputs['mu'] = mu
        self.outputs['std'] = log_std.exp()

        if compute_pi:
            std = log_std.exp()
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            pi = None

        if compute_log_pi:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        mu, pi, log_pi = squash(mu, pi, log_pi)
        return mu, pi, log_pi, log_std

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log histograms of the latest ``mu`` and ``std`` every ``log_freq`` steps."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)


class KoopmanCritic(Critic):
    """Critic used by the Koopman agent.

    Adds no behavior of its own: the constructor forwards every argument,
    unchanged and in the same order, to ``Critic``.
    """
    def __init__(self, obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters):
        super().__init__(
            obs_shape,
            action_shape,
            hidden_dim,
            encoder_type,
            encoder_feature_dim,
            num_layers,
            num_filters,
        )




######################################## LQR Koopman ##################
def discretize(K, L, step):
    """Discretize a diagonal continuous-time system with step size ``step``.

    Returns the pair ``(exp(step * K), (exp(step * K) - 1) / K * L)``, with
    all operations element-wise and broadcast by torch. Entries of ``K``
    equal to zero produce a division by zero.
    """
    state_transition = torch.exp(step * K)
    input_transition = (state_transition - 1) / K * L
    return state_transition, input_transition

class KoopmanLQR(nn.Module):
    """Finite-horizon LQR controller on a diagonal, complex-valued latent model.

    The continuous-time parameters (``K_real``, ``K_im``, ``L``, ``log_step``)
    are discretized once in ``__init__`` into ``_g_affine`` (state transition)
    and ``_u_affine`` (control matrix). ``forward`` solves the Riccati
    recursion over ``T`` steps and applies the first feedback gain.

    NOTE(review): several attributes are created as
    ``nn.Parameter(...).to(device)``. ``Tensor.to`` returns a plain tensor
    whenever it actually copies, so on such devices those attributes are not
    registered module parameters. ``_g_affine`` / ``_u_affine`` are also only
    computed in ``__init__``. Confirm whether both are intended before relying
    on these quantities being trained.
    """
    def __init__(self, k, T, g_dim, u_dim, u_encode_dim, g_goal=None, u_affine=None, device = None, config = None):
        """
        k:          rank of approximated koopman operator
        T:          length of horizon
        g_dim:      dimension of latent state
        u_dim:      dimension of control input
        u_encode_dim: dimension of the complex action embedding; the action
                    encoder outputs 2 * u_encode_dim real features
        g_goal:     None by default. If not, override the x_goal so it is not necessarily corresponding to a concrete goal state
                    might be useful for non regularization tasks.
        u_affine:   accepted for interface compatibility; this constructor does
                    not read it (the control matrix is built from ``self.L``)
        device:     device on which parameters and derived tensors are created
        config:     must provide config['koopman']['real_init_type'] and
                    config['koopman']['im_init_type']
        """
        super().__init__()
        self._k = k
        self._T = T
        self._g_dim = g_dim
        self._u_dim = u_dim
        # The device has to be stored before the goal is moved onto it; reading
        # self.device first raised AttributeError whenever g_goal was given.
        self.device = device
        self._g_goal = g_goal.to(self.device) if g_goal is not None else None
        self.action_emb_dim = u_encode_dim
        self.activation = F.relu
        self.state_dim = k
        print(self.device, "device")
        self.action_encoder = MLP([self._u_dim, self.action_emb_dim*2], activation=self.activation).to(self.device)
        self.action_decoder = (MLP([self.action_emb_dim*2, self._u_dim], activation=F.relu)).to(self.device)

        # L Parameter: last axis holds (real, imaginary) parts
        self.L = nn.Parameter(torch.normal(0, 0.1, size=(self.action_emb_dim, self.state_dim // 2, 2))).to(self.device)

        # Step parameter
        self.log_step = nn.Parameter(log_step_init(1)).to(self.device)
        self.step = torch.exp(self.log_step).to(self.device)
        self.real_init_type = config['koopman']['real_init_type']
        print(self.real_init_type, "real_init_type")
        self.real_init_value = -0.2
        self.im_init_type = config['koopman']['im_init_type']

        self.init_koopman_real_params()
        self.init_koopman_im_params()

        # Complex parameters: eigenvalues come in conjugate pairs
        self.K_complex = torch.cat([self.K_real + 1j * self.K_im, self.K_real - 1j * self.K_im], dim=-1).to(self.device)
        print(self.K_complex.real, "K_complex")
        print(self.K_complex.shape, "K_complex")

        self.L_complex = self.L[:, :, 0] + 1j * self.L[:, :, 1].to(self.device)
        print(self.K_complex.shape, self.L_complex.shape)

        # Discretize the continuous-time diagonal system
        self.K_dis, self.L_dis = discretize(K = self.K_complex, L= self.L_complex, step=self.step)

        print((self.K_dis), "K_dis")

        # Linear system x_{t+1} = _g_affine x_t + _u_affine u_t
        self._g_affine = torch.diag(self.K_dis).to(torch.complex64).to(self.device)
        self._u_affine = self.L_dis.to(self.device).transpose(0, 1).to(torch.complex64).to(self.device)

        # parameters of quadratic functions
        self._q_diag_log = nn.Parameter(torch.zeros(self._k//2)) # to use: Q = diag(_q_diag_log.exp())
        self._r_diag_log = nn.Parameter(torch.zeros(self.action_emb_dim)) # gain of control penalty, in theory need to be parameterized...
        self._r_diag_log.requires_grad = False
        self._q_diag_log.requires_grad = True
        # (The former `self._q_diag_log.to(...)` / `self._r_diag_log.to(...)`
        # statements discarded their result and had no effect; registered
        # parameters move together with the module.)

        # zero tensor constant for k and v in the case of fixed origin
        # these will be automatically moved to gpu so no need to create and check in the forward process
        self.register_buffer('_zero_tensor_constant_k', torch.zeros((1, self._u_dim)))
        self.register_buffer('_zero_tensor_constant_v', torch.zeros((1, self._k)))

        # cache for K, k, V and v because they are not dependent on x; avoids
        # repeating the riccati recursion in eval mode
        self._riccati_solution_cache = None

    def init_koopman_im_params(self):
        """Create ``self.K_im`` according to ``self.im_init_type``.

        Raises:
            ValueError: for an unknown ``im_init_type``.
        """
        if self.im_init_type == 'increasing_freq':
            self.K_im = nn.Parameter(increasing_im_init((self.state_dim // 4,)), requires_grad=True).to(self.device)
        elif self.im_init_type == 'random':
            self.K_im = nn.Parameter(random_im_init((self.state_dim // 4,)), requires_grad=True).to(self.device)
        else:
            raise ValueError("Invalid imaginary initialization type")

    def init_koopman_real_params(self):
        """Create ``self.K_real`` according to ``self.real_init_type``.

        Raises:
            ValueError: for an unknown ``real_init_type``.
        """
        if self.real_init_type == 'constant':
            self.K_real = torch.full((self.state_dim // 4,), self.real_init_value).to(self.device)
        elif self.real_init_type == 'learnable':
            # Assign self.K_real only once. Assigning the Parameter first and
            # the clamped plain tensor second raises TypeError in
            # nn.Module.__setattr__ whenever `.to()` is a no-op and the first
            # value is therefore registered as a parameter.
            # NOTE(review): the unclamped parameter is not registered on the
            # module, so it is not part of `parameters()`; confirm intent.
            raw_real = nn.Parameter(torch.ones(self.state_dim // 4) * self.real_init_value, requires_grad=True).to(self.device)
            self.K_real = torch.clamp(raw_real, -20, -0.01)
        else:
            raise ValueError("Invalid real initialization type")

    def forward(self, g0):
        """
        perform mpc with current parameters given the initial x0

        ``g0`` holds real parts in its first ``_g_dim // 2`` columns and
        imaginary parts in the remaining ones. The first LQR control is
        decoded by ``action_decoder`` and returned as ``-2 * tanh(.)``.
        """
        g0 = g0[:, :self._g_dim // 2] + 1j * g0[:, self._g_dim // 2:].to(self.device)
        K, k, V, v = self._retrieve_riccati_solution()
        u = -self._batch_mv(K[0], g0) + k[0]  # apply the first control as mpc
        # back to a real vector [real, imag] for the decoder
        u = torch.cat([u.real, u.imag], -1).to(self.device)
        u = self.action_decoder(u).to(self.device)
        return -2*torch.tanh(u)

    @staticmethod
    def _batch_mv(bmat, bvec):
        """Batched matrix-vector product; both operands are cast to complex64."""
        bmat = bmat.to(torch.complex64)
        bvec = bvec.to(torch.complex64)

        return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)

    def _retrieve_riccati_solution(self):
        """Return ``(K, k, V, v)``, re-solving in training mode or when no cache exists.

        Each solve also stores detached copies in ``_riccati_solution_cache``,
        which is what eval mode returns afterwards.
        """
        if self.training or self._riccati_solution_cache is None:
            Q = torch.diag(self._q_diag_log.exp()).unsqueeze(0).to(torch.complex64)
            R = torch.diag(self._r_diag_log.exp()).unsqueeze(0).to(torch.complex64)

            # use g_goal: the same goal is repeated for all T+1 time steps
            if self._g_goal is not None:
                goals = torch.repeat_interleave(self._g_goal.unsqueeze(0).unsqueeze(0), repeats=self._T+1, dim=1).to(self.device)
            else:
                goals = None

            # solve the lqr problem via a differentiable process.
            K, k, V, v = self._solve_lqr(self._g_affine.unsqueeze(0), self._u_affine.unsqueeze(0), Q, R, goals)
            self._riccati_solution_cache = (
                [tmp.detach().clone() for tmp in K],
                [tmp.detach().clone() for tmp in k],
                [tmp.detach().clone() for tmp in V],
                [tmp.detach().clone() for tmp in v])

        else:
            K, k, V, v = self._riccati_solution_cache
        return K, k, V, v

    def _solve_lqr(self, A, B, Q, R, goals):
        """Backward Riccati recursion over the horizon ``self._T``.

        a differentiable process of solving LQR,
        time-invariant A, B, Q, R (with leading batch dimensions), but goals can be a batch of trajectories (batch_size, T+1, k)
              min \\Sigma^{T} (x_t - goal[t])^T Q (x_t - goal[t]) + u_t^T R u_t
        s.t.  x_{t+1} = A x_t + B u_t
        return feedback gain and feedforward terms such that u = -K x + k
        """
        T = self._T
        K = [None] * T
        k = [None] * T
        V = [None] * (T+1)
        v = [None] * (T+1)

        A_trans = A.transpose(-2,-1)
        B_trans = B.transpose(-2,-1)

        V[-1] = Q  # initialization for backpropagation
        if goals is not None:
            v[-1] = self._batch_mv(Q, goals[:, -1, :])
            for i in reversed(range(T)):
                # solve (B^T V B + R) X = B^T instead of forming the inverse explicitly
                V_uu_inv_B_trans = torch.linalg.solve(torch.matmul(torch.matmul(B_trans, V[i+1]), B) + R, B_trans)
                K[i] = torch.matmul(torch.matmul(V_uu_inv_B_trans, V[i+1]), A)
                k[i] = self._batch_mv(V_uu_inv_B_trans, v[i+1])

                # riccati difference equation, A-BK
                A_BK = A - torch.matmul(B, K[i])
                V[i] = torch.matmul(torch.matmul(A_trans, V[i+1]), A_BK) + Q
                v[i] = self._batch_mv(A_BK.transpose(-2, -1), v[i+1]) + self._batch_mv(Q, goals[:, i, :])
        else:
            # None goals means a fixed regulation point at origin. ignore k and v for efficiency
            for i in reversed(range(T)):
                V_uu_inv_B_trans = torch.linalg.solve(torch.matmul(torch.matmul(B_trans, V[i+1]), B) + R, B_trans)
                K[i] = torch.matmul(torch.matmul(V_uu_inv_B_trans, V[i+1]), A)

                A_BK = A - torch.matmul(B, K[i]) #riccati difference equation: A-BK
                V[i] = torch.matmul(torch.matmul(A_trans, V[i+1]), A_BK) + Q
            # Slice assignment iterates over the buffer's leading dimension
            # (size 1), so k and v each become a one-element list here.
            k[:] = self._zero_tensor_constant_k
            v[:] = self._zero_tensor_constant_v

        # note K is for negative feedback, namely u = -Kx+k
        return K, k, V, v

    def _predict_koopman(self, G, U):
        """
        predict dynamics with current koopman parameters
        note both input and return are embeddings of the predicted state, we can recover that by using invertible net, e.g. normalizing-flow models
        but that would require a same dimensionality

        ``G`` and the returned tensor use the [real, imag] column layout;
        ``U`` is a raw action that is first passed through ``action_encoder``.
        """
        # if u is 1d
        if len(U.shape) == 1:
            U = U.unsqueeze(0)
        G = G[:, :self._g_dim // 2] + 1j * G[:, self._g_dim // 2:].to(torch.complex64)

        u = self.action_encoder(U)
        # first half of the embedding is the real part, second half the imaginary part
        U = u[ :, :self.action_emb_dim ] + 1j * u[:,self.action_emb_dim:]
        new_g = torch.matmul(G, self._g_affine.transpose(0, 1))+torch.matmul(U, self._u_affine.transpose(0, 1))
        return torch.cat([new_g.real, new_g.imag], -1).to(self.device)

    def _control_loss(self, U):
        """
        loss function for control: MSE between ``U`` and its reconstruction
        through ``action_encoder`` followed by ``action_decoder``
        """
        encoded_u = self.action_encoder(U)
        decoded_u = self.action_decoder(encoded_u)
        return F.mse_loss(U, decoded_u)
    






########################################### Complex Koopman  Agent ###################
            
    
    
class ComplexCurlSacKoopmanAgent(CurlSacAgent):
    def __init__(self, obs_shape,
                 action_shape,
                 device,
                 config,
                 hidden_dim=256,
                 discount=0.99,
                 init_temperature=0.01,
                 alpha_lr=1e-3,
                 alpha_beta=0.9,
                 actor_lr=1e-3,
                 actor_beta=0.9,
                 actor_log_std_min=-10,
                 actor_log_std_max=2,
                 actor_update_freq=2,
                 critic_lr=1e-3, critic_beta=0.9,
                 critic_tau=0.005,
                 critic_target_update_freq=2,
                 encoder_type='pixel',
                 encoder_feature_dim=128,
                 encoder_lr=1e-3,
                 encoder_tau=0.005,
                 num_layers=4,
                 num_filters=32,
                 cpc_update_freq=1,
                 log_interval=100,
                 detach_encoder=False,
                 curl_latent_dim=128):
        """Set up the base agent, then replace actor, critics and optimizers.

        After the parent constructor has run, the actor becomes a
        ``KoopmanActor``, the critic and target critic become
        ``KoopmanCritic`` instances, and all optimizers are rebuilt for the
        new networks. An extra Adam optimizer (``koopman_optimizers``) is
        created over the actor trunk's parameters. Reads
        ``config['koopman']['koopman_update_freq']`` and
        ``config['koopman']['koopman_fit_coeff']``.
        """
        super().__init__(
            obs_shape, action_shape, device, config,
            hidden_dim=hidden_dim,
            discount=discount,
            init_temperature=init_temperature,
            alpha_lr=alpha_lr,
            alpha_beta=alpha_beta,
            actor_lr=actor_lr,
            actor_beta=actor_beta,
            actor_log_std_min=actor_log_std_min,
            actor_log_std_max=actor_log_std_max,
            actor_update_freq=actor_update_freq,
            critic_lr=critic_lr,
            critic_beta=critic_beta,
            critic_tau=critic_tau,
            critic_target_update_freq=critic_target_update_freq,
            encoder_type=encoder_type,
            encoder_feature_dim=encoder_feature_dim,
            encoder_lr=encoder_lr,
            encoder_tau=encoder_tau,
            num_layers=num_layers,
            num_filters=num_filters,
            cpc_update_freq=cpc_update_freq,
            log_interval=log_interval,
            detach_encoder=detach_encoder,
            curl_latent_dim=curl_latent_dim,
        )

        print(f"encoder type:{encoder_type}")

        koopman_cfg = config['koopman']
        self.koopman_update_freq = koopman_cfg['koopman_update_freq']
        self.koopman_fit_coeff = koopman_cfg['koopman_fit_coeff']

        # Networks: Koopman actor plus an online/target critic pair.
        self.actor = KoopmanActor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters, config
        ).to(device)

        critic_args = (
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters,
        )
        self.critic = KoopmanCritic(*critic_args).to(device)
        self.critic_target = KoopmanCritic(*critic_args).to(device)
        # the target critic starts as an exact copy of the online critic
        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie encoders between actor and critic, and CURL and critic
        self.actor.encoder.copy_weights_from(self.critic.encoder)

        # Entropy temperature, optimized in log space.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
        )

        # XL: additional optimizers
        self.koopman_optimizers = torch.optim.Adam(
            self.actor.trunk.parameters(), lr=0.001, betas=(actor_beta, 0.999)
        )

        if self.encoder_type in ('pixel', 'fc'):
            # CURL module built on the online and target critics
            self.CURL = CURL(
                obs_shape, encoder_feature_dim, self.curl_latent_dim,
                self.critic, self.critic_target, output_type='continuous'
            ).to(self.device)

            # optimizer for critic encoder for reconstruction loss
            self.encoder_optimizer = torch.optim.Adam(
                self.critic.encoder.parameters(), lr=encoder_lr
            )
            self.cpc_optimizer = torch.optim.Adam(
                self.CURL.parameters(), lr=encoder_lr
            )
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        self.train()
        self.critic_target.train()
    
    # TODO: check loss terms in koopmanlqr_sac_garage.py
    def update_actor_and_alpha(self, obs, next_obs, action, L, step):
        """Take one optimizer step on the actor, then one on the temperature.

        ``next_obs`` and ``action`` are part of the signature but are not
        used by this method. Scalars are written to the logger ``L`` whenever
        ``step`` is a multiple of ``self.log_interval``.
        """
        # detach encoder, so we don't update it with the actor loss
        _, sampled_action, log_prob, log_std = self.actor(obs, detach_encoder=True)
        q1, q2 = self.critic(obs, sampled_action, detach_encoder=True)

        min_q = torch.min(q1, q2)
        actor_loss = (self.alpha.detach() * log_prob - min_q).mean()

        should_log = step % self.log_interval == 0
        if should_log:
            L.log('train_actor/loss', actor_loss, step)
            L.log('train_actor/target_entropy', self.target_entropy, step)

        # Gaussian entropy computed from the per-dimension log std
        gaussian_const = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi))
        entropy = gaussian_const + log_std.sum(dim=-1)
        if should_log:
            L.log('train_actor/entropy', entropy.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        self.actor_optimizer.step()

        self.actor.log(L, step)

        # optimize the temperature; the entropy gap is treated as a constant
        self.log_alpha_optimizer.zero_grad()
        entropy_gap = (-log_prob - self.target_entropy).detach()
        alpha_loss = (self.alpha * entropy_gap).mean()
        if should_log:
            L.log('train_alpha/loss', alpha_loss, step)
            L.log('train_alpha/value', self.alpha, step)

        alpha_loss.backward(retain_graph=True)
        self.log_alpha_optimizer.step()

    # def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
    #     with torch.no_grad():
    #         _, policy_action, log_pi, _ = self.actor(next_obs)  # TODO: debug this actor, the policy_action has 1 more dimension
    #         target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
    #         target_V = torch.min(target_Q1,
    #                                 target_Q2) - self.alpha.detach() * log_pi
    #         target_Q = reward + (not_done * self.discount * target_V)

    #     # get current Q estimates
    #     current_Q1, current_Q2 = self.critic(
    #         obs, action, detach_encoder=True)
    #     critic_loss = F.mse_loss(current_Q1,
    #                                 target_Q) + F.mse_loss(current_Q2, target_Q)
    #     if step % self.log_interval == 0:
    #         L.log('train_critic/loss', critic_loss, step)


    #     # Optimize the critic
    #     self.critic_optimizer.zero_grad()
    #     critic_loss.backward()
    #     self.critic_optimizer.step()

    #     self.critic.log(L, step)


    def update_kpm(self, obs, next_obs, action, L, step, use_ls=False):
        """Fit the Koopman dynamics with a one-step latent prediction loss.

        Both observations are encoded with ``self.actor.encoder``; the next
        latent is predicted from the current latent and ``action`` by
        ``self.actor.trunk._predict_koopman``. Predicted and encoded next
        latents are each read as a complex vector (first ``_g_dim // 2``
        columns are the real part, the remaining columns the imaginary part)
        and compared in polar form (magnitude and phase) with an MSE loss.

        XL: only the fitting loss is used for now. No reconstruction loss
        (curl already acts as a kind of reconstruction and there is no decoder
        spec) and no regularisation loss (not considered important).

        Args:
            obs: batch of current observations fed to the encoder.
            next_obs: batch of next observations fed to the encoder.
            action: batch of actions passed to ``_predict_koopman``.
            L: logger exposing ``log(key, value, step)``.
            step: current training step; the loss is logged whenever
                ``step % self.log_interval == 0``.
            use_ls: currently unused; kept for interface compatibility.

        Side effects:
            Zeroes, back-propagates into and steps ``self.koopman_optimizers``.
        """
        half = self.actor.trunk._g_dim // 2

        def to_polar(latent):
            # Slices of the same tensor already share a device, so no explicit
            # device transfer is needed when assembling the complex view.
            z = latent[:, :half] + 1j * latent[:, half:]
            # torch.abs / torch.angle define a zero (sub)gradient at z == 0,
            # whereas sqrt(re**2 + im**2) and atan2(im, re) back-propagate NaN
            # there and would corrupt every parameter touched by the step.
            return torch.cat([torch.abs(z), torch.angle(z)], dim=-1)

        g = self.actor.encoder(obs)
        g_next = self.actor.encoder(next_obs)
        g_pred = self.actor.trunk._predict_koopman(g, action)

        # NOTE(review): the phase term is compared without wrapping, so angles
        # close to +pi and -pi count as far apart -- confirm this is intended.
        loss_fn = nn.MSELoss()
        fit_loss = loss_fn(to_polar(g_pred), to_polar(g_next))

        if step % self.log_interval == 0:
            L.log('train/kpm_fitting_loss', fit_loss, step)

        # Update the parameters registered with the Koopman optimizer
        # (presumably self.actor.trunk's A, B, Q, R -- verify at construction).
        self.koopman_optimizers.zero_grad()
        # NOTE(review): retain_graph=True is kept from the original; nothing in
        # this method reuses the graph afterwards -- confirm whether
        # _predict_koopman caches tensors that require it.
        fit_loss.backward(retain_graph=True)
        self.koopman_optimizers.step()


    def update(self, replay_buffer, L, step):
        """Run one training iteration on a batch sampled from ``replay_buffer``.

        For the 'pixel' and 'fc' encoder types the batch comes from
        ``sample_cpc()`` and carries extra ``cpc_kwargs``; otherwise it comes
        from ``sample_proprio()``. The critic is updated on every call. The
        actor/alpha, the target networks, the CPC objective and the Koopman
        fit are each updated only on steps that are a multiple of their
        respective frequency attribute.

        Args:
            replay_buffer: object exposing ``sample_cpc()`` and
                ``sample_proprio()``.
            L: logger exposing ``log(key, value, step)``.
            step: current training step used for all update schedules.
        """
        with_cpc_batch = self.encoder_type in ['pixel', 'fc']

        def due(every):
            # True on steps that are a multiple of the given period.
            return step % every == 0

        if with_cpc_batch:
            batch = replay_buffer.sample_cpc()
            obs, action, reward, next_obs, not_done, cpc_kwargs = batch
        else:
            batch = replay_buffer.sample_proprio()
            obs, action, reward, next_obs, not_done = batch

        if due(self.log_interval):
            L.log('train/batch_reward', reward.mean(), step)

        # The critic is trained on every call.
        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if due(self.actor_update_freq):
            # XL: argument order follows the new update_actor_and_alpha().
            self.update_actor_and_alpha(obs, next_obs, action, L, step)

        if due(self.critic_target_update_freq):
            online, target = self.critic, self.critic_target
            # Q heads track with critic_tau, the encoder with encoder_tau.
            tracked = (
                (online.Q1, target.Q1, self.critic_tau),
                (online.Q2, target.Q2, self.critic_tau),
                (online.encoder, target.encoder, self.encoder_tau),
            )
            for source_net, target_net, tau in tracked:
                utils.soft_update_params(source_net, target_net, tau)

        if due(self.cpc_update_freq) and with_cpc_batch:
            anchor = cpc_kwargs["obs_anchor"]
            positive = cpc_kwargs["obs_pos"]
            self.update_cpc(anchor, positive, cpc_kwargs, L, step)

        if (due(self.koopman_update_freq) and with_cpc_batch
                and self.koopman_fit_coeff > 0):
            self.update_kpm(obs, next_obs, action, L, step)
        
        
        

                        # with utils.eval_mode(agent):


# class ComplexKoopmanSacAeAgent(SacAeAgent):
#     def __init__(self, obs_shape,
#         action_shape,
#         device,
#         config,
#         hidden_dim=256,
#         discount=0.99,
#         init_temperature=0.01,
#         alpha_lr=1e-3,
#         alpha_beta=0.9,
#         actor_lr=1e-3,
#         actor_beta=0.9,
#         actor_log_std_min=-10,
#         actor_log_std_max=2,
#         actor_update_freq=2,
#         critic_lr=1e-3,
#         critic_beta=0.9,
#         critic_tau=0.005,
#         critic_target_update_freq=2,
#         encoder_type='pixel',
#         encoder_feature_dim=50,
#         encoder_lr=1e-3,
#         encoder_tau=0.005,
#         decoder_type='pixel',
#         decoder_lr=1e-3,
#         decoder_update_freq=1,
#         decoder_latent_lambda=0.0,
#         decoder_weight_lambda=0.0,
#         num_layers=4,
#         num_filters=32):
#         super(KoopmanSacAeAgent, self).__init__(obs_shape, action_shape, device, config,
#                                                 hidden_dim=hidden_dim,
#                                                 discount=discount,
#                                                 init_temperature=init_temperature,
#                                                 alpha_lr=alpha_lr,
#                                                 alpha_beta=alpha_beta,
#                                                 actor_lr=actor_lr,
#                                                 actor_beta=actor_beta,
#                                                 actor_log_std_min=actor_log_std_min,
#                                                 actor_log_std_max=actor_log_std_max,
#                                                 actor_update_freq=actor_update_freq,
#                                                 critic_lr=critic_lr,
#                                                 critic_beta=critic_beta,
#                                                 critic_tau=critic_tau,
#                                                 critic_target_update_freq=critic_target_update_freq,
#                                                 encoder_type=encoder_type,
#                                                 encoder_feature_dim=encoder_feature_dim,
#                                                 encoder_lr=encoder_lr,
#                                                 encoder_tau=encoder_tau,
#                                                 decoder_type=decoder_type,
#                                                 decoder_lr=decoder_lr,
#                                                 decoder_update_freq=decoder_update_freq,
#                                                 decoder_latent_lambda=decoder_latent_lambda,
#                                                 decoder_weight_lambda=decoder_weight_lambda,
#                                                 num_layers=num_layers,
#                                                 num_filters=num_filters)

#         print("encoder type:{}".format(encoder_type))

#         self.koopman_update_freq = config['koopman']['koopman_update_freq']
#         self.koopman_fit_coeff   = config['koopman']['koopman_fit_coeff']

#         self.actor = KoopmanActor(
#             obs_shape, action_shape, hidden_dim, encoder_type,
#             encoder_feature_dim, actor_log_std_min, actor_log_std_max,
#             num_layers, num_filters, config).to(device)
        
#         self.critic = KoopmanCritic(
#             obs_shape, action_shape, hidden_dim, encoder_type,
#             encoder_feature_dim, num_layers, num_filters).to(device)
        
#         self.critic_target = KoopmanCritic(
#             obs_shape, action_shape, hidden_dim, encoder_type,
#             encoder_feature_dim, num_layers, num_filters).to(device)
        
    
#         # XL: additional optimizers
#         self.koopman_optimizers = torch.optim.Adam(
#             self.actor.trunk.parameters(), lr=actor_lr, betas=(actor_beta, 0.999))
#         self.critic_target.load_state_dict(self.critic.state_dict())

#         # tie encoders between actor and critic
#         self.actor.encoder.copy_weights_from(self.critic.encoder)

#         self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
#         self.log_alpha.requires_grad = True
#         # set target entropy to -|A|
#         self.target_entropy = -np.prod(action_shape)

#         self.decoder = None
#         if decoder_type != 'identity':
#             # create decoder
#             self.decoder = make_decoder(
#                 decoder_type, obs_shape, encoder_feature_dim, num_layers,
#                 num_filters
#             ).to(device)
#             self.decoder.apply(weight_init)

#             # optimizer for critic encoder for reconstruction loss
#             self.encoder_optimizer = torch.optim.Adam(
#                 self.critic.encoder.parameters(), lr=encoder_lr
#             )

#             # optimizer for decoder
#             self.decoder_optimizer = torch.optim.Adam(
#                 self.decoder.parameters(),
#                 lr=decoder_lr,
#                 weight_decay=decoder_weight_lambda
#             )

#         # optimizers
#         self.actor_optimizer = torch.optim.Adam(
#             self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
#         )

#         self.critic_optimizer = torch.optim.Adam(
#             self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
#         )

#         self.log_alpha_optimizer = torch.optim.Adam(
#             [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
#         )

#         self.train()
#         self.critic_target.train()
