import argparse
import os
import pdb
import shutil
from datetime import datetime
from functools import partial
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
import torch
import transformers
import wandb
import yaml
from accelerate import Accelerator, DistributedDataParallelKwargs
from datasets import load_dataset, concatenate_datasets
from gluonts.time_feature import month_of_year
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torchinfo import summary
from tqdm.auto import tqdm
from transformers import InformerConfig, InformerForPrediction, PretrainedConfig

from transformer_uda.informer_models import InformerFourierPEForPrediction, MaskedInformerFourierPE
from transformer_uda.dataset_preprocess import create_train_dataloader
from transformer_uda.dataset_preprocess_raw import create_train_dataloader_raw
from transformer_uda.plotting_utils import plot_batch_examples

# Python does not expand "$SCRATCH" inside string literals, so resolve it from
# the environment here. os.path.expandvars leaves the text unchanged when the
# variable is not set.
WANDB_DIR = os.path.expandvars("$SCRATCH/sn_transformer/wandb")
CACHE_DIR = os.path.expandvars("$SCRATCH/huggingface_datasets_cache")
CHECKPOINT_DIR = os.path.expandvars("$SCRATCH/plasticc/plasticc_all_gp_interp/models/checkpoints")

def get_dataset(data_dir, data_subset_file=None, force_redownload=False):
    """Load a dataset through ``datasets.load_dataset`` using ``CACHE_DIR`` as cache.

    Args:
        data_dir: Passed to ``load_dataset`` as its first positional argument.
        data_subset_file: Optional path to a text file listing one data file
            per line. When given, those files are passed as the ``train``
            entry of ``data_files``. Blank lines are ignored.
        force_redownload: When true, ``download_mode="force_redownload"`` is
            passed to ``load_dataset``.

    Returns:
        The object returned by ``load_dataset``; its ``train`` split is
        indexed here for the size message, so it must exist.
    """
    kwargs = {"cache_dir": CACHE_DIR}
    if data_subset_file is not None:
        with open(data_subset_file) as f:
            # Skip blank lines (e.g. a trailing newline) so that no empty
            # path ends up in data_files.
            data_subset = [line.strip() for line in f if line.strip()]
        print(f"using data subset: {data_subset}")
        kwargs["data_files"] = {'train': data_subset}
    if force_redownload:
        kwargs["download_mode"] = "force_redownload"

    dataset = load_dataset(data_dir, **kwargs)
    print(f"loading dataset {'from file ' if data_subset_file is not None else ''}with {len(dataset['train'])} examples")

    return dataset

def save_model(model, optimizer, output_dir):
    """Write the model and optimizer to ``output_dir`` in two formats.

    ``output_dir/torch_model/model.pt`` receives a ``torch.save`` dict with
    the model and optimizer state dicts, and ``output_dir/hf_model`` is handed
    to ``model.save_pretrained``. Both subdirectories are deleted and
    recreated first, so earlier contents do not survive.

    Args:
        model: Object providing ``state_dict()`` and ``save_pretrained(dir)``.
        optimizer: Object providing ``state_dict()``.
        output_dir: Destination directory (str or Path); created if missing.
    """
    def _recreate(directory):
        # Remove any previous contents, then start from an empty directory.
        if directory.exists():
            shutil.rmtree(directory)
        directory.mkdir(parents=True)

    print(f"Saving model to {output_dir}")
    root = Path(output_dir)
    if not root.exists():
        root.mkdir(parents=True)

    torch_model_dir = root / "torch_model"
    hf_model_dir = root / "hf_model"

    print(f"overwriting torch model at {torch_model_dir}")
    _recreate(torch_model_dir)

    print(f"overwriting hf model at {hf_model_dir}")
    _recreate(hf_model_dir)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, torch_model_dir / "model.pt")

    model.save_pretrained(hf_model_dir)

def prepare_model_input(batch, device, config, mask):
    """Pick the entries of ``batch`` the model needs and move them to ``device``.

    Args:
        batch: Mapping from field name to an object with a ``.to(device)``
            method.
        device: Target passed to each ``.to`` call.
        config: Object exposing ``num_static_categorical_features`` and
            ``num_static_real_features``; a static field is included only
            when its count is positive.
        mask: When truthy, ``batch["mask_label"]`` is returned under the key
            ``"labels"``; otherwise the three ``future_*`` fields are added.

    Returns:
        Dict of keyword arguments for the model call.
    """
    def on_device(key):
        return batch[key].to(device)

    model_inputs = {
        key: on_device(key)
        for key in ("past_time_features", "past_values", "past_observed_mask")
    }

    # Static features are optional and controlled by the model config.
    if config.num_static_categorical_features > 0:
        model_inputs["static_categorical_features"] = on_device("static_categorical_features")
    if config.num_static_real_features > 0:
        model_inputs["static_real_features"] = on_device("static_real_features")

    if mask:
        model_inputs["labels"] = on_device("mask_label")
        return model_inputs

    for key in ("future_time_features", "future_observed_mask", "future_values"):
        model_inputs[key] = on_device(key)
    return model_inputs

def setup_model_config(args, config):
    """Build the ``InformerConfig`` used to instantiate the model.

    Args:
        args: Namespace providing ``redshift`` (adds one static real feature
            when truthy) and ``mask_probability``.
        config: Mapping with the required keys ``dropout_rate``,
            ``num_encoder_layers``, ``num_decoder_layers`` and ``d_model``.
            The optional keys listed below are copied over when present.

    Returns:
        The populated ``InformerConfig``.
    """
    # These are passed to the constructor rather than through update()
    # because the config derives other properties from them.
    model_config = InformerConfig(
        input_size=2,
        prediction_length=0,
        context_length=300,
        lags_sequence=[0],
        num_time_features=2,  # wavelength + time
        num_static_real_features=1 if args.redshift else 0,

        # informer params:
        dropout=config['dropout_rate'],
        encoder_layers=config['num_encoder_layers'],
        decoder_layers=config['num_decoder_layers'],
        d_model=config['d_model'],
        scaling=None,
        has_labels=False,
        mask=True,
        mask_probability=args.mask_probability,
    )

    optional_keys = (
        # additional encoder/decoder hyperparams
        'encoder_attention_heads',
        'decoder_attention_heads',
        'encoder_ffn_dim',
        'decoder_ffn_dim',
        # additional hyperparams for learnable fourier PE
        'fourier_dim',
        'PE_hidden_dim',
    )
    overrides = {key: config[key] for key in optional_keys if key in config}
    model_config.update(overrides)

    return model_config

def train(args, base_config, add_config=None):
    """Train an Informer variant on the main dataset plus SDSS for a fixed number of steps.

    The model class and dataloader factory are chosen from ``args.fourier_pe``
    and ``args.mask``. The loss is logged to wandb from the main process
    unless ``args.dry_run`` is set. An Accelerate state checkpoint is written
    every 5,000 steps (step 0 included) under ``CHECKPOINT_DIR``, and the
    final model is saved by the main process when ``args.save_model`` is given.

    Args:
        args: Parsed command-line namespace. This function reads ``data_dir``,
            ``log_level``, ``dry_run``, ``fourier_pe``, ``mask``,
            ``load_model``, ``load_checkpoint`` and ``save_model``;
            ``setup_model_config`` additionally reads ``redshift`` and
            ``mask_probability``.
        base_config: Dict of hyperparameters. Must contain ``wandb_name``,
            ``lr``, ``weight_decay``, ``batch_size`` and ``num_steps`` in
            addition to the keys ``setup_model_config`` requires. It is
            copied, not modified.
        add_config: Optional mapping merged over ``base_config``. On the main
            process it is replaced by ``wandb.config`` once wandb has been
            initialised.
    """
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(mixed_precision='bf16', kwargs_handlers=[ddp_kwargs])

    device = accelerator.device

    if args.log_level:
        print(f"setting log level to {args.log_level}")
        log_levels = {"debug": 10, "info": 20, "warning": 30, "error": 40, "critical": 50}
        transformers.logging.set_verbosity(log_levels[args.log_level])
        datasets.logging.set_verbosity(log_levels[args.log_level])
        datasets.logging.enable_propagation()

    if not args.dry_run and accelerator.is_main_process:
        print("initializing wandb")
        wandb.init(project="informer", name=base_config['wandb_name'], config=base_config, dir=WANDB_DIR)
        add_config = wandb.config
    # Merge into a copy so the caller's dict is left untouched.
    config = dict(base_config)
    if add_config is not None:
        config.update(add_config)
    print(config)

    dataset = get_dataset(args.data_dir)
    # Python does not expand "$SCRATCH" in a string literal; resolve it from
    # the environment (left unchanged when the variable is unset).
    sdss_dataset = get_dataset(os.path.expandvars("$SCRATCH/sdss/dataset"))
    full_sdss_dataset = concatenate_datasets([sdss_dataset['train'], sdss_dataset['validation'], sdss_dataset['test']])
    # NOTE(review): 'redshift' is dropped from the SDSS rows here even though
    # --redshift enables a static real feature; confirm this is intended.
    full_sdss_dataset = full_sdss_dataset.remove_columns(['label', 'redshift'])
    dataset['train'] = concatenate_datasets([dataset['train'], full_sdss_dataset])
    # Report the number of training examples, not the number of splits.
    print(f"added SDSS data, dataset size: {len(dataset['train'])}")

    model_config = setup_model_config(args, config)

    if args.fourier_pe and args.mask:
        print("instantiating model with fourier PE and masking")
        model = MaskedInformerFourierPE(model_config)
        dataloader_fn = create_train_dataloader_raw
    elif args.fourier_pe:
        print("instantiating model with fourier PE")
        model = InformerFourierPEForPrediction(model_config)
        dataloader_fn = create_train_dataloader_raw
    else:
        print("instantiating model with GP-interpolated inputs")
        model = InformerForPrediction(model_config)
        dataloader_fn = create_train_dataloader
    print(model)
    print(f"num total parameters: {sum(p.numel() for p in model.parameters())}")
    print(f"num trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    model.to(device)

    optimizer = AdamW(
        model.parameters(),
        lr=float(config['lr']),
        betas=(0.9, 0.95),
        weight_decay=float(config['weight_decay'])
    )

    if args.load_model:
        # Map the saved tensors onto this process's device instead of the
        # device they were saved from.
        ckpt = torch.load(args.load_model, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

    train_dataloader = dataloader_fn(
        config=model_config,
        dataset=dataset['train'],
        batch_size=config['batch_size'],
    )

    model, optimizer, train_dataloader = accelerator.prepare(
        model,
        optimizer,
        train_dataloader,
    )
    if args.load_checkpoint:
        accelerator.load_state(args.load_checkpoint)

    progress_bar = tqdm(range(config['num_steps']))

    def cycle(dataloader):
        # Restart the dataloader whenever it is exhausted so training is
        # bounded by the step count rather than by epochs.
        while True:
            for x in dataloader:
                yield x
            # TODO shuffle the dataloader?

    start_time = datetime.now()
    run_stamp = start_time.strftime('%Y-%m-%d_%H:%M:%S')
    model.train()
    # zip() stops as soon as range() is exhausted, so no extra batch is
    # pulled from the dataloader after the last step.
    for idx, batch in zip(range(config['num_steps']), cycle(train_dataloader)):
        optimizer.zero_grad()

        outputs = model(**prepare_model_input(batch, device, model_config, args.mask))
        loss = outputs.loss

        # Backpropagation
        accelerator.backward(loss)
        optimizer.step()
        progress_bar.update(1)

        if not args.dry_run and accelerator.is_main_process:
            wandb.log({"loss": loss.item()})

        if idx % 1_000 == 0:
            print(f"step {idx}: loss = {loss.item()}")

        if idx % 5_000 == 0:
            ckpt_dir = Path(CHECKPOINT_DIR) / f"checkpoint_{config['wandb_name']}_{run_stamp}_step_{idx}"
            print(f"saving ckpt at {ckpt_dir}")
            accelerator.save_state(output_dir=ckpt_dir)

    if args.save_model:
        # Let every rank finish its last step, then write from one process
        # only: save_model deletes and recreates its target directories, so
        # concurrent writers would race with each other.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            save_model(accelerator.unwrap_model(model), optimizer, args.save_model)

if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the YAML
    # hyperparameters, apply the run-specific overrides and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True,
                        help="dataset location passed to datasets.load_dataset")
    parser.add_argument("--num_steps", type=int, required=True,
                        help="number of optimizer steps to run")
    parser.add_argument("--save_model", type=str,
                        help="directory the final model is written to")
    parser.add_argument("--load_model", type=str,
                        help="torch checkpoint holding model_state_dict and optimizer_state_dict")
    parser.add_argument("--load_checkpoint", type=str,
                        help="Accelerate state directory to resume from")
    parser.add_argument("--dry_run", action="store_true",
                        help="do not initialise or log to wandb")
    parser.add_argument("--fourier_pe", action="store_true",
                        help="use the Fourier positional-encoding model and raw dataloader")
    parser.add_argument("--mask", action="store_true",
                        help="train with mask labels instead of future values")
    parser.add_argument("--redshift", action="store_true",
                        help="configure the model with one static real feature")
    parser.add_argument("--log_level", type=str,
                        help="one of: debug, info, warning, error, critical")
    parser.add_argument("--lr", type=float,
                        help="learning rate; defaults to 0.0001 when omitted")
    parser.add_argument("--mask_probability", default=0.6, type=float,
                        help="mask probability stored in the model config")

    args = parser.parse_args()

    # Python does not expand "$HOME" in string literals; resolve it from the
    # environment before opening the files.
    config_root = os.path.expandvars("$HOME/time_series_transformer/transformer_uda")
    with open(os.path.join(config_root, "configs", "bigger_model_hyperparameters.yml")) as f:
        config = yaml.safe_load(f)
    # NOTE(review): sweep_config is loaded but never used below - confirm
    # whether it should be passed to train() as add_config.
    with open(os.path.join(config_root, "hyperparameters.yml")) as f:
        sweep_config = yaml.safe_load(f)

    config['num_steps'] = args.num_steps
    config['weight_decay'] = 0.01
    config['dropout_rate'] = 0.2
    # Honour --lr when it is given; 0.0001 was also the value used with
    # batch size 1024.
    config['lr'] = args.lr if args.lr is not None else 0.0001
    config['batch_size'] = 256
    config['scaling'] = None
    # NOTE(review): setup_model_config hard-codes context_length=300 and
    # prediction_length=0, so these two values do not reach the model config
    # through that path - confirm which values are intended.
    config["context_length"] = 170
    config["prediction_length"] = 10
    config["allow_padding"] = True

    # Keep the run name in step with --mask_probability ("masked_60%" at the
    # default of 0.6).
    config['wandb_name'] = f"masked_{args.mask_probability:.0%}"
    train(args, config)
