from data_modules import GCRLDataset, MilestoneDataLoader
from networks import *

import torch
import torch.nn.functional as F
from torch.optim import Adam

from torch.utils.data import DataLoader

from copy import deepcopy


class DTAMP:
    """Goal-conditioned agent: deterministic actor, critic ensemble and milestone diffuser.

    A goal embedding is ``goal_dim`` features wide. The first ``goal_dim // 2``
    features come from the actor's encoder and the last ``goal_dim // 2`` from
    the critic's encoder (see :meth:`encode` and :meth:`split_and_normalize`).
    The diffuser models sequences of ``n_milestones + 2`` such embeddings whose
    first and last entries are fixed by conditioning (see
    :meth:`compute_diffuser_loss` and :meth:`planning`).

    The hook-style methods (``prepare_data``, ``train_dataloader``,
    ``training_step``, ``validation_epoch_end``, ...) are named like PyTorch
    Lightning hooks, but this is a plain class.
    NOTE(review): presumably driven by a custom training loop — confirm.
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        goal_dim,
        visual_perception,
        n_milestones,
        n_critics,
        bc_epochs,
        rl_coeff,
        decoder_coeff,
        diffuser_coeff,
        predict_epsilon,
        diffuser_timesteps,
        td_condition,
        condition_guidance_w,
        target_td,
        hidden_size,
        n_hiddens,
        learning_rate,
        dataset_cfg,
        device=torch.device('cuda')
    ):
        """Build the networks and store the hyper-parameters.

        Args:
            state_dim: forwarded to the actor and critic networks.
            act_dim: forwarded to the actor and critic networks.
            goal_dim: width of the full goal embedding. The actor and critic
                encoders each get ``goal_dim // 2``, and the diffuser's
                ``transition_dim`` is ``goal_dim``.
            visual_perception: if true, observations/goals go through
                ``data_modules.Transform`` and an ``ObsDecoder`` reconstruction
                loss is added to the total loss.
            n_milestones: number of intermediate milestones. The diffuser
                horizon is ``n_milestones + 2``.
            n_critics: size of the critic ensemble.
            bc_epochs: while ``batch['epoch'] < bc_epochs`` the actor is trained
                with the behaviour-cloning loss only.
            rl_coeff: weight of the Q term in the actor loss after BC epochs.
            decoder_coeff: weight of the reconstruction loss (visual only).
            diffuser_coeff: weight of the diffuser loss.
            predict_epsilon, diffuser_timesteps, td_condition,
            condition_guidance_w: forwarded to ``GaussianDiffusion``.
            target_td: value passed as ``td`` to the diffuser in :meth:`planning`.
            hidden_size, n_hiddens: forwarded to the actor and critic networks.
            learning_rate: Adam learning rate (see :meth:`configure_optimizers`).
            dataset_cfg: dict read by :meth:`prepare_data` and the dataloader
                methods.
            device: device every network and batch tensor is moved to.

        Raises:
            ValueError: if ``goal_dim`` is odd. The embedding is split evenly
                between actor and critic, so an odd width would make
                :meth:`encode` produce ``goal_dim - 1`` features while the
                diffuser expects ``goal_dim``.
        """
        if goal_dim % 2 != 0:
            raise ValueError(f'goal_dim must be even, got {goal_dim}')
        half_goal_dim = goal_dim // 2

        # Construction order (actor, critic, diffuser, decoder) is kept as-is
        # so that random weight initialisation is unchanged.
        self.actor = DeterministicActor(
            state_dim=state_dim,
            act_dim=act_dim,
            goal_dim=half_goal_dim,
            visual_perception=visual_perception,
            hidden_size=hidden_size,
            n_hiddens=n_hiddens
        ).to(device)

        self.critic = Critics(
            state_dim=state_dim,
            act_dim=act_dim,
            goal_dim=half_goal_dim,
            visual_perception=visual_perception,
            n_critics=n_critics,
            hidden_size=hidden_size,
            n_hiddens=n_hiddens
        ).to(device)

        self.diffuser = GaussianDiffusion(
            horizon=n_milestones + 2,
            transition_dim=goal_dim,
            predict_epsilon=predict_epsilon,
            n_timesteps=diffuser_timesteps,
            td_condition=td_condition,
            condition_guidance_w=condition_guidance_w
        ).to(device)

        # Stored unconditionally so the attribute always exists; it is only
        # used when visual_perception is true.
        self.decoder_coeff = decoder_coeff
        if visual_perception:
            from data_modules import Transform
            self.transform = Transform()
            self.obs_decoder = ObsDecoder(goal_dim).to(device)

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.goal_dim = goal_dim
        self.visual_perception = visual_perception
        self.bc_epochs = bc_epochs
        self.n_milestones = n_milestones
        self.n_critics = n_critics
        self.diffuser_coeff = diffuser_coeff
        self.learning_rate = learning_rate
        self.dataset_cfg = dataset_cfg
        self.device = device
        self.rl_coeff = rl_coeff
        self.target_td = target_td

    def state_dict(self):
        """Return the module state dicts keyed by network name.

        Keys are 'actor', 'critic', 'diffuser', plus 'obs_decoder' in visual mode.
        """
        state_dict = {
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'diffuser': self.diffuser.state_dict()
        }
        if self.visual_perception:
            state_dict['obs_decoder'] = self.obs_decoder.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        """Load weights produced by :meth:`state_dict` into each network."""
        self.actor.load_state_dict(state_dict['actor'])
        self.critic.load_state_dict(state_dict['critic'])
        self.diffuser.load_state_dict(state_dict['diffuser'])
        if self.visual_perception:
            self.obs_decoder.load_state_dict(state_dict['obs_decoder'])

    def configure_optimizers(self):
        """Return a single Adam optimizer with one parameter group per network."""
        params = [
            {'params': self.actor.parameters()},
            {'params': self.critic.parameters()},
            {'params': self.diffuser.parameters()}
        ]
        if self.visual_perception:
            params.append({'params': self.obs_decoder.parameters()})
        return Adam(params, lr=self.learning_rate)

    def sample_negative_goals(self, batch):
        """Pick one negative critic goal embedding for every batch row.

        Visual mode uses hard negatives. For row ``i`` it selects the other row
        ``j != i`` whose critic goal embedding has the largest dot product with
        row ``i`` (the first such ``j`` wins ties). Otherwise the rows are
        rolled by one: row ``i`` gets row ``i + 1`` and the last row gets row 0.

        Gradients flow through the selected embeddings. Only the similarity
        used to choose them is computed without grad.

        Args:
            batch: dict holding 'g_critic', expected to be 2-D
                ``(batch_size, embedding_dim)``.

        Returns:
            Tensor with the same shape as ``batch['g_critic']``.

        Raises:
            ValueError: in visual mode when the batch has fewer than two rows,
                so there is no other row to choose from.
        """
        g_critic = batch['g_critic']
        if not self.visual_perception:
            return torch.cat([g_critic[1:], g_critic[:1]], dim=0)

        if g_critic.shape[0] < 2:
            raise ValueError(
                'sample_negative_goals needs at least 2 rows in visual mode, '
                f'got {g_critic.shape[0]}'
            )
        with torch.no_grad():
            similarity = g_critic @ g_critic.transpose(0, 1)
            # A row must never be its own negative.
            similarity.fill_diagonal_(float('-inf'))
            hardest = similarity.argmax(dim=-1)
        return g_critic[hardest]

    def compute_actor_loss(self, batch, bc_epoch):
        """Behaviour-cloning (+ Q-maximisation) loss for the actor.

        Args:
            batch: dict with 'x_actor', 'g_actor', 'x_critic', 'g_critic' and
                'actions'.
            bc_epoch: if true, only the BC (MSE) loss is returned as the actor
                loss.

        Returns:
            ``(actor_loss, bc_loss, mean of the ensemble-minimum Q(s, pi, g))``.
        """
        pi = self.actor.get_pi_from_embeddings(batch['x_actor'], batch['g_actor'])

        # Q is evaluated on a throwaway copy of the critic, and the critic
        # embeddings are detached, so the actor loss never produces gradients
        # for the critic's own parameters.
        # NOTE(review): deep-copying the critic every step is costly. Temporarily
        # disabling requires_grad on self.critic would avoid the copy, but it is
        # only equivalent if the critic keeps no train-mode state (e.g. BatchNorm
        # running statistics) — confirm before changing.
        frozen_critic = deepcopy(self.critic)
        q_pis = frozen_critic.get_value_from_embeddings(
            batch['x_critic'].detach(), pi, batch['g_critic'].detach()
        )
        min_q_pi = None
        for q_pi in q_pis:
            min_q_pi = q_pi if min_q_pi is None else torch.minimum(q_pi, min_q_pi)

        bc_loss = (pi - batch['actions']).pow(2).mean()
        if bc_epoch:
            actor_loss = bc_loss
        else:
            # The Q term is scaled by the detached mean |Q| so that rl_coeff is
            # independent of the Q magnitude (same form as TD3+BC).
            q_scale = min_q_pi.abs().mean().detach()
            actor_loss = -self.rl_coeff * min_q_pi.mean() / q_scale + bc_loss

        return actor_loss, bc_loss, min_q_pi.mean()

    def compute_critic_loss(self, batch):
        """Binary classification loss for the critic ensemble.

        Each critic gets logits for (state, dataset action, goal) labelled 1 and
        for (state, dataset action, negative goal) labelled 0. The summed BCE is
        divided by ``n_critics``.

        Returns:
            ``(critic_loss, mean positive logit, mean negative logit)``. The
            logit means are averaged over the whole ensemble.
        """
        # Materialised so the outputs can be iterated twice, even if a
        # generator is returned.
        q_acts = list(self.critic.get_value_from_embeddings(
            batch['x_critic'], batch['actions'], batch['g_critic']
        ))
        q_negs = list(self.critic.get_value_from_embeddings(
            batch['x_critic'], batch['actions'], batch['ng_critic']
        ))
        critic_loss = 0
        for q_act, q_neg in zip(q_acts, q_negs):
            critic_loss = critic_loss + (
                F.binary_cross_entropy_with_logits(q_act, torch.ones_like(q_act))
                + F.binary_cross_entropy_with_logits(q_neg, torch.zeros_like(q_neg))
            )
        critic_loss = critic_loss / self.n_critics

        # Averaged across all critics, not just the last one from the loop.
        mean_q_pos = torch.stack([q.mean() for q in q_acts]).mean()
        mean_q_neg = torch.stack([q.mean() for q in q_negs]).mean()
        return critic_loss, mean_q_pos, mean_q_neg

    def compute_reconstruction_loss(self, batch):
        """MSE between the decoded goal embedding and 'raw_observations'.

        'raw_observations' is set by :meth:`preprocess_batch` to the goals
        transformed with ``eval=True``. The decoder input is the actor and
        critic goal embeddings concatenated in the same order as :meth:`encode`.
        """
        g = torch.cat([batch['g_actor'], batch['g_critic']], dim=-1)
        obs_recon = self.obs_decoder(g)
        return (obs_recon - batch['raw_observations']).pow(2).mean()

    def compute_diffuser_loss(self, batch):
        """Diffuser loss on milestone embeddings 'g'.

        The sequence is conditioned on its first and last entries. 'intervals'
        is forwarded as ``td``.
        NOTE(review): 'intervals' presumably holds the temporal spacing sampled
        by MilestoneDataLoader — confirm.
        """
        cond = {0: batch['g'][:, 0], -1: batch['g'][:, -1]}
        return self.diffuser.loss(batch['g'], cond, td=batch['intervals'])

    def _shared_step(self, batch, milestone_dataloader, split):
        """Compute every loss for one batch; shared by training and validation.

        The statement order matches the original per-split implementations,
        since transforms and milestone sampling may consume randomness.

        Args:
            batch: dict from the GCRL dataloader plus an 'epoch' entry. It is
                preprocessed in place and extended with intermediate embeddings.
            milestone_dataloader: object whose ``sample()`` provides the
                diffuser batch.
            split: 'train' or 'val'; used in the log-key prefix.

        Returns:
            ``{'loss': total loss (keeps its autograd graph), 'log': detached
            loss values keyed '<bc|rl>_<split>/<name>'}``.
        """
        bc_epoch = batch['epoch'] < self.bc_epochs
        batch = self.preprocess_batch(batch)

        batch['x_actor'] = self.actor.perception(batch['observations'])
        batch['x_critic'] = self.critic.perception(batch['observations'])
        batch['g_actor'] = self.actor.encode(batch['goals'])
        batch['g_critic'] = self.critic.encode(batch['goals'])
        batch['ng_critic'] = self.sample_negative_goals(batch)

        actor_loss, bc_loss, _ = self.compute_actor_loss(batch, bc_epoch)
        critic_loss, _, _ = self.compute_critic_loss(batch)

        milestone_batch = self.preprocess_batch(milestone_dataloader.sample())
        milestone_batch['g'] = self.encode(milestone_batch['observations'])
        diffuser_loss = self.compute_diffuser_loss(milestone_batch)

        loss = actor_loss + critic_loss + self.diffuser_coeff * diffuser_loss
        if self.visual_perception:
            decoder_loss = self.compute_reconstruction_loss(batch)
            loss = loss + self.decoder_coeff * decoder_loss

        field = f"{'bc' if bc_epoch else 'rl'}_{split}"
        logs = {
            f'{field}/loss': loss,
            f'{field}/actor_loss': actor_loss,
            f'{field}/bc_loss': bc_loss,
            f'{field}/critic_loss': critic_loss,
            f'{field}/diffuser_loss': diffuser_loss
        }
        if self.visual_perception:
            logs[f'{field}/decoder_loss'] = decoder_loss

        # Detached so stored logs do not keep the autograd graph alive.
        logs = {k: v.detach() if torch.is_tensor(v) else v for k, v in logs.items()}
        return {'loss': loss, 'log': logs}

    def training_step(self, batch):
        """Compute the total training loss and logs for one batch.

        Requires :meth:`prepare_data` to have been called, for
        ``train_milestone_dataloader``. Log keys are prefixed 'bc_train/'
        during BC epochs and 'rl_train/' afterwards.
        """
        return self._shared_step(batch, self.train_milestone_dataloader, 'train')

    @torch.no_grad()
    def validation_step(self, batch):
        """Compute validation losses without gradients.

        Log keys are prefixed 'bc_val/' during BC epochs and 'rl_val/' afterwards.
        """
        return self._shared_step(batch, self.val_milestone_dataloader, 'val')

    @torch.no_grad()
    def validation_epoch_end(self, outputs):
        """Average the outputs of :meth:`validation_step` over an epoch.

        Args:
            outputs: non-empty list of step outputs. Every 'log' dict must have
                the keys of the first one.

        Returns:
            ``{'avg_val_loss': mean loss, 'log': per-key mean of the logs}``.
        """
        avg_loss = torch.stack([out['loss'] for out in outputs]).mean()
        logs = {
            key: torch.stack([out['log'][key] for out in outputs]).mean()
            for key in outputs[0]['log']
        }
        return {'avg_val_loss': avg_loss, 'log': logs}

    def prepare_data(self):
        """Create the GCRL datasets and milestone loaders from ``dataset_cfg``."""
        cfg = self.dataset_cfg
        gcrl_kwargs = dict(
            min_dt=cfg['min_dt'],
            max_dt=cfg['max_dt'],
            use_skill=cfg['use_skill'],
            skill_length=cfg['skill_length'],
        )
        milestone_kwargs = dict(
            max_interval=cfg['max_interval'],
            horizon=self.n_milestones + 2,
            batch_size=cfg['diffuser_batch_size'],
        )
        self.train_dataset = GCRLDataset(data_container=cfg['train_data'], **gcrl_kwargs)
        # NOTE(review): percentage=0.1 presumably keeps a 10% subset of the
        # validation data — confirm against GCRLDataset.
        self.val_dataset = GCRLDataset(
            data_container=cfg['val_data'], percentage=0.1, **gcrl_kwargs
        )
        self.train_milestone_dataloader = MilestoneDataLoader(
            data_container=cfg['train_data'], **milestone_kwargs
        )
        self.val_milestone_dataloader = MilestoneDataLoader(
            data_container=cfg['val_data'], **milestone_kwargs
        )

    def _make_dataloader(self, dataset):
        """Wrap ``dataset`` in the DataLoader settings shared by train and val."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.dataset_cfg['rl_batch_size'],
            shuffle=True,
            drop_last=True,
            num_workers=self.dataset_cfg['num_workers'],
            pin_memory=True
        )

    def train_dataloader(self):
        """Shuffled, drop-last DataLoader over the training dataset."""
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        """DataLoader over the validation dataset (also shuffled, drop-last)."""
        return self._make_dataloader(self.val_dataset)

    def preprocess_batch(self, batch, eval=False):
        """Move a batch to ``self.device`` and, in visual mode, transform images.

        The dict is modified in place and also returned.

        Args:
            batch: dict of tensors and other values. In visual mode it must hold
                'observations' and, for 4-D observations, also 'goals'.
            eval: forwarded to ``self.transform`` for observations and goals.
                The name shadows the builtin but is kept for keyword callers.

        Returns:
            The same dict. For 4-D visual observations it also gains
            'raw_observations': the goals transformed with ``eval=True``.
        """
        for key, value in batch.items():
            if torch.is_tensor(value):
                batch[key] = value.to(self.device)
        if not self.visual_perception:
            return batch

        observations = batch['observations']
        if observations.dim() == 5:
            # observations: [batch_size, horizon, channels, height, width].
            # Transform each horizon step separately, then restack on dim 1.
            batch['observations'] = torch.stack(
                [self.transform(step, eval=eval) for step in observations.transpose(0, 1)],
                dim=1
            )
        else:
            # observations / goals: [batch_size, channels, height, width]
            batch['observations'] = self.transform(observations, eval=eval)
            batch['raw_observations'] = self.transform(batch['goals'], eval=True)
            batch['goals'] = self.transform(batch['goals'], eval=eval)
        return batch

    def preprocess_obs(self, observation):
        """Convert one raw observation to a batch-of-one tensor on ``self.device``.

        Visual observations are cast to uint8 and passed through
        ``self.transform`` with ``eval=True``. Others are cast to float32.
        """
        if self.visual_perception:
            observation = torch.as_tensor(observation, dtype=torch.uint8, device=self.device).unsqueeze(0)
            return self.transform(observation, eval=True)
        return torch.as_tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)

    def encode(self, obs):
        """Return ``[actor embedding, critic embedding]`` concatenated on the last dim."""
        g_actor = self.actor.encode(obs)
        g_critic = self.critic.encode(obs)
        return torch.cat([g_actor, g_critic], dim=-1)

    def split_and_normalize(self, g):
        """L2-normalise the actor and critic halves of ``g`` independently.

        ``g``'s last dimension is expected to be ``goal_dim`` wide.
        """
        g_actor, g_critic = torch.split(g, self.goal_dim // 2, dim=-1)
        g_actor = F.normalize(g_actor, p=2.0, dim=-1)
        g_critic = F.normalize(g_critic, p=2.0, dim=-1)
        return torch.cat([g_actor, g_critic], dim=-1)

    @torch.no_grad()
    def planning(self, obs, goal):
        """Sample a milestone sequence from ``obs`` to ``goal`` with the diffuser.

        The sample is conditioned on the encoded start (index 0) and goal
        (index -1) with ``td = self.target_td``. Its halves are L2-normalised,
        then it is squeezed.
        NOTE(review): the result is presumably ``(n_milestones + 2, goal_dim)`` —
        confirm against GaussianDiffusion.conditional_sample.
        """
        obs = self.preprocess_obs(obs)
        goal = self.preprocess_obs(goal)
        cond = {0: self.encode(obs), -1: self.encode(goal)}
        target_td = torch.full([1, 1], self.target_td, dtype=torch.float32, device=self.device)
        milestones = self.diffuser.conditional_sample(cond, td=target_td)
        milestones = self.split_and_normalize(milestones)
        return milestones.squeeze()

    @torch.no_grad()
    def get_action(self, obs, enc_goal):
        """Return the actor's action for ``obs`` towards ``enc_goal`` as a numpy array.

        Only the actor half (first ``goal_dim // 2`` features) of ``enc_goal``
        is used. ``enc_goal`` must be 2-D, e.g. ``(1, goal_dim)``.
        """
        obs = self.preprocess_obs(obs)
        emb_state = self.actor.perception(obs)
        actor_goal = enc_goal[:, :self.goal_dim // 2]
        pi = self.actor.get_pi_from_embeddings(emb_state, actor_goal)
        return pi.squeeze().cpu().detach().numpy()

    @torch.no_grad()
    def compute_distance(self, obs, enc_goal):
        """Squared Euclidean distance between ``encode(obs)`` and ``enc_goal``, as a float."""
        obs = self.preprocess_obs(obs)
        enc_obs = self.encode(obs)
        return (enc_obs - enc_goal).pow(2).sum().item()