import torch
from torch import nn
from torch.autograd import Variable
import os
import time


class ClassifierTrainer:
    def __init__(self, opt, model, dataset, model_name, optim='adam', lr_decay=0.99,
                 continue_train=True, print_freq=1):
        self.opt        = opt
        self.model      = model
        self.dataset    = dataset
        self.model_name = model_name
        self.optim      = optim
        self.lr_decay   = lr_decay
        self.continue_train = continue_train
        self.print_freq = print_freq

    def weight_init(self, m):
        if isinstance(m, nn.Linear):
            # nn.init.constant(m.weight, 1e-2)
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias,0)
        elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            # nn.init.constant(m.weight, 1e-3)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 2e-1)
            nn.init.constant_(m.bias, 0)

    def train(self,
        n_epochs = 200,
        optim = 'adam',
        lr = 0.0002,
        loss_type = 'categorical',
        weight_decay=5e-4
    ):
        self.model.apply(self.weight_init)

        if self.continue_train:
            model_exists = False
            ckpt_path = '%scheckpoints/%s_state_dict'%(self.opt.data_dir, self.model_name)
            if os.path.exists(ckpt_path):
                self.model.load_state_dict(torch.load(ckpt_path))
                model_exists = True

        if self.opt.use_gpu:
            self.model.cuda()

        if self.optim == 'adam':
            optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay
            )
        else:
            optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=lr,
                momentum=0.9,
                weight_decay=weight_decay
            )

        dataloader = self.dataset.train_dataloader()

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 40, gamma=0.2)
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.005, factor=0.2)
        # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=30,
        #                                                 steps_per_epoch=len(dataloader))
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

        starting_epoch_n = 0
        training_was_in_progress = False
        if self.continue_train:
            root_optimizer_ckpt_path = 'optimizer_for_%s_state_dict'%self.model_name
            optimizer_ckpt_path = root_optimizer_ckpt_path
            for filename in os.listdir('%scheckpoints'%self.opt.data_dir):
                if optimizer_ckpt_path in filename:
                    training_was_in_progress = True
                    optimizer_ckpt_path = filename

            if training_was_in_progress:
                optimizer.load_state_dict(torch.load('%scheckpoints/%s'%(self.opt.data_dir,optimizer_ckpt_path)))
                # if model exists and no optimizer ckpt is found, train from the first epoch.
                starting_epoch_n = int(optimizer_ckpt_path.split('_')[-1])

        loss_function = torch.nn.CrossEntropyLoss()
        if loss_type == 'binary':
            loss_function = torch.nn.BCELoss()
        
        has_val = 'val_dataset' in self.dataset.__dict__
        best_acc = 0.
        # begin from next epoch
        previous_loss = 1000.
        for epoch in range(starting_epoch_n+1, n_epochs+1):
            start_time = time.perf_counter()
            current_loss = 0.
            self.model.train()

            for iter_n, batch in enumerate(dataloader):
                images = Variable(batch[0], requires_grad=True)
                targets = Variable(batch[1])
                if self.opt.use_gpu:
                    images = images.cuda()
                    targets = targets.cuda()
                outputs = self.model(images)
                if loss_type == 'binary':
                    outputs = torch.softmax(outputs, dim=-1)
                loss = loss_function(outputs, targets)
                current_loss += loss.item()
                loss.backward()
                # clip_gradient(optimizer, 0.1)
                optimizer.step()
                optimizer.zero_grad()
            current_loss /= iter_n+1
            # if self.optim == 'sgd':
            #     if current_loss > previous_loss:
            #         lr = lr*self.lr_decay
            #         for param_group in optimizer.param_groups:
            #             param_group['lr'] = lr
            #     previous_loss = current_loss
            #     print(f'Current learning rate: {lr}')
            if loss_type == 'binary':
                acc = outputs.max(1)[1].eq(targets.max(1)[1])
            else:
                acc = outputs.max(1)[1].eq(targets)
            acc = acc.float().mean().detach().cpu()
            end_time = time.perf_counter()
            if self.opt.print_epoch and epoch%self.print_freq == 0:
                print('Epoch %d/%d | Iter %d | Acc %.5f | Loss %.5f | Time %.2fs'%
                    (epoch,n_epochs,iter_n,acc,loss,end_time-start_time))
            

            accs = self.evaluate()
            if self.continue_train:
                # torch.save(self.model.state_dict(),ckpt_path)
                # self.model.save(ckpt_path)
                # if not has_val or 'substitute' in self.model_name:
                #     torch.save(self.model.state_dict(),ckpt_path)
                #     # self.model.save(ckpt_path)
                if accs > best_acc:
                    best_acc = accs
                    torch.save(self.model.state_dict(),ckpt_path)
                    # self.model.save(ckpt_path)
                new_checkpoint_path = '%s_%d'%(root_optimizer_ckpt_path,epoch)
                torch.save(optimizer.state_dict(), '%scheckpoints/%s'%(self.opt.data_dir,new_checkpoint_path))
                if os.path.exists('%scheckpoints/%s'%(self.opt.data_dir,optimizer_ckpt_path)):
                    os.unlink('%scheckpoints/%s'%(self.opt.data_dir,optimizer_ckpt_path))
                optimizer_ckpt_path = new_checkpoint_path

            scheduler.step()
            # scheduler.step(loss)

        return self.model

    def evaluate(self):
        self.model.eval()
        accs = 0
        n_samples = 0
        dataloader = self.dataset.test_dataloader()
        for iter_n, batch in enumerate(dataloader):
            images = batch[0]
            targets = batch[1]
            n_samples += targets.shape[0]
            if self.opt.use_gpu:
                images = images.cuda()
                targets = targets.cuda()
            with torch.no_grad():
                outputs = self.model(images)
                acc = outputs.max(1)[1].eq(targets).float().sum()
                acc = acc.detach().cpu()
            accs += acc
        accs /= n_samples
        print('%s accuracy: %.5f'%(self.model_name, accs))
        return accs