import copy

import numpy as np
import torch as tr
import torch.nn.functional as F
from torch.distributions import kl_divergence

from networks import Policy, DoubleQFunc
from utils import ReplayPool, KL_gauss, DefaultReplayPool

# Compute device shared by this module: the first CUDA GPU when one is available, else the CPU.
device = tr.device("cuda:0") if tr.cuda.is_available() else tr.device("cpu")

class MDLCSeq_Agent:

    def __init__(
        self,
        seed,
        state_dim,
        action_dim,
        lr=3e-4,
        gamma=0.99,
        tau=5e-3,
        batchsize=256,
        hidden_size=256,
        update_interval=1,
        buffer_size=int(1e6),
        target_entropy=None,
        beta=0.05,
        control_vdo=False,
        default_vdo=True,
        norm_vdo=False,
        learned_asymmetry=False,
        num_tasks=1,
        **kwargs
        ):
        """Build the critic, both policies, the temperature(s), optimizers and replay pools.

        Args:
            seed: value passed to ``torch.manual_seed`` before any network is created.
            state_dim: state size forwarded to the networks and replay pools.
            action_dim: action size forwarded to the networks and replay pools; its
                negation is the default target entropy.
            lr: Adam learning rate of the critic and both policies; the temperature
                optimizer uses ``lr / 4``.
            gamma: discount factor used when building the Q-target.
            tau: interpolation weight of the moving-average target update.
            batchsize: number of transitions sampled per update.
            hidden_size: ``hidden_size`` forwarded to ``DoubleQFunc`` and ``Policy``.
            update_interval: the target critic is refreshed every this many updates.
            buffer_size: capacity of both replay pools.
            target_entropy: entropy target used by the temperature loss; ``None``
                selects ``-action_dim``.
            beta: stored as ``self.beta``.
            control_vdo: ``vdo`` flag of the control policy.
            default_vdo: ``vdo`` flag of the default policy.
            norm_vdo: ``norm_vdo`` flag of both policies.
            learned_asymmetry: forwarded to the default policy only.
            num_tasks: values above 1 allocate one log-temperature per task.
            **kwargs: accepted and ignored.
        """
        self.gamma = gamma
        self.tau = tau
        # Compare against None explicitly: a plain truthiness test would silently
        # replace an explicit target entropy of 0 with the default.
        self.target_entropy = target_entropy if target_entropy is not None else -action_dim
        self.batchsize = batchsize
        self.update_interval = update_interval

        tr.manual_seed(seed)
        self.seed = seed
        self.device = device

        # aka critic
        self.q_funcs = DoubleQFunc(state_dim, action_dim, hidden_size=hidden_size).to(device)
        # the target critic is a frozen copy that is only changed by update_target()
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # aka actor
        self.control_policy = Policy(
            state_dim, action_dim, hidden_size=hidden_size, vdo=control_vdo, norm_vdo=norm_vdo
            ).to(device)
        self.default_policy = Policy(
            state_dim, action_dim, hidden_size=hidden_size, vdo=default_vdo, norm_vdo=norm_vdo,
            learned_asymmetry=learned_asymmetry
            ).to(device)
        # remembered so that reset_control() can rebuild the control side
        self.control_vdo, self.default_vdo, self.norm_vdo = control_vdo, default_vdo, norm_vdo
        self.state_dim, self.action_dim, self.hidden_size = state_dim, action_dim, hidden_size

        # aka temperature (stored in log space; see the ``alpha`` property)
        self.log_alpha = tr.zeros(1, requires_grad=True, device=device)
        self.num_tasks = num_tasks
        if num_tasks > 1:
            # one log-temperature per task
            self.log_alpha = tr.tensor([0.0] * self.num_tasks, device=device).requires_grad_()
        self.temp_optimizer = tr.optim.Adam([self.log_alpha], lr=lr/4)
        self.beta = beta

        self.q_optimizer = tr.optim.Adam(self.q_funcs.parameters(), lr=lr)
        self.control_policy_optimizer = tr.optim.Adam(self.control_policy.parameters(), lr=lr)
        self.default_policy_optimizer = tr.optim.Adam(self.default_policy.parameters(), lr=lr)
        self.lr = lr

        self.buffer_size = buffer_size
        self.replay_pool = ReplayPool(action_dim=action_dim, state_dim=state_dim, capacity=buffer_size)
        self.default_replay_pool = DefaultReplayPool(action_dim=action_dim, state_dim=state_dim, capacity=buffer_size)
    
    def reset_control(self):
        """Re-create the critic, control policy, temperature(s) and main replay pool.

        The default policy, its optimizer and the default replay pool are not touched.
        """
        # fresh critic first, then a frozen copy of it as the target
        critic = DoubleQFunc(self.state_dim, self.action_dim, hidden_size=self.hidden_size)
        self.q_funcs = critic.to(device)
        frozen_critic = copy.deepcopy(self.q_funcs)
        frozen_critic.eval()
        for weight in frozen_critic.parameters():
            weight.requires_grad = False
        self.target_q_funcs = frozen_critic

        # fresh control policy built with the flags remembered from construction
        policy_kwargs = dict(
            hidden_size=self.hidden_size,
            vdo=self.control_vdo,
            norm_vdo=self.norm_vdo,
        )
        self.control_policy = Policy(self.state_dim, self.action_dim, **policy_kwargs).to(device)

        # log-temperature(s) back at zero: one entry, or one per task
        if self.num_tasks > 1:
            fresh_log_alpha = tr.tensor([0.0] * self.num_tasks, device=device).requires_grad_()
        else:
            fresh_log_alpha = tr.zeros(1, requires_grad=True, device=device)
        self.log_alpha = fresh_log_alpha

        # optimizers must point at the newly created parameters
        self.temp_optimizer = tr.optim.Adam([self.log_alpha], lr=self.lr/4)
        self.q_optimizer = tr.optim.Adam(self.q_funcs.parameters(), lr=self.lr)
        self.control_policy_optimizer = tr.optim.Adam(self.control_policy.parameters(), lr=self.lr)

        # empty replay pool for the control side
        self.replay_pool = ReplayPool(
            action_dim=self.action_dim, state_dim=self.state_dim, capacity=self.buffer_size)
    
    
    def get_action(self, state, state_filter=None, deterministic=False, get_dist=False):
        if state_filter:
            state = state_filter(state)
        with tr.no_grad():
            action, _, mean, std, _ = self.control_policy(tr.Tensor(state).view(1,-1).to(device), get_dist=True)
        if deterministic:
            return mean.squeeze().cpu().numpy()
        elif get_dist:
            return np.atleast_1d(action.squeeze().cpu().numpy()), mean, std
        return np.atleast_1d(action.squeeze().cpu().numpy())

    def update_target(self):
        """moving average update of target networks"""
        with tr.no_grad():
            for target_q_param, q_param in zip(self.target_q_funcs.parameters(), self.q_funcs.parameters()):
                target_q_param.data.copy_(self.tau * q_param.data + (1.0 - self.tau) * target_q_param.data)

    def update_q_functions(
        self,
        state_batch,
        action_batch,
        reward_batch,
        nextstate_batch,
        done_batch,
        use_kl=True):
        with tr.no_grad():
            nextaction_batch, control_logprobs_batch, control_mean_batch, control_std_batch, control_dist_batch = self.control_policy(
                nextstate_batch, get_logprob=True, get_dist=True)
            _, default_logprobs_batch, default_mean_batch, default_std_batch, default_dist_batch = self.default_policy(
                nextstate_batch, get_logprob=True, get_dist=True)
            q_t1, q_t2 = self.target_q_funcs(nextstate_batch, nextaction_batch)
            # take min to mitigate positive bias in q-function training
            q_target = tr.min(q_t1, q_t2)
            # compute kl/entropy regularization
            kl_batch = kl_divergence(control_dist_batch, default_dist_batch)
            kl_batch = tr.mean(kl_batch, dim=-1, keepdims=True)
            reg_batch = kl_batch if use_kl else control_logprobs_batch
            # get alpha(s)
            # if multitask, select temperature 
            if self.num_tasks > 1:
                task_id_batch_one_hot = state_batch[:, -self.num_tasks:] # batch_size x num_tasks
                task_id_batch = tr.argmax(task_id_batch_one_hot, axis=1) # batch_size
                alpha = self.alpha
                alpha_batch = tr.tensor([alpha[id] for id in task_id_batch]).unsqueeze(dim=-1).to(device) # select the alpha for each example (probably slow...)
            else:
                alpha_batch = self.alpha
            value_target = reward_batch + (1.0 - done_batch) * self.gamma * (q_target - alpha_batch * reg_batch)
        
        q_1, q_2 = self.q_funcs(state_batch, action_batch)
        loss_1 = F.mse_loss(q_1, value_target)
        loss_2 = F.mse_loss(q_2, value_target)
        return loss_1, loss_2

    def update_control_policy_and_temp(self, state_batch):
        action_batch, control_logprobs_batch, _ = self.control_policy(
            state_batch, get_logprob=True, get_dist=False)
        q_b1, q_b2 = self.q_funcs(state_batch, action_batch)
        qval_batch = tr.min(q_b1, q_b2)
        control_loss = (self.alpha * control_logprobs_batch - qval_batch).mean()
        temp_loss = -self.alpha * (control_logprobs_batch.detach() + self.target_entropy).mean()
        if self.num_tasks > 1:
            temp_loss = temp_loss.mean()

        return control_loss, temp_loss

    def update_default_policy(self, state_batch, control_mean_batch, control_std_batch, beta=1.0):
        """Compute the default-policy loss for one batch.

        The loss is the batch mean of ``KL_gauss`` evaluated on the (detached)
        control mean/std and the default policy's mean/std, plus ``beta`` times
        the value of ``default_policy.network.kl_vdo()``.

        Args:
            state_batch: states fed to the default policy.
            control_mean_batch: control-policy means for those states.
            control_std_batch: control-policy standard deviations for those states.
            beta: weight of the ``kl_vdo`` term.

        Returns:
            ``(default_loss, kl_batch, kl_vdo)``.
        """
        _, _, default_mean_batch, default_std_batch, _ = self.default_policy(
            state_batch, get_logprob=True, get_dist=True)
        # the control statistics are targets here, so no gradient flows into them
        kl_batch = KL_gauss(
            control_mean_batch.detach(), control_std_batch.detach(),
            default_mean_batch, default_std_batch).mean()
        kl_vdo = self.default_policy.network.kl_vdo()
        default_loss = kl_batch + beta * kl_vdo

        return default_loss, kl_batch, kl_vdo

    def optimize(
        self, n_updates, state_filter=None, update_default=False, beta=0.01, use_kl=True, **kwargs
        ):
        q1_loss, q2_loss, pi_loss, a_loss = 0, 0, 0, 0
        if self.num_tasks > 1:
            a_loss = np.zeros(self.num_tasks)

        for i in range(n_updates):
            samples = self.replay_pool.sample(self.batchsize)

            if state_filter:
                state_batch = tr.FloatTensor(state_filter(samples.state)).to(device)
                nextstate_batch = tr.FloatTensor(state_filter(samples.nextstate)).to(device)
            else:
                state_batch = tr.FloatTensor(np.array(samples.state)).to(device)
                nextstate_batch = tr.FloatTensor(np.array(samples.nextstate)).to(device)
            action_batch = tr.FloatTensor(np.array(samples.action)).to(device)
            reward_batch = tr.FloatTensor(np.array(samples.reward)).to(device).unsqueeze(1)
            done_batch = tr.FloatTensor(np.array(samples.real_done)).to(device).unsqueeze(1)
            
            # update q-funcs
            q1_loss_step, q2_loss_step = self.update_q_functions(
                state_batch, action_batch, reward_batch, nextstate_batch, done_batch, use_kl=use_kl
                )
            q_loss_step = q1_loss_step + q2_loss_step
            self.q_optimizer.zero_grad()
            q_loss_step.backward()
            self.q_optimizer.step()

            # update policy and temperature parameter
            for p in self.q_funcs.parameters():
                p.requires_grad = False
            pi_loss_step, a_loss_step = self.update_control_policy_and_temp(
                state_batch
                )
            # update control policy
            self.control_policy_optimizer.zero_grad()
            pi_loss_step.backward(retain_graph=True)
            # call control policy step below
            # update temp
            self.temp_optimizer.zero_grad()
            a_loss_step.backward()
            self.temp_optimizer.step()
            
            self.control_policy_optimizer.step()    

            for p in self.q_funcs.parameters():
                p.requires_grad = True


            q1_loss += q1_loss_step.detach().item()
            q2_loss += q2_loss_step.detach().item()
            pi_loss += pi_loss_step.detach().item()
            if self.num_tasks > 1:
                a_loss += a_loss_step.detach().cpu().numpy()
            else:
                a_loss += a_loss_step.detach().item()
            if i % self.update_interval == 0:
                self.update_target()
        return q1_loss, q2_loss, pi_loss, a_loss

    def optimize_default_policy(
        self,
        n_updates,
        state_filter=None,
        beta_start=0.5,
        beta_max=1.0,
        beta_warmup=0.2,
        writer=None,
        log_freq=10,
        task_name=None,
        **kwargs
        ):
        beta_update_start = int(beta_start * n_updates)
        n_beta_ramp = int(beta_warmup * n_updates)
        beta_schedule = [0.0] * beta_update_start + list(np.linspace(0.0, beta_max, num=n_beta_ramp))

        for i in range(n_updates):
            samples = self.default_replay_pool.sample(self.batchsize)
            # set VDO KL coefficient
            beta = beta_max if i >= len(beta_schedule) else beta_schedule[i]

            if state_filter:
                state_batch = tr.FloatTensor(state_filter(samples.state)).to(device)
            else:
                state_batch = tr.FloatTensor(np.array(samples.state)).to(device)
            control_mean_batch = tr.FloatTensor(np.array(samples.mean)).to(device)
            control_std_batch = tr.FloatTensor(np.array(samples.std)).to(device)


            # update default policy 
            for p in self.q_funcs.parameters():
                p.requires_grad = False
            pi0_loss_step, policy_kl_step, vdo_kl_step = self.update_default_policy(
                state_batch, control_mean_batch, control_std_batch, beta=beta
                )
            
            self.default_policy_optimizer.zero_grad()
            pi0_loss_step.backward()
            self.default_policy_optimizer.step()

            for p in self.q_funcs.parameters():
                p.requires_grad = True

            prefix = "" if task_name is None else task_name 
            if i % log_freq == 0:
                writer.add_scalar(f"{prefix}-DefaultLoss/default_policy", pi0_loss_step.detach().item(), i)
                writer.add_scalar(f"{prefix}-DefaultLoss/policy_kl", policy_kl_step.detach().item(), i)
                default_sparsity = self.default_policy.network.sparsity()
                writer.add_scalar(f"{prefix}-DefaultValues/default_sparsity", default_sparsity, i)
                writer.add_scalar(f"{prefix}-DefaultLoss/default_vdo_kl", vdo_kl_step.detach().item(), i)


    @property
    def alpha(self):
        return self.log_alpha.exp()

