from smuco.stc import update_stc
from smuco.stc import ReplayBuffer as MVCReplayBuffer 
from sac_utils.logx import EpochLogger

import logging
import os
from copy import deepcopy
import itertools
import numpy as np
import time
import torch
from torch.optim import Adam
from torch.distributions.normal import Normal
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

log = logging.getLogger(__name__)

# Clamp range applied to the actor's predicted log standard deviation.
LOG_STD_MIN, LOG_STD_MAX = -20, 2

def combined_shape(length, shape=None):
    """Prepend ``length`` to ``shape`` to build a buffer shape.

    ``shape`` may be None (giving ``(length,)``), a scalar (giving
    ``(length, shape)``), or an iterable of dims (giving ``(length, *shape)``).
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build an ``nn.Sequential`` of Linear layers joined by activations.

    ``sizes`` lists the layer widths including input and output. Every Linear
    layer is followed by an instance of ``activation``, except the last one,
    which is followed by ``output_activation``.
    """
    layers = []
    last_idx = len(sizes) - 2
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(nn.Linear(n_in, n_out))
        layers.append(output_activation() if idx == last_idx else activation())
    return nn.Sequential(*layers)

def count_vars(module):
    """Return the total number of scalar entries across all parameters of ``module``."""
    return sum(np.prod(param.shape) for param in module.parameters())

def random_crop_single(img, out):
    """
    Randomly crop a single image to an ``out x out`` square.

    Args:
        img: array indexed as ``img[c, h, w]``. The shape must have exactly
            three dims because it is unpacked into ``(c, h, w)``.
        out: side length of the square crop.

    Returns:
        ``img[:, top:top+out, left:left+out]``. Both offsets are drawn with
        ``np.random.randint``, each bounded by its own axis.

    Raises:
        ValueError: if ``out`` is larger than the image height or width.
    """
    _, h, w = img.shape
    if out > h or out > w:
        raise ValueError(f"crop size {out} exceeds image size {h}x{w}")
    # Bound each offset by its own axis. The previous code reused the height
    # range for the horizontal offset, which yielded crops narrower than `out`
    # on images with w < h. The draw order (horizontal first, then vertical)
    # is kept, so seeded runs on square images produce identical crops.
    w1 = np.random.randint(0, w - out + 1)
    h1 = np.random.randint(0, h - out + 1)
    return img[:, h1:h1 + out, w1:w1 + out]


class MLPActor(nn.Module):
    """Squashed-Gaussian policy network.

    An MLP trunk feeds two linear heads that predict the mean and log-std of a
    diagonal Normal. Sampled actions are passed through tanh and scaled by
    ``act_limit``.

    Args:
        obs_dim: width of the input feature vector.
        act_dim: number of action dimensions.
        hidden_sizes: hidden layer widths. The last one is the input width of
            both heads.
        activation: activation module class, instantiated between layers by ``mlp``.
        act_limit: scale applied to the tanh-squashed action.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Compute an action and (optionally) its log-probability.

        Args:
            obs: input tensor. The last axis is the feature axis; leading axes
                (if any) are treated as batch axes.
            deterministic: if True, use the Gaussian mean instead of sampling.
            with_logprob: if False, skip the log-probability and return None for it.

        Returns:
            ``(action, logp)``. ``action = act_limit * tanh(u)`` for the
            pre-squash sample/mean ``u``. ``logp`` is the log-density of the
            squashed action, summed over the last (action) axis, or None.
        """
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        # Keep std within [exp(LOG_STD_MIN), exp(LOG_STD_MAX)].
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution; rsample keeps the sample differentiable.
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            # tanh change-of-variables correction, in a numerically stable form:
            # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).
            # Summed over the last axis, like the line above. The previous
            # axis=1 raised for unbatched 1-D inputs.
            logp_pi = logp_pi - (
                2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))
            ).sum(axis=-1)
        else:
            logp_pi = None

        pi_action = self.act_limit * torch.tanh(pi_action)

        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """State-action value network: an MLP over ``cat([obs, act], dim=-1)`` with one output.

    Args:
        obs_dim: width of the observation/feature vector.
        act_dim: number of action dimensions.
        hidden_sizes: hidden layer widths.
        activation: activation module class used between hidden layers.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        obs_act = torch.cat([obs, act], dim=-1)
        # Drop the trailing size-1 output dim so Q has shape obs.shape[:-1].
        return self.q(obs_act).squeeze(-1)


class MLPActorCritic(nn.Module):
    """Squashed-Gaussian actor plus twin Q critics, all reading ``feature_dim`` inputs.

    Args:
        feature_dim: width of the feature vector fed to actor and critics.
        action_space: object exposing ``.shape`` and ``.high`` (presumably a gym
            ``Box``). ``act_limit`` is taken from ``high[0]``, so the bounds are
            assumed equal across dims. TODO confirm for the envs used.
        hidden_sizes: hidden layer widths shared by all three networks.
        activation: activation module class.
    """

    def __init__(self, feature_dim, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        super().__init__()

        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]

        def make_critic():
            return MLPQFunction(feature_dim, act_dim, hidden_sizes, activation)

        # Construction order (pi, q1, q2) matches parameter-init RNG consumption.
        self.pi = MLPActor(feature_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q1 = make_critic()
        self.q2 = make_critic()

    def act(self, obs, deterministic=False):
        """Return the policy's action for ``obs`` as a numpy array, without autograd tracking."""
        with torch.no_grad():
            action, _ = self.pi(obs, deterministic, False)
        return action.numpy()


class ReplayBuffer:
    """
    Fixed-capacity FIFO ring buffer of (obs, act, rew, next_obs, done) transitions.

    Once ``max_size`` transitions are stored, each new one overwrites the oldest.
    All fields are kept as float32 numpy arrays.
    """

    def __init__(self, obs_dim, act_dim, size):
        obs_shape = combined_shape(size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr = 0
        self.size = 0
        self.max_size = size

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the write pointer and advance it (wrapping around)."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.done_buf[slot] = done
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` stored transitions uniformly with replacement.

        Returns a dict with keys obs, obs2, act, rew, done mapping to float32 tensors.
        """
        idxs = np.random.randint(0, self.size, size=batch_size)
        sources = {
            'obs': self.obs_buf,
            'obs2': self.obs2_buf,
            'act': self.act_buf,
            'rew': self.rew_buf,
            'done': self.done_buf,
        }
        return {name: torch.as_tensor(buf[idxs], dtype=torch.float32)
                for name, buf in sources.items()}

def sac(cfg, env_fn, exp, ac_kwargs=None, logger_kwargs=None):
    """
    Soft Actor-Critic (SAC) on fused multi-view latent features.

    Each raw observation is cropped into ``cfg.replay_buffer.num_views`` random
    ``crop_out_dim`` views. The views are added to a multi-view replay buffer,
    which is also attached as ``exp.replay_buffer``. Each view is encoded by the
    matching encoder in ``exp.mm_vae.encoders``. The per-view (mu, logvar)
    outputs are fused by ``exp.mm_vae.ivw_fusion`` (presumably inverse-variance
    weighting). A sample from the fused Normal is the feature vector that the
    SAC actor and critics consume. ``update_stc(cfg, exp)`` is called after
    every SAC gradient step.

    Args:
        cfg: config object. Reads ``seed``, ``crop_out_dim``,
            ``encoder.feature_dim``, ``replay_buffer`` (also splatted into the
            multi-view buffer constructor) and the ``sac.*`` hyper-parameters.
        env_fn: zero-argument callable returning an environment. Called twice,
            once for training and once for testing.
        exp: experiment object exposing ``mm_vae``. Its ``replay_buffer``
            attribute is overwritten here.
        ac_kwargs: extra keyword arguments for ``MLPActorCritic``.
        logger_kwargs: keyword arguments for ``EpochLogger``. Its
            ``output_dir`` is also where the multi-view buffer is saved
            (``<output_dir>/mvc_rb``).

    Raises:
        ValueError: if no output directory can be determined.
    """
    # None defaults avoid sharing one mutable dict across calls.
    ac_kwargs = {} if ac_kwargs is None else ac_kwargs
    logger_kwargs = {} if logger_kwargs is None else logger_kwargs

    seed = cfg.seed
    feature_dim = cfg.encoder.feature_dim
    gamma = cfg.sac.gamma
    alpha = cfg.sac.alpha
    lr = cfg.sac.lr
    polyak = cfg.sac.polyak
    num_test_episodes = cfg.sac.num_test_episodes
    max_ep_len = cfg.sac.max_ep_len
    epochs = cfg.sac.epochs
    steps_per_epoch = cfg.sac.steps_per_epoch
    start_steps = cfg.sac.start_steps
    update_every = cfg.sac.update_every
    update_after = cfg.sac.update_after
    save_freq = cfg.sac.save_freq
    batch_size = cfg.sac.batch_size

    logger = EpochLogger(**logger_kwargs)
    # logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    act_dim = env.action_space.shape[0]

    ac = MLPActorCritic(feature_dim, env.action_space, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # Target networks change only through Polyak averaging in update().
    for p in ac_targ.parameters():
        p.requires_grad = False

    # Materialize as a list. A bare itertools.chain is a one-shot iterator:
    # Adam(q_params) below would exhaust it, silently turning the Q
    # freeze/unfreeze loops in update() into no-ops.
    q_params = list(itertools.chain(ac.q1.parameters(), ac.q2.parameters()))

    # Experience buffer: stores encoded features (feature_dim), not raw images.
    replay_buffer = ReplayBuffer(obs_dim=feature_dim, act_dim=act_dim, size=cfg.sac.replay_size)

    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n'%var_counts)

    num_views = cfg.replay_buffer.num_views
    # The 3-way unpack also enforces a 3-D raw observation shape (C, H, W).
    C, _, _ = env.observation_space.shape
    out = cfg.crop_out_dim
    model = exp.mm_vae
    # Replay buffer: stores the multi-view crops.
    mvc_replay_buffer = MVCReplayBuffer(obs_shape=(C, out, out), action_shape=(act_dim, ), **cfg.replay_buffer)
    # Previously logger_kwargs['output_dir'] was required, so the default
    # logger_kwargs crashed with KeyError. Fall back to the logger's own
    # directory if it exposes one (as Spinning Up's logger does; unverified here).
    output_dir = logger_kwargs.get('output_dir') or getattr(logger, 'output_dir', None)
    if output_dir is None:
        raise ValueError(
            "sac() needs logger_kwargs['output_dir']: the multi-view replay "
            "buffer is saved under <output_dir>/mvc_rb")
    mvc_replay_buffer_dir = os.path.join(output_dir, 'mvc_rb')
    os.makedirs(mvc_replay_buffer_dir, exist_ok=True)
    exp.replay_buffer = mvc_replay_buffer

    def compute_loss_q(data):
        """Clipped double-Q soft Bellman loss for both critics."""
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = ac.q1(o,a)
        q2 = ac.q2(o,a)

        # Bellman backup: next action from the current policy, values from the targets.
        with torch.no_grad():
            a2, logp_a2 = ac.pi(o2)

            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        q_info = dict(Q1Vals=q1.detach().numpy(), Q2Vals=q2.detach().numpy())

        return loss_q, q_info

    def compute_loss_pi(data):
        """Entropy-regularized policy loss against the minimum of the two critics."""
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        loss_pi = (alpha * logp_pi - q_pi).mean()

        pi_info = dict(LogPi=logp_pi.detach().numpy())

        return loss_pi, pi_info

    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)

    logger.setup_pytorch_saver(ac)

    def update(data):
        """One SAC step (critics, actor, targets), then one update_stc call."""
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        logger.store(LossQ=loss_q.item(), **q_info)

        # Freeze the critics so loss_pi.backward() does not compute gradients
        # for Q parameters. Gradients still reach pi through the action input.
        for p in q_params:
            p.requires_grad = False

        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        for p in q_params:
            p.requires_grad = True

        logger.store(LossPi=loss_pi.item(), **pi_info)

        # Polyak averaging: targ <- polyak * targ + (1 - polyak) * online.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

        # External hook from smuco.stc, called once per SAC gradient step.
        update_stc(cfg, exp)

    def encode(ori_obs):
        """Crop ``ori_obs`` into views, store them, and return a fused latent sample.

        ``ori_obs`` must support ``.shape == (c, h, w)`` and slicing, as
        required by ``random_crop_single``.
        """
        views = [random_crop_single(ori_obs, out=out) for _ in range(num_views)]
        mvc_replay_buffer.add(views)
        mus, logvars = [], []
        # Only a detached Normal.sample() leaves this function, so skip
        # building an autograd graph through the encoders.
        with torch.no_grad():
            for m_key, encoder in model.encoders.items():
                # Keys must end in 'v<int>'; that int selects which crop this encoder sees.
                idx = int(m_key.strip().split('v')[-1])
                m = torch.Tensor(views[idx]).unsqueeze(0)
                _, _, mu, logvar = encoder(m)
                mus.append(mu)
                logvars.append(logvar)
            mus = torch.cat(mus)
            logvars = torch.cat(logvars)
            ivw_mu, ivw_logvar = model.ivw_fusion(mus, logvars)
            ivw_std = torch.sqrt(torch.exp(ivw_logvar))
            sample = Normal(ivw_mu, ivw_std).sample()
        return sample

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32),
                      deterministic)

    def test_agent():
        # NOTE(review): encode() also pushes test-episode views into
        # mvc_replay_buffer, mixing evaluation data into training data. Confirm intended.
        for _ in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            o = encode(o)
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                o = encode(o)
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
    o = encode(o)

    for t in range(total_steps):
        # Uniform random actions during the initial exploration phase.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, _ = env.step(a)
        o2 = encode(o2)
        ep_ret += r
        ep_len += 1

        # Do not treat hitting the episode-length cap as a terminal state.
        d = False if ep_len==max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)
        o = o2
        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0
            o = encode(o)

        # Update handling
        if t >= update_after and t % update_every == 0:
            for _ in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                # TODO: 3-rd round, nan problem occurs
                update(data=batch)

        # End of epoch handling
        if (t+1) % steps_per_epoch == 0:
            epoch = (t+1) // steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs):
                logger.save_state({'env': env}, None)
                mvc_replay_buffer.save(save_dir=mvc_replay_buffer_dir)

            # Test the performance of the deterministic version of the agent.
            test_agent()

            # Log info about epoch
            # NOTE(review): LossQ/Q1Vals/Q2Vals/LossPi/LogPi are stored in
            # update() but never passed to log_tabular. If EpochLogger only
            # clears keys when they are logged, these lists grow without bound. Confirm.
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            # t is 0-based, so t+1 environment steps have been taken by now.
            logger.log_tabular('TotalEnvInteracts', t + 1)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

