import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import os
import time


class ClassifierTrainerOOD:
    """Train a classifier on in-distribution data plus an auxiliary outlier set.

    Every training batch concatenates one in-distribution batch with one
    outlier batch.  The loss is the cross-entropy on the in-distribution part
    plus a squared-hinge penalty on the negative log-sum-exp of the logits
    (see ``_energy_margin_loss``), using the margins ``m_in`` and ``m_out``.

    Checkpoints live in ``'<opt.data_dir>checkpoints/'`` (``opt.data_dir`` is
    concatenated as-is, so it must carry its own trailing separator):

    * ``<model_name>_state_dict``                       -- best model weights
    * ``optimizer_for_<model_name>_state_dict_<epoch>`` -- optimizer state
      after ``<epoch>``; only the most recent one is kept.

    ``opt`` must expose ``data_dir``, ``use_gpu`` and ``print_epoch``.
    ``dataset`` must expose ``train_dataloader()`` and ``test_dataloader()``;
    ``dataset_out`` must expose ``train_dataloader()``.
    """

    def __init__(self, opt, model, dataset, dataset_out,
                 model_name, pre_model_name, optim='adam', lr_decay=0.99,
                 m_in=-25., m_out=-7.,
                 continue_train=True, print_freq=1):
        self.opt            = opt
        self.model          = model
        self.dataset        = dataset
        self.dataset_out    = dataset_out
        self.model_name     = model_name
        # Checkpoint name used as a fallback when no checkpoint exists yet
        # under `model_name`.
        self.pre_model_name = pre_model_name
        # 'adam' selects Adam; any other value selects SGD with momentum 0.9.
        self.optim          = optim
        # Stored for callers; not read anywhere in this class.
        self.lr_decay       = lr_decay
        self.m_in           = m_in
        self.m_out          = m_out
        self.continue_train = continue_train
        self.print_freq     = print_freq

    def weight_init(self, m):
        """Initialise one module in place; meant for ``nn.Module.apply``.

        * ``nn.Linear``: Xavier-normal weights, zero bias.
        * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Kaiming-normal weights
          (``fan_out``); the bias is left untouched.
        * ``nn.BatchNorm2d``: weight 0.2, bias 0.

        Parameters that are ``None`` (``bias=False`` layers, non-affine batch
        norm) are skipped instead of raising.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 2e-1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _checkpoint_path(self, name):
        """Return the path of checkpoint file `name` under the checkpoint dir."""
        return '%scheckpoints/%s' % (self.opt.data_dir, name)

    def _latest_optimizer_checkpoint(self, root):
        """Find the newest ``<root>_<epoch>`` file in the checkpoint directory.

        Only names made of `root`, an underscore and a decimal epoch number
        are considered, so unrelated files are ignored.

        Returns:
            ``(filename, epoch)`` of the highest epoch found, or ``(None, 0)``
            when there is none.
        """
        prefix = root + '_'
        latest_name, latest_epoch = None, 0
        for filename in os.listdir('%scheckpoints' % self.opt.data_dir):
            if not filename.startswith(prefix):
                continue
            suffix = filename[len(prefix):]
            if not suffix.isdecimal():
                continue
            epoch = int(suffix)
            if latest_name is None or epoch > latest_epoch:
                latest_name, latest_epoch = filename, epoch
        return latest_name, latest_epoch

    def _energy_margin_loss(self, logits_in, logits_out):
        """Squared-hinge penalty on the negative log-sum-exp of the logits.

        Penalises in-distribution samples whose value exceeds ``m_in`` and
        outlier samples whose value falls below ``m_out``.
        """
        energy_in = -torch.logsumexp(logits_in, dim=1)
        energy_out = -torch.logsumexp(logits_out, dim=1)
        penalty_in = torch.pow(F.relu(energy_in - self.m_in), 2).mean()
        penalty_out = torch.pow(F.relu(self.m_out - energy_out), 2).mean()
        return penalty_in + penalty_out

    def train(self,
        n_epochs = 200,
        optim = 'adam',
        lr = 0.0002,
        loss_type = 'categorical',
        weight_decay=5e-4
    ):
        """Run the training loop and return the trained model.

        Args:
            n_epochs: Index of the last epoch to run (epochs are 1-based).
                When resuming, training continues from the epoch after the
                newest optimizer checkpoint.
            optim: Unused; kept for backward compatibility.  The optimizer is
                chosen by the ``optim`` argument given to the constructor.
            lr: Initial learning rate.  When an optimizer checkpoint is
                loaded, the learning rate stored in it takes precedence.
            loss_type: ``'binary'`` means the targets are 2-D score/one-hot
                rows and are arg-maxed before computing the logged accuracy;
                any other value treats them as class indices.  The loss
                itself is always ``F.cross_entropy``.
            weight_decay: Weight decay passed to the optimizer.

        Returns:
            ``self.model`` after training.

        Raises:
            FileNotFoundError: ``continue_train`` is set and neither the
                ``model_name`` nor the ``pre_model_name`` checkpoint exists.
            ValueError: an epoch yields no training batch.
        """
        self.model.apply(self.weight_init)

        ckpt_path = None
        if self.continue_train:
            ckpt_path = self._checkpoint_path('%s_state_dict' % self.model_name)
            load_path = ckpt_path
            if not os.path.exists(load_path):
                # No checkpoint of this model yet: start from the pre-trained
                # model (the one trained without outlier data).
                load_path = self._checkpoint_path('%s_state_dict' % self.pre_model_name)
            # map_location='cpu' lets a checkpoint saved on GPU be loaded on a
            # CPU-only host; load_state_dict copies onto the model's device.
            self.model.load_state_dict(torch.load(load_path, map_location='cpu'))

        if self.opt.use_gpu:
            self.model.cuda()

        if self.optim == 'adam':
            optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay
            )
        else:
            optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=lr,
                momentum=0.9,
                weight_decay=weight_decay
            )

        dataloader = self.dataset.train_dataloader()
        dataloader_out = self.dataset_out.train_dataloader()

        # If the model exists but no optimizer checkpoint is found, training
        # starts from the first epoch.
        starting_epoch_n = 0
        root_optimizer_ckpt_path = 'optimizer_for_%s_state_dict' % self.model_name
        optimizer_ckpt_path = root_optimizer_ckpt_path
        if self.continue_train:
            found_name, found_epoch = self._latest_optimizer_checkpoint(root_optimizer_ckpt_path)
            if found_name is not None:
                optimizer.load_state_dict(
                    torch.load(self._checkpoint_path(found_name), map_location='cpu'))
                optimizer_ckpt_path = found_name
                starting_epoch_n = found_epoch

        # The scheduler is created after the optimizer state is restored and
        # is aligned with the resumed epoch: the optimizer was saved before
        # the scheduler step of epoch `starting_epoch_n`, so passing
        # last_epoch = starting_epoch_n - 1 lets the constructor's initial
        # step land on the schedule position for the next epoch.
        last_epoch = -1
        if starting_epoch_n > 0:
            last_epoch = starting_epoch_n - 1
            for param_group in optimizer.param_groups:
                # Required by the scheduler whenever last_epoch != -1.
                param_group.setdefault('initial_lr', lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100, last_epoch=last_epoch)

        # NOTE: best_acc restarts at 0 on every call, so after a resume the
        # first epoch with non-zero accuracy overwrites the saved weights.
        best_acc = 0.
        # Begin from the epoch after the last completed one.
        for epoch in range(starting_epoch_n+1, n_epochs+1):
            start_time = time.perf_counter()
            self.model.train()

            # Randomise where the outlier dataset starts reading this epoch
            # (the dataset object is expected to consume `offset`).
            dataloader_out.dataset.offset = np.random.randint(len(dataloader_out.dataset))

            loss_sum = 0.
            n_correct = 0
            n_seen = 0
            n_batches = 0
            # zip stops at the shorter of the two loaders.
            for in_set, out_set in zip(dataloader, dataloader_out):
                n_in = len(in_set[0])
                data = torch.cat((in_set[0], out_set[0]), 0)
                targets = in_set[1]
                if self.opt.use_gpu:
                    data = data.cuda()
                    targets = targets.cuda()
                logits = self.model(data)

                optimizer.zero_grad()
                # The first n_in rows are in-distribution, the rest outliers.
                loss = F.cross_entropy(logits[:n_in], targets)
                loss = loss + 0.1*self._energy_margin_loss(logits[:n_in], logits[n_in:])

                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    labels = targets.max(1)[1] if loss_type == 'binary' else targets
                    n_correct = n_correct + logits[:n_in].max(1)[1].eq(labels).sum()
                loss_sum = loss_sum + loss.detach()
                n_seen += n_in
                n_batches += 1

            if n_batches == 0:
                raise ValueError(
                    'no training batches: the in-distribution or outlier dataloader is empty')

            # Epoch-level statistics (mean batch loss, accuracy over all
            # in-distribution samples seen this epoch).
            current_loss = float(loss_sum) / n_batches
            acc = float(n_correct) / max(n_seen, 1)
            end_time = time.perf_counter()
            if self.opt.print_epoch and epoch%self.print_freq == 0:
                print('Epoch %d/%d | Iter %d | Acc %.5f | Loss %.5f | Time %.2fs'%
                    (epoch,n_epochs,n_batches-1,acc,current_loss,end_time-start_time))

            accs = self.evaluate()
            if self.continue_train:
                # Keep only the weights with the best test accuracy so far.
                if accs > best_acc:
                    best_acc = accs
                    torch.save(self.model.state_dict(),ckpt_path)
                # Write the new optimizer checkpoint first, then drop the
                # previous one, so a checkpoint always exists on disk.
                new_checkpoint_path = '%s_%d'%(root_optimizer_ckpt_path,epoch)
                torch.save(optimizer.state_dict(), self._checkpoint_path(new_checkpoint_path))
                previous_path = self._checkpoint_path(optimizer_ckpt_path)
                if os.path.exists(previous_path):
                    os.unlink(previous_path)
                optimizer_ckpt_path = new_checkpoint_path

            scheduler.step()

        return self.model

    def evaluate(self):
        """Compute and print top-1 accuracy on ``dataset.test_dataloader()``.

        Switches the model to eval mode (and leaves it there).

        Returns:
            The accuracy as a 0-dim CPU tensor.
        """
        self.model.eval()
        accs = 0
        n_samples = 0
        dataloader = self.dataset.test_dataloader()
        with torch.no_grad():
            for batch in dataloader:
                images = batch[0]
                targets = batch[1]
                n_samples += targets.shape[0]
                if self.opt.use_gpu:
                    images = images.cuda()
                    targets = targets.cuda()
                outputs = self.model(images)
                # Number of correct top-1 predictions in this batch.
                n_correct = outputs.max(1)[1].eq(targets).float().sum()
                accs += n_correct.detach().cpu()
        accs /= n_samples
        print('%s accuracy: %.5f'%(self.model_name, accs))
        return accs