import sys, os
import numpy as np
import hydra
from hydra.utils import instantiate
import wandb
from omegaconf import OmegaConf, DictConfig

import torch
from functools import reduce
from functools import partial
from dataloaders.pdebench.pdebench_loader import FNODatasetSingle 
from models.fno import FNO1d, FNO2d, FNO3d
from models.s4seq_model import OneToSeqModel, ChainModel
from utils.utilities3 import count_params, LpLoss, time_it
from timeit import default_timer
from einops import rearrange, repeat, reduce
from utils.scheduler import WarmupScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from models.prob_forecasting import get_prob_loss, ProbForecaster

from training.step import rollout_step, sequential_step, one_to_seq_step, single_step

from loss import loss_registry, SensitiveWeighted

from utils.log_utils import get_logger, add_file_handler
import logging
from utils.evaluator import evaluator_registry, DummyEvaluator, eval_to_print, eval_to_wandb, eval_to_wandb_summary, s4model_eval

from dataloaders.dataset_utils import compute_eta_by_dataset

from copy import deepcopy

# Module-level logger shared by every function in this script.
log = get_logger(__name__, level = logging.INFO)

# Target device for the model and all batches: CUDA when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_dataset(args, batch_size):
    """Instantiate the train/test datasets and wrap them in DataLoaders.

    Args:
        args: Config object. Reads ``args.dataset.dataloader`` (instantiated
            twice, the second time with ``if_test=True``), ``args.num_workers``
            and the optional ``args.dataset.shuffle_train`` flag (default True).
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; the test loader is never shuffled.
    """
    dataset_cfg = args.dataset
    train_data = instantiate(dataset_cfg.dataloader)
    test_data = instantiate(dataset_cfg.dataloader, if_test=True)

    # Settings common to both loaders.
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": args.num_workers,
    }

    train_loader = torch.utils.data.DataLoader(
        train_data,
        shuffle=dataset_cfg.get("shuffle_train", True),
        **shared_kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        shuffle=False,
        **shared_kwargs,
    )
    return train_loader, test_loader


def learning_step(step_cfg, model, batch, loss_type, is_training=True,  evaluator=DummyEvaluator, dataset = None):
    """Run one forward pass on ``batch`` and report the result to ``evaluator``.

    Args:
        step_cfg: Step configuration. Reads ``name``, ``normalize_per_trajectory``,
            ``noise`` (``use_noise``, ``std`` and the optional
            ``only_training_inputs`` / ``only_inputs`` flags), the optional
            ``use_batch_dt`` flag and, in evaluation mode, ``scale``.
        model: Model forwarded to ``sequential_step``.
        batch: ``(yy, grid)`` or ``(yy, grid, batch_dt)``. ``yy`` is laid out as
            ``(B, Sx, [Sy], [Sz], T, V)``.
        loss_type: Key passed to ``get_loss_value`` to build the loss function.
        is_training: Selects training behaviour (target normalisation, noise
            restricted to training inputs) versus evaluation behaviour
            (predictions mapped back to the unscaled range before evaluation).
        evaluator: Object exposing ``evaluate(pred, target)``. A class may be
            passed (as the default is); it is instantiated without arguments.
        dataset: Dataset exposing ``unscale_data``; only needed in evaluation
            mode when ``step_cfg.scale`` is set.

    Returns:
        ``(loss, pred)`` as produced by ``sequential_step`` (``pred`` possibly
        unscaled in evaluation mode).

    Raises:
        KeyError: if ``step_cfg.name`` is not ``"sequential"`` or
            ``"sequential_markov"`` (the only step types wired in here).
        ValueError: if unscaling is requested but ``dataset`` is ``None``.
    """
    # The default value is the DummyEvaluator *class*; build an instance so
    # that ``evaluate`` is always called as a bound method.
    if isinstance(evaluator, type):
        evaluator = evaluator()

    loss_fn = get_loss_value(loss_type)

    if len(batch) == 2:
        yy, grid = batch
        batch_dt = None
    else:
        yy, grid, batch_dt = batch

    # Two-element batches carry no time-step tensor, so only move it to the
    # device when it exists and the step configuration wants it.
    if batch_dt is not None and step_cfg.get("use_batch_dt", True):
        batch_dt = batch_dt.to(device)
    else:
        batch_dt = None

    # yy: (B, Sx, [Sy], [Sz], T, V)
    yy, grid = yy.to(device), grid.to(device)
    # target is the ground truth; yy is the model input (may be noisy, etc.)
    target = yy

    if step_cfg.normalize_per_trajectory:
        B, *spatial_shape, T, V = yy.shape
        normalize_dims = tuple(range(1, 1 + len(spatial_shape)))
        # Statistics come from the first time step only, reduced over the
        # spatial axes, then broadcast over the time axis.
        first_frame = yy[..., 0, :]
        mean = first_frame.mean(normalize_dims, keepdim=True).unsqueeze(-2)
        std = first_frame.std(normalize_dims, keepdim=True).unsqueeze(-2)
        yy = (yy - mean) / std
        # In evaluation mode the target is left in its original range; the
        # prediction is mapped back to that range further below.
        if is_training:
            target = (target - mean) / std

    if step_cfg.noise.use_noise:
        if step_cfg.noise.get("only_training_inputs", True):
            # Noise is applied to the inputs, and only while training.
            if is_training:
                yy = yy + torch.randn_like(yy) * step_cfg.noise.std
        else:
            yy = yy + torch.randn_like(yy) * step_cfg.noise.std
            if not step_cfg.noise.get("only_inputs", True):
                # The noisy sequence also becomes the target.
                target = yy

    # Both supported step types share the same implementation.
    if step_cfg.name not in ("sequential", "sequential_markov"):
        raise KeyError(f"Unsupported step type: {step_cfg.name!r}")
    loss, _target, pred = sequential_step(
        model, yy, target, grid, loss_fn, is_training, step_cfg, batch_dt)

    if not is_training:
        if step_cfg.scale:
            assert not step_cfg.normalize_per_trajectory, "Scaling is not supported with normalize_per_trajectory"
            if dataset is None:
                raise ValueError(
                    "step_cfg.scale is set but no dataset was given to unscale the predictions")
            # evaluate with unscaled values
            pred = dataset.unscale_data(pred)
            _target = dataset.unscale_data(_target)
        elif step_cfg.normalize_per_trajectory:
            # evaluate with unscaled values
            pred = pred * std + mean
    evaluator.evaluate(pred, _target)
    return loss, pred
        
        
def get_loss_value(loss_type, sensitive_weighted = False, num_epochs = None):
    """Build the loss function registered under ``loss_type``.

    Args:
        loss_type: Key looked up in ``loss_registry``; the registered entry is
            called with no arguments to create the loss.
        sensitive_weighted: When true, the loss is wrapped in
            ``SensitiveWeighted``.
        num_epochs: Forwarded to ``SensitiveWeighted`` when wrapping.

    Returns:
        The loss object, wrapped or not.
    """
    loss_factory = loss_registry.get(loss_type)
    base_loss = loss_factory()
    if not sensitive_weighted:
        return base_loss
    return SensitiveWeighted(base_loss, num_epochs)


def train(step_cfg, train_loader, model, optimizer, loss_type, evaluator=DummyEvaluator):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        step_cfg: Step configuration forwarded to ``learning_step``. Also read
            here for the optional ``use_grad_clip`` flag and ``grad_clip`` norm.
        train_loader: Iterable of training batches; its ``dataset`` attribute
            is forwarded to ``learning_step``.
        model: Model to optimise; switched to training mode.
        optimizer: Optimizer stepped once per batch.
        loss_type: Loss key forwarded to ``learning_step``.
        evaluator: Object exposing ``evaluate(pred, target)``. A class may be
            passed (as the default is); it is instantiated without arguments.

    Returns:
        ``(model, elapsed_seconds)``.
    """
    # The default value is the DummyEvaluator *class*; work with an instance.
    if isinstance(evaluator, type):
        evaluator = evaluator()

    model.train()
    epoch_start = default_timer()
    batch_timer = default_timer()
    for batch in train_loader:
        loss, _ = learning_step(
            step_cfg, model, batch,
            loss_type=loss_type,
            evaluator=evaluator,
            dataset=train_loader.dataset)
        optimizer.zero_grad()
        loss.backward()
        # gradient clipping
        if step_cfg.get("use_grad_clip", False):
            torch.nn.utils.clip_grad_norm_(model.parameters(), step_cfg.grad_clip)
        optimizer.step()
        log.debug(f"Training batch time: {default_timer() - batch_timer}")
        batch_timer = default_timer()
    elapsed = default_timer() - epoch_start
    return model, elapsed


@torch.no_grad()
def test(step_cfg, test_loader, model, loss_type, evaluator=DummyEvaluator):
    """Evaluate ``model`` on every batch of ``test_loader`` without gradients.

    Args:
        step_cfg: Step configuration forwarded to ``learning_step``.
        test_loader: Iterable of evaluation batches; its ``dataset`` attribute
            is forwarded to ``learning_step``.
        model: Model to evaluate; switched to eval mode.
        loss_type: Loss key forwarded to ``learning_step``.
        evaluator: Object exposing ``evaluate(pred, target)``. A class may be
            passed (as the default is); it is instantiated without arguments.

    Returns:
        ``(mean_loss, elapsed_seconds)`` where ``mean_loss`` is the average of
        the per-batch losses, so the value reflects the whole loader rather
        than whichever batch happened to come last.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # The default value is the DummyEvaluator *class*; work with an instance.
    if isinstance(evaluator, type):
        evaluator = evaluator()

    model.eval()
    eval_start = default_timer()
    batch_timer = default_timer()
    loss_sum = 0.0
    n_batches = 0
    for batch in test_loader:
        loss, _ = learning_step(
            step_cfg, model, batch,
            loss_type=loss_type,
            evaluator=evaluator,
            is_training=False,
            dataset=test_loader.dataset)
        loss_sum = loss_sum + loss
        n_batches += 1
        log.debug(f"Evaluation batch time: {default_timer() - batch_timer}")
        batch_timer = default_timer()
    elapsed = default_timer() - eval_start
    if n_batches == 0:
        raise ValueError("test_loader yielded no batches; cannot compute a test loss")
    return loss_sum / n_batches, elapsed


def get_model(args, initial_step):
    """Instantiate the object described by ``args`` and move it to ``device``.

    ``initial_step`` is accepted but not used by this function.
    """
    return instantiate(args).to(device)


def _build_scheduler(args: DictConfig, optimizer):
    """Create the learning-rate scheduler named by ``args.model.scheduler``.

    Supports ``'cosine'``, ``'step'`` and ``'plateau'``; any other value
    raises ``NotImplementedError``.
    """
    kind = args.model.scheduler
    if kind == 'cosine':
        # A non-positive factor (the default -1) means "anneal down to zero".
        eta_min_factor = args.model.get("eta_min_factor", -1)
        eta_min = 0.0
        if eta_min_factor > 0:
            eta_min = args.model.optimizer.lr / eta_min_factor
        # The annealing period is counted in epochs, not optimizer iterations.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.num_epochs + 1, eta_min=eta_min)
    if kind == 'step':
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.model.step_size, gamma=args.model.gamma)
    if kind == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10, verbose=True)
    raise NotImplementedError(f"Scheduler {args.model.scheduler} not implemented")


def initialize(args: DictConfig):
    """Build every object the training loop needs from the config.

    Creates the data loaders, logs their shapes, instantiates the model
    (optionally wrapped in ``OneToSeqModel`` / ``ChainModel`` and in the
    probabilistic forecaster), then the optimizer, scheduler and evaluators.

    Returns:
        ``(model, train_loader, test_loader, optimizer, scheduler,
        train_evaluator, test_evaluator, loss_type)``.

    Raises:
        NotImplementedError: if ``args.model.scheduler`` is not one of
            ``'cosine'``, ``'step'`` or ``'plateau'``.
    """
    train_loader, test_loader = get_dataset(args, batch_size=args.model.batch_size)

    # One batch is pulled from each loader purely to report its shape.
    test_sample = next(iter(test_loader))[0]
    log.info(f"Test Dataset shape: {test_sample.shape}")
    train_sample = next(iter(train_loader))[0]
    log.info(f"Train Dataset shape: {train_sample.shape}")
    log.info(f"Number of batches: {len(train_loader)}")
    log.info(f"Number of training steps: {args.dataset.t_train}")
    log.info(f"Number of evaluation steps: {args.dataset.t_test}")

    # Input layout is (Sx, [Sy], [Sz], T, V); V is the number of states
    # (ie pressure, density, vx, vy in Navier Stokes 2D).
    input_shape = test_loader.dataset.input_shape()
    T, V = input_shape[-2:]
    spatial_shape = input_shape[:-2]

    model = instantiate(
        args.model.params,
        _recursive_=False,
        spatial_shape=spatial_shape,
        n_states=V,
        d_output=V * args.prob_forecasting.n_params,
        n_dim=len(spatial_shape),
        n_timesteps=args.step.train_timesteps,
    ).to(device)

    # Optional wrappers; "onetoseq" takes precedence over "chain".
    if args.model.get("onetoseq", False):
        model = OneToSeqModel(model)
    elif args.model.get("chain", False):
        model = ChainModel(model, args.model.chain_length)

    # The probabilistic forecaster wraps the model and brings its own loss.
    if args.prob_forecasting.use_prob_forecasting:
        model = instantiate(args.prob_forecasting.params, base_model=model)
        loss_type = args.prob_forecasting.name
    else:
        loss_type = args.loss_type

    # The optimizer config instantiates to a pair; only the first item is used.
    optimizer, _ = instantiate(
        args.model.optimizer,
        model=model,
        epochs=args.num_epochs)

    # Warmup wrapping (WarmupScheduler) is currently not applied here.
    scheduler = _build_scheduler(args, optimizer)

    train_evaluator = instantiate(args.step.train_evaluator)
    # NOTE(review): the keyword is spelled 'train_timestesp' — confirm it
    # matches the evaluator's signature before renaming it.
    test_evaluator = instantiate(
        args.step.test_evaluator,
        t_train=args.dataset.t_train,
        train_timestesp=args.step.train_timesteps)

    return (model, train_loader, test_loader, optimizer, scheduler,
            train_evaluator, test_evaluator, loss_type)

def main_epoch(epoch, step_cfg, model, train_loader, test_loader, train_evaluator, test_evaluator, optimizer, loss_type, scheduler, is_final, evaluate_frequency, test_eval_prefix = ""):
    """Run one training epoch, or one evaluation-only pass, and log the metrics.

    With ``is_final`` false the model is trained for one epoch, evaluated when
    an evaluation is due (first epoch, or every ``evaluate_frequency`` epochs)
    and the scheduler is stepped. With ``is_final`` true only the test pass
    runs and the scheduler is left untouched.

    Returns:
        ``(train_evaluation, test_evaluation)`` as returned by the evaluators'
        ``get_evaluation``.
    """
    do_evaluation = is_final or (epoch + 1) % evaluate_frequency == 0 or epoch == 0

    # When no evaluation is due, throw-away evaluators absorb the metrics.
    if do_evaluation:
        train_evaluator_ = train_evaluator
        test_evaluator_ = test_evaluator
    else:
        train_evaluator_ = DummyEvaluator()
        test_evaluator_ = DummyEvaluator()

    if is_final:
        log.info("Final evaluation")
        test_loss, test_time = test(
            step_cfg, test_loader, model, loss_type, evaluator=test_evaluator_)
    else:
        model, train_time = train(
            step_cfg, train_loader, model, optimizer, loss_type, evaluator=train_evaluator_)
        test_time = 0.0
        if do_evaluation:
            test_loss, test_time = test(
                step_cfg, test_loader, model, loss_type, evaluator=test_evaluator_)

        # Plateau schedulers (bare or wrapped in a warmup) need the test loss,
        # so they only advance on epochs where an evaluation was run.
        plateau_driven = isinstance(scheduler, ReduceLROnPlateau) or (
            isinstance(scheduler, WarmupScheduler)
            and isinstance(scheduler.base_scheduler, ReduceLROnPlateau)
        )
        if not plateau_driven:
            scheduler.step()
        elif do_evaluation:
            scheduler.step(test_loss)

    # Collect the metrics, then clear the real evaluators for the next call.
    train_evaluation = train_evaluator_.get_evaluation()
    train_evaluator.reset()
    test_evaluation = test_evaluator_.get_evaluation()
    test_evaluator.reset()

    train_msg = eval_to_print(train_evaluation, is_train=True)
    test_msg = eval_to_print(test_evaluation, is_train=False)

    if not is_final:
        timing_msg = f" | train_time {train_time:.2f} | test_time {test_time:.2f}"
        log.info(f"Epoch {epoch}" + train_msg + test_msg + timing_msg)
        log.wandb(eval_to_wandb(train_evaluation, is_train=True, prefix="train"), step=epoch)
        timings = {"time/train_time": train_time, "time/test_time": test_time}
        log.wandb(timings, step=epoch)

    log.wandb(eval_to_wandb(test_evaluation, is_train=False, prefix=test_eval_prefix), step=epoch)

    # log learning rate
    for group_idx, group in enumerate(optimizer.param_groups):
        log.wandb({f"lr/group_{group_idx}": group['lr']}, step=epoch)

    return train_evaluation, test_evaluation

@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(args: DictConfig):
    """Hydra entry point: receives the composed config and delegates to ``_main``."""
    _main(args)

def _main(args: DictConfig):
    """Run the full training job described by ``args``.

    Seeds torch and numpy, optionally starts a wandb run, builds the model and
    data through ``initialize``, then loops over ``args.num_epochs`` epochs.
    Every ``args.final_eval_frequency`` epochs an extra evaluation-only pass
    is run with the ``args.step.final_eval`` overrides (and, for the
    ``"sequential"`` step, another one with ``discard_state=True``). When
    ``args.log_model`` is set, the model weights and the config are written to
    ``<workdir.root>/<workdir.name>`` every ``args.log_frequency`` epochs and
    after the last epoch.
    """
    if not args.print:
        log.setLevel(logging.WARNING)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.set_num_threads(args.num_threads)

    if args.use_wandb:
        log.enable_wandb()
        wandb_config = OmegaConf.to_container(
            args, resolve=True, throw_on_missing=True
        )
        init_kwargs = {
            "project": args.wandb.project,
            "group": args.wandb.group,
            "name": args.wandb.name,
            "config": wandb_config,
        }
        # Tags are only passed along when the config provides them.
        if args.wandb.tags is not None:
            init_kwargs["tags"] = list(args.wandb.tags)
        run = wandb.init(**init_kwargs)

    if args.log_model:
        model_folder_path = os.path.join(
            args.workdir.root,
            args.workdir.name,)
        os.makedirs(model_folder_path, exist_ok=True)
    else:
        model_folder_path = None

    model, train_loader, test_loader, optimizer, scheduler, train_evaluator, test_evaluator, loss_type = initialize(args)

    total_params = count_params(model)
    total_grad_params = count_params(model, only_grad = True)
    log.info(f'Total parameters = {total_params} require_grad {total_grad_params}')
    if args.use_wandb:
        run.summary["total_parameters"] = total_params
        run.summary["total_grad_parameters"] = total_grad_params

    # Step configuration for the periodic "final" evaluation: the regular
    # step settings with the final_eval overrides applied on top.
    step_final_eval = deepcopy(args.step)
    step_final_eval.update(args.step.final_eval)

    # Step configuration for the extra evaluation that discards the state.
    step_discard_state = deepcopy(args.step)
    step_discard_state.discard_state = True

    def run_epoch(epoch, step_cfg, is_final, prefix):
        # Thin wrapper so the three call sites only spell out what differs.
        return main_epoch(
            epoch, step_cfg, model, train_loader, test_loader,
            train_evaluator, test_evaluator, optimizer, loss_type, scheduler,
            is_final=is_final,
            evaluate_frequency=args.evaluate_frequency,
            test_eval_prefix=prefix)

    for epoch in range(args.num_epochs):
        run_epoch(epoch, args.step, is_final=False, prefix="test")

        if args.final_eval_frequency > 0 and (epoch + 1) % args.final_eval_frequency == 0:
            run_epoch(epoch, step_final_eval, is_final=True, prefix="final_test")
            if args.step.name == "sequential":
                # evaluation with discard = True
                run_epoch(epoch, step_discard_state, is_final=True, prefix="discard_state")
                log.wandb(s4model_eval(model), step=epoch)

        if args.log_model:
            # Save every `log_frequency` epochs (the remainder must be zero;
            # testing it for truthiness saved on every *other* epoch instead)
            # and always after the last epoch. A non-positive frequency
            # disables the periodic saves rather than dividing by zero.
            periodic_save = args.log_frequency > 0 and (epoch + 1) % args.log_frequency == 0
            if periodic_save or epoch == args.num_epochs - 1:
                checkpoint_path = os.path.join(model_folder_path, "ckpt.pt")
                torch.save(model.state_dict(), checkpoint_path)
                cfg_path = os.path.join(model_folder_path, "cfg.yaml")
                with open(cfg_path, 'w') as f:
                    f.write(OmegaConf.to_yaml(args))
    

                        
# Script entry point: Hydra composes the config and passes it to main().
if __name__ == "__main__":
    main()

