from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
import argparse
from data.load_dataset import load_dataset
import numpy as np
import torch
import os
from global_config import ROOT_DIRECTORY
from torch.utils.data import DataLoader
from agents.ddpm_trainer import DDPMTrainer
from data.datasets.heterogeneous_2d_clusters import visualize_clusters_2d
from models.generative_models.ddpm import DDPM
from agents.model_trainer import TaskModelTrainer
from utils import fix_seed
from agents.functions import get_weighted_dataset, get_average_loss

# Switch for the `visualize_clusters_2d(...)` calls in main(); every plot is skipped while this is False.
PLOTTING = False


def parse_args(argv=None):
    """Build the command-line parser for this experiment and parse ``argv``.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` falls back to ``sys.argv[1:]``, so calling
            ``parse_args()`` behaves like a normal script entry point.

    Returns:
        argparse.Namespace holding the parsed options.

    Note:
        ``--train_initial_task_model`` and ``--train_task_model`` use
        ``action='store_false'``: both are ``True`` by default and passing the
        flag *disables* the corresponding training step.
        ``--train_base_score`` is the opposite (``store_true``, off by default).
    """
    parser = argparse.ArgumentParser(description="Example-CSI")

    # Optimisation hyper-parameters.
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--beta', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--fine_tuning_epochs', type=int, default=1000)

    # Experiment selection.
    parser.add_argument('--dataset_name', type=str, help='dataset_name',
                        default='h2c')
    parser.add_argument('--task_name', type=str, help='task name', default='ae')
    parser.add_argument('--algorithm_name', type=str, help='algorithm name', default='is_cvar',
                        choices=['is_cvar', 'cvar', 'cvar_doro', 'erm', 'chisq'])  # is_cvar = RAMIS

    # Sampling / runtime settings.
    parser.add_argument('--n_samples', type=int, help='define sampling amounts after every epoch trained', default=200)
    parser.add_argument('--timesteps', type=int, help='sampling steps of DDPM', default=100)
    parser.add_argument('--device', default='cuda:0', type=str, help='device')
    parser.add_argument('--random_seed', type=int, default=123)

    # Stage toggles (see the Note in the docstring for their polarity).
    parser.add_argument('--train_base_score', action='store_true')
    parser.add_argument('--train_initial_task_model', action='store_false')
    parser.add_argument('--train_task_model', action='store_false')

    return parser.parse_args(argv)


def _results_file(dataset_name, suffix):
    """Return ``<ROOT_DIRECTORY>/results/<dataset_name>/<dataset_name><suffix>``.

    The containing directory is created when it does not exist yet, so a
    subsequent ``np.save`` to the returned path cannot fail on a missing folder.
    """
    results_dir = os.path.join(ROOT_DIRECTORY, "results", dataset_name)
    os.makedirs(results_dir, exist_ok=True)
    return os.path.join(results_dir, dataset_name + suffix)


def _draw_and_save(draw_batch, save_path, rounds=1):
    """Draw ``rounds`` batches, stack them along axis 0 and save them with ``np.save``.

    Args:
        draw_batch: Zero-argument callable returning a tensor of samples.
        save_path: Destination ``.npy`` path.
        rounds: Number of times ``draw_batch`` is invoked.

    Returns:
        The concatenated samples as a numpy array.
    """
    batches = []
    for i in range(rounds):
        batches.append(draw_batch().detach().cpu().numpy())
        print("normal sampling iter : ", i)
    samples = np.concatenate(batches, axis=0)
    np.save(save_path, samples)
    return samples


def main(args):
    """Run the full experiment described by ``args`` (see ``parse_args``).

    Stages, in execution order:

    1. Fix the random seed and load the train/test loaders for ``args.dataset_name``.
    2. Build a ``DDPMTrainer`` around a ``GaussianMixtureDenoisingModel`` and
       train it only when ``args.train_base_score`` is set.
    3. Draw ``args.n_samples`` samples from the trainer's EMA module, save them
       as ``<dataset>_original_samples_ae_initial.npy`` and reload that file
       through ``load_dataset`` / ``get_weighted_dataset``.
    4. Train (unless disabled) a degree-2 ``PolynomialRegressor`` with the
       ``'erm'`` algorithm, load the ``"initial_<task_name>"`` checkpoint and
       evaluate it on the test loader.
    5. Draw a second plain sample set and an importance-sampled set (guided by
       the task model's per-sample loss) and save both as ``.npy`` files.
    6. Build weighted datasets from those samples, concatenate them with the
       initial weighted datasets, fine-tune (unless disabled) with
       ``args.algorithm_name`` and evaluate again on the test loader.

    Args:
        args: Namespace with the attributes defined in ``parse_args``.
    """
    device = args.device
    fix_seed(args.random_seed)

    # Load datasets
    train_dataloader, test_dataloader = load_dataset(
        dataset_name=args.dataset_name,
        batch_size=args.batch_size
    )
    # NOTE(review): `(2)` is the int 2, not the tuple `(2,)` — confirm which one DDPM expects.
    data_shape = (2)  # (2, 2, 2)

    from models.score_functions.gaussian_mixture import GaussianMixtureDenoisingModel
    denoising_model = GaussianMixtureDenoisingModel(num_steps=args.timesteps, device=device)

    ddpm_trainer = DDPMTrainer(
        learning_rate=args.learning_rate * 0.1,
        batch_size=args.batch_size,
        training_epoch=args.epochs,
        denoising_module=DDPM(timesteps=args.timesteps, data_shape=data_shape, model=denoising_model).to(device),
        dataset_name=args.dataset_name,
        train_data_loader=train_dataloader,
        test_data_loader=test_dataloader,
        timesteps=args.timesteps,
        device=device
    )

    # Train or just load the model.
    if args.train_base_score:
        ddpm_trainer.train()
    # ddpm_trainer.load_model()

    def draw_base_batch():
        # The EMA module is looked up on every call rather than cached.
        return ddpm_trainer.model_ema.module.sampling(
            args.n_samples, clipped_reverse_diffusion=True, device=device)

    # ---- Stage 3: initial samples from the base generative model -------------
    base_samples = _draw_and_save(
        draw_base_batch,
        _results_file(args.dataset_name, '_original_samples_ae_initial.npy'))

    if PLOTTING:
        visualize_clusters_2d(data=base_samples, title="initial model sample density", density_plot=True)

    initial_train_dataloader, initial_validation_dataloader = load_dataset(
        dataset_name=args.dataset_name + '_original_samples_ae_initial.npy',
        batch_size=args.batch_size
    )
    initial_weighted_train_dataset = get_weighted_dataset(dataloader=initial_train_dataloader,
                                                          device=device, batch_size=args.batch_size,
                                                          algorithm_name=args.algorithm_name)
    # From here on `initial_train_dataloader` iterates the weighted dataset.
    initial_train_dataloader = DataLoader(initial_weighted_train_dataset, batch_size=args.batch_size, shuffle=True)

    initial_weighted_validation_dataset = get_weighted_dataset(dataloader=initial_validation_dataloader,
                                                               device=device, batch_size=args.batch_size,
                                                               algorithm_name=args.algorithm_name)
    initial_validation_dataloader = DataLoader(initial_weighted_validation_dataset, batch_size=args.batch_size,
                                               shuffle=True)

    # ---- Stage 4: initial task model ------------------------------------------
    from models.task_models.linear_regressor import PolynomialRegressor
    task_model = PolynomialRegressor(degree=2)

    class ScaledMSELoss(torch.nn.Module):
        """Element-wise squared error of ``scale_factor * (input - target)`` (no reduction)."""

        def __init__(self, scale_factor=1.0):
            super(ScaledMSELoss, self).__init__()
            self.scale_factor = scale_factor

        def forward(self, input, target):
            return ((self.scale_factor * (input - target)) ** 2)  # .mean(dim=list(range(1, input.ndim)))
            # return torch.abs(self.scale_factor * (input - target))

    loss_function = ScaledMSELoss()

    model_name = args.algorithm_name + "_" + args.task_name + "_beta_" + str(args.beta) + "_seed_" + str(
        args.random_seed)
    task_model_trainer = TaskModelTrainer(
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        training_epoch=args.epochs,
        task_model=task_model,
        loss_function=loss_function,
        task_name=args.task_name,
        dataset_name=args.dataset_name,
        train_data_loader=initial_train_dataloader,
        test_data_loader=test_dataloader,
        device=device,
        model_name=model_name,
    )

    # `train_initial_task_model` is True unless the CLI flag was passed (store_false).
    if args.train_initial_task_model:
        task_model_trainer.train('erm', initial_train_dataloader,
                                 initial_validation_dataloader, epochs=args.epochs,
                                 learning_rate=args.learning_rate, beta=None,
                                 model_name="initial_" + args.task_name)
    task_model_trainer.load_model("initial_" + args.task_name)

    test_samples, initial_model_losses = task_model_trainer.test(dataloader=test_dataloader, distribution_name='p')
    test_samples = test_samples.detach().cpu().numpy()
    initial_model_losses = initial_model_losses.detach().cpu().numpy()
    model_params = task_model.get_params()
    if PLOTTING:
        visualize_clusters_2d(data=test_samples, losses=initial_model_losses, title="initial model",
                              savedir=os.path.join(ROOT_DIRECTORY, "results", args.dataset_name),
                              density_plot=False, model_params=model_params)

    class CustomLoss(torch.nn.Module):
        """Per-sample loss of ``task_model`` on ``x``, used to guide importance sampling.

        ``forward`` averages the squared scaled difference between
        ``task_model(x)["sample"]`` and ``x`` over dim 1 and adds ``1e-8``.
        """

        def __init__(self, task_model, scale_factor=1.0):
            super(CustomLoss, self).__init__()
            self.mse = torch.nn.MSELoss(reduction='none')  # not used by forward()
            # self.mse = torch.abs
            self.scale_factor = scale_factor
            self.task_model = task_model

        def forward(self, x):
            output = self.task_model(x)
            loss = torch.mean(((self.scale_factor * (output["sample"] - x)) ** 2), dim=(1)) + 1e-8
            return loss  # (loss * 10.0)**1.2

    importance_sampling_loss_function = CustomLoss(task_model=task_model)

    # ---- Stage 5: plain and importance-weighted samples ----------------------
    base_samples = _draw_and_save(
        draw_base_batch,
        _results_file(args.dataset_name, '_original_samples_ae.npy'))

    if PLOTTING:
        visualize_clusters_2d(data=base_samples, title="sample density", density_plot=True)

    def draw_importance_batch():
        # importance_sampling returns three values; only the final samples are kept.
        q_x_samples, q_x_t_list, posterior_mean_t_list = ddpm_trainer.model_ema.module.importance_sampling(
            args.n_samples, loss_function=importance_sampling_loss_function,
            clipped_reverse_diffusion=True, device=device)
        return q_x_samples

    importance_samples = _draw_and_save(
        draw_importance_batch,
        _results_file(args.dataset_name, '_importance_samples_ae.npy'))

    if PLOTTING:
        visualize_clusters_2d(data=importance_samples, title="importance sample density", density_plot=True)

    # Load datasets
    original_train_dataloader, original_validation_dataloader = load_dataset(
        dataset_name=args.dataset_name + '_original_samples_ae.npy',
        batch_size=args.batch_size
    )

    importance_train_dataloader, importance_validation_dataloader = load_dataset(
        dataset_name=args.dataset_name + '_importance_samples_ae.npy',
        batch_size=args.batch_size
    )

    # ---- Stage 6: fine-tuning --------------------------------------------------
    # Only the "is_cvar" algorithm trains on the importance samples; all others use the plain ones.
    use_importance_samples = "is_cvar" in args.algorithm_name
    cvar_train_dataloader = importance_train_dataloader if use_importance_samples else original_train_dataloader
    cvar_validation_dataloader = (importance_validation_dataloader if use_importance_samples
                                  else original_validation_dataloader)

    # Average loss over the (weighted) initial training loader; reused as the
    # normalizing constant below and passed to the trainer via set_alpha.
    E_p_l_x = get_average_loss(importance_sampling_loss_function, dataloader=initial_train_dataloader, device=device)
    importance_cvar_weighted_train_dataset = get_weighted_dataset(cvar_train_dataloader,
                                                                  importance_sampling_loss_function,
                                                                  normalizing_constant=E_p_l_x,
                                                                  device=device, batch_size=args.batch_size,
                                                                  algorithm_name=args.algorithm_name)

    importance_cvar_train_dataloader = DataLoader(
        ConcatDataset([importance_cvar_weighted_train_dataset, initial_weighted_train_dataset]),
        batch_size=args.batch_size, shuffle=True)

    importance_cvar_weighted_validation_dataset = get_weighted_dataset(cvar_validation_dataloader,
                                                                       importance_sampling_loss_function,
                                                                       normalizing_constant=E_p_l_x,
                                                                       device=device, batch_size=args.batch_size,
                                                                       algorithm_name=args.algorithm_name)
    importance_cvar_validation_dataloader = DataLoader(
        ConcatDataset([importance_cvar_weighted_validation_dataset, initial_weighted_validation_dataset]),
        batch_size=args.batch_size, shuffle=True)

    task_model_trainer.set_alpha(E_p_l_x)
    # `train_task_model` is True unless the CLI flag was passed (store_false).
    if args.train_task_model:
        task_model_trainer.train(args.algorithm_name, importance_cvar_train_dataloader,
                                 importance_cvar_validation_dataloader, args.fine_tuning_epochs,
                                 args.learning_rate, args.beta, model_name=model_name)
    task_model_trainer.load_model()

    test_samples, losses = task_model_trainer.test(dataloader=test_dataloader, distribution_name='p')
    test_samples = test_samples.detach().cpu().numpy()
    losses = losses.detach().cpu().numpy()

    if PLOTTING:
        visualize_clusters_2d(data=test_samples, losses=losses, title=model_name, density_plot=False,
                              model_params=task_model.get_params())


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to main().
    main(parse_args())
