import os
from tqdm import tqdm

import numpy as np

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader


from utils import common, logs_handler, misc
from utils.metrics import MeanAccumulator, MixedAccumulator, MetricsTracer, print_format

# Module-level logger for this file, obtained through the project's logs_handler helper.
logger = logs_handler.get_logger(__name__)

class Trainer:
    """Run training epochs, periodic evaluation and checkpointing for a model.

    The model is moved to ``cfg.device`` and wrapped in ``nn.DataParallel``
    over ``cfg.device_ids``. Forward passes go through the wrapper, while the
    checkpoint helpers operate on the unwrapped model.
    """

    def __init__(self, stage, cfg, model, evaluator, envs, optimizer=None, scheduler=None):
        """Store the training components and set up metric bookkeeping.

        Args:
            stage: Value forwarded to the model as the ``stage`` keyword on
                every forward pass.
            cfg: Configuration object; this class reads ``device``,
                ``device_ids``, ``batch_size`` and ``num_workers`` from it.
            model: The ``nn.Module`` to train.
            evaluator: Object exposing ``set_env`` and ``evaluate``.
            envs: A single environment or a list of environments; a single
                one is wrapped in a list.
            optimizer: Optimizer used by ``train_step`` (required there).
            scheduler: Optional learning-rate scheduler.
        """
        self.metrics_keys = ['loss', 'state_rep_loss', 'cgpt_loss', 'bc_loss']

        self.stage = stage
        self.cfg = cfg
        self.device = cfg.device
        self.device_ids = cfg.device_ids
        self.model = model.to(self.device)
        self.data_distributed_model = nn.DataParallel(self.model, device_ids=self.device_ids)

        self.envs = envs if isinstance(envs, list) else [envs]
        self.evaluator = evaluator

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = GradScaler()

        # One mean accumulator per tracked metric key.
        self.loss_accum = MixedAccumulator(
            accumulators={key: MeanAccumulator(ignore_zero=True) for key in self.metrics_keys})
        self.metrics_tracer = MetricsTracer(metrics_keys=self.metrics_keys, index_key='epoch')

    def train_step(self, train_dataset, max_num_batches=None, epoch=None, num_epochs=None,
                   grad_accum_step=1, scheduler_step=None, max_grad_norm=None,
                   freeze_at=None, unfreeze_at=None, freeze_all=False, unfreeze_all=False):
        """Train for one pass over (at most ``max_num_batches`` of) the dataset.

        Args:
            train_dataset: Dataset whose items are ``(states, targets)`` with
                ``targets['action']`` present. It must expose ``use_strl``;
                when that is truthy, ``states`` is indexed with
                ``'augmented'`` and ``'augmented_prime'``.
            max_num_batches: Upper bound on batches processed; a falsy value
                means the whole loader.
            epoch: Current epoch number, used for logging and for the
                freeze/unfreeze checks.
            num_epochs: Total number of epochs, used only in the progress text.
            grad_accum_step: The loss is divided by this value; it is also
                forwarded to ``common.optimizer_update_step``.
            scheduler_step: Forwarded to ``common.scheduler_update_step``;
                defaults to the number of steps in this pass.
            max_grad_norm: Forwarded to ``common.optimizer_update_step``.
            freeze_at: Epoch at which ``model.freeze`` is called.
            unfreeze_at: Epoch at which ``model.unfreeze`` is called (only
                checked when the freeze condition did not match).
            freeze_all: Forwarded to ``model.freeze``.
            unfreeze_all: Forwarded to ``model.unfreeze``.

        Returns:
            The value returned by the last ``loss_accum.accumulate`` call, or
            ``None`` when the loader produced no batches.

        Raises:
            AssertionError: If no optimizer was supplied to the trainer.
        """
        assert self.optimizer is not None, '...'
        batch_size = self.cfg.batch_size
        num_workers = self.cfg.num_workers
        self.data_distributed_model.train(True)
        if (freeze_at is not None) and (epoch == freeze_at):
            self.data_distributed_model.module.freeze(freeze_all=freeze_all)
            logger.info(f'freezed - epoch: {epoch}/x')
        elif (unfreeze_at is not None) and (epoch == unfreeze_at):
            self.data_distributed_model.module.unfreeze(unfreeze_all=unfreeze_all)
            logger.info(f'unfreezed - epoch: {epoch}/x')

        self.loss_accum.reset()

        data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                 pin_memory=True, num_workers=num_workers, drop_last=True)
        data_loader = iter(data_loader)

        # Never run more steps than the loader can provide.
        num_steps = min(max_num_batches or len(data_loader), len(data_loader))
        scheduler_step = scheduler_step or num_steps
        pbar = tqdm(range(1, num_steps + 1), position=0, leave=True)
        metrics, lr = None, None
        for episode in pbar:
            states, targets = next(data_loader)
            actions = targets['action'].to(self.device)

            if train_dataset.use_strl:
                # Two augmented views of the states are fed to the model together.
                states1 = torch.as_tensor(states['augmented'], dtype=torch.float32, device=self.device)
                states2 = torch.as_tensor(states['augmented_prime'], dtype=torch.float32, device=self.device)
                with autocast():
                    outputs = self.data_distributed_model(stage=self.stage, states=states1, states_prime=states2,
                                                          actions=actions)
            else:
                states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
                with autocast():
                    outputs = self.data_distributed_model(stage=self.stage, states=states, actions=actions)

            loss = outputs.get('loss') / grad_accum_step
            if loss.ndim != 0:
                # Reduce a non-scalar loss to a scalar before the backward pass.
                loss = loss.mean()
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            loss.backward()
            # NOTE(review): these helpers presumably decide, from the step index,
            # when to actually step the optimizer/scheduler -- confirm in utils.common.
            common.optimizer_update_step(self.model, self.optimizer, self.scaler, grad_accum_step, episode,
                                            num_steps, max_grad_norm)
            lr = common.scheduler_update_step(self.scheduler, scheduler_step, episode, num_steps)

            with torch.no_grad():
                # True only on the last step of this pass.
                use_cache = (episode == len(pbar))
                metrics = self.loss_accum.accumulate(values=outputs)
                metrics_logs = self.metrics_tracer.get_logs(metrics, use_cache=use_cache, index=epoch)
                misc.wandb_log(metrics)
            log = ''
            if lr is not None:
                log += f'{print_format({"lr": lr}, use_float=False)}'
            if epoch is not None:
                log += ' -- epoch: {}/{}'.format(epoch, num_epochs or 'x')
            log += ' | {}'.format(metrics_logs)
            pbar.set_description(log)
        pbar.close()
        return metrics

    def evaluate_step(self, trials, reduce_method=None, epoch=None, summary=False):
        """Evaluate the model on every configured environment.

        Args:
            trials: Number of trials, passed to the evaluator as ``num_trials``.
            reduce_method: Callable forwarded to the evaluator; defaults to
                ``np.mean`` when falsy.
            epoch: Epoch number forwarded to the evaluator.
            summary: Forwarded to the evaluator.
        """
        self.data_distributed_model.train(False)
        reduce_method = reduce_method or np.mean
        # Environments are only tagged with an id when there is more than one.
        use_id = len(self.envs) > 1
        for env in self.envs:
            self.evaluator.set_env(env)
            _ = self.evaluator.evaluate(num_trials=trials, reduce_method=reduce_method,
                                        epoch=epoch, progress=True, summary=summary, use_id=use_id)

    def train(self, train_dataset, num_epochs, max_num_batches=None, grad_accum_step=1, scheduler_step=None,
              max_grad_norm=None, freeze_at=None, unfreeze_at=None, freeze_all=False, unfreeze_all=False,
              eval_cfg=None, csv_path=None, ckpt_path=None, mode=None, monitor=None):
        """Run the full training loop with optional evaluation and checkpointing.

        Args:
            train_dataset: Dataset forwarded to ``train_step``.
            num_epochs: Number of epochs to run (epochs are numbered from 1).
            max_num_batches: Forwarded to ``train_step``.
            grad_accum_step: Forwarded to ``train_step``.
            scheduler_step: Forwarded to ``train_step``.
            max_grad_norm: Forwarded to ``train_step``.
            freeze_at: Forwarded to ``train_step``.
            unfreeze_at: Forwarded to ``train_step``.
            freeze_all: Forwarded to ``train_step``.
            unfreeze_all: Forwarded to ``train_step``.
            eval_cfg: Optional mapping with keys ``enabled``, ``step``,
                ``include_first_step``, ``trials`` and ``reduce_method``.
                Evaluation runs every ``step`` epochs, on the last epoch, and
                on the first epoch when ``include_first_step`` is truthy.
            csv_path: When given, the traced metrics are exported there.
            ckpt_path: When given, a checkpoint is written whenever the
                monitored metric is the best seen so far.
            mode: Reducer choosing the better of two values; defaults to the
                builtin ``min``. Anything other than ``min`` starts from
                ``-inf`` (i.e. is treated as maximising).
            monitor: Key of the metric to monitor; defaults to ``'loss'``.
        """
        mode = mode or min
        monitor = monitor or 'loss'

        optim_value = float('inf') if (mode == min) else float('-inf')
        for epoch in range(1, num_epochs + 1):
            metrics = self.train_step(train_dataset, max_num_batches, epoch, num_epochs,
                                      grad_accum_step, scheduler_step, max_grad_norm,
                                      freeze_at, unfreeze_at, freeze_all, unfreeze_all)
            misc.wandb_log({'epoch': epoch})
            if (eval_cfg is not None) and (eval_cfg['enabled']) and\
                ((epoch % eval_cfg['step'] == 0) or (epoch == num_epochs)\
                    or ((epoch == 1) and eval_cfg['include_first_step'])):
                # summary = ((epoch + eval_cfg['step']) > num_epochs)
                self.evaluate_step(eval_cfg['trials'], eval_cfg['reduce_method'],
                                   epoch=epoch, summary=False)
            if metrics is None:
                # train_step processed no batch, so there is no metric to
                # compare; the fallback after the loop still writes a checkpoint.
                continue
            current_value = metrics[monitor]
            new_optim_value = mode(optim_value, current_value)
            # The reducer returning (approximately) this epoch's value means
            # this epoch is the best so far.
            if (ckpt_path is not None) and (abs(new_optim_value - current_value) <= 1e-7):
                optim_value = new_optim_value
                logger.info(f'optim_{monitor} = {new_optim_value}')
                self.save_checkpoint(ckpt_path)

        # Guarantee a checkpoint exists even if no epoch ever triggered a save.
        if (ckpt_path is not None) and (not os.path.exists(ckpt_path)):
            logger.info('ckpt got created!')
            self.save_checkpoint(ckpt_path)

        if csv_path is not None:
            self.metrics_tracer.export(csv_path)

    def save_checkpoint(self, path):
        """Write the unwrapped model's ``state_dict`` to ``path`` via ``torch.save``."""
        model_state = self.model.state_dict()
        torch.save(model_state, path)

    def load_checkpoint(self, path, strict=True, warn=False):
        """Load a checkpoint into the unwrapped model and log returned warnings.

        Args:
            path: Checkpoint location, forwarded to ``common.load_checkpoint``.
            strict: Forwarded to ``common.load_checkpoint``.
            warn: Forwarded to ``common.load_checkpoint``.
        """
        device = torch.device(self.device) if self.device else None
        warn_msgs =\
            common.load_checkpoint(self.model, path=path, strict=strict,
                                   device=device, warn=warn)
        for msg in warn_msgs:
            # Logger.warn is a deprecated alias (removed in Python 3.13).
            logger.warning(msg)
