import torch
import os
from utils.util import get_optimizer, get_backbone, read_yaml, load_state_dict
from dataset import data_process
import torchvision.transforms as transforms


import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import warnings
warnings.filterwarnings("ignore")

class StandardTrainer:
    """Train a backbone from scratch and keep the checkpoint with the best test accuracy.

    Hyper-parameters (``lr``, ``num_epoch``, ``batch_size``, ``weight_decay``,
    ``momentum``) come from the config returned by ``read_yaml``; ``opt.lr`` and
    ``opt.num_epoch`` override the config values when they are not ``None``.
    """

    def __init__(self, opt):
        self.root = opt.root
        self.dataset = opt.dataset
        self.device = opt.device
        self.backbone = opt.backbone
        # Both are filled in by data_process() once the loaders exist.
        self.num_classes = None
        self.total_steps = None
        self.conf = read_yaml(opt.data_conf, self.backbone, self.dataset)
        self.save = opt.save

        if self.save == 'default':
            self.save = 'checkpoints/{}/{}_{}_standard.pth'.format(self.backbone, self.backbone, self.dataset)
        if opt.lr is not None:
            self.conf['lr'] = opt.lr
            print('overwrite lr: {}'.format(opt.lr))
        if opt.num_epoch is not None:
            self.conf['num_epoch'] = opt.num_epoch
            print('overwrite num_epoch: {}'.format(opt.num_epoch))

    def data_process(self):
        """Build the train/test loaders and record ``num_classes`` and ``total_steps``.

        Returns:
            (train_loader, test_loader)
        """
        batch_size = self.conf['batch_size']
        train_loader, test_loader = data_process(root=self.root, dataset=self.dataset,
                                                 batch_size=batch_size, train=True)
        self.num_classes = test_loader.dataset.num_classes
        # One scheduler step per batch over the whole run.
        self.total_steps = self.conf['num_epoch'] * len(train_loader)
        return train_loader, test_loader

    def net_process(self):
        """Build a randomly initialised backbone; requires ``self.num_classes``."""
        net = get_backbone(self.backbone)
        net = net(pretrained=False, num_classes=self.num_classes)
        return net

    def get_optimizer(self, net):
        """Return ``(optimizer, scheduler)``.

        The scheduler is cosine annealing down to 1e-6 over ``self.total_steps``
        steps, so it must be stepped once per batch.
        """
        optimizer = get_optimizer(net, self.conf['lr'],
                                  self.conf['weight_decay'], self.conf['momentum'])
        schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.total_steps, 1e-6)

        return optimizer, schedule

    def info(self):
        """Print a short summary of the run configuration."""
        print('-------------standard train-------------')
        print('dataset: {}\tbackbone: {}'.format(self.dataset, self.backbone))
        print('save to: {}'.format(self.save))

    def _train_one_epoch(self, net, loader, loss_func, normalize, optimizer, schedule):
        """Run one optimisation pass over ``loader``, stepping ``schedule`` after every batch.

        Returns:
            (mean loss, accuracy) over ``loader.dataset`` as Python floats.
        """
        net.train()
        losses, corrects = 0, 0
        for images, labels in loader:
            images, labels = images.to(self.device), labels.to(self.device)
            logits = net(normalize(images))
            loss = loss_func(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            schedule.step()

            # Accumulate on-device; the float() conversion below syncs once per epoch.
            losses += loss.detach() * len(images)
            corrects += (logits.argmax(dim=1) == labels).sum()
        size = len(loader.dataset)
        return float(losses / size), float(corrects / size)

    def _evaluate(self, net, loader, loss_func, normalize):
        """Evaluate ``net`` on ``loader`` without tracking gradients.

        Returns:
            (mean loss, accuracy) over ``loader.dataset`` as Python floats; both are
            normalised by the size of *this* loader's dataset.
        """
        net.eval()
        losses, corrects = 0, 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                logits = net(normalize(images))
                losses += loss_func(logits, labels) * len(images)
                corrects += (logits.argmax(dim=1) == labels).sum()
        size = len(loader.dataset)
        return float(losses / size), float(corrects / size)

    def _save_checkpoint(self, net, epoch, best_acc):
        """Write ``net``'s weights plus epoch/accuracy metadata to ``self.save``.

        The parent directory is created first, because torch.save does not create it.
        """
        save_dir = os.path.dirname(self.save)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save({'state_dict': net.state_dict(), 'epoch': epoch, 'best_acc': best_acc},
                   self.save)

    def train(self):
        """Train for ``conf['num_epoch']`` epochs.

        After each epoch the network is evaluated on the test loader. The
        checkpoint is saved whenever test accuracy improves.
        """
        self.info()
        train_loader, test_loader = self.data_process()
        net = self.net_process().to(self.device)

        optimizer, schedule = self.get_optimizer(net)

        loss_func = torch.nn.CrossEntropyLoss()
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        best_acc = 0

        for epoch in range(self.conf['num_epoch']):
            loss, acc = self._train_one_epoch(net, train_loader, loss_func, normalize,
                                              optimizer, schedule)
            print('train epoch:{}\tloss:{:.4f}\taccuracy:{:.4f}'.format(epoch, loss, acc))

            # simple eval
            loss, acc = self._evaluate(net, test_loader, loss_func, normalize)
            print('test epoch:{}\tloss:{:.4f}\taccuracy:{:.4f}'.format(epoch, loss, acc))

            if acc > best_acc and self.save is not None:
                best_acc = acc
                self._save_checkpoint(net, epoch, best_acc)

            print('---------------------------')

class FinetuneTrainer(StandardTrainer):
    """Fine-tune a backbone from pretrained weights and keep the best checkpoint.

    ``opt.load == 'imagenet'`` builds the backbone with ``pretrained=True``.
    Any other value is passed to ``load_state_dict`` as the weights source.
    The optimizer is built with ``freeze_level=1``.
    """

    def __init__(self, opt):
        # Resolve the default checkpoint name before StandardTrainer.__init__ copies opt.save.
        if opt.save == 'default':
            opt.save = 'checkpoints/{}/{}_{}_finetune.pth'.format(opt.backbone, opt.backbone, opt.dataset)
        self.load = opt.load
        super().__init__(opt)

    def net_process(self):
        """Build the backbone and load its initial weights; requires ``self.num_classes``."""
        net = get_backbone(self.backbone)
        if self.load == 'imagenet':
            net = net(pretrained=True)
            # Swap the classifier head for one sized to this dataset.
            # NOTE(review): assumes the head is exposed as ``fc`` (ResNet-style) — confirm for other backbones.
            net.fc = torch.nn.Linear(net.fc.in_features, self.num_classes)
        else:
            net = net(pretrained=False, num_classes=self.num_classes)
            # Loading semantics (e.g. strictness) are defined by utils.util.load_state_dict.
            load_state_dict(net, self.load)
        return net

    def get_optimizer(self, net):
        """Return ``(optimizer, scheduler)`` with ``freeze_level=1``."""
        optimizer = get_optimizer(net, self.conf['lr'],
                                  self.conf['weight_decay'], self.conf['momentum'], freeze_level=1)
        schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.total_steps, 1e-6)
        return optimizer, schedule

    def info(self):
        """Print a short summary of the run configuration."""
        print('-------------finetune train-------------')
        print('dataset: {}\tbackbone: {}'.format(self.dataset, self.backbone))
        print('load from: {}'.format(self.load))
        print('save to: {}'.format(self.save))

    def train(self):
        """Fine-tune for ``conf['num_epoch']`` epochs.

        After each epoch the network is evaluated on the test loader. The
        checkpoint is saved whenever test accuracy improves.

        Unlike StandardTrainer.train, the scheduler returned by get_optimizer is
        never stepped here, so the optimizer keeps the learning rate it was built with.
        """
        self.info()
        train_loader, test_loader = self.data_process()
        train_size, test_size = len(train_loader.dataset), len(test_loader.dataset)
        net = self.net_process().to(self.device)

        # The scheduler is intentionally not stepped during fine-tuning (see docstring).
        optimizer, _ = self.get_optimizer(net)

        loss_func = torch.nn.CrossEntropyLoss()
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        best_acc = 0

        for epoch in range(self.conf['num_epoch']):
            net.train()
            losses, corrects = 0, 0
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                logits = net(normalize(images))

                loss = loss_func(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses += loss.detach() * len(images)
                corrects += (logits.argmax(dim=1) == labels).sum()

            loss_avg, acc = float(losses / train_size), float(corrects / train_size)
            print('train epoch:{}\tloss:{:.4f}\taccuracy:{:.4f}'.format(epoch, loss_avg, acc))

            # simple eval
            net.eval()
            losses, corrects = 0, 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    logits = net(normalize(images))
                    losses += loss_func(logits, labels) * len(images)
                    corrects += (logits.argmax(dim=1) == labels).sum()
            # Both test metrics are normalised by the *test* set size.
            loss_avg, acc = float(losses / test_size), float(corrects / test_size)
            print('test epoch:{}\tloss:{:.4f}\taccuracy:{:.4f}'.format(epoch, loss_avg, acc))

            if acc > best_acc and self.save is not None:
                best_acc = acc
                # torch.save does not create missing parent directories.
                save_dir = os.path.dirname(self.save)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save({'state_dict': net.state_dict(), 'epoch': epoch, 'best_acc': best_acc},
                           self.save)

            print('---------------------------')

class PartialTrainer(FinetuneTrainer):
    """Fine-tune like FinetuneTrainer, but build the optimizer with ``freeze_level=2``.

    Network construction and the training loop are inherited from FinetuneTrainer.
    Only the optimizer and the default checkpoint name differ.
    """

    def __init__(self, opt):
        if opt.save == 'default':
            opt.save = 'checkpoints/{}/{}_{}_partial.pth'.format(opt.backbone, opt.backbone, opt.dataset)
        # opt.save is no longer 'default' at this point, so FinetuneTrainer.__init__
        # keeps the partial path; it records opt.load and then runs StandardTrainer.__init__.
        super().__init__(opt)

    def get_optimizer(self, net):
        """Return ``(optimizer, scheduler)`` with ``freeze_level=2``."""
        optimizer = get_optimizer(net, self.conf['lr'],
                                  self.conf['weight_decay'], self.conf['momentum'], freeze_level=2)
        schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.total_steps, 1e-6)
        return optimizer, schedule

    def info(self):
        """Print a short summary of the run configuration."""
        print('-------------partial train-------------')
        print('dataset: {}\tbackbone: {}'.format(self.dataset, self.backbone))
        print('load from: {}'.format(self.load))
        print('save to: {}'.format(self.save))


import argparse
if __name__ == '__main__':
    # Map each --train mode to the trainer class that implements it.
    trainers = {
        'standard': StandardTrainer,
        'finetune': FinetuneTrainer,
        'partial': PartialTrainer,
    }

    parser = argparse.ArgumentParser()

    parser.add_argument('--backbone', type=str, default='resnet18')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--data_conf', type=str, default='conf.yaml')
    parser.add_argument('--root', type=str, default='datasets')
    # choices= makes argparse reject unknown modes with a proper usage error
    # (the old `raise ('Not implement.')` raised a str, itself a TypeError).
    parser.add_argument('--train', type=str, default='standard', choices=sorted(trainers))

    parser.add_argument('--load', type=str, default='imagenet')
    parser.add_argument('--save', type=str, default='default')
    parser.add_argument('--device', type=int, default=0)

    # NOTE(review): nargs='+' makes opt.lr a list of floats, which the trainers store
    # in conf['lr'] as-is; confirm utils.util.get_optimizer accepts a list there.
    parser.add_argument('--lr', type=float, nargs='+', default=None)
    parser.add_argument('--num_epoch', type=int, default=None)

    opt = parser.parse_args()
    trainer = trainers[opt.train](opt)
    trainer.train()
