import numpy as np
import torch
from torch.optim import Adam
from tqdm import tqdm
import pickle
from ema import ExponentialMovingAverage
import os
import pathlib
import time

def get_optimizer(params, lr=2e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.):
    """Return a `torch.optim.Adam` optimizer over `params`.

    The keyword defaults are the values this function has always hard-coded, so
    callers that pass only `params` get exactly the same optimizer as before.

    Args:
        params: iterable of parameters (e.g. `model.parameters()`) or param-group dicts.
        lr, betas, eps, weight_decay: forwarded unchanged to `Adam`.

    Returns:
        The constructed `Adam` optimizer.
    """
    return Adam(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

def optimization_manager(lr=2e-4, warmup=5000, grad_clip=1.):
    """Return an `optimize_fn` that applies linear LR warmup and gradient clipping.

    The factory arguments only set the defaults of the returned function; calling
    `optimization_manager()` yields the historical defaults (lr=2e-4, warmup=5000,
    grad_clip=1.0), and each call to `optimize_fn` may still override them.
    """
    default_lr, default_warmup, default_grad_clip = lr, warmup, grad_clip

    def optimize_fn(optimizer, params, step, lr=default_lr,
                    warmup=default_warmup,
                    grad_clip=default_grad_clip):
        """Clip gradients, set the warmed-up learning rate and take one optimizer step.

        Args:
            optimizer: torch optimizer; when `warmup > 0` the 'lr' of every param
                group is overwritten with `lr * min(step / warmup, 1)`.
            params: parameters whose gradients are clipped (e.g. `model.parameters()`).
            step: current global step, used for the warmup factor.
            lr: learning rate reached once `step >= warmup`.
            warmup: warmup length in steps; `<= 0` leaves the optimizer's lr untouched.
            grad_clip: maximum total gradient norm; a negative value disables clipping.
        """
        if warmup > 0:
            scaled_lr = lr * min(step / warmup, 1.0)
            for group in optimizer.param_groups:
                group['lr'] = scaled_lr
        if grad_clip >= 0:
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
        optimizer.step()

    return optimize_fn

def save_checkpoint(ckpt_dir, state):
    """Serialize the training `state` to the file path `ckpt_dir` with `torch.save`.

    `state` must hold 'optimizer', 'model' and 'ema' objects exposing `state_dict()`,
    plus a 'global_step' value; the file stores those three state dicts and the
    step under the same keys.
    """
    payload = {name: state[name].state_dict() for name in ('optimizer', 'model', 'ema')}
    payload['global_step'] = state['global_step']
    torch.save(payload, ckpt_dir)
  
def restore_checkpoint(ckpt_dir, state, device):
    """Load a checkpoint file into the objects held by `state` and return `state`.

    The optimizer, model and EMA are updated in place via `load_state_dict`, and
    `state['global_step']` is replaced by the stored step. Tensors are mapped onto
    `device` while loading.
    """
    checkpoint = torch.load(ckpt_dir, map_location=device)
    optimizer, model, ema = state['optimizer'], state['model'], state['ema']
    optimizer.load_state_dict(checkpoint['optimizer'])
    # strict=False: keys missing from, or unexpected in, the checkpoint are skipped
    # silently rather than raising.
    model.load_state_dict(checkpoint['model'], strict=False)
    ema.load_state_dict(checkpoint['ema'])
    state['global_step'] = checkpoint['global_step']
    return state

def _ema_validation_loss(model, ema, valid_loader, epoch_no, global_step):
    """Return the mean of `model(batch, is_train=0)` over `valid_loader` using EMA weights.

    The live weights are stashed with `ema.store`, the EMA weights are copied in
    once for the whole pass, and `ema.restore` puts the live weights back even if
    a batch raises (assuming the usual store/copy_to/restore semantics of the
    project's ExponentialMovingAverage). Returns None if the loader is empty.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        ema.store(model.parameters())
        ema.copy_to(model.parameters())
        try:
            with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it:
                for batch_no, valid_batch in enumerate(it, start=1):
                    loss = model(valid_batch, is_train=0)
                    loss_sum += loss.item()
                    num_batches = batch_no
                    it.set_postfix(
                        ordered_dict={
                            "valid_avg_epoch_loss": loss_sum / batch_no,
                            "epoch": epoch_no,
                            "global_step": global_step,
                        },
                        refresh=False,
                    )
        finally:
            ema.restore(model.parameters())
    if num_batches == 0:
        return None
    return loss_sum / num_batches


def train(
    model,
    config,
    train_loader,
    valid_loader=None,
    valid_epoch_interval=20,
    foldername="",
):
    """Train `model` with Adam, LR warmup, gradient clipping and an EMA of the weights.

    Args:
        model: module whose `model(batch)` returns the scalar training loss and whose
            `model(batch, is_train=0)` returns the validation loss.
        config: dict providing "lr", "epochs" and "itr_per_epoch".
        train_loader: iterable of training batches.
        valid_loader: optional iterable of validation batches.
        valid_epoch_interval: run validation after every this-many epochs.
        foldername: directory receiving `checkpoint_<k>.pth`, written whenever the
            global step is a multiple of 1000 (k = global_step // 1000).
    """
    ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)
    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6)
    optimize_fn = optimization_manager()

    best_valid_loss = 1e10
    # Everything save_checkpoint serializes lives in `state`.
    state = dict(optimizer=optimizer, model=model, ema=ema, global_step=1)

    for epoch_no in range(config["epochs"]):
        avg_loss = 0
        model.train()
        with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, train_batch in enumerate(it, start=1):
                global_step = state["global_step"]
                optimizer = state["optimizer"]
                optimizer.zero_grad()
                loss = model(train_batch)
                loss.backward()
                # NOTE(review): optimize_fn warms up towards its own default lr (2e-4)
                # and overwrites the config["lr"] given to Adam above on every step --
                # confirm this is intended.
                optimize_fn(optimizer, model.parameters(), step=global_step)
                state["global_step"] += 1
                state["ema"].update(model.parameters())

                avg_loss += loss.item()
                it.set_postfix(
                    ordered_dict={
                        "avg_epoch_loss": avg_loss / batch_no,
                        "epoch": epoch_no,
                        "global_step": global_step,
                    },
                    refresh=False,
                )
                # Checkpoint BEFORE the itr_per_epoch cut-off: otherwise a step that is
                # a multiple of 1000 and also the epoch's last iteration is never saved.
                if global_step % 1000 == 0:
                    save_checkpoint(
                        os.path.join(foldername, f'checkpoint_{global_step//1000}.pth'), state
                    )
                if batch_no >= config["itr_per_epoch"]:
                    break

        if valid_loader is not None and (epoch_no + 1) % valid_epoch_interval == 0:
            # state["global_step"] - 1 is the last step actually taken.
            avg_loss_valid = _ema_validation_loss(
                model, state["ema"], valid_loader, epoch_no, state["global_step"] - 1
            )
            # Compare averages (what is printed), and skip an empty validation set.
            if avg_loss_valid is not None and best_valid_loss > avg_loss_valid:
                best_valid_loss = avg_loss_valid
                print(
                    "\n best loss is updated to ",
                    avg_loss_valid,
                    "at",
                    epoch_no,
                )


def quantile_loss(target, forecast, q: float, eval_points) -> torch.Tensor:
    """Return twice the masked pinball (quantile) loss of `forecast` at level `q`.

    Computes 2 * sum(|(forecast - target) * eval_points * (1[target <= forecast] - q)|),
    so only entries where `eval_points` is non-zero contribute.

    Returns:
        A 0-dim tensor (not a Python float); call `.item()` to get a number.
    """
    over_predicted = (target <= forecast) * 1.0  # 1.0 where the forecast is >= target
    weighted_error = (forecast - target) * eval_points * (over_predicted - q)
    return 2 * torch.sum(torch.abs(weighted_error))


def calc_denominator(target, eval_points):
    """Return sum(|target * eval_points|), the masked absolute target mass that normalizes CRPS."""
    masked_target = target * eval_points
    return masked_target.abs().sum()


def calc_quantile_CRPS(target, forecast, eval_points, mean_scaler, scaler):
    """Approximate the normalized CRPS of sampled forecasts with 19 quantile losses.

    Args:
        target: normalized ground truth; same shape as `eval_points`.
        forecast: normalized samples with the sample axis at dim 1 and the
            remaining axes matching `target` (e.g. (B, nsample, L, K) as built in
            `evaluate`).
        eval_points: mask selecting which entries are scored.
        mean_scaler, scaler: de-normalization, applied as `x * scaler + mean_scaler`.

    Returns:
        float: mean over q in {0.05, 0.10, ..., 0.95} of
        quantile_loss(q) / sum(|target * eval_points|).
    """
    target = target * scaler + mean_scaler
    forecast = forecast * scaler + mean_scaler

    quantiles = np.arange(0.05, 1.0, 0.05)
    denom = calc_denominator(target, eval_points)

    # Quantiles are still taken one forecast slice at a time (presumably to keep
    # each torch.quantile input small), but every level is obtained from a single
    # call per slice, so each slice is sorted once instead of once per level.
    # torch.quantile casts a scalar q to the input dtype internally; match that here.
    q_levels = torch.as_tensor(quantiles, dtype=forecast.dtype, device=forecast.device)
    per_slice = [
        torch.quantile(forecast[j : j + 1], q_levels, dim=1) for j in range(len(forecast))
    ]
    # Each slice is (len(quantiles), 1, ...); concatenate along the batch axis.
    q_preds = torch.cat(per_slice, dim=1)

    CRPS = 0
    for i, q in enumerate(quantiles):
        q_loss = quantile_loss(target, q_preds[i], q, eval_points)
        CRPS += q_loss / denom
    return CRPS.item() / len(quantiles)

def calc_quantile_CRPS_sum(target, forecast, eval_points, mean_scaler, scaler):
    """Quantile-based CRPS of the forecast aggregated over the last axis.

    Target and forecast samples are de-normalized (`x * scaler + mean_scaler`) and
    summed over their last axis; the mask is averaged over that axis. The score is
    then the mean over q in {0.05, ..., 0.95} of quantile_loss(q) divided by
    sum(|summed_target * averaged_mask|), with quantiles taken over the sample
    axis (dim 1) of `forecast`.

    Returns:
        float: the averaged, normalized quantile loss.
    """
    eval_points = eval_points.mean(-1)
    target = target * scaler + mean_scaler
    target = target.sum(-1)
    forecast = forecast * scaler + mean_scaler

    quantiles = np.arange(0.05, 1.0, 0.05)
    denom = calc_denominator(target, eval_points)

    # The aggregate and its sort do not depend on q: compute them once and get
    # every level from one torch.quantile call (the level becomes the first axis).
    forecast_sum = forecast.sum(-1)
    q_levels = torch.as_tensor(quantiles, dtype=forecast_sum.dtype, device=forecast_sum.device)
    q_preds = torch.quantile(forecast_sum, q_levels, dim=1)

    CRPS = 0
    for i, q in enumerate(quantiles):
        q_loss = quantile_loss(target, q_preds[i], q, eval_points)
        CRPS += q_loss / denom
    return CRPS.item() / len(quantiles)

# Checkpoint index evaluated per dataset by the `args.get_std` branch of `evaluate`.
# The "step N" notes are carried over from the original code; presumably the number
# of sampling steps (args.nstep) each checkpoint was chosen with -- TODO confirm.
_GET_STD_CHECKPOINTS = {
    'electricity': 265,  # step 50
    'solar': 335,  # step 200
    'wiki': 30,  # step 50
    'exchange': 180,  # step 100
    'traffic': 155,  # step 50
    'taxi': 85,  # step 50
}


def _list_checkpoint_steps(modelfolder):
    """Return the sorted integers k of all `checkpoint_<k>.pth` files in `modelfolder`.

    Other `.pth` files are ignored rather than crashing the integer parse, and the
    name is read from `Path.stem`, so any path separator works.
    """
    prefix = 'checkpoint_'
    steps = []
    for ckpt_file in pathlib.Path(modelfolder).glob(prefix + '*.pth'):
        suffix = ckpt_file.stem[len(prefix):]
        if suffix.isdigit():
            steps.append(int(suffix))
    return sorted(steps)


def _sample_and_score(model, ema, test_loader, nsample, scaler, mean_scaler, foldername,
                      args, global_step, print_sample_std=False, unconditional_banner=False):
    """Run one sampling pass over `test_loader`, save the outputs and print the metrics.

    For every batch the EMA weights are copied into `model`, `nsample` forecasts are
    drawn with `model.evaluate`, and the per-entry sample median is scored against
    the target with RMSE/MAE (masked by eval_points, rescaled by `scaler`). The full
    sample set is scored with CRPS and CRPS_sum. The concatenated outputs are
    pickled to `generated_outputs_nsample<nsample>.pk` and [RMSE, MAE, CRPS] to
    `result_nsample<nsample>.pk` inside `foldername`; both files are overwritten on
    every call.

    Args:
        print_sample_std: debug output -- for every batch element, print the
            across-sample std at time index -24, averaged over the last axis.
        unconditional_banner: print the 'Result of Unconditional sampling' line.

    Returns:
        (rmse, mae, crps, crps_sum)
    """
    with torch.no_grad():
        model.eval()
        mse_total = 0
        mae_total = 0
        evalpoints_total = 0

        all_target = []
        all_observed_point = []
        all_observed_time = []
        all_evalpoint = []
        all_generated_samples = []
        with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, test_batch in enumerate(it, start=1):
                ema.copy_to(model.parameters())
                start = time.time()
                output = model.evaluate(test_batch, nsample, predictor=args.predictor,
                                        corrector=args.corrector, ode_sampling=args.ode_sampling)
                print("Sampling finished ! Takes {} seconds ".format(time.time() - start))
                samples, c_target, eval_points, observed_points, observed_time = output

                samples = samples.permute(0, 1, 3, 2)  # (B,nsample,L,K)
                c_target = c_target.permute(0, 2, 1)  # (B,L,K)
                eval_points = eval_points.permute(0, 2, 1)
                observed_points = observed_points.permute(0, 2, 1)

                if print_sample_std:
                    # Reduce once per batch, not once per printed row.
                    sample_std = samples.std(dim=1)[:, -24:, :].mean(dim=-1)
                    for i in range(len(samples)):
                        print(f'{i} of samples: {sample_std[i, 0].item():.5f}')

                samples_median = samples.median(dim=1)
                torch.set_printoptions(precision=4)
                print(f'sample: {samples_median.values[0,-24:,0]}')
                print(f'c_target: {c_target[0,-24:,0]}')

                all_target.append(c_target)
                all_evalpoint.append(eval_points)
                all_observed_point.append(observed_points)
                all_observed_time.append(observed_time)
                all_generated_samples.append(samples)

                # Error of the sample median, restricted to evaluation points.
                median_error = (samples_median.values - c_target) * eval_points
                mse_current = (median_error ** 2) * (scaler ** 2)
                mae_current = torch.abs(median_error) * scaler

                mse_total += mse_current.sum().item()
                mae_total += mae_current.sum().item()
                evalpoints_total += eval_points.sum().item()

                it.set_postfix(
                    ordered_dict={
                        "rmse_total": np.sqrt(mse_total / evalpoints_total),
                        "mae_total": mae_total / evalpoints_total,
                        "batch_no": batch_no,
                    },
                    refresh=True,
                )

        all_target = torch.cat(all_target, dim=0)
        all_evalpoint = torch.cat(all_evalpoint, dim=0)
        all_observed_point = torch.cat(all_observed_point, dim=0)
        all_observed_time = torch.cat(all_observed_time, dim=0)
        all_generated_samples = torch.cat(all_generated_samples, dim=0)

        # os.path.join: with the default foldername="" this writes to the working
        # directory instead of the filesystem root.
        with open(os.path.join(foldername, f"generated_outputs_nsample{nsample}.pk"), "wb") as f:
            pickle.dump(
                [
                    all_generated_samples,
                    all_target,
                    all_evalpoint,
                    all_observed_point,
                    all_observed_time,
                    scaler,
                    mean_scaler,
                ],
                f,
            )

        CRPS = calc_quantile_CRPS(
            all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
        )
        CRPS_sum = calc_quantile_CRPS_sum(
            all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
        )

        rmse = np.sqrt(mse_total / evalpoints_total)
        mae = mae_total / evalpoints_total
        with open(os.path.join(foldername, f"result_nsample{nsample}.pk"), "wb") as f:
            pickle.dump([rmse, mae, CRPS], f)

    print(f'Modelfolder: {args.modelfolder}')
    print(f'Result of {global_step}')
    if unconditional_banner:
        print('Result of Unconditional sampling')
    print(f'#sample: {args.nsample}')
    print(f'#step: {args.nstep}')
    print(f'Datatype: {args.datatype}')
    if args.ode_sampling:
        print('Result of ODE sampling')
    else:
        print('Result of PC sampling')
        print(f'Predictor: {args.predictor}')
        print(f'Corrector {args.corrector}')
    print("RMSE:", rmse)
    print("MAE:", mae)
    print("CRPS:", CRPS)
    print("CRPS_sum:", CRPS_sum)
    return rmse, mae, CRPS, CRPS_sum


def evaluate(model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername="", args=None):
    """Evaluate saved checkpoints of `model` on `test_loader` and print/pickle the metrics.

    Checkpoints are read from `./save/<args.modelfolder>_<args.datatype>/checkpoint_<k>.pth`
    and loaded into fresh optimizer/EMA state; each evaluation pass uses the EMA weights.

    * `args.get_std and not args.train`: evaluate the fixed checkpoint listed in
      `_GET_STD_CHECKPOINTS` for the dataset (every checkpoint for an unlisted
      datatype) 5 times each and print mean +/- std of RMSE, MAE, CRPS and CRPS_sum.
    * otherwise: skip the first four checkpoints and evaluate every fifth one after that.

    Args:
        model: model exposing `evaluate(batch, nsample, predictor=..., corrector=...,
            ode_sampling=...)`.
        test_loader: iterable of test batches.
        nsample: number of forecast samples drawn per batch.
        scaler, mean_scaler: de-normalization, applied as `x * scaler + mean_scaler`.
        foldername: directory receiving the output pickles.
        args: run options; reads modelfolder, datatype, get_std, train, device,
            predictor, corrector, ode_sampling, nsample and nstep.

    Raises:
        ValueError: if `args` is None.
    """
    if args is None:
        raise ValueError("evaluate() requires `args` (modelfolder, datatype, device, sampler options)")

    modelfolder = "./save/" + args.modelfolder + '_' + args.datatype
    ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)
    optimizer = get_optimizer(model.parameters())
    state = dict(optimizer=optimizer, model=model, ema=ema, global_step=1)

    global_step_list = _list_checkpoint_steps(modelfolder)

    if args.get_std and not args.train:
        num_test = 5
        if args.datatype in _GET_STD_CHECKPOINTS:
            global_step_list = [_GET_STD_CHECKPOINTS[args.datatype]]

        for global_step in global_step_list:
            print(f'Calculating {global_step}')
            ckpt_path = os.path.join(modelfolder, f'checkpoint_{global_step}.pth')
            state = restore_checkpoint(ckpt_path, state, device=args.device)

            rmse_list, mae_list, crps_list, crps_sum_list = [], [], [], []
            for _ in range(num_test):
                rmse, mae, crps, crps_sum = _sample_and_score(
                    model, state['ema'], test_loader, nsample, scaler, mean_scaler,
                    foldername, args, global_step, unconditional_banner=True,
                )
                rmse_list.append(rmse)
                mae_list.append(mae)
                crps_list.append(crps)
                crps_sum_list.append(crps_sum)

            rmse_arr = np.array(rmse_list)
            mae_arr = np.array(mae_list)
            crps_arr = np.array(crps_list)
            crps_sum_arr = np.array(crps_sum_list)
            print(f'Datatype: {args.datatype}')
            print(f'Result of {global_step}')
            print(f"All RMSE: {rmse_arr}")
            print(f"All MAE: {mae_arr}")
            print(f"All CRPS: {crps_arr}")
            print(f"All CRPS_sum: {crps_sum_arr}")
            print(f"RMSE_result: {rmse_arr.mean()} +/- {rmse_arr.std()}")
            print(f"MAE_result: {mae_arr.mean()} +/- {mae_arr.std()}")
            print(f"CRPS_result: {crps_arr.mean()} +/- {crps_arr.std()}")
            print(f"CRPS_sum_result: {crps_sum_arr.mean()} +/- {crps_sum_arr.std()}")
    else:
        for global_step in global_step_list[4:][::5]:
            print(f'Calculating {global_step}')
            ckpt_path = os.path.join(modelfolder, f'checkpoint_{global_step}.pth')
            state = restore_checkpoint(ckpt_path, state, device=args.device)
            _sample_and_score(
                model, state['ema'], test_loader, nsample, scaler, mean_scaler,
                foldername, args, global_step, print_sample_std=True,
            )
