import hydra
import omegaconf
import os
import yaml
import numpy as np
import tqdm
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from dataset import DynamicsDataset
from meter import AverageMeter

def _build_dataloaders(config):
    """Build the train/valid DataLoaders and record how many batches each yields."""
    dataloaders = {}
    data_n_batches = {}
    for phase in ['train', 'valid']:
        dataset = DynamicsDataset(config, phase=phase)
        dataloaders[phase] = DataLoader(
            dataset,
            batch_size=config['train']['batch_size'],
            shuffle=(phase == 'train'),  # only the training split is shuffled
            num_workers=config['train']['num_workers'],
            # NOTE(review): drop_last also applies to 'valid', so a trailing
            # partial batch is excluded from the validation loss.
            drop_last=True)
        data_n_batches[phase] = len(dataloaders[phase])
    return dataloaders, data_n_batches


def _build_scheduler(optimizer, sc):
    """Return the LR scheduler described by ``sc`` (config['train']['lr_scheduler']), or None if disabled.

    Raises:
        ValueError: if the scheduler is enabled but ``sc['type']`` is not recognised.
    """
    if not sc['enabled']:
        return None
    if sc['type'] == "ReduceLROnPlateau":
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases; kept
        # for compatibility with the version this project was written against.
        return ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=sc['factor'],
            patience=sc['patience'],
            threshold_mode=sc['threshold_mode'],
            cooldown=sc['cooldown'],
            verbose=True)
    if sc['type'] == "StepLR":
        return StepLR(optimizer, step_size=sc['step_size'], gamma=sc['gamma'])
    raise ValueError("unknown scheduler type: %s" % (sc['type']))


def _l1_weight_reg(model):
    """L1 norm of the weights of ``model.model``, divided by the number of weight elements."""
    loss_reg = 0.
    n_param = 0.
    for ii, W in enumerate(model.model.parameters()):
        # NOTE(review): assumes parameters alternate weight/bias so that even
        # indices are weights; confirm for layers built without a bias.
        if ii % 2 == 0:
            loss_reg = loss_reg + W.norm(1)
            n_param += W.numel()  # number of elements in the weight tensor
    return loss_reg / n_param


def _sparse_mask_loss(model, config, n_remain):
    """Sparsity loss on the mask class probabilities, hinged at ``n_remain`` and normalised by entry count.

    For every mask entry that is still 1 (still ReLU), the class-1 probability
    is summed (plus, if ``minimize_ID`` is set, a weighted class-2 probability).
    Only the part of that total exceeding ``n_remain`` is penalised.
    """
    loss_sp = 0.
    n_reg = 0.
    for ii, logits in enumerate(model.mask_prob):
        # only calculate this loss for places that are still Relu
        mask_reg = (model.mask[ii] == 1)[0]
        p = F.softmax(logits[mask_reg], dim=-1)
        mask_loss = torch.sum(p[:, 1])  # the sum of prob on class 1
        if config['train']['minimize_ID']:
            mask_loss = mask_loss + torch.sum(p[:, 2]) * config['train']['ID_loss_weight']
        n_reg += p.shape[0]
        loss_sp = loss_sp + mask_loss

    # if the class-1 total is already below the n_remain budget, this loss is 0
    return F.relu(loss_sp - n_remain) / n_reg


def optimize(
    idx_compress_round,
    config,  # yaml config
    model,  # model object
    global_iteration,
    writer,  # summary writer
    n_remain):
    """Train ``model`` for one compression round, validating after every epoch.

    Args:
        idx_compress_round: index of the compression round. Round 0 trains for
            ``config['train']['n_epoch_initial']`` epochs, later rounds for
            ``config['train']['n_epoch_prune']``.
        config: parsed yaml config (nested dict).
        model: dynamics model exposing ``rollout_model``, ``model`` (the
            submodule whose weights receive the L1 penalty) and ``save_model``.
            When ``use_sp_loss`` is set, it must also expose ``mask`` and
            ``mask_prob``.
        global_iteration: running step counter used as the TensorBoard x-axis.
            It is incremented once per batch in both phases.
        writer: TensorBoard ``SummaryWriter``.
        n_remain: budget for the sparse-mask loss. The summed class-1
            probability is only penalised above this value.

    Returns:
        ``(model, global_iteration, activations)``:
        - the (possibly GPU-moved) model;
        - the updated step counter;
        - the ``rollout_data['activation']`` entries collected from every batch
          of the final epoch, in both train and valid phases.

    Side effects:
        Calls ``model.save_model('best_valid_model')`` whenever the mean
        validation loss improves. Writes scalars to ``writer``.
    """
    dataloaders, data_n_batches = _build_dataloaders(config)

    # criterion
    MSELoss = nn.MSELoss()
    L1Loss = nn.L1Loss()

    # Move the model to the GPU *before* constructing the optimizer, as the
    # torch.optim documentation recommends.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print("using gpu")
        model = model.cuda()

    lr = float(config['train']['lr'])
    optimizer = optim.Adam(
        model.parameters(), lr=lr,
        betas=(config['train']['lr_params']['adam_beta1'], 0.999))

    sc = config['train']['lr_scheduler']
    scheduler = _build_scheduler(optimizer, sc)

    best_valid_loss = np.inf
    activations = []

    use_sp_loss = config['train']['use_sp_loss']
    n_his, n_roll = config['train']['n_history'], config['train']['n_rollout']
    n_samples = n_his + n_roll

    if idx_compress_round == 0:
        epochs = config['train']['n_epoch_initial']
    else:
        epochs = config['train']['n_epoch_prune']

    for epoch in range(epochs):
        writer.add_scalar("Training Params/epoch", epoch, global_iteration)

        for phase in ['train', 'valid']:
            is_train = (phase == 'train')
            model.train(is_train)

            # set up recording metrics
            meter_loss = AverageMeter()
            meter_loss_rmse = AverageMeter()
            meter_loss_reg = AverageMeter()
            if use_sp_loss:
                meter_loss_sp = AverageMeter()

            for i, data in enumerate(dataloaders[phase]):
                loss_container = dict()  # losses logged to TensorBoard for this step
                global_iteration += 1

                with torch.set_grad_enabled(is_train):
                    # [B, n_samples, obs_dim]
                    observations = data['observations']
                    # [B, n_samples, action_dim]
                    actions = data['actions']
                    B = actions.shape[0]

                    if use_cuda:
                        observations = observations.cuda()
                        actions = actions.cuda()

                    assert actions.shape[1] == n_samples

                    # we don't have any visual observations, so states are observations
                    states = observations

                    # [B, n_his, state_dim]
                    state_init = states[:, :n_his]

                    # We want to rollout n_roll steps
                    # [B, n_his + n_roll - 1, action_dim]
                    action_seq = actions[:, :-1]

                    rollout_data = model.rollout_model(
                        state_init=state_init,
                        action_seq=action_seq)

                    # [B, n_roll, state_dim]
                    state_rollout_pred = rollout_data['state_pred']
                    # [B, n_roll, state_dim]
                    state_rollout_gt = states[:, n_his:]

                    if epoch == epochs - 1:
                        # NOTE(review): collected in both phases. If these are
                        # tensors, train-phase entries keep their autograd graph
                        # alive; consider .detach() once the type is confirmed.
                        activations.append(rollout_data['activation'])

                    # [B, n_roll, state_dim]
                    state_pred_err = state_rollout_pred - state_rollout_gt

                    # everything is in 3D space, so no scaling is applied
                    loss_mse = MSELoss(state_rollout_pred, state_rollout_gt)
                    loss_l1 = L1Loss(state_rollout_pred, state_rollout_gt)
                    loss_reg = _l1_weight_reg(model)

                    loss = loss_mse + loss_reg * float(config['train']['lam_l1_reg'])

                    if use_sp_loss:
                        loss_sp = _sparse_mask_loss(model, config, n_remain)
                        loss = loss + loss_sp * float(config['train']['lam_sp_loss'])
                        loss_container['sp'] = loss_sp
                        meter_loss_sp.update(loss_sp.item(), B)

                    meter_loss.update(loss.item(), B)
                    meter_loss_rmse.update(np.sqrt(loss_mse.item()), B)
                    meter_loss_reg.update(loss_reg.item(), B)

                    # compute losses at final step of the rollout
                    mse_final_step = MSELoss(state_rollout_pred[:, -1, :], state_rollout_gt[:, -1, :])
                    l2_final_step = torch.norm(state_pred_err[:, -1], dim=-1).mean()
                    l1_final_step = L1Loss(state_rollout_pred[:, -1, :], state_rollout_gt[:, -1, :])

                    loss_container['mse'] = loss_mse
                    loss_container['l1'] = loss_l1
                    loss_container['mse_final_step'] = mse_final_step
                    loss_container['l1_final_step'] = l1_final_step
                    loss_container['l2_final_step'] = l2_final_step
                    loss_container['reg'] = loss_reg

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                if i % config['log']['log_per_iter'] == 0 and i != 0:
                    current_lr = optimizer.param_groups[0]['lr']
                    log = 'Epoch: [%d/%d] Phase: %s Step:[%d/%d] Global Iter: %d LR: %.6f' % (
                        epoch, epochs, phase, i, data_n_batches[phase], global_iteration,
                        current_lr)
                    log += ', loss: %.6f (%.6f)' % (
                        loss.item(), meter_loss.avg)
                    log += ', rmse: %.6f (%.6f)' % (
                        np.sqrt(loss_mse.item()), meter_loss_rmse.avg)
                    log += ', reg: %.6f (%.6f)' % (
                        loss_reg.item(), meter_loss_reg.avg)
                    if use_sp_loss:
                        log += ', sp: %.6f (%.6f)' % (
                            loss_sp.item(), meter_loss_sp.avg)
                    print(log)

                    # log data to tensorboard
                    writer.add_scalar("Params/learning rate", current_lr, global_iteration)

                    for loss_type, loss_obj in loss_container.items():
                        plot_name = "Loss/%s/%s" % (loss_type, phase)
                        writer.add_scalar(plot_name, loss_obj.item(), global_iteration)

            if is_train:
                # StepLR advances once per training epoch.
                if isinstance(scheduler, StepLR):
                    scheduler.step()
            else:
                if meter_loss.avg < best_valid_loss:
                    best_valid_loss = meter_loss.avg
                    model.save_model('best_valid_model')
                # ReduceLROnPlateau must be fed the monitored metric. Previously
                # it was never stepped, so it silently had no effect.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(meter_loss.avg)

            writer.flush()  # flush SummaryWriter events to disk

    return model, global_iteration, activations
