import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from copy import deepcopy

import utils
import utilsmod
from encoder import make_encoder
import data_augs as rad 
from algorithms.rad_byol import RAD_BYOL


class RAD_BYOL_SharedProj(RAD_BYOL):
    """RAD_BYOL variant labelled "Shared Projector" by its author.

    The constructor does not call ``RAD_BYOL.__init__``; it builds every
    network and optimizer itself. The predictor is constructed from the
    critic's encoder object, and the actor's ``encoder.cnn`` attribute is
    rebound to the critic's ``encoder.cnn`` module.
    """

    def __init__(
        self,
        obs_shape,
        action_shape,
        device,
        hidden_dim=256,
        discount=0.99,
        init_temperature=0.01,
        alpha_lr=1e-3,
        alpha_beta=0.9,
        actor_lr=1e-3,
        actor_beta=0.5,
        actor_log_std_min=-10,
        actor_log_std_max=2,
        actor_update_freq=2,
        critic_lr=1e-3,
        critic_beta=0.9,
        critic_tau=0.005,
        critic_target_update_freq=2,
        encoder_type='pixel',
        encoder_feature_dim=50,
        encoder_lr=1e-3,
        encoder_tau=0.05,
        num_layers=4,
        num_filters=32,
        cpc_update_freq=1,
        log_interval=100,
        detach_encoder=False,
        latent_dim=128,
        soda_tau=0.005,
        data_augs=''
    ):
        """Build the networks, the entropy temperature and all optimizers.

        Args:
            obs_shape: Observation shape forwarded to the actor and critic
                constructors; its last entry is stored as ``image_size``.
            action_shape: Action shape forwarded to the actor and critic.
                ``action_shape[0]`` is passed to the dynamics model and
                ``-prod(action_shape)`` becomes the target entropy.
            device: Device every network and ``log_alpha`` are moved to.
            hidden_dim: Forwarded to the actor and critic constructors.
            discount, critic_tau, encoder_tau, actor_update_freq,
            critic_target_update_freq, cpc_update_freq, log_interval,
            detach_encoder, latent_dim, soda_tau: Stored on the instance
                unchanged; not otherwise used by this constructor.
            init_temperature: ``log_alpha`` is initialised to its natural log.
            alpha_lr, alpha_beta: Learning rate and first Adam beta for the
                ``log_alpha`` optimizer.
            actor_lr, actor_beta: Learning rate and first Adam beta for the
                actor optimizer.
            actor_log_std_min, actor_log_std_max: Forwarded to the actor.
            critic_lr, critic_beta: Learning rate and first Adam beta for the
                critic optimizer. ``critic_beta`` is also the first Adam beta
                of the predictor and dynamics Adam optimizers.
            encoder_type: Forwarded to the actor and critic. The predictor
                and dynamics optimizers are only created when it equals
                ``'pixel'``.
            encoder_feature_dim: Forwarded to the actor, critic, predictor
                and dynamics constructors.
            encoder_lr: Learning rate of the predictor/dynamics optimizers
                (both the Adam and the SGD variants).
            num_layers, num_filters: Forwarded to the actor and critic.
            data_augs: ``'-'``-separated augmentation names, each of which
                must be a key of the augmentation table below. Note that the
                default ``''`` is rejected, because ``''.split('-')`` yields
                ``['']``; callers must pass at least one valid name
                (e.g. ``'no_aug'``).

        Raises:
            AssertionError: If ``data_augs`` contains an unknown name.
        """
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.cpc_update_freq = cpc_update_freq
        self.log_interval = log_interval
        self.image_size = obs_shape[-1]
        self.latent_dim = latent_dim
        self.detach_encoder = detach_encoder
        self.encoder_type = encoder_type
        self.data_augs = data_augs
        self.soda_tau = soda_tau

        self.augs_funcs = {}

        # Supported augmentation names and the functions implementing them.
        aug_to_func = {
            'crop': rad.random_crop,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            'flip': rad.random_flip,
            'rotate': rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'translate': rad.random_translate,
            'no_aug': rad.no_aug,
        }

        for aug_name in self.data_augs.split('-'):
            if aug_name not in aug_to_func:
                # Raised explicitly rather than via `assert`, so the check is
                # not stripped when Python runs with -O (where the lookup
                # below would otherwise fail with a bare KeyError).
                raise AssertionError('invalid data aug string')
            self.augs_funcs[aug_name] = aug_to_func[aug_name]

        # NOTE: networks are created in a fixed order (actor, critic, critic
        # target, predictor, predictor target, dynamics); keep it stable so
        # seeded weight initialisation stays reproducible.
        self.actor = utilsmod.Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters
        ).to(device)

        self.critic = utilsmod.Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        self.critic_target = utilsmod.Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters
        ).to(device)

        # Start the target critic from the online critic's weights.
        self.critic_target.load_state_dict(self.critic.state_dict())

        # add predictor and dynamic
        # predenc = utilsmod.PredEnc(self.critic.encoder.cnn, self.critic.encoder.cnn_out_dim[0],
        #     encoder_feature_dim, encoder_feature_dim)

        # utilsmod.Predictor is produced by PredMLP(2-layers)
        # The predictor is handed the critic's encoder object itself.
        self.predictor = utilsmod.Predictor(
            self.critic.encoder, encoder_feature_dim, encoder_feature_dim
        ).to(device)
        # Independent deep copy of the predictor, used as its target network.
        self.predictor_target = deepcopy(self.predictor).to(device)
        self.dynamic = utilsmod.Dynamic2Pred(encoder_feature_dim, action_shape[0]).to(device)

        # Share the critic's CNN with the actor: both encoders now hold the
        # very same module object.
        self.actor.encoder.cnn = self.critic.encoder.cnn

        # Learnable log of the entropy temperature.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
        )

        if self.encoder_type == 'pixel':
            # Adam optimizers for the predictor and the dynamics model.
            self.predictor_optimizer = torch.optim.Adam(
                self.predictor.parameters(), lr=encoder_lr, betas=(critic_beta, 0.999)
            )

            self.dynamic_optimizer = torch.optim.Adam(
                self.dynamic.parameters(), lr=encoder_lr, betas=(critic_beta, 0.999)
            )

            # SGD alternatives over the same parameter sets; which optimizer
            # is actually stepped is decided outside this constructor.
            self.pred_opt_SGD = torch.optim.SGD(
                self.predictor.parameters(), lr=encoder_lr, momentum=0.9, weight_decay=1e-4
            )

            self.dyn_opt_SGD = torch.optim.SGD(
                self.dynamic.parameters(), lr=encoder_lr, momentum=0.9, weight_decay=1e-4
            )

        self.cross_entropy_loss = nn.CrossEntropyLoss()

        # `train` is not defined in this class; it is expected to come from
        # RAD_BYOL.
        self.train()
        self.critic_target.train()
