import pdb
import os
import logging
import torch
import torch.nn as nn


from torchvision.utils import save_image
from dl.src.trainers.utils import save_gif
from dl.src.info.utils import write_info
from dl.src.dl_metrics.base import estimate_all_distenglement

from dl.src.file.base import make_run_files
from dl.src.utils.utils import load_model, set_seed, get_constant_schedule, get_linear_schedule_with_warmup
from dl.src.constants import BASE_DATA, SEMI_DATA, Factor_DATA, DATA_HIDDEN_DIM, DATA_STEPS


# SET CONFIG
from dl.models.config.base import BetaVAEConfig, BetaTCVAEConfig, FactorVAEConfig, CLGVAEConfig
from dl.models.config.cmcs import CMCS_Config, CMCS_Unsuper_Config

# SET MODEL
from dl.models.betavae import BetaVAE
from dl.models.betatcvae import BetaTCVAE
from dl.models.clgvae import CommutativeVAE
from dl.models.factorvae import FactorVAE
from dl.models.CMCS.gt import CMCS_GT_VAE
from dl.models.CMCS.super import CMCS_Super_VAE
from dl.models.CMCS.semisuper import CMCS_SemiSuper_VAE

# SET OPTIMIZER
from torch.optim import SGD, Adam

# SET DATALOADER SAMPLERS
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Registry: ``args.model_type`` -> model config class.
CONFIG = dict(
    betavae=BetaVAEConfig,
    factorvae=FactorVAEConfig,
    betatcvae=BetaTCVAEConfig,
    clgvae=CLGVAEConfig,
    cmcs_gt=CMCS_Config,
    cmcs_super=CMCS_Config,
    cmcs_unsuper=CMCS_Unsuper_Config,
    cmcs_semisuper=CMCS_Config,
)

# Registry: ``args.model_type`` -> model class (constructed with ``config=``).
# NOTE(review): CONFIG has a "cmcs_unsuper" entry but MODELS does not, so
# model_type="cmcs_unsuper" would raise KeyError here — confirm whether a
# model class is missing.
MODELS = dict(
    betavae=BetaVAE,
    factorvae=FactorVAE,
    betatcvae=BetaTCVAE,
    clgvae=CommutativeVAE,
    cmcs_gt=CMCS_GT_VAE,
    cmcs_super=CMCS_Super_VAE,
    cmcs_semisuper=CMCS_SemiSuper_VAE,
)

# Registry: ``args.optimizer`` -> torch optimizer class.
OPTIMIZER = dict(
    sgd=SGD,
    adam=Adam,
)

class CMCS_Trainer:
    """Base trainer for the VAE variants registered in ``CONFIG`` / ``MODELS``.

    ``__init__`` builds the dataset, loss function, model config and model from
    ``args``. ``setting`` builds the dataloader, optimizer and scheduler and
    applies the multi-GPU / distributed wrappers. ``run`` drives training and
    evaluation. ``train``, ``eval`` and ``qualitative`` must be implemented by
    subclasses.
    """

    def __init__(self, args):
        """Build the dataset, loss function, model config and model from ``args``.

        Args:
            args: argparse-style namespace. Read directly here: ``model_type``,
                ``dataset``, ``no_cuda``, ``per_gpu_train_batch_size``,
                ``max_steps``, ``num_epoch``, ``epoch``, ``do_train`` and
                ``do_eval``. Project helpers (``set_seed``, the config class,
                ``make_run_files``, ``load_model``) may read more. This method
                also writes ``steps``, ``dense_dim``, ``n_gpu``,
                ``train_batch_size``, ``num_epoch`` and ``t_total`` back onto it.
        """
        set_seed(args)

        # Pick the dataset registry by model family: FactorVAE-style and
        # semi-supervised models use their own dataset variants.
        if 'factor' in args.model_type:
            self.dataset = Factor_DATA[args.dataset]()
        elif 'semisuper' in args.model_type:
            self.dataset = SEMI_DATA[args.dataset]()
        else:
            self.dataset = BASE_DATA[args.dataset]()

        # Summed BCE for the dSprites variants and summed MSE otherwise
        # (presumably because dSprites images are binary — confirm).
        if args.dataset in ("dsprites", "cdsprites"):
            self.loss_fn = nn.BCELoss(reduction="sum")
        else:
            self.loss_fn = nn.MSELoss(reduction="sum")
        self.args = args

        # Derived run settings, stored back on args so downstream code sees them.
        self.args.steps = DATA_STEPS[self.args.dataset]
        self.args.dense_dim = DATA_HIDDEN_DIM[self.args.dataset]
        self.args.n_gpu = 0 if self.args.no_cuda else torch.cuda.device_count()
        self.args.train_batch_size = self.args.per_gpu_train_batch_size * max(1, self.args.n_gpu)

        dataset_size = len(self.dataset)
        # Number of full batches per epoch (the train loader uses drop_last=True).
        dataset_size_per_epoch = dataset_size // self.args.train_batch_size
        # NOTE(review): with max_steps == 0, t_total is computed from
        # args.num_epoch, but num_epoch is then overwritten with args.epoch.
        # If the two differ, the scheduler length and the epoch count disagree.
        # Also, with max_steps != 0 and a dataset smaller than one batch, the
        # division below raises ZeroDivisionError. Confirm intent.
        if self.args.max_steps == 0:
            t_total = dataset_size_per_epoch * self.args.num_epoch
            self.args.num_epoch = self.args.epoch
        else:
            t_total = self.args.max_steps
            self.args.num_epoch = self.args.max_steps // dataset_size_per_epoch + 1
        self.args.t_total = t_total

        # The BetaTCVAE config is additionally given the dataset size.
        if 'betatcvae' in self.args.model_type:
            self.config = CONFIG[args.model_type](args=self.args, dataset_size=dataset_size)
        else:
            self.config = CONFIG[args.model_type](args=self.args)

        self.model = MODELS[args.model_type](config=self.config)
        self.model.init_weights()

        self.save_file, self.run_file, self.output_dir = make_run_files(self.args)

        # Evaluation-only run: restore saved weights when the checkpoint exists.
        if not args.do_train and args.do_eval:
            sub_model, path = load_model(args, self.save_file)
            if os.path.exists(path):
                self.model.load_state_dict(sub_model)
            else:
                logger.warning(
                    "No checkpoint found at %s; evaluating freshly initialised weights.",
                    path,
                )

        self.model.to(device)

        # Populated by setting() / the training loop.
        self.train_sampler, self.train_dataloader = None, None
        self.optimizer, self.scheduler = None, None
        self.global_step = 0

    def setting(self):
        """Build the dataloader, optimizer and scheduler, then wrap the model.

        Optimizer and scheduler state are restored from ``save_file`` when both
        ``optimizer.pt`` and ``scheduler.pt`` exist there.

        Raises:
            ValueError: if ``args.optimizer`` is not a key of ``OPTIMIZER``.
        """
        set_seed(self.args)

        # Random order for single-process runs; sharded sampling under DDP.
        self.train_sampler = (
            RandomSampler(self.dataset)
            if self.args.local_rank == -1
            else DistributedSampler(self.dataset)
        )
        self.train_dataloader = DataLoader(
            self.dataset,
            sampler=self.train_sampler,
            batch_size=self.args.train_batch_size,
            drop_last=True,
            pin_memory=True,
        )

        self.optimizer = self._build_optimizer()
        self.scheduler = (
            get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.t_total
            )
            if self.args.scheduler == "linear"
            else get_constant_schedule(self.optimizer)
        )

        self._restore_optimizer_state()

        if self.args.local_rank != -1:
            # Distributed: DDP pins this process to device ``local_rank``.
            # Do not also wrap in DataParallel, which would make each process
            # replicate the model across every visible GPU.
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )
        elif self.args.n_gpu > 1:
            # Single-process multi-GPU.
            self.model = torch.nn.DataParallel(self.model)

    def _build_optimizer(self):
        """Instantiate the optimizer named by ``args.optimizer`` over the model parameters.

        Every optimizer gets ``lr=args.lr_rate`` and
        ``weight_decay=args.weight_decay``; Adam additionally uses
        ``betas=(0.9, 0.999)``.

        Raises:
            ValueError: if ``args.optimizer`` is not a key of ``OPTIMIZER``.
        """
        name = self.args.optimizer
        if name not in OPTIMIZER:
            raise ValueError(
                f"Unsupported optimizer {name!r}; expected one of {sorted(OPTIMIZER)}"
            )
        kwargs = {"lr": self.args.lr_rate, "weight_decay": self.args.weight_decay}
        if name == "adam":
            kwargs["betas"] = (0.9, 0.999)
        return OPTIMIZER[name](self.model.parameters(), **kwargs)

    def _restore_optimizer_state(self):
        """Load optimizer/scheduler state from ``save_file`` if both files are present."""
        optimizer_path = os.path.join(self.save_file, "optimizer.pt")
        scheduler_path = os.path.join(self.save_file, "scheduler.pt")
        if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
            # map_location="cpu" lets GPU-saved state load on CPU-only hosts;
            # Optimizer.load_state_dict moves the state onto each parameter's device.
            self.optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
            self.scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))

    def train(self):
        """Run the training loop. Must be implemented by subclasses."""
        raise NotImplementedError

    def eval(self):
        """Evaluate the model and return the results forwarded to ``estimate_all_distenglement``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def qualitative(self):
        """Return ``(imgs, gen_imgs)`` for the saved image grid and traversal GIF.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def run(self):
        """Set up, then train and/or evaluate according to ``args.do_train`` / ``args.do_eval``.

        When evaluating, this:
        - sets ``args.results_file`` to ``results.csv``, or to
          ``eval_only_results.csv`` when not training (presumably read by
          ``write_info``);
        - writes the disentanglement metrics via ``write_info``;
        - saves ``images.png`` and ``travel.gif`` under ``save_file``.
        """
        self.setting()
        if self.args.do_train:
            self.train()

        if not self.args.do_eval:
            return

        results = self.eval()
        imgs, gen_imgs = self.qualitative()

        # Quantitative analysis.
        dl_results = estimate_all_distenglement(dataset=self.dataset,
                                                model=self.model,
                                                loss_fn=self.loss_fn,
                                                continuous_factors=False,
                                                args=self.args,
                                                results=results)

        results_name = "results.csv" if self.args.do_train else "eval_only_results.csv"
        self.args.results_file = os.path.join(self.output_dir, results_name)
        write_info(self.args, dl_results)

        # Qualitative analysis.
        imgs_dir = os.path.join(self.save_file, 'images.png')
        save_image(imgs, imgs_dir, nrow=self.args.interval + 1, pad_value=1.0)
        gif_dir = os.path.join(self.save_file, 'travel.gif')
        save_gif(gen_imgs, gif_dir, args=self.args)


















