from data_modules import PlayDataset, Transform
from networks import *

import torch
from torch.distributions.kl import kl_divergence
from torch.optim import Adam

from torch.utils.data import DataLoader


class PlayLMP:
    """Skill-learning model built from three networks plus its train/val logic.

    * ``skill_recognition`` maps an (observations, actions) window to a
      posterior distribution over skills,
    * ``skill_proposal`` maps the first observation and the goal to a prior
      distribution over skills,
    * ``skill_decoder`` scores / samples actions given observations and a skill.

    The class also builds the datasets and dataloaders. Its hook names
    (``training_step``, ``validation_epoch_end``, ...) mirror PyTorch
    Lightning's, but the class is a plain object driven by an external loop —
    TODO confirm which trainer calls it.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            goal_dim,
            skill_dim,
            kl_coeff,
            kl_balance_coeff,
            learning_rate,
            dataset_cfg,
            device=torch.device('cuda')
        ):
        """Build the three networks on ``device`` and store hyperparameters.

        Args:
            state_dim: Observation dimensionality passed to every network.
            act_dim: Action dimensionality.
            goal_dim: Goal dimensionality (input of the skill proposal).
            skill_dim: Dimensionality of the latent skill.
            kl_coeff: Weight of the KL term in the total loss.
            kl_balance_coeff: Split of the KL term between its two
                stop-gradient directions (see ``_compute_losses``).
            learning_rate: Adam learning rate.
            dataset_cfg: Mapping with the dataset/dataloader settings read by
                ``prepare_data`` and the ``*_dataloader`` methods.
            device: Device the networks and batches are moved to.
        """
        self.skill_recognition = SkillRecognition(
            state_dim=state_dim,
            act_dim=act_dim,
            skill_dim=skill_dim
        ).to(device)

        self.skill_proposal = SkillProposal(
            state_dim=state_dim,
            goal_dim=goal_dim,
            skill_dim=skill_dim
        ).to(device)

        self.skill_decoder = ManipulationSkillDecoder(
            state_dim=state_dim,
            act_dim=act_dim,
            skill_dim=skill_dim
        ).to(device)

        self.state_dim, self.skill_dim, self.act_dim, self.goal_dim = state_dim, skill_dim, act_dim, goal_dim

        self.kl_coeff = kl_coeff
        self.kl_balance_coeff = kl_balance_coeff
        self.learning_rate = learning_rate
        self.dataset_cfg = dataset_cfg
        self.transform = Transform()
        self.device = device

    def state_dict(self):
        """Return the three networks' state dicts keyed by network name."""
        return {
            'skill_recognition': self.skill_recognition.state_dict(),
            'skill_proposal': self.skill_proposal.state_dict(),
            'skill_decoder': self.skill_decoder.state_dict()
        }

    def load_state_dict(self, state_dict):
        """Load weights produced by ``state_dict`` into the three networks."""
        self.skill_recognition.load_state_dict(state_dict['skill_recognition'])
        self.skill_proposal.load_state_dict(state_dict['skill_proposal'])
        self.skill_decoder.load_state_dict(state_dict['skill_decoder'])

    def configure_optimizers(self):
        """Return a single Adam optimizer over the parameters of all three networks."""
        params = [
            p
            for net in (self.skill_recognition, self.skill_proposal, self.skill_decoder)
            for p in net.parameters()
        ]
        return Adam(params, lr=self.learning_rate)

    def _compute_losses(self, batch, prefix):
        """Compute the total loss and its logged components for a preprocessed batch.

        Shared by ``training_step`` and ``validation_step`` so both always
        evaluate the same objective.

        Args:
            batch: Preprocessed batch with ``observations``, ``actions`` and
                ``goals`` entries.
            prefix: Log-key prefix, e.g. ``'train'`` or ``'val'``.

        Returns:
            ``{'loss': loss, 'log': {f'{prefix}/loss', f'{prefix}/decoder_loss',
            f'{prefix}/kl_loss', f'{prefix}/err'}}``.
        """
        observations, actions = batch['observations'], batch['actions']
        skill_posterior = self.skill_recognition(observations, actions)
        # The proposal only sees the first step of the window (index 0 of dim 1) and the goal.
        skill_prior = self.skill_proposal(observations[:, 0], batch['goals'])

        # KL balancing. The prior is detached in the first term and the posterior in
        # the second. Assuming the networks' distribution ``detach()`` stops gradients
        # (defined in ``networks`` -- confirm there), the first term only updates the
        # recognition network and the second only updates the proposal network.
        kl_loss = (
            self.kl_balance_coeff * kl_divergence(skill_posterior, skill_prior.detach()).mean()
            + (1. - self.kl_balance_coeff) * kl_divergence(skill_posterior.detach(), skill_prior).mean()
        )

        # Reparameterized sample so the decoder loss can back-propagate into the recognition network.
        skill = skill_posterior.rsample()
        log_prob, pi = self.skill_decoder.log_prob_with_sample(observations, actions, skill)
        decoder_loss = -log_prob.mean()
        with torch.no_grad():
            # Monitoring only: mean L2 distance between the decoder's actions and the recorded ones.
            err = torch.norm(pi - actions, dim=-1).mean()

        loss = decoder_loss + self.kl_coeff * kl_loss
        logs = {
            f'{prefix}/loss': loss,
            f'{prefix}/decoder_loss': decoder_loss,
            f'{prefix}/kl_loss': kl_loss,
            f'{prefix}/err': err
        }
        return {'loss': loss, 'log': logs}

    def training_step(self, batch):
        """Preprocess ``batch`` (training transform) and return loss plus ``train/*`` logs."""
        batch = self.preprocess_batch(batch)
        return self._compute_losses(batch, prefix='train')

    @torch.no_grad()
    def validation_step(self, batch):
        """Preprocess ``batch`` (eval transform) and return loss plus ``val/*`` logs, without gradients."""
        batch = self.preprocess_batch(batch, eval=True)
        return self._compute_losses(batch, prefix='val')

    @torch.no_grad()
    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and logs.

        Args:
            outputs: List of dicts returned by ``validation_step``.

        Returns:
            ``{'avg_val_loss': mean loss, 'log': {key: mean value}}``.

        Raises:
            ValueError: If ``outputs`` is empty. This can happen because the
                validation loader uses ``drop_last=True``.
        """
        if not outputs:
            raise ValueError(
                'validation_epoch_end received no outputs: the validation loader yielded '
                'no batches (fewer validation samples than batch_size with drop_last=True?)'
            )
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        logs = {k: torch.stack([x['log'][k] for x in outputs]).mean() for k in outputs[0]['log']}
        return {'avg_val_loss': avg_loss, 'log': logs}

    def prepare_data(self):
        """Build ``self.train_dataset`` and ``self.val_dataset`` from ``self.dataset_cfg``."""
        cfg = self.dataset_cfg
        common = dict(
            min_skill_length=cfg['min_skill_length'],
            max_skill_length=cfg['max_skill_length'],
            use_padding=cfg['use_padding'],
        )
        self.train_dataset = PlayDataset(data_container=cfg['train_data'], **common)
        # NOTE(review): percentage=0.1 presumably keeps only a 10% subset of the
        # validation data -- confirm against PlayDataset.
        self.val_dataset = PlayDataset(data_container=cfg['val_data'], percentage=0.1, **common)

    def _build_dataloader(self, dataset):
        """Return the shuffled, drop-last DataLoader used for both splits."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.dataset_cfg['batch_size'],
            shuffle=True,
            drop_last=True,
            num_workers=self.dataset_cfg['num_workers'],
            pin_memory=True
        )

    def train_dataloader(self):
        """Return the training DataLoader (requires ``prepare_data`` first)."""
        return self._build_dataloader(self.train_dataset)

    def val_dataloader(self):
        """Return the validation DataLoader (requires ``prepare_data`` first)."""
        return self._build_dataloader(self.val_dataset)

    def preprocess_batch(self, batch, eval=False):
        """Move tensors in ``batch`` to ``self.device`` and apply ``self.transform``.

        ``batch`` is updated in place and also returned.

        * ``observations`` are transformed one slice along dim 1 at a time,
          then re-stacked. Dim 1 is presumably the time axis: the first step
          is taken as ``[:, 0]``.
        * ``goals`` are transformed directly.
        * Other entries are only moved to the device.

        ``eval`` is forwarded to ``self.transform``.
        """
        for key, value in batch.items():
            if torch.is_tensor(value):
                value = value.to(self.device)
            if key == 'observations':
                value = torch.stack(
                    [self.transform(step, eval=eval) for step in value.transpose(0, 1)], dim=1
                )
            elif key == 'goals':
                value = self.transform(value, eval=eval)
            # Assigning to an existing key does not resize the dict, so this is safe while iterating.
            batch[key] = value
        return batch

    def preprocess_obs(self, observation):
        """Convert ``observation`` to a tensor on ``self.device``, add a leading batch dim, and apply the eval transform."""
        observation = torch.as_tensor(observation, device=self.device).unsqueeze(0)
        return self.transform(observation, eval=True)

    @torch.no_grad()
    def decode_skill(self, obs_list, skill):
        """Decode ``skill`` against the observation history ``obs_list``.

        Returns the last time step of the sampled action for the single batch
        element, as a numpy array.
        """
        # Local import: this module never imports numpy itself. ``np`` was previously
        # only available if ``from networks import *`` happened to re-export it.
        import numpy as np

        obs = self.preprocess_obs(np.array(obs_list))
        skill = torch.as_tensor(skill, dtype=torch.float32, device=self.device).unsqueeze(0)
        action = self.skill_decoder.sample(obs, skill)
        return action[0, -1].cpu().detach().numpy()