"""
Train a model on a regression problem.
"""
# comet_ml integrations removed in anonymized copy
from collections import defaultdict
import os
import pickle as pkl
import random

import collections
import collections.abc
# Workaround: re-export every public name of ``collections.abc`` on the
# ``collections`` module itself, so that lookups such as
# ``collections.Mapping`` resolve.
# NOTE(review): presumably needed by a third-party import below (attrdict?)
# that still reads the ABCs from ``collections`` -- confirm.
for type_name in collections.abc.__all__:
    abc_type = getattr(collections.abc, type_name)
    setattr(collections, type_name, abc_type)

from attrdict import AttrDict
import git
import hydra
import numpy as np
from omegaconf import OmegaConf, open_dict
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from krt import KRT_PATH
from krt.utils import NoSchedule
import time

def _sample_element_mask(reference, num_ctx):
    """Return a float32 0/1 mask shaped like ``reference`` with ``num_ctx`` ones.

    The ones are placed at uniformly random positions drawn over *all*
    elements of ``reference`` (via ``torch.randperm``); if ``num_ctx`` exceeds
    the element count, slicing caps it and the mask is all ones.
    """
    num_elems = reference.numel()
    flat_mask = torch.zeros(num_elems, dtype=torch.float32,
                            device=reference.device)
    perm = torch.randperm(num_elems, device=reference.device)
    flat_mask[perm[:num_ctx]] = 1.0
    return flat_mask.view_as(reference)


def _sample_binomial_mask(xi, prob, device, ensure_nonempty=False):
    """Return a Bernoulli(``prob``) 0/1 mask of shape ``(xi.shape[0], xi.shape[1], 1)``.

    The mask keeps the integer dtype produced by ``np.random.binomial``.
    With ``ensure_nonempty`` set, an all-zero draw gets a single one placed at
    a random position along dim 1 of the first batch element.
    """
    draws = np.random.binomial(1, prob, size=(xi.shape[0], xi.shape[1], 1))
    mask = torch.from_numpy(draws).to(device)
    if ensure_nonempty and mask.sum() == 0:
        mask[0, random.randint(0, xi.shape[1] - 1), 0] = 1.0
    return mask


def _build_masked_batch(xi, yi, device, ctx_sampling_binomial, min_ctx_size,
                        min_trg_size, ensure_nonempty=False):
    """Move one ``(xi, yi)`` pair to ``device`` and build the model batch fields.

    ``xi`` and ``yi`` are concatenated along the last dim; ``xc`` is a random
    0/1 mask, ``xt`` its complement, and ``yc``/``yt`` are both the
    concatenated data. The RNG call order (``random`` first, then
    ``torch``/``numpy``) is kept stable so seeded runs stay reproducible.

    Returns a plain dict with keys ``xc, xt, yc, yt, x, y``.
    """
    xi, yi = xi.to(device), yi.to(device)
    datai = torch.cat([xi, yi], dim=-1)
    if ctx_sampling_binomial is None:
        num_ctx = random.randint(min_ctx_size, xi.shape[1] - min_trg_size)
        mask = _sample_element_mask(datai, num_ctx)
    else:
        mask = _sample_binomial_mask(xi, ctx_sampling_binomial, device,
                                     ensure_nonempty=ensure_nonempty)
    return {'xc': mask, 'xt': 1 - mask, 'yc': datai, 'yt': datai,
            'x': xi, 'y': yi}


def _run_validation(model, data, device, ctx_sampling_binomial, min_ctx_size,
                    min_trg_size, ctx_sizes, log_joint_in_val):
    """Evaluate ``model`` on ``data.val_data`` and return a dict of statistics.

    Always reports ``Loss`` and ``LL`` (each averaged over the number of
    validation batches). With ``log_joint_in_val`` set, additionally reports
    ``NLL_CTX{cs}`` / ``NLL_CTX{cs}_PerPt`` for every context size in
    ``ctx_sizes`` plus their means ``NLL`` / ``NLL_PerPt``.

    The caller is responsible for switching the model to eval mode.
    """
    val_stats = defaultdict(float)
    num_val_batches = len(data.val_data)
    # No gradients are needed for evaluation; this avoids building graphs.
    with torch.no_grad():
        # Regular loss with randomly masked batches.
        for xi, yi in data.val_data:
            fields = _build_masked_batch(
                xi, yi, device, ctx_sampling_binomial,
                min_ctx_size, min_trg_size)
            batch = AttrDict(fields)
            model_out = model.forward(batch)
            loss_out = model.loss(batch, model_out)
            val_stats['Loss'] += loss_out.item() / num_val_batches
            val_ll = model.seq_ll(fields['xc'], fields['yc'],
                                  fields['xt'], fields['yt']).sum().item()
            val_stats['LL'] += val_ll / num_val_batches
        if log_joint_in_val:
            # Thorough logging over a fixed grid of context sizes.
            ctx_nlls, ctx_nlls_per_pt = [], []
            for cs in ctx_sizes:
                total_nll = 0.0
                num_trg = None
                for xi, yi in data.val_data:
                    xi, yi = xi.to(device), yi.to(device)
                    xc, xt = xi[:, :cs], xi[:, cs:]
                    yc, yt = yi[:, :cs], yi[:, cs:]
                    ll = model.seq_ll(xc, yc, xt, yt)
                    total_nll -= ll.sum().item() / data.num_val
                    num_trg = xt.shape[1]
                val_stats[f'NLL_CTX{cs}'] = total_nll
                ctx_nlls.append(total_nll)
                if num_trg:
                    per_pt = total_nll / num_trg
                    val_stats[f'NLL_CTX{cs}_PerPt'] = per_pt
                    ctx_nlls_per_pt.append(per_pt)
            # Average only the per-context NLL entries; unrelated stats such
            # as 'LL' must not leak into these means.
            if ctx_nlls:
                val_stats['NLL'] = np.mean(ctx_nlls)
            if ctx_nlls_per_pt:
                val_stats['NLL_PerPt'] = np.mean(ctx_nlls_per_pt)
    return val_stats


@hydra.main(config_path=f'{KRT_PATH}/cfgs/regression',
            config_name='train', version_base='1.1')
def train(cfg):
    """Train a model from a hydra config, validate periodically, then test.

    Args:
        cfg: Hydra/OmegaConf config. Keys read here:
            ``seed``, ``debug``, ``cuda_device``, ``checkpoint_every``,
            ``ctx_sampling_binomial``, ``log_joint_in_val``;
            ``model.architecture`` (instantiated as the model) and
            ``model.training.{learning_rate, weight_decay, lr_schedule}``;
            ``data`` (instantiated as the data module) and
            ``data.{epochs, max_train_hours, test_normalize,
            early_stop_patience, min_tr_ctx_size, min_tr_trg_size,
            min_val_ctx_size, num_val_ctx_sizes, min_te_ctx_size,
            max_te_ctx_size}``.

    The instantiated data module is expected to expose ``train_data``,
    ``val_data`` (may be ``None``), ``test_data``, ``L``,
    ``batches_per_epoch``, ``train_num_batches``, ``num_val`` and ``num_te``.

    Side effects (all in the current working directory): writes
    ``config.yaml``, ``version.txt``, TensorBoard event files,
    ``best_ep_{ep}.pt``, optional ``checkpoints/ep_{ep}.pt``, ``final.pt``
    and, if a test set exists, ``test_stats.pkl``.

    Validation runs every 100 optimizer steps (when ``val_data`` is present)
    and its scalars are logged against the global step count. Early stopping
    triggers when more than ``early_stop_patience`` epochs passed since the
    best validation loss, and ends training after the current epoch's
    bookkeeping.
    """
    if cfg.get('debug', False):
        breakpoint()
    random.seed(cfg['seed'])
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    # Instantiate data module.
    data = hydra.utils.instantiate(
        cfg['data'],
        seed=cfg['seed'],
    )
    # The model's x/y dimensionality is hard-coded to 5 here.
    with open_dict(cfg):
        cfg['model']['architecture']['dim_x'] = 5
        cfg['model']['architecture']['dim_y'] = 5
    # Instantiate model.
    model = hydra.utils.instantiate(
        cfg['model']['architecture'],
    )
    # Load in other parameters.
    device = ('cpu' if cfg.get('cuda_device', None) is None
              else f'cuda:{cfg["cuda_device"]}')
    model = model.to(device)
    checkpoint_every = cfg.get('checkpoint_every', None)
    ctx_sampling_binomial = cfg.get('ctx_sampling_binomial', None)
    learning_rate, weight_decay, lr_schedule = (
        cfg['model']['training']['learning_rate'],
        cfg['model']['training']['weight_decay'],
        cfg['model']['training']['lr_schedule']
    )
    epochs = cfg['data']['epochs']
    max_train_hours = cfg['data'].get('max_train_hours', None)
    print('max train hours:', max_train_hours)
    test_normalize = cfg['data'].get('test_normalize', True)
    early_stop_patience = cfg['data'].get('early_stop_patience', float('inf'))
    min_tr_ctx_size = cfg['data']['min_tr_ctx_size']
    min_tr_trg_size = cfg['data']['min_tr_trg_size']
    min_val_ctx_size = cfg['data'].get('min_val_ctx_size', 3)
    # Number of context sizes to use when validating. This is a fixed grid.
    num_val_ctx_sizes = cfg['data'].get('num_val_ctx_sizes', 5)
    min_te_ctx_size = cfg['data'].get('min_te_ctx_size', 1)
    max_te_ctx_size = cfg['data'].get('max_te_ctx_size', None)
    # Validation cadence, in optimizer steps.
    val_every_steps = 100
    if data.val_data is None:
        ctx_sizes = []
    else:
        L = data.L
        ctx_sizes = [
            int(s) for s in np.linspace(min_val_ctx_size, L - 1, num_val_ctx_sizes)]
    log_joint_in_val = cfg.get('log_joint_in_val', False)
    # Set up optimizer.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    if lr_schedule:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs * data.batches_per_epoch)
    else:
        scheduler = NoSchedule()
    # Log information and set up logger.
    run_dir = os.getcwd()
    OmegaConf.save(cfg, os.path.join(run_dir, 'config.yaml'))
    with open(os.path.join(run_dir, 'version.txt'), 'w') as f:
        f.write(str(git.Repo(search_parent_directories=True).head.object.hexsha))
    logger = SummaryWriter(run_dir)
    pbar = tqdm(total=epochs)
    if checkpoint_every is not None:
        # exist_ok so that re-running in the same directory does not crash.
        os.makedirs(os.path.join(run_dir, 'checkpoints'), exist_ok=True)

    # comet logging removed in anonymized copy

    time_start = time.time()
    # Train!
    best_ep, best_loss = None, float('inf')
    step_count = 0
    stop_early = False
    for ep in range(epochs):
        # Train batch update.
        tr_stats = defaultdict(float)

        for xi, yi in data.train_data:
            # Training guarantees at least one context entry in the
            # binomial mask (ensure_nonempty).
            batch = AttrDict(_build_masked_batch(
                xi, yi, device, ctx_sampling_binomial,
                min_tr_ctx_size, min_tr_trg_size, ensure_nonempty=True))
            model_out = model.forward(batch)
            loss_out = model.loss(batch, model_out)
            optimizer.zero_grad()
            loss_out.backward()
            optimizer.step()
            scheduler.step()
            step_count += 1
            tr_stats['Loss'] += loss_out.item()
            # Skip validation entirely when there is no validation set.
            if data.val_data is not None and step_count % val_every_steps == 0:
                model.eval()
                val_stats = _run_validation(
                    model, data, device, ctx_sampling_binomial,
                    min_tr_ctx_size, min_tr_trg_size, ctx_sizes,
                    log_joint_in_val)
                # comet logging removed in anonymized copy
                print(f"Ep {ep} Step {step_count} Val LL: {val_stats['LL']:0.2f}") #comment this if comet
                # Validation is step-based, so log against the global step;
                # using the epoch would stack several points on one x-value.
                for k, v in val_stats.items():
                    logger.add_scalar(f'val/{k}', v, step_count)
                if val_stats['Loss'] < best_loss:
                    if best_ep is not None:
                        old_path = os.path.join(run_dir,
                                                f'best_ep_{best_ep}.pt')
                        if os.path.exists(old_path):
                            os.remove(old_path)
                    best_loss = val_stats['Loss']
                    best_ep = ep
                    torch.save(model.state_dict(),
                               os.path.join(run_dir, f'best_ep_{best_ep}.pt'))
                elif best_ep is not None and ep - best_ep > early_stop_patience:
                    # best_ep can still be None if no finite improvement was
                    # ever observed (e.g. NaN loss); do not stop in that case.
                    stop_early = True
                model.train()
                if stop_early:
                    break

        for k, v in tr_stats.items():
            logger.add_scalar(f'tr/{k}', v / data.train_num_batches, ep)
            tr_stats[k] = v / data.train_num_batches
        if best_loss < float('inf'):
            pbar.set_postfix_str(
                f'BestEp: {best_ep} '
                f'Train Loss: {tr_stats["Loss"]:0.2f} '
                f'Best Val: {best_loss:0.2f} '
            )
        else:
            pbar.set_postfix_str(
                f'Train Loss: {tr_stats["Loss"]:0.2f} '
            )
        # Possibly save off the current model weights.
        if checkpoint_every is not None and ep % checkpoint_every == 0:
            torch.save(model.state_dict(), os.path.join(run_dir,
                                                        'checkpoints',
                                                        f'ep_{ep}.pt'))
        # Update progress.
        pbar.update(1)
        if stop_early:
            # Leave the epoch loop too; breaking the batch loop alone would
            # let training silently continue with the next epoch.
            break
        time_new = time.time()
        if max_train_hours is not None and (time_new - time_start)/3600 > max_train_hours:
            print(f"Stopping training since {max_train_hours} hours have passed.")
            break
    torch.save(model.state_dict(), os.path.join(run_dir, 'final.pt'))
    pbar.close()
    # Flush pending TensorBoard events to disk.
    logger.close()
    # If there is a test set, evaluate and log the results.
    if data.num_te > 0:
        if best_ep is not None:
            model.load_state_dict(torch.load(
                os.path.join(run_dir, f'best_ep_{best_ep}.pt'),
                map_location=device))
        # Evaluate in eval mode and without autograd bookkeeping.
        model.eval()
        te_stats = {}
        L = data.L
        if max_te_ctx_size is None:
            max_te_ctx_size = L - 1
        ctx_sizes = list(range(min_te_ctx_size, max_te_ctx_size + 1))
        with torch.no_grad():
            for cs in tqdm(ctx_sizes, desc='Testing'):
                total_ll = 0.0
                total_min = 0.0
                total_max = 0.0
                if test_normalize:
                    # Test batches carry two extra tensors (ci, mi) that are
                    # used to rescale the log-likelihood to a percentage.
                    for xi, yi, ci, mi in data.test_data:
                        xi, yi = xi.to(device), yi.to(device)
                        ci, mi = ci.to(device), mi.to(device)
                        xc, xt = xi[:, :cs], xi[:, cs:]
                        yc, yt = yi[:, :cs], yi[:, cs:]
                        total_ll += model.seq_ll(xc, yc, xt, yt).sum().item()
                        total_min += mi[:, cs:].sum().item()
                        total_max += (ci[:, -1] - ci[:, cs - 1]).sum().item()
                    total_ll = (total_ll - total_min) / (total_max - total_min) * 100
                else:
                    num_batches = 0
                    for xi, yi in data.test_data:
                        num_batches += 1
                        xi, yi = xi.to(device), yi.to(device)
                        datai = torch.cat([xi, yi], dim=-1)
                        # Mask has a singleton last dim here, with `cs` ones
                        # spread over the whole batch.
                        mask = _sample_element_mask(datai[:, :, -1:], cs)
                        xc, xt = mask, 1 - mask
                        yc, yt = datai, datai
                        total_ll += model.seq_ll(xc, yc, xt, yt).sum().item()
                    # Average per batch; guard against an empty loader.
                    total_ll = total_ll / max(num_batches, 1)
                te_stats[f'Normalized_LL_Ctx{cs}'] = total_ll
        te_stats['Normalized_LL'] = np.mean(list(te_stats.values()))
        print('=' * 20)
        for k, v in te_stats.items():
            print(f'{k} : {v:0.2f}')
        print('=' * 20)
        with open(os.path.join(run_dir, 'test_stats.pkl'), 'wb') as f:
            pkl.dump(te_stats, f)


# Script entry point. ``train`` is called without arguments; it is wrapped by
# ``@hydra.main``, which is expected to build and pass in ``cfg``.
if __name__ == '__main__':
    train()
