import pdb
import os
import logging
import torch
import torch.nn as nn


from torchvision.utils import save_image
from dl.src.trainers.utils import save_gif
from dl.src.info.utils import write_info
from dl.src.dl_metrics.base import estimate_all_distenglement

from dl.src.file.base import make_run_files
from dl.src.utils.utils import load_model, set_seed, get_constant_schedule, get_linear_schedule_with_warmup
from dl.src.constants import BASE_DATA, Factor_DATA, DATA_HIDDEN_DIM, DATA_STEPS


# SET CONFIG
from dl.models.config.base import BetaVAEConfig, BetaTCVAEConfig, FactorVAEConfig, CLGVAEConfig
from dl.models.config.cmcs import CMCS_Config

# SET MODEL
from dl.models.betavae import BetaVAE
from dl.models.betatcvae import BetaTCVAE
from dl.models.clgvae import CommutativeVAE
from dl.models.factorvae import FactorVAE
from dl.models.CMCS.gt import CMCS_GT_VAE
from dl.models.CMCS.super import CMCS_Super_VAE

# SET OPTIMIZER
from torch.optim import SGD, Adam

# SET DATALOADER SAMPLERS
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Registry: ``args.model_type`` -> configuration class used to build the model.
# Both CMCS variants share the same configuration class.
CONFIG = {
    "betavae": BetaVAEConfig,
    "factorvae": FactorVAEConfig,
    "betatcvae": BetaTCVAEConfig,
    "clgvae": CLGVAEConfig,
    "cmcs_gt": CMCS_Config,
    "cmcs_super": CMCS_Config,
}

# Registry: ``args.model_type`` -> model class; keys mirror those of CONFIG.
MODELS = {
    "betavae": BetaVAE,
    "factorvae": FactorVAE,
    "betatcvae": BetaTCVAE,
    "clgvae": CommutativeVAE,
    "cmcs_gt": CMCS_GT_VAE,
    "cmcs_super": CMCS_Super_VAE,
}

# Registry: ``args.optimizer`` -> torch optimizer class.
OPTIMIZER = {"sgd": SGD, "adam": Adam}

class Trainer:
    """Base trainer: builds the dataset, model, optimizers and schedulers.

    Subclasses must implement :meth:`train`, :meth:`eval`,
    :meth:`qualitative` and :meth:`permutation`.  :meth:`run` calls
    :meth:`setting` and then drives training / evaluation according to
    ``args.do_train`` and ``args.do_eval``.
    """

    def __init__(self, args):
        """Create the dataset, loss, config and model described by ``args``.

        Besides reading ``args``, this also writes derived values back onto
        it: ``steps``, ``dense_dim``, ``n_gpu``, ``train_batch_size``,
        ``num_epoch`` and ``t_total``.

        Args:
            args: Run configuration namespace.  Attributes read here are
                ``dataset``, ``model_type``, ``no_cuda``,
                ``per_gpu_train_batch_size``, ``max_steps``, ``epoch``,
                ``do_train`` and ``do_eval``.
        """
        set_seed(args)
        self.dataset = (
            Factor_DATA[args.dataset]()
            if 'factor' in args.model_type
            else BASE_DATA[args.dataset]()
        )
        # These datasets use a summed BCE loss; everything else uses summed MSE.
        self.loss_fn = (
            nn.BCELoss(reduction="sum")
            if args.dataset in ("dsprites", "cdsprites", "mmnist")
            else nn.MSELoss(reduction="sum")
        )
        self.args = args

        # Derived arguments.
        self.args.steps = DATA_STEPS[self.args.dataset]
        self.args.dense_dim = DATA_HIDDEN_DIM[self.args.dataset]
        self.args.n_gpu = 0 if self.args.no_cuda else torch.cuda.device_count()
        self.args.train_batch_size = self.args.per_gpu_train_batch_size * max(1, self.args.n_gpu)

        dataset_size = len(self.dataset)
        steps_per_epoch = dataset_size // self.args.train_batch_size
        # ``max_steps == 0`` means "train for ``args.epoch`` epochs"; otherwise
        # the step budget wins and the epoch count is derived from it.
        # ``num_epoch`` is resolved *before* ``t_total`` so the total never
        # depends on a missing or stale ``args.num_epoch``.
        if self.args.max_steps == 0:
            self.args.num_epoch = self.args.epoch
            self.args.t_total = steps_per_epoch * self.args.num_epoch
        else:
            self.args.num_epoch = self.args.max_steps // steps_per_epoch + 1
            self.args.t_total = self.args.max_steps

        # Build the model; only the beta-TCVAE config takes the dataset size.
        config_cls = CONFIG[args.model_type]
        if 'betatcvae' in self.args.model_type:
            self.config = config_cls(args=self.args, dataset_size=dataset_size)
        else:
            self.config = config_cls(args=self.args)

        self.model = MODELS[args.model_type](config=self.config)
        self.model.init_weights()

        self.save_file, self.run_file, self.output_dir = make_run_files(self.args)

        # Evaluation-only runs start from previously saved weights.
        if not args.do_train and args.do_eval:
            sub_model, path = load_model(args, self.save_file)
            if os.path.exists(path):
                self.model.load_state_dict(sub_model)
            else:
                logger.warning(
                    "No checkpoint found at %s; evaluating freshly initialised weights.",
                    path,
                )

        self.model.to(device)

        # Populated later by ``setting()``.
        self.train_sampler, self.train_dataloader = None, None
        self.optimizer, self.scheduler = None, None
        if self.args.model_type == "factorvae":
            self.disc_optimizer, self.disc_scheduler = None, None
        self.global_step = 0

    def _build_optimizer(self, params, lr, betas):
        """Instantiate the optimizer selected by ``args.optimizer``.

        Args:
            params: Parameters (or parameter groups) to optimise.
            lr: Learning rate.
            betas: Adam beta coefficients; ignored for optimizers other
                than Adam.

        Raises:
            ValueError: If ``args.optimizer`` is not a key of ``OPTIMIZER``.
        """
        name = self.args.optimizer
        if name not in OPTIMIZER:
            raise ValueError(
                f"Unsupported optimizer {name!r}; expected one of {sorted(OPTIMIZER)}"
            )
        kwargs = {"lr": lr, "weight_decay": self.args.weight_decay}
        if name == "adam":
            # Only Adam accepts ``betas``.
            kwargs["betas"] = betas
        return OPTIMIZER[name](params, **kwargs)

    def _build_scheduler(self, optimizer):
        """Return the LR scheduler selected by ``args.scheduler`` for ``optimizer``."""
        if self.args.scheduler == "linear":
            return get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=self.args.t_total,
            )
        return get_constant_schedule(optimizer)

    def _restore_state(self, optimizer, scheduler, optimizer_file, scheduler_file):
        """Load saved optimizer/scheduler state if BOTH files exist in ``save_file``."""
        optimizer_path = os.path.join(self.save_file, optimizer_file)
        scheduler_path = os.path.join(self.save_file, scheduler_file)
        if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
            # ``map_location`` lets state saved on a GPU be restored on a
            # host without CUDA.
            optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
            scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    def setting(self):
        """Create the dataloader, optimizers and schedulers, then wrap the model.

        For ``factorvae`` the parameters whose name contains
        ``"discriminator"`` get their own optimizer and scheduler
        (``disc_optimizer`` / ``disc_scheduler``).  Any saved optimizer or
        scheduler state found under ``save_file`` is restored.

        Raises:
            ValueError: If ``args.optimizer`` is not a key of ``OPTIMIZER``.
        """
        set_seed(self.args)

        # Dataloader: random sampling locally, distributed sampling otherwise.
        self.train_sampler = (
            RandomSampler(self.dataset)
            if self.args.local_rank == -1
            else DistributedSampler(self.dataset)
        )
        self.train_dataloader = DataLoader(
            self.dataset,
            sampler=self.train_sampler,
            batch_size=self.args.train_batch_size,
            drop_last=True,
            pin_memory=True,
        )

        is_factorvae = self.args.model_type == "factorvae"

        if is_factorvae:
            # Split the discriminator parameters from the rest of the model.
            disc_params, main_params = [], []
            for name, param in self.model.named_parameters():
                if "discriminator" in name:
                    disc_params.append(param)
                else:
                    main_params.append(param)

            self.disc_optimizer = self._build_optimizer(
                [{"params": disc_params}],
                lr=self.args.lr_rate_disc,
                betas=(0.5, 0.9),
            )
            self.optimizer = self._build_optimizer(
                [{"params": main_params}],
                lr=self.args.lr_rate,
                betas=(0.9, 0.999),
            )
        else:
            self.optimizer = self._build_optimizer(
                self.model.parameters(),
                lr=self.args.lr_rate,
                betas=(0.9, 0.999),
            )

        self.scheduler = self._build_scheduler(self.optimizer)
        if is_factorvae:
            self.disc_scheduler = self._build_scheduler(self.disc_optimizer)

        # Resume from saved optimizer / scheduler states when present.
        self._restore_state(self.optimizer, self.scheduler, "optimizer.pt", "scheduler.pt")
        if is_factorvae:
            # The discriminator files belong to the discriminator's own
            # optimizer and scheduler, not to the main ones.
            self._restore_state(
                self.disc_optimizer,
                self.disc_scheduler,
                "disc_optimizer.pt",
                "disc_scheduler.pt",
            )

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

    def train(self):
        """Run the training loop; must be provided by a subclass."""
        raise NotImplementedError

    def eval(self):
        """Evaluate the model; ``run()`` forwards the return value to the metric estimation."""
        raise NotImplementedError

    def qualitative(self):
        """Return ``(imgs, gen_imgs)``, which ``run()`` saves as an image grid and a GIF."""
        raise NotImplementedError

    def permutation(self):
        """Subclass hook; not called by this base class."""
        raise NotImplementedError

    def run(self):
        """Prepare the trainer, then train and/or evaluate as requested by ``args``.

        When ``args.do_eval`` is set, this saves ``images.png`` and
        ``travel.gif`` under ``save_file``, estimates the disentanglement
        metrics and hands them to ``write_info``.
        """
        self.setting()
        if self.args.do_train:
            self.train()

        if self.args.do_eval:
            results = self.eval()
            imgs, gen_imgs = self.qualitative()

            # QUALITATIVE ANALYSIS
            imgs_dir = os.path.join(self.save_file, 'images.png')
            save_image(imgs, imgs_dir, nrow=self.args.interval + 1, pad_value=1.0)
            gif_dir = os.path.join(self.save_file, 'travel.gif')
            save_gif(gen_imgs, gif_dir, args=self.args)

            # QUANTITATIVE ANALYSIS
            dl_results = estimate_all_distenglement(dataset=self.dataset,
                                                    model=self.model,
                                                    loss_fn=self.loss_fn,
                                                    continuous_factors=False,
                                                    args=self.args,
                                                    results=results)

            # Eval-only runs write to a separate file from train+eval runs.
            results_name = "results.csv" if self.args.do_train else "eval_only_results.csv"
            self.args.results_file = os.path.join(self.output_dir, results_name)

            write_info(self.args, dl_results)


















