import numpy as np
import torch
from torch.optim import Adam
from tqdm import tqdm
import pickle
import os
import random
from torch.utils.data import Dataset, DataLoader

class _PreBatchedDataset(Dataset):
    """Dataset view over a sequence of already-assembled batches.

    Item ``idx`` is simply ``pre_batches[idx]``; nothing is collated or
    transformed here.
    """

    def __init__(self, pre_batches):
        # Kept by reference; the container is not copied.
        self.pre_batches = pre_batches

    def __len__(self):
        """Return the number of stored batches."""
        batches = self.pre_batches
        return len(batches)

    def __getitem__(self, idx):
        """Return the stored batch at position ``idx``."""
        batches = self.pre_batches
        return batches[idx]

def train(
    model,
    config,
    train_loader,
    valid_loader=None,
    valid_epoch_interval=20,
    foldername="",
    is_dp=0,
    logger=None,
):
    """Fit ``model`` with Adam, optionally validating and saving the final weights.

    Args:
        model: Callable module. When ``model.is_ort`` is truthy,
            ``model(batch, is_dp=is_dp)`` must return ``(loss_mit, loss_ort)``;
            otherwise ``model(batch)`` must return a single loss.
            ``model.evaluate(batch, 1)`` is used for validation and
            ``model.target`` is embedded in the checkpoint file name.
        config: Mapping providing ``"lr"``, ``"epochs"``, ``"epochs_pr"`` and
            ``"itr_per_epoch"``.
        train_loader: Iterable of training batches.
        valid_loader: Optional iterable of validation batches.
        valid_epoch_interval: Validate every this many epochs.
        foldername: If non-empty, the final ``state_dict`` is written to
            ``<foldername>/model_<model.target>.pth`` after the last epoch.
        is_dp: Forwarded to the model only on the ``is_ort`` path.
        logger: Optional logger that receives the validation MSE.

    Returns:
        None.
    """
    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6)
    if foldername != "":
        # Resolved up front so a model without ``target`` fails before training.
        output_path = foldername + f"/model_{model.target}.pth"

    # NOTE(review): the milestones are derived from config["epochs_pr"] although
    # the loop below runs for config["epochs"]; confirm this is intentional.
    p1 = int(0.75 * config["epochs_pr"])
    p2 = int(0.9 * config["epochs_pr"])
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[p1, p2], gamma=0.1
    )

    for epoch_no in range(config["epochs"]):
        avg_mit_loss = 0
        avg_ort_loss = 0
        model.train()
        with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, train_batch in enumerate(it, start=1):
                optimizer.zero_grad()
                use_ort = bool(model.is_ort)
                if use_ort:
                    loss_mit, loss_ort = model(train_batch, is_dp=is_dp)
                    loss = loss_mit + loss_ort
                else:
                    # The single-loss path does not receive ``is_dp``.
                    loss_mit = model(train_batch)
                    loss = loss_mit
                loss.backward()
                optimizer.step()

                # Running means shown in the progress bar.
                avg_mit_loss += loss_mit.item()
                postfix = {"avg_mit_loss": avg_mit_loss / batch_no}
                if use_ort:
                    avg_ort_loss += loss_ort.item()
                    postfix["avg_ort_loss"] = avg_ort_loss / batch_no
                postfix["epoch"] = epoch_no
                it.set_postfix(ordered_dict=postfix, refresh=False)

                # Cap the number of optimisation steps per epoch.
                if batch_no >= config["itr_per_epoch"]:
                    break

            lr_scheduler.step()

        if valid_loader is None or (epoch_no + 1) % valid_epoch_interval != 0:
            continue

        model.eval()
        mse_total = 0
        mae_total = 0
        evalpoints_total = 0
        mredenom_total = 0
        with torch.no_grad():
            with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it:
                for batch_no, valid_batch in enumerate(it, start=1):
                    output = model.evaluate(valid_batch, 1)
                    samples, c_target, eval_points, _, _ = output
                    samples = samples.permute(0, 1, 3, 2)  # (B,nsample,L,K)
                    c_target = c_target.permute(0, 2, 1)  # (B,L,K)
                    eval_points = eval_points.permute(0, 2, 1)

                    # Point estimate: median over the sample dimension.
                    samples_median = samples.median(dim=1)
                    masked_error = (samples_median.values - c_target) * eval_points

                    mse_total += (masked_error ** 2).sum().item()
                    mae_total += torch.abs(masked_error).sum().item()
                    evalpoints_total += eval_points.sum().item()
                    mredenom_total += (torch.abs(c_target * eval_points)).sum().item()

                    # Fall back to 1 so an empty mask never divides by zero.
                    n_eval = evalpoints_total if evalpoints_total > 0 else 1
                    mre_denom = mredenom_total if mredenom_total > 0 else 1
                    it.set_postfix(
                        ordered_dict={
                            "rmse_total": np.sqrt(mse_total / n_eval),
                            "mae_total": mae_total / n_eval,
                            "mre_total": mae_total / mre_denom * 100,
                            "batch_no": batch_no,
                        },
                        refresh=False,
                    )
        if logger is not None:
            # Same zero guard as the progress bar; the unguarded division used
            # to raise ZeroDivisionError when no points were evaluated.
            avg_mit_loss_valid = mse_total / (evalpoints_total if evalpoints_total > 0 else 1)
            logger.info(f"Validation loss {avg_mit_loss_valid} at Epoch {epoch_no+1}")

    if foldername != "":
        torch.save(model.state_dict(), output_path)


def train_pr_em(
    model,
    model_pr,
    config,
    train_loader,
    foldername="",
    logger=None,
    scale=1,
):
    """Jointly train ``model`` and ``model_pr`` with an EM-style loop.

    For every batch, the expectation step imputes the batch with
    ``model.evaluate(..., is_impute=True, model_pr=model_pr, scale=scale)`` and
    takes the median over samples; the maximization step then optimises both
    networks on that imputed batch with a summed loss.

    Args:
        model: Module exposing ``evaluate``, ``is_ort`` and ``target_strategy``.
            ``model(batch)`` returns ``(loss_mit, loss_ort)`` when ``is_ort`` is
            truthy, otherwise a single loss. Its ``target_strategy`` is set to
            ``"random"`` as a side effect.
        model_pr: Module whose call returns ``(loss_true, loss_fake)``.
        config: Mapping providing ``"lr"``, ``"epochs_pr"`` and
            ``"itr_per_epoch"``.
        train_loader: Iterable of dict-like batches (each is copied with
            ``dict(batch)`` before ``'observed_data'`` is replaced).
        foldername: If non-empty, both state dicts are saved every 10 epochs as
            ``model_ft_em<epoch>.pth`` and ``model_pr_em<epoch>.pth``.
        logger: Optional logger for per-epoch average losses.
        scale: Forwarded to ``model.evaluate``.

    Returns:
        None.
    """
    model.target_strategy = "random"

    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6)
    optimizer_pr = Adam(model_pr.parameters(), lr=config["lr"], weight_decay=1e-6)

    p1 = int(0.75 * config["epochs_pr"])
    p2 = int(0.9 * config["epochs_pr"])

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[p1, p2], gamma=0.1
    )
    lr_scheduler_pr = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_pr, milestones=[p1, p2], gamma=0.1
    )

    for epoch_no in range(config["epochs_pr"]):
        avg_true_loss = 0
        avg_fake_loss = 0
        avg_mit_loss = 0
        avg_ort_loss = 0
        batch_num = 0  # batches actually processed in this epoch

        if logger is not None:
            logger.info(f"EM iteration {epoch_no+1} start")

        with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, train_batch in enumerate(it, start=1):
                optimizer.zero_grad()
                optimizer_pr.zero_grad()

                #### expectation step
                model.eval()
                model_pr.eval()
                # NOTE(review): this call is intentionally not wrapped in
                # torch.no_grad() (the original had it commented out);
                # presumably evaluate needs autograd when model_pr is given.
                output = model.evaluate(train_batch, 1, is_impute=True, model_pr=model_pr, scale=scale)

                samples, _, _, _, _ = output
                # Median over samples, detached so the M-step does not
                # backpropagate through the imputation.
                samples_median = samples.permute(0, 1, 3, 2).median(dim=1).values.detach()  # (B,L,K)
                del samples, output
                imputed_batch = dict(train_batch)
                imputed_batch['observed_data'] = samples_median

                #### maximization step
                model.train()
                model_pr.train()

                loss_true, loss_fake = model_pr(imputed_batch)
                avg_true_loss += loss_true.item()
                avg_fake_loss += loss_fake.item()

                use_ort = bool(model.is_ort)
                if use_ort:
                    loss_mit, loss_ort = model(imputed_batch)
                    avg_ort_loss += loss_ort.item()
                    loss_total = loss_true + loss_fake + loss_mit + loss_ort
                else:
                    loss_mit = model(imputed_batch)
                    loss_total = loss_true + loss_fake + loss_mit
                avg_mit_loss += loss_mit.item()

                loss_total.backward()
                optimizer.step()
                optimizer_pr.step()

                postfix = {
                    "avg_true_loss": avg_true_loss / batch_no,
                    "avg_fake_loss": avg_fake_loss / batch_no,
                    "avg_mit_loss": avg_mit_loss / batch_no,
                }
                if use_ort:
                    postfix["avg_ort_loss"] = avg_ort_loss / batch_no
                postfix["epoch"] = epoch_no
                it.set_postfix(ordered_dict=postfix, refresh=False)

                # Count the batch BEFORE the early exit. Previously the counter
                # was bumped after the break, so the logged averages divided by
                # one batch too few (and by zero when itr_per_epoch == 1).
                batch_num = batch_no
                if batch_no >= config["itr_per_epoch"]:
                    break

            lr_scheduler.step()
            lr_scheduler_pr.step()
            if logger is not None:
                # Guard against an empty loader.
                batch_num = max(batch_num, 1)
                logger.info(f"avg_true_loss: {avg_true_loss / batch_num}, avg_fake_loss: {avg_fake_loss / batch_num}, avg_mit_loss: {avg_mit_loss / batch_num} at EM epoch {epoch_no+1}.")

        if foldername != "" and (epoch_no+1)%10==0:
            output_path_ft = foldername + f"/model_ft_em{epoch_no+1}.pth"
            output_path = foldername + f"/model_pr_em{epoch_no+1}.pth"
            torch.save(model.state_dict(), output_path_ft)
            torch.save(model_pr.state_dict(), output_path)


def quantile_loss(target, forecast, q: float, eval_points) -> torch.Tensor:
    """Return twice the masked pinball (quantile) loss summed over all entries.

    Args:
        target: Tensor of reference values.
        forecast: Tensor of predicted ``q``-quantiles, broadcastable to ``target``.
        q: Quantile level.
        eval_points: Mask/weight tensor; entries multiplied by 0 do not contribute.

    Returns:
        A 0-dim tensor (not a Python float; callers use ``.item()`` on the
        accumulated result).
    """
    # 1.0 where the forecast is at or above the target, else 0.0.
    indicator = (target <= forecast) * 1.0
    weighted_error = (forecast - target) * eval_points * (indicator - q)
    return 2 * torch.sum(torch.abs(weighted_error))


def calc_denominator(target, eval_points):
    """Return the sum of ``|target * eval_points|`` as a 0-dim tensor."""
    masked_target = target * eval_points
    return masked_target.abs().sum()


def calc_quantile_CRPS(target, forecast, eval_points, mean_scaler, scaler):
    """Approximate the normalised CRPS by averaging quantile losses.

    ``target`` and ``forecast`` are first mapped back with
    ``x * scaler + mean_scaler``. For each level in 0.05, 0.10, ..., 0.95 the
    forecast quantile is taken along dim 1 (the sample axis, going by the
    ``(B,nsample,L,K)`` comments at the call sites), scored with
    ``quantile_loss`` and divided by ``calc_denominator(target, eval_points)``.

    Returns:
        Python float: the mean of the normalised quantile losses.
    """
    target = target * scaler + mean_scaler
    forecast = forecast * scaler + mean_scaler

    levels = np.arange(0.05, 1.0, 0.05)
    normaliser = calc_denominator(target, eval_points)
    total = 0
    for level in levels:
        # Quantiles are taken one leading-index slice at a time and then
        # concatenated (presumably to keep each torch.quantile input small —
        # confirm before vectorising).
        per_item = [
            torch.quantile(forecast[k : k + 1], level, dim=1)
            for k in range(len(forecast))
        ]
        level_forecast = torch.cat(per_item, 0)
        level_loss = quantile_loss(target, level_forecast, level, eval_points)
        total += level_loss / normaliser
    return total.item() / len(levels)

def calc_quantile_CRPS_sum(target, forecast, eval_points, mean_scaler, scaler):
    """Approximate the normalised CRPS of the values summed over the last axis.

    ``eval_points`` is averaged over its last axis, while the rescaled
    (``x * scaler + mean_scaler``) ``target`` and ``forecast`` are summed over
    theirs. Quantile losses for levels 0.05, 0.10, ..., 0.95 (quantiles taken
    along dim 1 of the summed forecast) are normalised by
    ``calc_denominator`` and averaged.

    Returns:
        Python float: the mean of the normalised quantile losses.
    """
    eval_points = eval_points.mean(-1)
    target = (target * scaler + mean_scaler).sum(-1)
    forecast = forecast * scaler + mean_scaler

    levels = np.arange(0.05, 1.0, 0.05)
    normaliser = calc_denominator(target, eval_points)
    # The summed forecast is the same for every level, so compute it once.
    summed_forecast = forecast.sum(-1)
    total = 0
    for level in levels:
        level_forecast = torch.quantile(summed_forecast, level, dim=1)
        level_loss = quantile_loss(target, level_forecast, level, eval_points)
        total += level_loss / normaliser
    return total.item() / len(levels)

def evaluate(model, test_loader, config, nsample=100, scaler=1, mean_scaler=0, foldername="", data_mode = 'test', target_mode='artificial', logger=None):
    """Sample from ``model`` over ``test_loader``, persist the outputs and report metrics.

    Writes two files into ``foldername``: a pickle with the generated samples,
    targets, masks, observation times and both scalers, and a text file with
    RMSE / MAE / MRE / CRPS. All five metrics (including CRPS_sum) are also
    sent to ``logger``, or printed when no logger is given.

    Args:
        model: Module exposing ``evaluate(batch, nsample[, is_impute=True])``
            returning ``(samples, c_target, eval_points, observed_points,
            observed_time)``.
        test_loader: Iterable of batches.
        config: Only ``config['model']['test_missing_ratio']`` is read, and only
            when ``target_mode == 'artificial'`` (it is part of the file names).
        nsample: Number of samples requested per batch.
        scaler, mean_scaler: Stored in the pickle and used to rescale values
            for the CRPS computations.
        foldername: Output directory.
        data_mode: Tag embedded in the output file names.
        target_mode: ``'artificial'`` or ``'original'``; the latter passes
            ``is_impute=True`` to ``model.evaluate``.
        logger: Optional logger; ``print`` is used when it is ``None``.

    Returns:
        Tuple ``(RMSE, MAE, MRE)`` computed from the per-sample median.

    Raises:
        ValueError: If ``target_mode`` is not one of the two supported values.
    """
    # Resolve the output names before sampling so that an unsupported mode or
    # a missing config key fails immediately instead of after the whole run.
    if target_mode == 'artificial':
        run_tag = f"{str(nsample)}_{data_mode}_{config['model']['test_missing_ratio']}"
    elif target_mode == 'original':
        run_tag = f"{str(nsample)}_{data_mode}"
    else:
        raise ValueError(
            f"target_mode must be 'artificial' or 'original', got {target_mode!r}"
        )
    file_name = f"generated_outputs_nsample_{run_tag}.pk"
    result_path = os.path.join(foldername, f"result_nsample_{run_tag}.txt")

    with torch.no_grad():
        model.eval()
        mse_total = 0
        mae_total = 0
        evalpoints_total = 0
        mredenom_total = 0

        all_target = []
        all_observed_point = []
        all_observed_time = []
        all_evalpoint = []
        all_generated_samples = []
        with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, test_batch in enumerate(it, start=1):
                if target_mode == 'artificial':
                    output = model.evaluate(test_batch, nsample)
                else:
                    output = model.evaluate(test_batch, nsample, is_impute = True)

                samples, c_target, eval_points, observed_points, observed_time = output
                samples = samples.permute(0, 1, 3, 2)  # (B,nsample,L,K)
                c_target = c_target.permute(0, 2, 1)  # (B,L,K)
                eval_points = eval_points.permute(0, 2, 1)
                observed_points = observed_points.permute(0, 2, 1)

                all_target.append(c_target)
                all_evalpoint.append(eval_points)
                all_observed_point.append(observed_points)
                all_observed_time.append(observed_time)
                all_generated_samples.append(samples)

                # Point estimate: median over the sample dimension.
                samples_median = samples.median(dim=1)
                masked_error = (samples_median.values - c_target) * eval_points

                mse_total += (masked_error ** 2).sum().item()
                mae_total += torch.abs(masked_error).sum().item()
                evalpoints_total += eval_points.sum().item()
                mredenom_total += (torch.abs(c_target * eval_points)).sum().item()

                # Fall back to 1 so an empty mask never divides by zero.
                n_eval = evalpoints_total if evalpoints_total > 0 else 1
                mre_denom = mredenom_total if mredenom_total > 0 else 1
                it.set_postfix(
                    ordered_dict={
                        "rmse_total": np.sqrt(mse_total / n_eval),
                        "mae_total": mae_total / n_eval,
                        "mre_total": mae_total / mre_denom * 100,
                        "batch_no": batch_no,
                    },
                    refresh=True,
                )

        all_target = torch.cat(all_target, dim=0)
        all_evalpoint = torch.cat(all_evalpoint, dim=0)
        all_observed_point = torch.cat(all_observed_point, dim=0)
        all_observed_time = torch.cat(all_observed_time, dim=0)
        all_generated_samples = torch.cat(all_generated_samples, dim=0)

        with open(os.path.join(foldername, file_name), "wb") as f:
            pickle.dump(
                [
                    all_generated_samples,
                    all_target,
                    all_evalpoint,
                    all_observed_point,
                    all_observed_time,
                    scaler,
                    mean_scaler,
                ],
                f,
            )

        CRPS = calc_quantile_CRPS(
            all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
        )
        CRPS_sum = calc_quantile_CRPS_sum(
            all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
        )

    # Same zero guard as the progress bar; the final metrics used to divide
    # unguarded and raised ZeroDivisionError when nothing was evaluated.
    n_eval = evalpoints_total if evalpoints_total > 0 else 1
    mre_denom = mredenom_total if mredenom_total > 0 else 1
    RMSE = np.sqrt(mse_total / n_eval)
    MAE = mae_total / n_eval
    MRE = mae_total / mre_denom * 100

    with open(result_path, "w") as f:
        f.write(f"RMSE: {RMSE:.6f}\n")
        f.write(f"MAE: {MAE:.6f}\n")
        f.write(f"MRE: {MRE:.6f}\n")
        f.write(f"CRPS: {CRPS:.6f}\n")

    report = logger.info if logger is not None else print
    report(f"{target_mode} RMSE w/o Pattern Recognizer: {RMSE:.4f}")
    report(f"{target_mode} MAE w/o Pattern Recognizer: {MAE:.4f}")
    report(f"{target_mode} MRE w/o Pattern Recognizer: {MRE:.4f}")
    report(f"{target_mode} CRPS w/o Pattern Recognizer: {CRPS:.4f}")
    report(f"{target_mode} CRPS_sum w/o Pattern Recognizer: {CRPS_sum:.4f}")

    return RMSE, MAE, MRE

def evaluate_pr(model, model_pr, config, test_loader, nsample=10, scaler=1, mean_scaler=0, foldername="", data_mode = 'test', target_mode='artificial', scale=1, logger=None):
    """Sample from ``model`` with ``model_pr`` supplied, persist outputs and report metrics.

    Mirrors ``evaluate`` but forwards ``model_pr`` and ``scale`` to
    ``model.evaluate``. Writes a pickle (file name ends in ``_pr<scale>.pk``)
    and a text report (file name ends in ``_pr.txt``) into ``foldername``, and
    sends RMSE / MAE / MRE / CRPS / CRPS_sum to ``logger`` or ``print``.

    Args:
        model: Module exposing ``evaluate(batch, nsample, model_pr=..., scale=...
            [, is_impute=True])`` returning ``(samples, c_target, eval_points,
            observed_points, observed_time)``.
        model_pr: Second module, switched to eval mode and passed to
            ``model.evaluate``.
        config: Only ``config['model']['test_missing_ratio']`` is read, and only
            when ``target_mode == 'artificial'``.
        test_loader: Iterable of batches.
        nsample: Number of samples requested per batch.
        scaler, mean_scaler: Stored in the pickle and used to rescale values
            for the CRPS computations.
        foldername: Output directory.
        data_mode: Tag embedded in the output file names.
        target_mode: ``'artificial'`` or ``'original'``; the latter passes
            ``is_impute=True`` to ``model.evaluate``.
        scale: Forwarded to ``model.evaluate`` and embedded in the pickle name.
        logger: Optional logger; ``print`` is used when it is ``None``.

    Returns:
        Tuple ``(RMSE, MAE, MRE)`` computed from the per-sample median.

    Raises:
        ValueError: If ``target_mode`` is not one of the two supported values.
    """
    # Resolve the output names before sampling so that an unsupported mode or
    # a missing config key fails immediately instead of after the whole run.
    if target_mode == 'artificial':
        run_tag = f"{str(nsample)}_{data_mode}_{config['model']['test_missing_ratio']}"
    elif target_mode == 'original':
        run_tag = f"{str(nsample)}_{data_mode}"
    else:
        raise ValueError(
            f"target_mode must be 'artificial' or 'original', got {target_mode!r}"
        )
    file_name = f"generated_outputs_nsample_{run_tag}_pr{scale}.pk"
    result_path = os.path.join(foldername, f"result_nsample_{run_tag}_pr.txt")

    model.eval()
    model_pr.eval()

    mse_total = 0
    mae_total = 0
    evalpoints_total = 0
    mredenom_total = 0

    all_target = []
    all_observed_point = []
    all_observed_time = []
    all_evalpoint = []
    all_generated_samples = []
    # NOTE(review): unlike evaluate(), this loop is not wrapped in
    # torch.no_grad(); presumably model.evaluate needs autograd when model_pr
    # is supplied — confirm before adding it.
    with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it:
        for batch_no, test_batch in enumerate(it, start=1):
            if target_mode == 'artificial':
                output = model.evaluate(test_batch, nsample, model_pr=model_pr, scale=scale)
            else:
                output = model.evaluate(test_batch, nsample, model_pr=model_pr, scale=scale, is_impute=True)

            samples, c_target, eval_points, observed_points, observed_time = output
            # Detach so the stored samples cannot keep any autograd history
            # alive across batches (values are unchanged).
            samples = samples.permute(0, 1, 3, 2).detach()  # (B,nsample,L,K)
            c_target = c_target.permute(0, 2, 1)  # (B,L,K)
            eval_points = eval_points.permute(0, 2, 1)
            observed_points = observed_points.permute(0, 2, 1)

            all_target.append(c_target)
            all_evalpoint.append(eval_points)
            all_observed_point.append(observed_points)
            all_observed_time.append(observed_time)
            all_generated_samples.append(samples)

            # Point estimate: median over the sample dimension.
            samples_median = samples.median(dim=1)
            masked_error = (samples_median.values - c_target) * eval_points

            mse_total += (masked_error ** 2).sum().item()
            mae_total += torch.abs(masked_error).sum().item()
            evalpoints_total += eval_points.sum().item()
            mredenom_total += (torch.abs(c_target * eval_points)).sum().item()

            # Fall back to 1 so an empty mask never divides by zero; these
            # divisions were unguarded here and could abort on the first batch.
            n_eval = evalpoints_total if evalpoints_total > 0 else 1
            mre_denom = mredenom_total if mredenom_total > 0 else 1
            it.set_postfix(
                ordered_dict={
                    "rmse_total": np.sqrt(mse_total / n_eval),
                    "mae_total": mae_total / n_eval,
                    "mre_total": mae_total / mre_denom * 100,
                    "batch_no": batch_no,
                },
                refresh=True,
            )

    all_target = torch.cat(all_target, dim=0)
    all_evalpoint = torch.cat(all_evalpoint, dim=0)
    all_observed_point = torch.cat(all_observed_point, dim=0)
    all_observed_time = torch.cat(all_observed_time, dim=0)
    all_generated_samples = torch.cat(all_generated_samples, dim=0)

    with open(os.path.join(foldername, file_name), "wb") as f:
        pickle.dump(
            [
                all_generated_samples,
                all_target,
                all_evalpoint,
                all_observed_point,
                all_observed_time,
                scaler,
                mean_scaler,
            ],
            f,
        )

    CRPS = calc_quantile_CRPS(
        all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
    )
    CRPS_sum = calc_quantile_CRPS_sum(
        all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
    )

    n_eval = evalpoints_total if evalpoints_total > 0 else 1
    mre_denom = mredenom_total if mredenom_total > 0 else 1
    RMSE = np.sqrt(mse_total / n_eval)
    MAE = mae_total / n_eval
    MRE = mae_total / mre_denom * 100

    with open(result_path, "w") as f:
        f.write(f"RMSE: {RMSE:.6f}\n")
        f.write(f"MAE: {MAE:.6f}\n")
        f.write(f"MRE: {MRE:.6f}\n")
        f.write(f"CRPS: {CRPS:.6f}\n")

    report = logger.info if logger is not None else print
    report(f"RMSE with Pattern Recognizer: {RMSE:.4f}")
    report(f"MAE with Pattern Recognizer: {MAE:.4f}")
    report(f"MRE with Pattern Recognizer: {MRE:.4f}")
    report(f"CRPS with Pattern Recognizer: {CRPS:.4f}")
    report(f"CRPS_sum with Pattern Recognizer: {CRPS_sum:.4f}")

    return RMSE, MAE, MRE