import os
import torch
import logging
import torch.nn as nn
import torch.nn.functional as F

from torchvision.utils import save_image


# SET OPTIMIZER
from torch.optim import Adam

from src.utils.seed import set_seed

logger = logging.getLogger(__name__)
# Configure root logging before the project modules below are imported.
logging.basicConfig(level=logging.INFO)
# Use CUDA when a GPU is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from src.models.model_config import JCGConfig

# SET Model
from src.models.jcgel.jcgel_betavae import CGEConv_BetaVAE
# Load Model
from src.utils.load_weight import load_model

# Set dataloader
from src.dataloaders.disent.shapes3d import Shapes3D
from src.dataloaders.disent.mpi3d_toy import MPI3D_toy


# Set file dir
from src.file.file_disent import make_finetuning_files

# set info
from src.info.info_disent import write_info


# SET DATALOADER SAMPLERS
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


# Load the combined disentanglement-metric evaluation
from disent.metrics.base import estimate_all_distenglement

# Registry of supported datasets, keyed by the ``args.dataset`` value;
# each entry is instantiated with no arguments by ``Trainer.__init__``.
DATASET = {
    'shapes3d': Shapes3D,
    'mpi3d_toy': MPI3D_toy,
}

class Trainer:
    """Base trainer for the CGEConv beta-VAE disentanglement experiments.

    ``__init__`` builds the dataset, model and output paths from ``args``;
    ``setting`` builds the dataloader/optimizer/scheduler; ``run`` drives
    optional training followed by qualitative and quantitative evaluation.

    ``train``, ``eval``, ``qualitative``, ``permutation`` and ``analysis``
    are hooks that concrete subclasses must implement.
    """

    def __init__(self, args):
        """Prepare dataset, model and output files from ``args``.

        ``args`` is mutated in place: ``steps``, ``dense_dim``, ``n_gpu``,
        ``train_batch_size``, ``num_epoch`` and ``t_total`` are (re)assigned.

        Args:
            args: experiment namespace. Read here: ``dataset``, ``no_cuda``,
                ``per_gpu_train_batch_size``, ``max_steps``, ``epoch`` (when
                ``max_steps == 0``) and ``do_train``.
        """
        set_seed(args)
        self.args = args
        self.dataset = DATASET[args.dataset]()
        self.loss_fn = nn.MSELoss(reduction="sum")

        # Fixed overrides applied to every run regardless of CLI values.
        self.args.steps = 500000
        self.args.dense_dim = [256, 256]
        self.args.n_gpu = 0 if self.args.no_cuda else torch.cuda.device_count()
        self.args.train_batch_size = self.args.per_gpu_train_batch_size * max(1, self.args.n_gpu)

        steps_per_epoch = len(self.dataset) // self.args.train_batch_size
        if self.args.max_steps == 0:
            # Epoch-driven schedule: assign num_epoch first so that t_total is
            # derived from the epoch count that will actually be used.
            self.args.num_epoch = self.args.epoch
            self.args.t_total = steps_per_epoch * self.args.num_epoch
        else:
            # Step-driven schedule: enough epochs to cover max_steps updates.
            self.args.t_total = self.args.max_steps
            self.args.num_epoch = self.args.max_steps // steps_per_epoch + 1

        # Build the model from the (now updated) args.
        self.config = JCGConfig(self.args)
        self.model = CGEConv_BetaVAE(self.config)
        self.model.init_weights()

        self.save_file, self.run_file, self.output_dir = make_finetuning_files(self.args)

        # Evaluation-only runs restore previously trained weights, if present.
        if not args.do_train:
            sub_model, path = load_model(args, self.save_file)
            if os.path.exists(path):
                # strict=False: keys absent from the checkpoint keep init_weights values.
                self.model.load_state_dict(sub_model, strict=False)
            else:
                logger.warning(
                    "No checkpoint found at %s; evaluating freshly initialised weights.",
                    path,
                )

        self.model.to(device)

        # Populated by setting() / the training loop.
        self.train_sampler, self.train_dataloader = None, None
        self.optimizer, self.scheduler = None, None
        self.global_step = 0

    def setting(self):
        """Build dataloader, optimizer and scheduler, then wrap the model for multi-GPU.

        Optimizer/scheduler states are restored from ``save_file`` when both
        ``optimizer.pt`` and ``scheduler.pt`` exist there.
        """
        set_seed(self.args)

        # Plain random sampling in a single process; sharded sampling under DDP.
        if self.args.local_rank == -1:
            self.train_sampler = RandomSampler(self.dataset)
        else:
            self.train_sampler = DistributedSampler(self.dataset)
        self.train_dataloader = DataLoader(
            self.dataset,
            sampler=self.train_sampler,
            batch_size=self.args.train_batch_size,
            drop_last=True,
            num_workers=4,
            pin_memory=True,
        )

        self.optimizer = Adam(
            self.model.parameters(),
            lr=self.args.lr_rate,
            betas=(0.9, 0.999),
            weight_decay=self.args.weight_decay,
        )
        # Constant learning rate: the multiplier is always 1.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda _: 1)

        self._load_optimizer_states()

        # DataParallel only for single-process multi-GPU runs; under DDP each
        # process already drives its own device, so DP must not be nested inside.
        if self.args.n_gpu > 1 and self.args.local_rank == -1:
            self.model = torch.nn.DataParallel(self.model)

        if self.args.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

    def _load_optimizer_states(self):
        """Restore optimizer and scheduler states when both files exist in ``save_file``."""
        optimizer_path = os.path.join(self.save_file, "optimizer.pt")
        scheduler_path = os.path.join(self.save_file, "scheduler.pt")
        if not (os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path)):
            return
        # map_location allows checkpoints written on a GPU to be restored on a CPU-only host.
        self.optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
        self.scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    def train(self):
        """Run the training loop. Must be implemented by a subclass."""
        raise NotImplementedError(f"{type(self).__name__} must implement train()")

    def eval(self):
        """Return evaluation results passed on to the metric estimation. Subclass hook."""
        raise NotImplementedError(f"{type(self).__name__} must implement eval()")

    def qualitative(self):
        """Return ``(imgs, gen_imgs)`` for the qualitative image grid. Subclass hook."""
        raise NotImplementedError(f"{type(self).__name__} must implement qualitative()")

    def permutation(self):
        """Return images for the permutation visualisation. Subclass hook."""
        raise NotImplementedError(f"{type(self).__name__} must implement permutation()")

    def analysis(self):
        """Run additional analysis. Subclass hook."""
        raise NotImplementedError(f"{type(self).__name__} must implement analysis()")

    def run(self):
        """Set up, optionally train, then optionally evaluate and record results.

        When ``args.do_eval`` is set, this saves ``<save_file>/images.png``,
        computes disentanglement metrics, sets ``args.results_file`` (to
        ``results.csv`` after training, ``eval_only_results.csv`` otherwise,
        under ``output_dir``) and passes the metrics to ``write_info``.
        """
        self.setting()
        if self.args.do_train:
            self.train()

        if not self.args.do_eval:
            return

        results = self.eval()
        imgs, _ = self.qualitative()  # generated images are not used here

        # Qualitative analysis: ``interval + 1`` images per grid row.
        imgs_dir = os.path.join(self.save_file, 'images.png')
        save_image(imgs, imgs_dir, nrow=self.args.interval + 1, pad_value=1.0)

        # NOTE: the permutation visualisation (self.permutation) is currently disabled.

        # Quantitative analysis.
        dl_results = estimate_all_distenglement(dataset=self.dataset,
                                                model=self.model,
                                                loss_fn=self.loss_fn,
                                                continuous_factors=False,
                                                args=self.args,
                                                results=results)

        # do_eval is already known to be true here; only do_train picks the file.
        results_name = "results.csv" if self.args.do_train else "eval_only_results.csv"
        self.args.results_file = os.path.join(self.output_dir, results_name)

        write_info(self.args, dl_results)

















