import math
import torch
import numpy as np
from torch.utils.data.dataloader import DataLoader
import wandb
import warnings
from mage.datasets import Batch

from .timer import Timer

def to(xs, device):
    """Move every element of ``xs`` to ``device``.

    Args:
        xs: iterable of objects exposing ``.to(device)`` (e.g. torch tensors).
        device: target forwarded unchanged to each element's ``.to``.

    Returns:
        A new list holding the moved elements, in their original order.
    """
    return list(map(lambda item: item.to(device), xs))

class VQTrainer:
    """Trainer for the VQ representation model.

    Optimizes ``model`` over ``dataset`` with an optional token-based
    learning-rate schedule (linear warm-up, then cosine decay floored at 10% of
    ``config.learning_rate``). ``n_epochs`` and ``n_tokens`` persist across
    ``train`` calls, so the schedule and epoch numbering continue from where
    the previous call stopped.
    """

    def __init__(self, config):
        self.config = config
        self.device = config.device

        self.n_epochs = 0
        self.n_tokens = 0  # tokens seen so far; drives the learning-rate schedule
        self.optimizer = None

    def get_optimizer(self, model):
        """Return the optimizer, creating it via ``model.configure_optimizers`` on first use."""
        if self.optimizer is None:
            print(f'[ utils/training ] Making optimizer at epoch {self.n_epochs}')
            self.optimizer = model.configure_optimizers(self.config)
        return self.optimizer

    def _lr_multiplier(self):
        """Return the learning-rate multiplier for the current ``self.n_tokens``.

        Linear warm-up from 0 to 1 over ``config.warmup_tokens`` tokens, then a
        half-cosine decay that reaches its floor of 0.1 at
        ``config.final_tokens``. Progress is clamped to 1 so the multiplier
        stays at the floor afterwards instead of rising again with the cosine.
        """
        config = self.config
        if self.n_tokens < config.warmup_tokens:
            return float(self.n_tokens) / float(max(1, config.warmup_tokens))
        progress = float(self.n_tokens - config.warmup_tokens) / float(
            max(1, config.final_tokens - config.warmup_tokens))
        progress = min(progress, 1.0)
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

    def _update_lr(self, optimizer):
        """Compute the current learning rate and write it to ``optimizer`` if decay is on.

        Returns:
            ``(lr, lr_mult)``. When ``config.lr_decay`` is false, ``lr`` is
            ``config.learning_rate`` and the optimizer is left untouched.
        """
        config = self.config
        lr_mult = self._lr_multiplier()
        if not config.lr_decay:
            return config.learning_rate, lr_mult
        lr = config.learning_rate * lr_mult
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr, lr_mult

    def train(self, model, dataset, n_epochs=1, log_freq=100):
        """Train ``model`` for ``n_epochs`` passes over ``dataset``.

        Each loader batch is moved to ``self.device`` and unpacked into
        ``model(*batch)``. The last four outputs of the model are the
        reconstruction loss, VQ loss, action loss and codebook usage. The
        optimized loss is ``recon_loss.mean() + vq_loss + action_loss``.

        Args:
            model: must provide ``configure_optimizers(config)``.
            dataset: map-style dataset that also exposes ``test_portion``;
                progress is printed only when it equals 0.
            n_epochs: number of passes over the data.
            log_freq: print progress every ``log_freq`` iterations.
        """
        config = self.config
        optimizer = self.get_optimizer(model)
        model.train(True)

        loader = DataLoader(dataset, shuffle=True, pin_memory=True,
                            batch_size=config.batch_size,
                            num_workers=config.num_workers)

        for _ in range(n_epochs):
            timer = Timer()
            for it, batch_numpy in enumerate(loader):
                batch = to(batch_numpy, self.device)

                # The element count of the second batch entry drives the LR schedule.
                self.n_tokens += batch[1].numel()
                lr, lr_mult = self._update_lr(optimizer)

                with torch.set_grad_enabled(True):
                    *_, recon_loss, vq_loss, action_loss, codebook_usage = model(*batch)
                    # Reduce once so both the loss and the logged value are scalars
                    # (``.item()`` below would fail on a non-scalar tensor).
                    recon_loss = recon_loss.mean()
                    loss = recon_loss + vq_loss + action_loss

                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                optimizer.step()

                if it % log_freq == 0 and dataset.test_portion == 0:
                    print(
                        f'[ utils/training ] epoch {self.n_epochs} [ {it:4d} / {len(loader):4d} ]',
                        f' recon_loss loss {recon_loss.item():.5f} |'
                        f' vq_loss {vq_loss.item():.5f} |'
                        f' action_loss {action_loss.item():.5f} |'
                        f' codebook_usage {codebook_usage:.5f} |'
                        f' lr {lr:.3e} | lr_mult: {lr_mult:.4f} |'
                        f' t: {timer():.2f}')

                # Release cached CUDA memory after every step (true for any
                # non-negative test_portion).
                if dataset.test_portion >= 0:
                    torch.cuda.empty_cache()
            self.n_epochs += 1

from .evaluate import evaluator
class PriorTrainer:
    """Trainer for the prior model, trained on top of a representation model.

    Only ``model``'s parameters are optimized (the optimizer comes from
    ``model.configure_optimizers``). ``n_epochs`` and ``n_tokens`` persist
    across ``train`` calls, so the learning-rate schedule and epoch numbering
    continue from where the previous call stopped.
    """

    def __init__(self, config):
        self.config = config
        self.device = config.device
        self.n_epochs = 0
        self.n_tokens = 0  # tokens seen so far; drives the learning-rate schedule
        self.optimizer = None
        self.use_action = config.use_action

    def get_optimizer(self, model):
        """Return the optimizer, creating it via ``model.configure_optimizers`` on first use."""
        if self.optimizer is None:
            print(f'[ utils/training ] Making optimizer at epoch {self.n_epochs}')
            self.optimizer = model.configure_optimizers(self.config)
        return self.optimizer

    def get_temp(self, step, decay_rate=0.99995, init_temp=2.0, final_temp=0.1):
        """Return ``init_temp * decay_rate ** step``, clamped below at ``final_temp``."""
        return max(final_temp, init_temp * (decay_rate ** step))

    def _lr_multiplier(self):
        """Return the learning-rate multiplier for the current ``self.n_tokens``.

        Linear warm-up from 0 to 1 over ``config.warmup_tokens`` tokens, then a
        half-cosine decay that reaches its floor of 0.1 at
        ``config.final_tokens``. Progress is clamped to 1 so the multiplier
        stays at the floor afterwards instead of rising again with the cosine.
        """
        config = self.config
        if self.n_tokens < config.warmup_tokens:
            return float(self.n_tokens) / float(max(1, config.warmup_tokens))
        progress = float(self.n_tokens - config.warmup_tokens) / float(
            max(1, config.final_tokens - config.warmup_tokens))
        progress = min(progress, 1.0)
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

    def _update_lr(self, optimizer):
        """Compute the current learning rate and write it to ``optimizer`` if decay is on.

        Returns:
            ``(lr, lr_mult)``. When ``config.lr_decay`` is false, ``lr`` is
            ``config.learning_rate`` and the optimizer is left untouched.
        """
        config = self.config
        lr_mult = self._lr_multiplier()
        if not config.lr_decay:
            return config.learning_rate, lr_mult
        lr = config.learning_rate * lr_mult
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr, lr_mult

    def _evaluate(self, representation, model, dataset, test_times):
        """Run ``evaluator`` with ``model`` in eval mode and gradients disabled; return its result."""
        model.eval()
        with torch.no_grad():
            return evaluator(test_times=test_times, dataset=dataset, gpt=representation, prior=model,
                             env=dataset.env, device=self.device, rtg_c=self.config.rtg)

    def train(self, representation, model, dataset, n_epochs=1, log_freq=100, temp=0.5, decay_rate=0.99995):
        """Train the prior ``model`` for ``n_epochs`` epochs.

        ``representation`` is put in eval mode (``train(False)``) and is not
        part of the optimizer. Each epoch runs a one-episode evaluation (its
        result is discarded), an optimization pass over ``dataset``, and a
        10-episode evaluation at the end.

        Args:
            representation: must provide ``history_horizon``, ``encode_to_idx``
                and ``padding_vector``.
            model: prior being trained; must provide ``joint_dim`` and the
                ``inv_loss`` / ``pred_loss`` / ``action_loss`` methods.
            dataset: yields batches with ``traj``, ``next_traj``, ``terminal``
                and ``mask`` fields, and exposes ``env`` for evaluation.
            n_epochs: number of epochs to run.
            log_freq: print and ``wandb.log`` every ``log_freq`` iterations.
            temp: temperature passed to ``model``; it is not annealed here.
            decay_rate: currently unused; kept for interface compatibility.

        Returns:
            ``(avgs, stdes, temp)``. ``avgs`` and ``stdes`` are the third and
            fourth values returned by the final evaluation of the last epoch,
            or ``None`` when ``n_epochs == 0``. ``temp`` is returned unchanged.
        """
        config = self.config
        optimizer = self.get_optimizer(model)
        representation.train(False)
        model.train(True)

        loader = DataLoader(dataset, shuffle=True, pin_memory=True,
                            batch_size=config.batch_size,
                            num_workers=config.num_workers)

        # Bound up front so n_epochs == 0 returns instead of raising UnboundLocalError.
        avgs, stdes = None, None

        for _ in range(n_epochs):
            self._evaluate(representation, model, dataset, test_times=1)

            timer = Timer()
            for it, batch in enumerate(loader):
                model.train(True)

                # Only the element count of next_traj is needed, so skip the device copy.
                self.n_tokens += batch.next_traj.numel()
                lr, lr_mult = self._update_lr(optimizer)

                states = batch.traj[:, :representation.history_horizon + 1, 1:model.joint_dim].to(self.device)
                gt_idx = representation.encode_to_idx(batch.traj[:, :, :model.joint_dim].to(self.device),
                                                      batch.terminal.to(self.device))
                rtg = batch.traj[:, 0, 0].to(self.device)
                terminal = batch.terminal.to(self.device)
                mask = batch.mask.to(self.device)
                padding = representation.padding_vector.clone().detach().to(
                    dtype=torch.float32, device=self.device)
                with torch.set_grad_enabled(True):
                    _, transformer_loss, _idx, pred_traj, pred_actions = model(rtg, states, gt_idx, representation, temp)
                    inv_loss = model.inv_loss(batch.traj[:, :-1, 1:model.joint_dim].to(self.device),
                                              batch.traj[:, 1:, 1:model.joint_dim].to(self.device),
                                              batch.traj[:, :-1, model.joint_dim:-1].to(self.device))
                    pred_loss = model.pred_loss(batch.traj[..., :model.joint_dim].to(self.device),
                                                pred_traj,
                                                mask, terminal,
                                                padding)
                    action_loss = model.action_loss(batch.traj[:, 0, model.joint_dim:-1].to(self.device), pred_actions)
                    loss = 0.5 * (inv_loss + transformer_loss) + pred_loss + action_loss

                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                optimizer.step()

                if it % log_freq == 0:
                    summary = dict(transformer_loss=transformer_loss.item(),
                                   inv_loss=inv_loss.item(),
                                   loss=loss.item(),
                                   pred_loss=pred_loss.item(),
                                   action_loss=action_loss.item(),
                                   lr=lr,
                                   lr_mulr=lr_mult, )
                    print(
                        f'[ utils/training ] epoch {self.n_epochs} [ {it:4d} / {len(loader):4d} ] ',
                        f' transformer loss {transformer_loss.item():.4f} |'
                        f' inv_loss {inv_loss.item():.4f} |'
                        f' pred_loss {pred_loss.item():.4f} |'
                        f' action_loss {action_loss.item():.4f} |'
                        f' loss {loss.item():.4f} |'
                        f' temp {temp:.4f} |'
                        f' t: {timer():.2f}')
                    wandb.log(summary, step=self.n_epochs * len(loader) + it)
            self.n_epochs += 1

            _, _, avgs, stdes = self._evaluate(representation, model, dataset, test_times=10)
            model.train(True)

        return avgs, stdes, temp