import time
import torch
import numpy as np

from decision_transformer.training.trainer import Trainer


def copy_param(target_link, source_link):
    """Hard-copy all parameters and buffers of ``source_link`` into ``target_link``.

    The copy goes through ``state_dict``/``load_state_dict``; ``load_state_dict``
    is called with its default (strict) key matching. Returns ``None``.
    """
    snapshot = source_link.state_dict()
    target_link.load_state_dict(snapshot)


def soft_copy_param(target_link, source_link, tau=0.05):
    """Soft-copy (exponential moving average) parameters of a link to another link.

    Floating-point state entries are blended in place::

        target = (1 - tau) * target + tau * source

    Non-floating entries, such as BatchNorm's integer ``num_batches_tracked``
    counter, cannot be averaged and are copied over verbatim.

    Args:
        target_link: module whose state is updated in place.
        source_link: module providing the new values. It must have the same
            state-dict keys and tensor shapes as ``target_link``.
        tau: weight given to ``source_link``. 0 leaves the target unchanged;
            1 copies the source exactly.
    """
    target_dict = target_link.state_dict()
    source_dict = source_link.state_dict()
    with torch.no_grad():
        for k, target_value in target_dict.items():
            source_value = source_dict[k]
            assert target_value.shape == source_value.shape, k
            if source_value.is_floating_point():
                # Tensors from state_dict() share storage with the module, so
                # these in-place ops update target_link directly.
                target_value.mul_(1 - tau).add_(source_value, alpha=tau)
            else:
                # Assigning into target_dict would only change the local dict
                # and leave the module untouched; write into the shared storage.
                target_value.copy_(source_value)


class SequenceTrainer(Trainer):
    """Trainer that regresses actions on masked sequence batches.

    This class defines no ``__init__``. Attributes it reads (``model``,
    ``optimizer``, ``batch_size``, ``get_batch``, ``loss_fn``, ``scheduler``,
    ``diagnostics``) are presumably set up by the ``Trainer`` base class
    (not shown here).
    """

    def train_step(self):
        """Run one gradient step on a fresh batch and return the scalar loss.

        Only timesteps with ``attention_mask > 0`` contribute. The mean squared
        action error over those timesteps is stored in
        ``self.diagnostics['training/action_error']``.
        """
        states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
        # Snapshot the targets before the forward pass.
        action_target = torch.clone(actions)

        # rtg is passed without its last timestep.
        _, action_preds, _ = self.model.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        # Flatten the leading dims (presumably (batch, seq)) and keep only
        # unmasked timesteps.
        act_dim = action_preds.shape[2]
        valid = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]

        loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm at 0.25.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target) ** 2).detach().cpu().item()

        return loss.detach().cpu().item()

    def train_only_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` training steps (no evaluation) and return a log dict.

        The dict holds:
          * the wall-clock training time;
          * the mean/std of the per-step losses, when at least one step ran;
          * every entry of ``self.diagnostics``.

        It is printed when ``print_logs`` is true.
        """
        train_losses = []
        logs = dict()

        train_start = time.time()

        self.model.train()
        for _ in range(num_steps):
            train_losses.append(self.train_step())
            if self.scheduler is not None:
                self.scheduler.step()

        logs['time/training'] = time.time() - train_start
        # Summarize the per-step losses. Skipped when no step ran, because
        # np.mean([]) warns and returns nan.
        if train_losses:
            logs['training/train_loss_mean'] = np.mean(train_losses)
            logs['training/train_loss_std'] = np.std(train_losses)

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs


class DistributionalSequenceTrainer(Trainer):
    """Sequence trainer whose batches and model inputs also carry ``dists``.

    ``get_batch`` returns an eighth element, ``dists``, which is forwarded to
    the model positionally after ``timesteps``. Attributes such as ``model``,
    ``optimizer``, ``scheduler``, ``eval_fns``, ``start_time`` and
    ``diagnostics`` are not set here. They presumably come from the
    ``Trainer`` base class (not shown).
    """

    def train_step(self):
        """Run one gradient step on a fresh batch and return the scalar loss.

        Only timesteps with ``attention_mask > 0`` contribute. The mean squared
        action error over those timesteps is stored in
        ``self.diagnostics['training/action_error']``.
        """
        states, actions, rewards, dones, rtg, timesteps, attention_mask, dists = self.get_batch(self.batch_size)
        # Snapshot the targets before the forward pass.
        action_target = torch.clone(actions)

        # rtg is passed without its last timestep.
        _, action_preds, _ = self.model.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, dists, attention_mask=attention_mask,
        )

        # Flatten the leading dims (presumably (batch, seq)) and keep only
        # unmasked timesteps.
        act_dim = action_preds.shape[2]
        valid = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]

        loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm at 0.25.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target) ** 2).detach().cpu().item()

        return loss.detach().cpu().item()

    def _run_train_steps(self, num_steps):
        """Put the model in train mode, run ``num_steps`` steps and return their losses.

        The scheduler, if any, is stepped once after every training step.
        """
        self.model.train()
        train_losses = []
        for _ in range(num_steps):
            train_losses.append(self.train_step())
            if self.scheduler is not None:
                self.scheduler.step()
        return train_losses

    @staticmethod
    def _print_iteration_logs(logs, iter_num):
        """Print ``logs`` under an ``Iteration <iter_num>`` banner."""
        print('=' * 80)
        print(f'Iteration {iter_num}')
        for k, v in logs.items():
            print(f'{k}: {v}')

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Train for ``num_steps`` steps, run every eval function, and return a log dict.

        Evaluation outputs are stored under ``evaluation/<key>``. Timing, loss
        statistics and ``self.diagnostics`` are stored as well.
        """
        logs = dict()

        train_start = time.time()
        train_losses = self._run_train_steps(num_steps)
        logs['time/training'] = time.time() - train_start

        eval_start = time.time()

        self.model.eval()
        for eval_fn in self.eval_fns:
            # NOTE(review): eval functions receive iter_num - 1 as their second
            # argument; the reason for the -1 offset is not visible here.
            outputs = eval_fn(self.model, iter_num - 1)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            self._print_iteration_logs(logs, iter_num)

        return logs

    def train_only_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` training steps (no evaluation) and return a log dict.

        The dict holds:
          * the wall-clock training time;
          * the mean/std of the per-step losses, when at least one step ran;
          * every entry of ``self.diagnostics``.
        """
        logs = dict()

        train_start = time.time()
        train_losses = self._run_train_steps(num_steps)
        logs['time/training'] = time.time() - train_start

        # Skipped when no step ran, because np.mean([]) warns and returns nan.
        if train_losses:
            logs['training/train_loss_mean'] = np.mean(train_losses)
            logs['training/train_loss_std'] = np.std(train_losses)

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            self._print_iteration_logs(logs, iter_num)

        return logs


class UnsupervisedTrainer:
    """Action-regression trainer with an auxiliary state auto-encoder.

    The model must expose ``encoder`` and ``decoder`` sub-modules. The
    reconstruction loss is the MSE between ``states`` and
    ``decoder(encoder(states))``. ``ae_training`` selects how that loss is used:

    * ``'ae'``    -- a separate ``ae_optimizer`` step on the reconstruction
      loss before each action-prediction step. The encoder/decoder are then
      put in eval mode and frozen for the action-prediction step.
    * ``'ae_dt'`` -- the reconstruction loss, scaled by ``ae_coef``, is added
      to the action-prediction loss and optimized jointly.
    * ``'dt'``, ``'pre'``, ``'fine'`` -- no reconstruction loss. ``'pre'``
      also freezes the encoder in ``train_only_iteration``.

    Any other value makes ``train_step`` raise ``ValueError``. ``add_rtg`` is
    stored but not used by this class.
    """

    _AE_MODES = ('ae', 'ae_dt')
    _NON_AE_MODES = ('dt', 'pre', 'fine')

    def __init__(
        self,
        model,
        optimizer,
        batch_size,
        get_batch,
        loss_fn,
        scheduler=None,
        eval_fns=None,
        ae_training='ae_dt',
        ae_coef=1.0,
        add_rtg=False,
        ae_optimizer=None,
        ):
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.get_batch = get_batch  # callable(batch_size) -> 7-tuple of batch tensors
        self.loss_fn = loss_fn  # called as loss_fn(None, a_pred, None, None, a_target, None)
        self.scheduler = scheduler  # stepped once per training step when not None
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()  # latest per-step metrics, copied into every log dict

        self.start_time = time.time()

        # for representation learning
        self.add_rtg = add_rtg
        self.ae_coef = ae_coef  # weight of the reconstruction loss in 'ae_dt' mode
        self.ae_training = ae_training
        self.ae_optimizer = ae_optimizer  # required in 'ae' mode

    def _run_train_steps(self, num_steps):
        """Run ``num_steps`` training steps and return their losses.

        The scheduler, if any, is stepped after each step. The caller is
        responsible for setting train mode beforehand.
        """
        train_losses = []
        for _ in range(num_steps):
            train_losses.append(self.train_step())
            if self.scheduler is not None:
                self.scheduler.step()
        return train_losses

    @staticmethod
    def _print_logs(logs, iter_num):
        """Print ``logs`` under an ``Iteration <iter_num>`` banner."""
        print('=' * 80)
        print(f'Iteration {iter_num}')
        for k, v in logs.items():
            print(f'{k}: {v}')

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Train for ``num_steps`` steps, run every eval function, and return a log dict."""
        logs = dict()

        train_start = time.time()
        self.model.train()
        train_losses = self._run_train_steps(num_steps)
        logs['time/training'] = time.time() - train_start

        eval_start = time.time()

        self.model.eval()
        for eval_fn in self.eval_fns:
            outputs = eval_fn(self.model)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        logs.update(self.diagnostics)

        if print_logs:
            self._print_logs(logs, iter_num)

        return logs

    def _validate_mode(self):
        """Raise ValueError unless ``self.ae_training`` is a supported mode."""
        valid = self._AE_MODES + self._NON_AE_MODES
        if self.ae_training not in valid:
            raise ValueError(
                f"unknown ae_training mode {self.ae_training!r}; expected one of {valid}"
            )

    def _reconstruction_loss(self, states):
        """Return the MSE between ``states`` and ``decoder(encoder(states))``."""
        s_hat = self.model.decoder(self.model.encoder(states))
        assert s_hat.shape == states.shape
        return torch.nn.functional.mse_loss(s_hat, states)

    def _autoencoder_step(self, states):
        """Take one ``ae_optimizer`` step on the reconstruction loss and return that loss.

        Afterwards the encoder and decoder are left in eval mode with
        ``requires_grad`` disabled.
        """
        assert self.ae_optimizer is not None, "ae_training='ae' requires ae_optimizer"
        self.model.encoder.train().requires_grad_(True)
        self.model.decoder.train().requires_grad_(True)
        ae_loss = self._reconstruction_loss(states)
        self.ae_optimizer.zero_grad()
        ae_loss.backward()
        self.ae_optimizer.step()
        # Clear the auto-encoder gradients now. Otherwise they linger on the
        # encoder/decoder parameters and, unless self.optimizer also owns (and
        # thus zeroes) those parameters, they are counted by clip_grad_norm_
        # over self.model.parameters() in the action-prediction step.
        self.ae_optimizer.zero_grad()
        self.model.encoder.eval().requires_grad_(False)
        self.model.decoder.eval().requires_grad_(False)
        return ae_loss

    def train_step(self):
        """Run one training step (plus the auto-encoder work for the mode) and return the loss.

        Side effect: updates ``self.diagnostics`` with
        ``'training/action_error'`` and ``'training/ae_loss'``. The latter is
        0.0 in modes without a reconstruction loss.

        Raises:
            ValueError: if ``self.ae_training`` is not a supported mode.
        """
        self._validate_mode()
        states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
        # Snapshot the targets before any forward pass.
        action_target = torch.clone(actions)

        ae_loss = 0.0
        if self.ae_training == 'ae':
            ae_loss = self._autoencoder_step(states)

        # rtg is passed without its last timestep.
        _, action_preds, _ = self.model.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        # Flatten the leading dims (presumably (batch, seq)) and keep only
        # unmasked timesteps.
        act_dim = action_preds.shape[2]
        valid = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]

        loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )
        if self.ae_training == 'ae_dt':
            ae_loss = self._reconstruction_loss(states)
            # Out-of-place: `+=` would mutate the tensor returned by loss_fn,
            # which autograd may still need for the backward pass.
            loss = loss + self.ae_coef * ae_loss

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm at 0.25.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target) ** 2).detach().cpu().item()
            if self.ae_training in self._AE_MODES:
                ae_loss = ae_loss.detach().cpu().item()
            self.diagnostics['training/ae_loss'] = ae_loss

        return loss.detach().cpu().item()

    def train_only_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` training steps (no evaluation) and return a log dict.

        In ``'pre'`` mode the encoder is put in eval mode and frozen first.
        The dict holds:
          * the training time;
          * the mean/std of the per-step losses, when at least one step ran;
          * every diagnostics entry.
        """
        logs = dict()

        train_start = time.time()

        self.model.train()
        if self.ae_training == 'pre':
            self.model.encoder.eval().requires_grad_(False)
        train_losses = self._run_train_steps(num_steps)

        logs['time/training'] = time.time() - train_start
        # Skipped when no step ran, because np.mean([]) warns and returns nan.
        if train_losses:
            logs['training/train_loss_mean'] = np.mean(train_losses)
            logs['training/train_loss_std'] = np.std(train_losses)

        logs.update(self.diagnostics)

        if print_logs:
            self._print_logs(logs, iter_num)

        return logs


class CPCTrainer:
    """Action-regression trainer with an auxiliary contrastive encoder loss.

    Besides ``encoder``, the model must expose:

    * ``momentum_encoder`` -- blended towards ``encoder`` with
      ``soft_copy_param(..., tau=0.05)``;
    * ``W`` -- applied to the momentum encodings to form bilinear logits.

    ``ae_training`` selects how the contrastive loss is used:

    * ``'ae'``    -- a separate ``ae_optimizer`` step on the contrastive loss
      before each action-prediction step. The encoder is then frozen for the
      action-prediction step.
    * ``'ae_dt'`` -- the contrastive loss, scaled by ``ae_coef``, is added to
      the action-prediction loss and optimized jointly.
    * ``'dt'``, ``'pre'``, ``'fine'`` -- no contrastive loss. ``'pre'``
      freezes the encoder in ``train_only_iteration``.

    Any other value makes ``train_step`` raise ``ValueError``. ``add_rtg`` is
    stored but not used by this class.
    """

    _AE_MODES = ('ae', 'ae_dt')
    _NON_AE_MODES = ('dt', 'pre', 'fine')

    def __init__(
        self,
        model,
        optimizer,
        batch_size,
        get_batch,
        loss_fn,
        scheduler=None,
        eval_fns=None,
        ae_training='ae_dt',
        ae_coef=1.0,
        add_rtg=False,
        ae_optimizer=None,
        ):
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.get_batch = get_batch  # callable(batch_size) -> 7-tuple of batch tensors
        self.loss_fn = loss_fn  # called as loss_fn(None, a_pred, None, None, a_target, None)
        self.scheduler = scheduler  # stepped once per training step when not None
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()  # latest per-step metrics, copied into every log dict

        self.start_time = time.time()

        # for representation learning
        self.add_rtg = add_rtg
        self.ae_coef = ae_coef  # weight of the contrastive loss in 'ae_dt' mode
        self.ae_training = ae_training
        self.ae_optimizer = ae_optimizer  # required in 'ae' mode

    def _run_train_steps(self, num_steps):
        """Run ``num_steps`` training steps and return their losses.

        The scheduler, if any, is stepped after each step. The caller is
        responsible for setting train mode beforehand.
        """
        train_losses = []
        for _ in range(num_steps):
            train_losses.append(self.train_step())
            if self.scheduler is not None:
                self.scheduler.step()
        return train_losses

    @staticmethod
    def _print_logs(logs, iter_num):
        """Print ``logs`` under an ``Iteration <iter_num>`` banner."""
        print('=' * 80)
        print(f'Iteration {iter_num}')
        for k, v in logs.items():
            print(f'{k}: {v}')

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Train for ``num_steps`` steps, run every eval function, and return a log dict."""
        logs = dict()

        train_start = time.time()
        self.model.train()
        train_losses = self._run_train_steps(num_steps)
        logs['time/training'] = time.time() - train_start

        eval_start = time.time()

        self.model.eval()
        for eval_fn in self.eval_fns:
            outputs = eval_fn(self.model)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        logs.update(self.diagnostics)

        if print_logs:
            self._print_logs(logs, iter_num)

        return logs

    def _validate_mode(self):
        """Raise ValueError unless ``self.ae_training`` is a supported mode."""
        valid = self._AE_MODES + self._NON_AE_MODES
        if self.ae_training not in valid:
            raise ValueError(
                f"unknown ae_training mode {self.ae_training!r}; expected one of {valid}"
            )

    def _contrastive_loss(self, states):
        """Return the contrastive loss between the online and momentum encoders.

        Each sample's online encoding is scored against the momentum
        encodings of every sample in a noise-perturbed copy of ``states``
        (Gaussian noise, std 0.1). Cross-entropy with the diagonal as target
        asks each sample to pick out its own perturbed encoding.
        """
        batch = states.shape[0]
        z = self.model.encoder(states).reshape(batch, -1)
        # randn_like puts the noise on states' device/dtype; a bare
        # torch.normal(size=...) allocates on the CPU and fails for CUDA states.
        eps = torch.randn_like(states) * 0.1
        z_hat = self.model.momentum_encoder(states + eps).reshape(batch, -1)
        z_hat = z_hat.detach()
        # NOTE(review): W is applied to z_hat.T, shape (features, batch). If W
        # is an nn.Linear it acts on the batch axis; confirm W's definition.
        proj_k = self.model.W(z_hat.T)  # bilinear product
        logits = torch.matmul(z, proj_k)  # B x B
        # Subtract each row's max for numerical stability. keepdim=True keeps
        # a (B, 1) column so the shift is per row. A (B,) vector would
        # broadcast across columns and change the softmax (and the loss).
        logits = logits - torch.max(logits, dim=1, keepdim=True).values
        labels = torch.arange(logits.shape[0], device=logits.device)
        return torch.nn.CrossEntropyLoss()(logits, labels)

    def _contrastive_step(self, states):
        """Take one ``ae_optimizer`` step on the contrastive loss and return that loss.

        Afterwards the encoder and momentum encoder are left in eval mode with
        ``requires_grad`` disabled.
        """
        assert self.ae_optimizer is not None, "ae_training='ae' requires ae_optimizer"
        self.model.encoder.train().requires_grad_(True)
        self.model.momentum_encoder.eval().requires_grad_(False)
        ae_loss = self._contrastive_loss(states)
        self.ae_optimizer.zero_grad()
        ae_loss.backward()
        # NOTE(review): the momentum encoder is blended here, before
        # ae_optimizer.step(), and again after the main optimizer step in
        # train_step. So 'ae' mode gets two EMA updates per training step;
        # confirm this is intended.
        soft_copy_param(self.model.momentum_encoder, self.model.encoder, tau=0.05)
        self.ae_optimizer.step()
        # Clear the contrastive gradients so they cannot leak into
        # clip_grad_norm_ over self.model.parameters() in train_step.
        self.ae_optimizer.zero_grad()
        self.model.encoder.eval().requires_grad_(False)
        self.model.momentum_encoder.eval().requires_grad_(False)
        return ae_loss

    def train_step(self):
        """Run one training step (plus the contrastive work for the mode) and return the loss.

        Side effects:
          * updates ``self.diagnostics`` with ``'training/action_error'``
            and ``'training/ae_loss'`` (0.0 in modes without a contrastive loss);
          * blends the momentum encoder towards the encoder after the
            optimizer step.

        Raises:
            ValueError: if ``self.ae_training`` is not a supported mode.
        """
        self._validate_mode()
        states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
        # Snapshot the targets before any forward pass.
        action_target = torch.clone(actions)

        ae_loss = 0.0
        if self.ae_training == 'ae':
            ae_loss = self._contrastive_step(states)

        # rtg is passed without its last timestep.
        _, action_preds, _ = self.model.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        # Flatten the leading dims (presumably (batch, seq)) and keep only
        # unmasked timesteps.
        act_dim = action_preds.shape[2]
        valid = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]

        loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )
        if self.ae_training == 'ae_dt':
            self.model.encoder.train().requires_grad_(True)
            self.model.momentum_encoder.eval().requires_grad_(False)
            ae_loss = self._contrastive_loss(states)
            # Out-of-place: `+=` would mutate the tensor returned by loss_fn,
            # which autograd may still need for the backward pass.
            loss = loss + self.ae_coef * ae_loss

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm at 0.25.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()
        soft_copy_param(self.model.momentum_encoder, self.model.encoder, tau=0.05)

        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target) ** 2).detach().cpu().item()
            if self.ae_training in self._AE_MODES:
                ae_loss = ae_loss.detach().cpu().item()
            self.diagnostics['training/ae_loss'] = ae_loss

        return loss.detach().cpu().item()

    def train_only_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` training steps (no evaluation) and return a log dict.

        The encoder is frozen in ``'pre'`` mode and unfrozen otherwise before
        training. The dict holds:
          * the training time;
          * the mean/std of the per-step losses, when at least one step ran;
          * every diagnostics entry.
        """
        logs = dict()

        train_start = time.time()

        self.model.train()
        if self.ae_training == 'pre':
            self.model.encoder.eval().requires_grad_(False)
        else:
            self.model.encoder.train().requires_grad_(True)
        train_losses = self._run_train_steps(num_steps)

        logs['time/training'] = time.time() - train_start
        # Skipped when no step ran, because np.mean([]) warns and returns nan.
        if train_losses:
            logs['training/train_loss_mean'] = np.mean(train_losses)
            logs['training/train_loss_std'] = np.std(train_losses)

        logs.update(self.diagnostics)

        if print_logs:
            self._print_logs(logs, iter_num)

        return logs
