import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from data import simulate, simulate_complex_ge
from trainer import QuantileEstimator, QuantileRegressor
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
import argparse
from betty.configs import Config, EngineConfig
from networks import EstimatorNet, RegressorNet, ConvRegressorNet, ConvEstimatorNet, DiscreteRegressorNet, DiscreteEstimatorNet
from trainer import MyEngine
from utils import InfiniteIterator, weights_init
from data import get_data
from utils import read_yaml_and_pass_to_argparse
from networks import UnetRegressor

# ----------------------------------------------------------------
# Command-line options, declared as (flag, type, default) rows and
# registered on the parser in the order listed.
# ----------------------------------------------------------------
_CLI_OPTIONS = (
    ('--log_iter', int, 0),
    ('--num_workers', int, 0),
    ('--estimator_hidden_dim', int, 32),
    ('--estimator_n_layers', int, 3),
    ('--regressor_hidden_dim', int, 32),
    ('--regressor_n_layers', int, 3),
    ('--batch_size', int, 256),
    ('--train_batch_size', int, 32),
    ('--test_batch_size', int, 128),
    ('--inner_iters', int, 30),
    ('--lr', float, 2e-3),
    ('--lambda_cyc', float, 0),
    ('--decay_every', float, 3000),
    ('--save_every', float, 10000),
    ('--decay_rate', float, 0.5),
    ('--seed', int, 42),
    ('--exp_num', int, 0),
    ('--config', str, "configs/simulation.yaml"),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()

# The device is picked before the YAML config is applied to `args`.
if torch.cuda.is_available():
    args.device = 'cuda'
else:
    args.device = 'cpu'

# Attributes used below but not declared on the parser (e.g. `train_iters`,
# `run_dir`) are expected to be attached to `args` by this call.
read_yaml_and_pass_to_argparse(args.config, args)

# Suffix that encodes the main hyper-parameters into the run directory name.
_TAG_TEMPLATE = '_iter%d_%d_bs%s_%s_%s_nl%s_%s_dim%s_%s_lr%s_cyc%s_decay%s_%s'
_tag_fields = (
    args.train_iters,
    args.inner_iters,
    args.batch_size,
    args.train_batch_size,
    args.test_batch_size,
    args.estimator_n_layers,
    args.regressor_n_layers,
    args.estimator_hidden_dim,
    args.regressor_hidden_dim,
    args.lr,
    args.lambda_cyc,
    args.decay_every,
    args.decay_rate,
)
tag = _TAG_TEMPLATE % _tag_fields
args.run_dir += tag

# Seed the RNGs (via pytorch_lightning) before any data or model is built.
pl.seed_everything(args.seed)

#================================================================
# Load the datasets and wrap them in data loaders
#================================================================
dataset, valid_dataset, eval_train_dataset, eval_test_dataset = get_data(args.dataset, args.exp_num)

# Options shared by every loader below.
_loader_opts = dict(drop_last=False, num_workers=args.num_workers, pin_memory=False)

# The regressor trains on `dataset`; `valid_dataset` feeds both the
# regressor's validation loader and the estimator's training loader
# (with different batch sizes).
regressor_train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **_loader_opts)
regressor_valid_loader = DataLoader(valid_dataset, batch_size=args.train_batch_size, shuffle=True, **_loader_opts)
estimator_train_loader = DataLoader(valid_dataset, batch_size=args.test_batch_size, shuffle=True, **_loader_opts)

# Evaluation loaders handed to the engine at the end of the script.
eval_train_loader = DataLoader(eval_train_dataset, batch_size=args.test_batch_size, shuffle=True, **_loader_opts)
# Fix: the test loader must wrap `eval_test_dataset`. It previously wrapped
# `eval_train_dataset`, so test evaluation silently ran on the training split
# and `eval_test_dataset` was never used.
eval_test_loader = DataLoader(eval_test_dataset, batch_size=10000, shuffle=False, **_loader_opts)

# The training-set size becomes the last component of the run directory name.
args.run_dir += '_%s' % len(dataset)
os.makedirs(args.run_dir, exist_ok=True)

#================================================================
# Outer problem: network that estimates the quantile of a sample
# (e.g., the 50% quantile), plus its optimizer and Betty problem.
#================================================================
if args.net in ('mlp', 'dismlp'):
    # Both MLP variants take the same constructor arguments.
    estimator_cls = EstimatorNet if args.net == 'mlp' else DiscreteEstimatorNet
    estimator_net = estimator_cls(
        cov_dim=args.cov_dim,
        treat_dim=args.treat_dim,
        out_dim=args.out_dim,
        n_layers=args.estimator_n_layers,
        hidden_dim=args.estimator_hidden_dim,
    ).to(args.device)
    if args.net == 'mlp':
        # Only the plain MLP architecture is printed.
        print(estimator_net)
else:
    # Any other value of `args.net` selects the convolutional estimator.
    estimator_net = ConvEstimatorNet(
        size=args.image_size,
        hidden_dim=args.estimator_hidden_dim,
    ).to(args.device)

estimator_optim = torch.optim.Adam(
    estimator_net.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
estimator_problem_config = Config(log_step=args.log_iter, first_order=True, retain_graph=True)
estimator_problem = QuantileEstimator(
    name='outer',
    module=estimator_net,
    optimizer=estimator_optim,
    train_data_loader=estimator_train_loader,
    config=estimator_problem_config,
)
# Extra attribute attached to the problem after construction.
estimator_problem.device = args.device
#================================================================
# Inner problem: network that regresses the quantile function
# given a quantile, plus its optimizer and Betty problem.
#================================================================
if args.net in ('mlp', 'dismlp'):
    # Both MLP variants take the same constructor arguments.
    regressor_cls = RegressorNet if args.net == 'mlp' else DiscreteRegressorNet
    regressor_net = regressor_cls(
        cov_dim=args.cov_dim,
        treat_dim=args.treat_dim,
        out_dim=args.out_dim,
        n_layers=args.regressor_n_layers,
        hidden_dim=args.regressor_hidden_dim,
    ).to(args.device)
    if args.net == 'mlp':
        # Only the plain MLP architecture is printed.
        print(regressor_net)
else:
    # Any other value of `args.net` selects the convolutional regressor.
    # Disabled alternative: regressor_net = UnetRegressor().to(args.device)
    regressor_net = ConvRegressorNet(
        size=args.image_size,
        hidden_dim=args.regressor_hidden_dim,
    ).to(args.device)

# Re-initialise the regressor weights with the 'kaiming' scheme from utils.
regressor_net.apply(weights_init('kaiming'))

regressor_problem_config = Config(type="sama", unroll_steps=args.inner_iters)
regressor_optim = torch.optim.Adam(
    regressor_net.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Iterator over the regressor validation loader, attached to the problem below.
test_iterator = InfiniteIterator(regressor_valid_loader)
regressor_problem = QuantileRegressor(
    name='inner',
    module=regressor_net,
    optimizer=regressor_optim,
    train_data_loader=regressor_train_loader,
    config=regressor_problem_config,
)

# Extra attributes attached to the problem after construction.
regressor_problem.test_iterator = test_iterator
regressor_problem.device = args.device
regressor_problem.lambda_cyc = args.lambda_cyc
regressor_problem.global_step = 0
#================================================================
# Wire the two problems together and run the engine
#================================================================
engine_config = EngineConfig(train_iters=args.train_iters, logger_type="none")

# Dependency graph between the problems: the estimator ('outer') is the
# upper-level problem and the regressor ('inner') is the lower-level one.
u2l = {estimator_problem: [regressor_problem]}
l2u = {regressor_problem: [estimator_problem]}
dependencies = dict(l2u=l2u, u2l=u2l)
problems = [estimator_problem, regressor_problem]

engine = MyEngine(
    config=engine_config,
    problems=problems,
    args=args,
    dependencies=dependencies,
)
# The engine receives the two evaluation loaders when it is run.
engine.run(eval_train_loader, eval_test_loader)









