import torch
import numpy as np
from tqdm.auto import tqdm

from einops import rearrange

# Control tqdm output
def type_of_script():
    """Detect the interactive environment, used to decide whether to show tqdm bars.

    Returns:
        'jupyter' when the IPython shell type mentions 'zmqshell',
        'ipython' when it mentions 'terminal',
        'terminal' when ``get_ipython`` is not defined (plain Python process),
        ``None`` for any other IPython-like shell.
    """
    try:
        # `get_ipython` only exists as a builtin inside IPython sessions.
        shell_type = str(type(get_ipython()))  # noqa: F821
    except NameError:
        return 'terminal'
    if 'zmqshell' in shell_type:
        return 'jupyter'
    if 'terminal' in shell_type:
        return 'ipython'
    return None


def _format_epoch_description(epoch, args, metrics):
    """Build the epoch progress-bar text: best val RMSE so far plus non-test metrics (x1e3)."""
    description = f'Epoch: {epoch}'  # Display metric * 1e3
    description += f' | Best val RMSE (1e-3): {args.best_val_metric * 1e3:.3f} (epoch = {args.best_val_metric_epoch:3d})'
    for split, split_metrics in metrics.items():
        if split == 'test':  # No look
            continue
        for metric_name, metric in split_metrics.items():
            if metric_name != 'total':
                description += f' | {split}/{metric_name} (1e-3): {metric * 1e3:.3f}'
    return description


def train_model(model, optimizer, scheduler, dataloaders_by_split,
                criterions, max_epochs, args,
                input_transform=None, output_transform=None,
                wandb=None, return_best=False, early_stopping_epochs=10):
    """Train ``model`` for up to ``max_epochs`` epochs with val-RMSE early stopping.

    Each epoch delegates to ``run_epoch``, which runs every split and records
    best-so-far values on ``args`` (``best_val_metric``, ``best_val_metric_epoch``,
    ``best_train_metric``, ``best_train_metric_epoch``).

    Args:
        model, optimizer, scheduler: passed through to ``run_epoch``.
        dataloaders_by_split: mapping of split name to dataloader; ``run_epoch``
            indexes the 'train' and 'val' results.
        criterions: mapping of metric name to loss callable.
        max_epochs: maximum number of epochs.
        args: mutable namespace carrying config and best-so-far bookkeeping.
        input_transform, output_transform: callables applied to model input /
            output; each defaults to identity independently when ``None``.
        wandb: optional logger; when given, ``wandb.log`` is called each epoch.
        return_best: if True, reload weights from ``args.best_val_checkpoint_path``.
        early_stopping_epochs: stop once this many consecutive epochs pass
            without a new best val epoch (a negative value never triggers).

    Returns:
        The trained ``model``.
    """
    args.best_val_metric = 1e10
    args.best_val_metric_epoch = -1
    args.best_train_metric = 1e10  # Interpolation / fitting also good to test
    args.best_train_metric_epoch = -1

    # Experiment with C coeffs (run_epoch appends per-epoch snapshots when available).
    args.learned_c_weights = []

    # Default each transform on its own. This keeps a caller-supplied
    # output_transform, and avoids passing None down to be called later.
    if input_transform is None:
        input_transform = lambda x: x
    if output_transform is None:
        output_transform = lambda x: x

    pbar = tqdm(range(max_epochs))
    early_stopping_count = 0
    metrics = None  # metrics of the previous epoch; none before the first epoch

    for epoch in pbar:
        if metrics is None:
            pbar.set_description(f'Epoch: {epoch}')
        else:
            pbar.set_description(_format_epoch_description(epoch, args, metrics))

        _, metrics = run_epoch(model, dataloaders_by_split, optimizer, scheduler,
                               criterions, args, epoch, input_transform, output_transform)

        # run_epoch sets best_val_metric_epoch to this epoch when val RMSE improved.
        if args.best_val_metric_epoch == epoch:
            early_stopping_count = 0
        else:
            early_stopping_count += 1

        if wandb is not None:
            log_metrics = {f'{split}/{k}': v
                           for split, split_metrics in metrics.items()
                           for k, v in split_metrics.items()}
            wandb.log(log_metrics, step=epoch)

        if early_stopping_count == early_stopping_epochs:
            print(f'Early stopping at epoch {epoch}...')
            break  # Exit for loop and do early stopping

    print(f'-> Saved best val model checkpoint at epoch {args.best_val_metric_epoch}!')
    print(f'-> Saved best train model checkpoint at epoch {args.best_train_metric_epoch}!')

    if return_best:
        # NOTE(review): checkpoints also store `c_weights` (numpy arrays). On torch
        # versions where torch.load defaults to weights_only=True, this load may be
        # rejected. Confirm against the installed torch version.
        best_model_dict = torch.load(args.best_val_checkpoint_path)
        best_epoch = best_model_dict['epoch']
        print(f'Returning best val model from epoch {best_epoch}')
        model.load_state_dict(best_model_dict['state_dict'])

    return model


def _save_run_checkpoint(path, epoch, rmse, model, optimizer, args):
    """Save model/optimizer state, epoch, RMSE and the learned C weights to ``path``."""
    torch.save({'epoch': epoch,
                'rmse': rmse,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'c_weights': args.learned_c_weights},
               path)


def run_epoch(model, dataloaders, optimizer, scheduler, criterions,
              args, epoch, input_transform=None, output_transform=None,
              wandb=None):
    """Run one epoch over every split, checkpoint improvements and step the scheduler.

    Args:
        dataloaders: mapping of split name to dataloader, e.g.
            {'train': train_loader, 'val': val_loader, 'test': test_loader}.
            Results for 'train' and 'val' are required below.
        args: namespace with ``device``, ``scheduler`` and the best-so-far fields
            and checkpoint paths set up by ``train_model``; updated in place.
        wandb: accepted for API compatibility; unused here.

    Returns:
        ``(model, metrics)``. ``metrics[split]`` holds the per-sample averages of
        'mse' and 'mae', plus 'rmse' = sqrt(mean mse). The raw 'total' count is
        kept alongside them.
    """
    metrics = {split: None for split in dataloaders}

    for split, dataloader in dataloaders.items():
        model, split_metrics = shared_step_with_control_k(
            model, dataloader, optimizer, scheduler, criterions, args.device, epoch,
            prefix=split, args=args,
            input_transform=input_transform, output_transform=output_transform)
        # Convert the weighted sums into per-sample means; rmse must use the raw sum.
        total = split_metrics['total']
        split_metrics['rmse'] = np.sqrt(split_metrics['mse'] / total)
        split_metrics['mse'] = split_metrics['mse'] / total
        split_metrics['mae'] = split_metrics['mae'] / total
        metrics[split] = split_metrics

    val_rmse = metrics['val']['rmse']
    if val_rmse < args.best_val_metric or epoch == 0:
        args.best_val_metric = val_rmse
        args.best_val_metric_epoch = epoch
        _save_run_checkpoint(args.best_val_checkpoint_path, epoch, val_rmse,
                             model, optimizer, args)

    train_rmse = metrics['train']['rmse']
    if train_rmse < args.best_train_metric:
        args.best_train_metric = train_rmse
        args.best_train_metric_epoch = epoch
        _save_run_checkpoint(args.best_train_checkpoint_path, epoch, train_rmse,
                             model, optimizer, args)

    if args.scheduler == 'plateau':
        scheduler.step(val_rmse)
    elif args.scheduler == 'timm_cosine':
        # NOTE(review): timm's reference loop calls step(epoch + 1) at epoch end;
        # confirm that passing `epoch` here is intended.
        scheduler.step(epoch)

    # Record the first row of c_shift when the model exposes it; other models skip this.
    try:
        args.learned_c_weights.append(model.nn[0].ssm.c_shift[:, 0, :].detach().cpu().numpy()[0])
    except AttributeError:
        pass
    return model, metrics


def free_gpu(x):
    """Deprecated no-op, kept so existing callers keep working.

    The old version did ``x = x.cpu(); del x``. That only rebinds and deletes
    the function's *local* name; it cannot drop the caller's reference, so no
    device memory was freed and the copy to host was wasted work.
    To actually release a tensor, delete every reference to it in the calling
    scope (e.g. ``del t``). Then optionally call ``torch.cuda.empty_cache()``.
    """
    del x


def shared_step_with_control_k(model, dataloader, optimizer, scheduler, criterions,
                               device, epoch, prefix='train', args=None,
                               input_transform=None, output_transform=None):
    """Run one pass over ``dataloader`` for one split; train only when ``prefix == 'train'``.

    For each batch, the input window (minus its last ``args.horizon`` steps) goes
    through ``model(x, (args.lag, horizon, args.lag))``. The prediction is compared
    with ``x[:, args.lag:]`` concatenated with the first ``horizon`` steps of ``y``.
    In training, a feedback loss from ``model.get_feedback()`` is added and the
    optimizer steps once per batch.

    Args:
        criterions: mapping of metric name to criterion. The 'mse' entry is
            required, and 'mae' as well when ``args.loss`` contains 'mae'.
        device: device the model and batches are moved to.
        args: namespace providing ``loss``, ``horizon``, ``lag``, ``features``
            and ``replicate``.
        input_transform, output_transform: callables applied to model input /
            output; each defaults to identity when ``None``.
        optimizer: stepped only in training. ``scheduler`` and ``epoch`` are
            accepted for API compatibility but unused here.

    Returns:
        ``(model, metrics)``. ``metrics['total']`` is the number of rows seen.
        Each criterion key holds the sum over batches of that criterion's mean on
        the forecast horizon, times the batch's row count; divide by 'total' to
        get the average.
    """
    if input_transform is None:
        input_transform = lambda t: t
    if output_transform is None:
        output_transform = lambda t: t

    # Feedback loss: MSE whenever args.loss mentions 'mse' (this includes 'rmse'), else L1.
    if 'mse' in args.loss:
        criterion_feedback = torch.nn.MSELoss(reduction='mean')
    else:
        criterion_feedback = torch.nn.L1Loss(reduction='mean')

    metrics = {'total': 0.}
    for k in criterions:
        metrics[k] = 0.

    grad_enabled = prefix == 'train'
    if grad_enabled:
        model.train()
        model.zero_grad()
    else:
        model.eval()

    model.to(device)
    model.joint_train()

    horizon = args.horizon
    lag = args.lag

    if type_of_script() == 'terminal':
        pbar = tqdm(dataloader, leave=False, desc=f'{prefix} with control: horizon = {horizon}')
    else:
        pbar = dataloader

    with torch.set_grad_enabled(grad_enabled):
        for data in pbar:
            x, y, *_ = data  # trailing items (e.g. sequence lengths) are unused here

            # Multivariate: fold the feature dim into the batch dim, one channel per row.
            if args.features == 'M':
                x = rearrange(x, 'b l d -> (b d) l').unsqueeze(-1)
                y = rearrange(y, 'b l d -> (b d) l').unsqueeze(-1)

            x = input_transform(x)
            x = x[:, :-horizon, :].to(device)
            y = y.to(device)

            # First process lag terms
            if grad_enabled:
                model.process_lag(freeze_weights=args.replicate in [32, 34])

            r_args = (lag, horizon, lag)
            y_r_pred = output_transform(model(x, r_args))
            y_r_true = torch.cat((x[:, lag:, :], y[:, :horizon, :]), dim=1)

            if 'rmse' in args.loss:
                loss = torch.sqrt(criterions['mse'](y_r_pred, y_r_true)).mean()
            elif 'mae' in args.loss:
                loss = criterions['mae'](y_r_pred, y_r_true).mean()
            else:
                loss = criterions['mse'](y_r_pred, y_r_true).mean()

            if grad_enabled:
                feedback, reference = model.get_feedback()
                feedback_loss = criterion_feedback(feedback.flatten(), reference.flatten())
                if 'rmse' in args.loss:
                    feedback_loss = torch.sqrt(feedback_loss)
                loss = loss + feedback_loss

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                model.process_horizon()  # clear out Kx terms

            # Accumulate metrics on the forecast horizon for EVERY batch
            # (standardized units for now). Detach so no autograd graph is kept.
            y_h_pred = y_r_pred[:, -horizon:, :].detach().cpu()
            y_h_true = y_r_true[:, -horizon:, :].detach().cpu()
            n_rows = len(y_h_true)
            for k, criterion in criterions.items():
                metrics[k] += (criterion(y_h_pred, y_h_true).mean() * n_rows).item()
            metrics['total'] += len(y)

    # Release cached allocator blocks once per pass, not once per batch.
    torch.cuda.empty_cache()

    return model, metrics

