import os
import numpy as np
import torch
import torchvision
from simclr.modules import NT_Xent


class SimCLRTrainer:
    """SimCLR contrastive training loop with resumable checkpoints.

    When ``continue_train`` is true, model weights are read from and written to
    ``<opt.data_dir>checkpoints/<model_name>_state_dict``. Optimizer state is
    written to ``<opt.data_dir>checkpoints/optimizer_for_<model_name>_state_dict_<epoch>``.
    An interrupted run therefore resumes after the last completed epoch.
    """

    def __init__(self, opt, model, dataset, model_name, n_epoch, simclr_batch_size,
                 print_freq=1,
                 continue_train=True):
        """Store the training configuration.

        Args:
            opt: options namespace; ``train`` reads ``data_dir``, ``use_gpu``,
                ``temperature``, ``world_size``, ``simclr_lr`` and
                ``simclr_weight_decay`` from it.
            model: module called as ``model(x_i, x_j)`` returning a 4-tuple
                whose last two items are fed to the NT-Xent criterion.
            dataset: dataset whose DataLoader batches unpack as
                ``((x_i, x_j), _)``.
            model_name: name used to build checkpoint file names.
            n_epoch: last (1-based) epoch number to train.
            simclr_batch_size: loader batch size; also passed to ``NT_Xent``.
            print_freq: print the mean loss every ``print_freq`` epochs.
            continue_train: load/save checkpoints so training can resume.
        """
        self.opt = opt
        self.model = model
        self.dataset = dataset
        self.model_name = model_name
        self.n_epoch = n_epoch
        self.simclr_batch_size = simclr_batch_size
        self.print_freq = print_freq
        self.continue_train = continue_train

    @staticmethod
    def _latest_optimizer_checkpoint(filenames, root):
        """Return ``(filename, epoch)`` of the newest ``<root>_<epoch>`` entry.

        Only names consisting exactly of ``root + '_'`` followed by decimal
        digits are considered. Returns ``(None, 0)`` when nothing matches.
        """
        prefix = root + '_'
        best_name, best_epoch = None, 0
        for filename in filenames:
            if not filename.startswith(prefix):
                continue
            suffix = filename[len(prefix):]
            if not suffix.isdecimal():
                continue
            epoch = int(suffix)
            if best_name is None or epoch > best_epoch:
                best_name, best_epoch = filename, epoch
        return best_name, best_epoch

    def _train_one_epoch(self, train_loader, criterion, optimizer):
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Returns ``float('nan')`` when the loader yields no batches (e.g. the
        dataset is smaller than one batch and ``drop_last`` is set), instead
        of dividing by zero.
        """
        loss_sum = 0.0
        num_batch = 0
        for (x_i, x_j), _ in train_loader:
            if self.opt.use_gpu:
                x_i = x_i.cuda()
                x_j = x_j.cuda()
            optimizer.zero_grad()
            # Positive pair through the model; only the last two outputs
            # enter the contrastive loss, the first two are unused here.
            _, _, z_i, z_j = self.model(x_i, x_j)
            loss = criterion(z_i, z_j)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            num_batch += 1
        return loss_sum / num_batch if num_batch else float('nan')

    def train(self):
        """Train ``self.model`` with the NT-Xent loss and return it.

        If ``continue_train`` is set, saved model and optimizer state are
        restored. Training then resumes at the epoch after the newest
        optimizer checkpoint. After every epoch both are saved again and the
        previous optimizer checkpoint is deleted.

        Returns:
            The trained model (the same object as ``self.model``).
        """
        # data_dir is prefixed verbatim (no os.path.join) to keep the existing
        # on-disk layout; presumably data_dir ends with a path separator.
        ckpt_dir = '%scheckpoints' % self.opt.data_dir
        ckpt_path = '%s/%s_state_dict' % (ckpt_dir, self.model_name)
        root_optimizer_ckpt_path = 'optimizer_for_%s_state_dict' % self.model_name

        # File name (inside ckpt_dir) of the newest optimizer checkpoint.
        optimizer_ckpt_path = None
        starting_epoch_n = 0
        if self.continue_train:
            # Listing and saving into the directory below would otherwise
            # fail on a fresh run where it does not exist yet.
            os.makedirs(ckpt_dir, exist_ok=True)
            if os.path.exists(ckpt_path):
                # map_location='cpu' lets a GPU-saved checkpoint load on a
                # CPU-only host; load_state_dict copies into the model's device.
                self.model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            optimizer_ckpt_path, starting_epoch_n = self._latest_optimizer_checkpoint(
                os.listdir(ckpt_dir), root_optimizer_ckpt_path)

        if self.opt.use_gpu:
            self.model.cuda()
        # Make sure dropout/normalization layers behave as in training even if
        # the caller left the model in eval mode.
        self.model.train()

        train_loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.simclr_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=8,
        )
        criterion = NT_Xent(self.simclr_batch_size,
                            self.opt.temperature, self.opt.world_size)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.opt.simclr_lr,
                                     weight_decay=self.opt.simclr_weight_decay)
        if optimizer_ckpt_path is not None:
            # Restore the saved optimizer state so a resumed run does not
            # restart Adam's moment estimates from scratch.
            optimizer.load_state_dict(torch.load(
                '%s/%s' % (ckpt_dir, optimizer_ckpt_path), map_location='cpu'))

        # Resume from the epoch after the last checkpointed one.
        for epoch in range(starting_epoch_n + 1, self.n_epoch + 1):
            mean_loss = self._train_one_epoch(train_loader, criterion, optimizer)

            if epoch % self.print_freq == 0:
                print(f"[SimCLR] epoch {epoch}/{self.n_epoch} | Loss {mean_loss}")

            if self.continue_train:
                torch.save(self.model.state_dict(), ckpt_path)
                new_optimizer_ckpt_path = '%s_%d' % (root_optimizer_ckpt_path, epoch)
                torch.save(optimizer.state_dict(),
                           '%s/%s' % (ckpt_dir, new_optimizer_ckpt_path))
                # Delete the previous optimizer checkpoint only after the new
                # one is written, so a crash never leaves none on disk.
                if optimizer_ckpt_path is not None:
                    old_path = '%s/%s' % (ckpt_dir, optimizer_ckpt_path)
                    if os.path.exists(old_path):
                        os.unlink(old_path)
                optimizer_ckpt_path = new_optimizer_ckpt_path

        if self.opt.use_gpu:
            torch.cuda.empty_cache()

        return self.model
