import pdb
import os
import logging
import torch
import torch.nn as nn


from torchvision.utils import save_image
from cg.src.trainers.utils import save_gif
from cg.src.info.utils import write_info

from cg.src.file.base import make_run_files
from cg.src.utils.utils import load_model, set_seed, get_constant_schedule, get_linear_schedule_with_warmup
from cg.src.constants import BASE_DATA, DATA_HIDDEN_DIM, DATA_STEPS, R2E_R2R


# SET CONFIG
from cg.models.config.base import BetaVAEConfig, BetaTCVAEConfig, CLGVAEConfig
from cg.models.config.cmcs import CMCS_Config

# SET MODEL
from cg.models.betavae import BetaVAE
from cg.models.betatcvae import BetaTCVAE
from cg.models.clgvae import CommutativeVAE
from cg.models.CMCS.gt import CMCS_GT_VAE
from cg.models.CMCS.super import CMCS_Super_VAE

# SET OPTIMIZER
from torch.optim import SGD, Adam

# SET DATALOADER SAMPLERS
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

# Use CUDA whenever it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# args.model_type -> config class. Both CMCS variants share CMCS_Config.
CONFIG = dict(
    betavae=BetaVAEConfig,
    betatcvae=BetaTCVAEConfig,
    clgvae=CLGVAEConfig,
    cmcs_gt=CMCS_Config,
    cmcs_super=CMCS_Config,
)

# args.model_type -> model class; the keys mirror CONFIG.
MODELS = dict(
    betavae=BetaVAE,
    betatcvae=BetaTCVAE,
    clgvae=CommutativeVAE,
    cmcs_gt=CMCS_GT_VAE,
    cmcs_super=CMCS_Super_VAE,
)

# Directory containing this file, and the absolute path two levels above it.
THIS_PATH = os.path.dirname(__file__)
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, '..', '..'))

# args.optimizer -> torch.optim class.
OPTIMIZER = dict(
    sgd=SGD,
    adam=Adam,
)


class CMCS_Trainer:
    """Base trainer: builds data, model, optimizer and scheduler from ``args``.

    Subclasses implement ``train``, ``eval``, ``qualitative`` and ``analysis``.
    ``run`` calls them according to the ``do_train`` / ``do_eval`` /
    ``do_analysis`` flags on ``args``.

    NOTE: ``__init__`` writes derived values back onto ``args`` (``steps``,
    ``dense_dim``, ``n_gpu``, ``train_batch_size``, ``num_epoch``,
    ``t_total``). ``save_results`` also sets ``args.results_file``.
    """

    def __init__(self, args):
        """Build datasets, loss function, model and output paths.

        Args:
            args: argparse-style namespace. Read here: ``dataset``, ``cg``,
                ``model_type``, ``no_cuda``, ``per_gpu_train_batch_size``,
                ``num_epoch``, ``max_steps``, ``do_train``, ``do_eval`` and
                ``do_analysis``, plus whatever ``set_seed``,
                ``make_run_files`` and ``load_model`` read.
        """
        set_seed(args)
        self.args = args
        dataset = BASE_DATA[self.args.dataset]()
        # R2E_R2R[cg][dataset] is a callable that returns (trainset, testset).
        self.trainset, self.testset = R2E_R2R[self.args.cg][self.args.dataset](dataset)

        # Summed BCE for the dSprites variants, summed MSE for every other dataset.
        if self.args.dataset in ("dsprites", "cdsprites"):
            self.loss_fn = nn.BCELoss(reduction="sum")
        else:
            self.loss_fn = nn.MSELoss(reduction="sum")

        # Derived arguments, stored on args so downstream code can read them.
        self.args.steps = DATA_STEPS[self.args.dataset]
        self.args.dense_dim = DATA_HIDDEN_DIM[self.args.dataset]
        self.args.n_gpu = 0 if self.args.no_cuda else torch.cuda.device_count()
        self.args.train_batch_size = self.args.per_gpu_train_batch_size * max(1, self.args.n_gpu)

        # Number of full batches per epoch (the train loader uses drop_last=True).
        dataset_size = len(self.trainset)
        batches_per_epoch = dataset_size // self.args.train_batch_size
        if self.args.max_steps == 0:
            t_total = batches_per_epoch * self.args.num_epoch
        else:
            # max_steps overrides num_epoch. The max(1, ...) avoids a
            # ZeroDivisionError when the train set is smaller than one batch.
            t_total = self.args.max_steps
            self.args.num_epoch = self.args.max_steps // max(1, batches_per_epoch) + 1
        self.args.t_total = t_total

        # Configs whose model_type contains 'betatcvae' also receive the train-set size.
        config_cls = CONFIG[self.args.model_type]
        if 'betatcvae' in self.args.model_type:
            self.config = config_cls(args=self.args, dataset_size=dataset_size)
        else:
            self.config = config_cls(args=self.args)

        self.model = MODELS[self.args.model_type](config=self.config)
        self.model.init_weights()

        self.save_file, self.run_file, self.output_dir = make_run_files(self.args)

        # Evaluation/analysis-only runs: load_model returns (state, path). The
        # state is applied only if that path exists.
        if not self.args.do_train and (self.args.do_eval or self.args.do_analysis):
            sub_model, path = load_model(self.args, self.save_file)
            if os.path.exists(path):
                self.model.load_state_dict(sub_model)

        # NOTE(review): the module-level `device` does not consult args.no_cuda.
        self.model.to(device)

        # Populated by setting().
        self.train_sampler, self.train_dataloader = None, None
        self.test_sampler, self.test_dataloader = None, None
        self.optimizer, self.scheduler = None, None
        # NOTE(review): "factorvae" is not a key of this file's MODELS, so the
        # MODELS lookup above raises KeyError before this branch can run.
        if self.args.model_type == "factorvae":
            self.disc_optimizer, self.disc_scheduler = None, None
        self.global_step = 0

    def setting(self):
        """Build the dataloaders, optimizer and scheduler, and wrap the model for multi-GPU.

        Restores optimizer and scheduler state from ``save_file`` when both
        ``optimizer.pt`` and ``scheduler.pt`` exist there.

        Raises:
            ValueError: if ``args.optimizer`` is not "adam" or "sgd".
        """
        set_seed(self.args)

        distributed = self.args.local_rank != -1

        self.train_sampler = (
            DistributedSampler(self.trainset) if distributed else RandomSampler(self.trainset)
        )
        self.train_dataloader = DataLoader(
            self.trainset,
            sampler=self.train_sampler,
            batch_size=self.args.train_batch_size,
            drop_last=True,
            pin_memory=True,
        )

        self.test_sampler = (
            DistributedSampler(self.testset) if distributed else SequentialSampler(self.testset)
        )
        # NOTE(review): drop_last=True also drops the final partial test batch.
        self.test_dataloader = DataLoader(
            self.testset,
            sampler=self.test_sampler,
            batch_size=self.args.test_batch_size,
            drop_last=True,
            pin_memory=True,
        )

        self.optimizer = self._build_optimizer()

        if self.args.scheduler == "linear":
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=self.args.t_total,
            )
        else:
            self.scheduler = get_constant_schedule(self.optimizer)

        self._load_optimizer_and_scheduler_states()

        # Multi-GPU data parallelism.
        # NOTE(review): n_gpu is device_count(), so with local_rank != -1 and
        # several visible GPUs the model is wrapped in both DataParallel and DDP.
        if self.args.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        # Distributed training.
        if distributed:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

    def _build_optimizer(self):
        """Return the optimizer named by ``args.optimizer`` ("adam" or "sgd").

        Raises:
            ValueError: if ``args.optimizer`` is neither "adam" nor "sgd".
        """
        name = self.args.optimizer
        if name == "adam":
            return OPTIMIZER[name](
                self.model.parameters(),
                lr=self.args.lr_rate,
                betas=(0.9, 0.999),
                weight_decay=self.args.weight_decay,
            )
        if name == "sgd":
            # SGD has no betas; it shares the learning rate and weight decay settings.
            return OPTIMIZER[name](
                self.model.parameters(),
                lr=self.args.lr_rate,
                weight_decay=self.args.weight_decay,
            )
        raise ValueError(f"Unsupported optimizer {name!r}; expected 'adam' or 'sgd'.")

    def _load_optimizer_and_scheduler_states(self):
        """Restore optimizer and scheduler state from ``save_file`` if both files exist."""
        optimizer_path = os.path.join(self.save_file, "optimizer.pt")
        scheduler_path = os.path.join(self.save_file, "scheduler.pt")
        if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
            # map_location lets GPU-saved state load on a CPU-only host.
            self.optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
            self.scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    def train(self):
        """Train the model. Must be implemented by subclasses."""
        raise NotImplementedError

    def eval(self):
        """Evaluate the model and return results for ``save_results``. Subclass hook."""
        raise NotImplementedError

    def qualitative(self):
        """Return ``(best_imgs, worst_imgs)`` for ``save_results``. Subclass hook."""
        raise NotImplementedError

    def analysis(self):
        """Run post-hoc analysis. Must be implemented by subclasses."""
        raise NotImplementedError

    def save_results(self, results, best_imgs, worst_imgs):
        """Write metrics via ``write_info`` and save the best/worst image grids.

        Args:
            results: metrics forwarded to ``write_info`` (its format is defined there).
            best_imgs: images passed to ``torchvision.utils.save_image``, saved
                as ``save_file/best_images.png`` with 10 images per row.
            worst_imgs: same as ``best_imgs``, saved as ``save_file/worst_images.png``.
        """
        # Runs that both trained and evaluated write results.csv; evaluation-only
        # runs write eval_only_results.csv.
        if self.args.do_train and self.args.do_eval:
            self.args.results_file = os.path.join(self.output_dir, "results.csv")
        else:
            self.args.results_file = os.path.join(self.output_dir, "eval_only_results.csv")

        write_info(self.args, results)

        # Qualitative analysis.
        imgs_dir = os.path.join(self.save_file, 'best_images.png')
        save_image(best_imgs, imgs_dir, nrow=10, pad_value=1.0)
        imgs_dir = os.path.join(self.save_file, 'worst_images.png')
        save_image(worst_imgs, imgs_dir, nrow=10, pad_value=1.0)

    def run(self):
        """Call ``setting``, then train, evaluate and analyse as selected by the ``args`` flags."""
        self.setting()
        if self.args.do_train:
            self.train()

        if self.args.do_eval:
            results = self.eval()
            best_imgs, worst_imgs = self.qualitative()
            self.save_results(results, best_imgs, worst_imgs)

        if self.args.do_analysis:
            self.analysis()


















