import os
import time
from datetime import datetime
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import math

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from absl import app, flags
from torch.multiprocessing import Process

from get_data import get_dataset
from models.wrapper import get_model
from algorithms.wrapper import get_algorithm
from utils import Logger, backup_code, check_args, get_optimizer, share_params

import torch.optim as optim

from algorithms.utils import compute_abs_mean, compute_eucdist, compute_cosdist

# Command-line configuration for the whole script. Every value below is read
# through the global ``FLAGS`` object after absl has parsed the command line.
FLAGS = flags.FLAGS

# Training
flags.DEFINE_integer("train_steps", 1000, "Total training steps for a single run")
flags.DEFINE_integer("episode_train_steps", 100, "# of training steps for a single task")
flags.DEFINE_integer("batch_size", 100, "Batch size")

# "adapt_*" flags are not read in this file; they are handed to the model /
# algorithm code through the FLAGS object.
flags.DEFINE_enum("adapt_opt", "sgd", ["adam", "sgd", "rmsprop"], "optimizer")
flags.DEFINE_float("adapt_inner_lr", 1e-2, "Learning rate")
flags.DEFINE_float("adapt_outer_lr", 1, "Learning rate")
flags.DEFINE_float("adapt_momentum", 0.9, "Momentum")
flags.DEFINE_float("adapt_weight_decay", 0, "Weight decay")
flags.DEFINE_bool("adapt_nesterov_test", False, "Nesterov")

# "hyper_*" flags configure the optimizer built with get_optimizer() over
# model.get_hyper_params() in run_single_process.
flags.DEFINE_enum("hyper_opt", "sgd", ["adam", "sgd", "rmsprop"], "meta optimizer")
flags.DEFINE_float("hyper_lr", 1e-3, "Meta learning rate")
flags.DEFINE_float("hyper_momentum", 0.9, "Momentum")
flags.DEFINE_float("hyper_weight_decay", 0, "Weight decay")
flags.DEFINE_bool("hyper_nesterov", False, "Nesterov")

# Meta learning algorithm
flags.DEFINE_string("algorithm", "ours", "algorithm")
flags.DEFINE_integer("regress_every", 50, "linear regression period")
flags.DEFINE_integer("num_neumann_steps", 5, "neumann")
flags.DEFINE_float("gamma", 0.9, "gamma")
flags.DEFINE_bool("lr_decay", False, "lr_decay")
flags.DEFINE_integer("neumann_factor", 1, "neumann_factor")

# Model
flags.DEFINE_string("model", "anil", "Model")
flags.DEFINE_integer("conv_channels", 32, "Model")

# Data
flags.DEFINE_string("data", "cifar", "Data")
flags.DEFINE_integer("num_classes", 10, "# of classes per task")
flags.DEFINE_integer("img_size", 32, "Image size")

# Misc
# NOTE(review): run_single_process reads FLAGS.log_dir, not FLAGS.tblog_dir;
# tblog_dir is not read anywhere in this file -- confirm which one is intended.
flags.DEFINE_string("tblog_dir", None, "Directory for tensorboard logs")
flags.DEFINE_string("code_dir", "/st1/hblee/longmeta/image_exp/codes", "Directory for backup code")
flags.DEFINE_string("save_dir", "/st1/hblee/longmeta/image_exp/checkpoints", "Directory for checkpoints")
flags.DEFINE_string("exp_name", "", "Experiment name")
flags.DEFINE_integer("print_every", 1, "Print period")
# Evaluation on the meta-train / meta-val splits runs every `test_every` steps.
flags.DEFINE_integer("test_every", 50, "Test period")
flags.DEFINE_integer("save_every", 1000, "Save period")
flags.DEFINE_list("gpus", "", "GPUs to use")
# Exported as MASTER_PORT by run_multi_process. TCP ports must lie in
# 0-65535, so the default has to be a valid port number ("123456" is not).
flags.DEFINE_string("port", "12345", "Port number for multiprocessing")
flags.DEFINE_integer("num_workers", 1, "The number of workers for dataloading")

#def run_single_process(rank, backend="nccl"):
def run_single_process(rank=0, backend="nccl"):
    """Run meta-training with periodic evaluation and a final test pass.

    The worker builds the three dataset splits, the model, the meta-train and
    meta-test algorithms, the hyper-parameter optimizer and the logger, then
    runs ``FLAGS.train_steps`` meta-training steps. Every ``FLAGS.test_every``
    steps it evaluates on the meta-train and meta-val splits; every
    ``FLAGS.regress_every`` steps (for the "ours" algorithm with a non-"mwn"
    model) it re-fits the regression coefficient passed to ``meta_train.run``.
    After the last step it evaluates on the meta-test split.

    Args:
        rank: Index of this worker. Selects ``FLAGS.gpus[rank]`` as the visible
            GPU, is passed as the process-group rank, and decides whether this
            worker prints to stdout and registers the model for saving.
        backend: Backend name forwarded to ``dist.init_process_group``.
    """
    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
    # initialised, so pin the GPU before any torch.distributed / CUDA call.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpus[rank]

    # NOTE(review): world_size is fixed at 1 although run_multi_process starts
    # one worker per GPU with ranks 0..N-1 -- confirm this is intended for N > 1.
    dist.init_process_group(backend, rank=rank, world_size=1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    meta_train_ds = get_dataset(FLAGS.data, "meta_train", FLAGS.img_size)
    meta_val_ds = get_dataset(FLAGS.data, "meta_val", FLAGS.img_size)
    meta_test_ds = get_dataset(FLAGS.data, "meta_test", FLAGS.img_size)

    # Models and Algorithms
    model = get_model(FLAGS, track_bn_stats=True).to(device)
    meta_train = get_algorithm(FLAGS.algorithm)
    meta_test = get_algorithm("meta_test")

    criterion = nn.CrossEntropyLoss().to(device)

    hyper_opt = get_optimizer(
        FLAGS.hyper_opt,
        model.get_hyper_params(),
        FLAGS.hyper_lr,
        momentum=FLAGS.hyper_momentum,
        weight_decay=FLAGS.hyper_weight_decay,
        nesterov=FLAGS.hyper_nesterov,
    )

    # Logger
    logger = Logger(
        exp_name=FLAGS.exp_name,
        log_dir=FLAGS.log_dir,
        save_dir=FLAGS.save_dir,
        exp_suffix=f"src/split_{rank+1}",
        print_every=FLAGS.print_every,
        save_every=FLAGS.save_every,
        total_step=FLAGS.train_steps,
        print_to_stdout=(rank == 0),
        use_wandb=True,
        # NOTE(review): keyword is spelled "wnadb" -- confirm it matches the
        # parameter name in utils.Logger before renaming it.
        wnadb_project_name="longmeta",
        wandb_tags=[f"split_{rank+1}"],
        wandb_config=FLAGS,
    )

    if rank == 0:
        logger.register_object_to_save(model.state_dict(), "model")

    def linear_regression(dictll):
        """Fit ``y ~ beta * X`` (no intercept, beta >= 0) and return beta.

        ``dictll`` must provide "X" and "y" entries that support ``reshape``.
        The fitted slope is wrapped in a one-element tensor on ``device``,
        passed through ``share_params`` and returned as a NumPy array.
        """
        # Imported lazily so scikit-learn is only needed when a fit is run.
        from sklearn.linear_model import LinearRegression

        inputs = dictll["X"].reshape([-1, 1])
        outputs = dictll["y"].reshape([-1])

        results = LinearRegression(fit_intercept=False, positive=True).fit(inputs, outputs)
        beta = results.coef_[0]
        print("beta", beta)

        coeff = [torch.Tensor([beta]).to(device)]
        share_params(coeff)
        return coeff[0].cpu().detach().numpy()

    def estimate_beta():
        """Collect regression data via ``meta_train.pre_run`` and fit beta."""
        dictll = meta_train.pre_run(
            rank, model, meta_train_ds, hyper_opt, device, criterion, logger, FLAGS)
        return linear_regression(dictll)

    def evaluate_train_and_val(repeats=10):
        """Run the meta-test algorithm ``repeats`` times on train and val splits."""
        for _ in range(repeats):
            meta_test.run(rank, model, meta_train_ds, device, criterion, logger, FLAGS, "meta_train")
            meta_test.run(rank, model, meta_val_ds, device, criterion, logger, FLAGS, "meta_val")

    # Weight given to the previous running coefficient when a new fit arrives.
    # NOTE(review): FLAGS.gamma is not used here -- confirm 0.5 is intended.
    running_weight = 0.5
    coeff_running = 0

    # exist_ok=True avoids the check-then-create race between workers that
    # start at the same time and share the same experiment name.
    folder_name = os.path.join('figure', FLAGS.exp_name)
    os.makedirs(folder_name, exist_ok=True)

    # NOTE(review): this initial fit uses a substring test ("ours" in ...),
    # while the training loop below requires an exact match (== "ours").
    if "ours" in FLAGS.algorithm and FLAGS.model != "mwn":
        coeff = estimate_beta()
        coeff_running = coeff
        logger.meter("meta_train", "beta", coeff[0])
        logger.meter("meta_train", "beta_running", coeff_running[0])

    logger.start()

    # Baseline evaluation before any meta-training step.
    evaluate_train_and_val()

    for i in range(1, FLAGS.train_steps + 1):

        # Only the "ours" algorithm takes the running regression coefficient.
        if FLAGS.algorithm == "ours":
            meta_train.run(rank, model, meta_train_ds, hyper_opt, device, criterion, logger, FLAGS,
                    coeff_running, i)
        else:
            meta_train.run(rank, model, meta_train_ds, hyper_opt, device, criterion, logger, FLAGS,
                    i)

        if i % FLAGS.test_every == 0:
            evaluate_train_and_val()

        if i % FLAGS.regress_every == 0:
            if FLAGS.algorithm == "ours" and FLAGS.model != "mwn":
                coeff = estimate_beta()
                # Blend the fresh estimate into the running coefficient.
                coeff_running = running_weight*coeff_running + (1-running_weight)*coeff
                if rank == 0:
                    print("coeff_running = ", coeff_running)
                logger.meter("meta_train", "beta", coeff[0])
                logger.meter("meta_train", "beta_running", coeff_running[0])

        # Final evaluation on the held-out split, done before the last
        # logger.step() so it is part of the final logged step.
        if i == FLAGS.train_steps:
            for _ in range(125):
                meta_test.run(rank, model, meta_test_ds, device, criterion, logger, FLAGS, "meta_test")

        logger.step()

    logger.finish()

def run_multi_process(argv):
    """Validate the flags, then launch one worker process per listed GPU.

    Used as the ``main`` callable for ``app.run``. The rendezvous address and
    port (and the wandb silence switch) are exported through the environment
    before the workers are started; the call returns once every worker has
    been joined.

    Args:
        argv: Positional command-line arguments handed over by absl; unused.
    """
    del argv
    check_args(FLAGS)

    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": FLAGS.port,
            "WANDB_SILENT": "true",
        }
    )

    # Start every worker first, then wait, so the workers run concurrently.
    workers = []
    for gpu_index, _ in enumerate(FLAGS.gpus):
        worker = Process(target=run_single_process, args=(gpu_index,))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()

# Script entry point: absl parses the command-line flags, then calls
# run_multi_process with the remaining arguments.
if __name__ == "__main__":
    app.run(run_multi_process)
