import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from robustness.datasets import ImageNet
from robustness.model_utils import make_and_restore_model
import argparse
from collections import OrderedDict
from torchvision import transforms, models
from tools.datasets import ImageNet9

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Directory of pretrained robust checkpoints, loaded as '{_ROBUST_ROOT}/{mkey}.ckpt'.
_ROBUST_ROOT = '/REDACTED/dcr_models/pretrained-robust'
# Root of the ImageNet-9 background-challenge data; each dataset variant
# (e.g. 'mixed_same', 'mixed_rand') lives in a subdirectory of the same name.
_IN_9_ROOT = '/REDACTED/data/bg_challenge'

def get_arch(mtype):
    """Build an untrained torchvision network matching a model-type string.

    Matching is by substring and checked in the order below, so the first
    family whose keyword appears in ``mtype`` wins. For example, any string
    containing 'wide' yields a wide ResNet: the 50-layer variant if '50'
    appears in the string, otherwise the 101-layer variant.

    Args:
        mtype: architecture prefix of a model key, e.g. 'mobilenet',
            'vgg16_bn' or 'resnext50_32x4d'.

    Returns:
        A freshly initialised ``torchvision.models`` network.

    Raises:
        ValueError: if ``mtype`` matches none of the supported families.
    """
    if 'wide' in mtype:
        return models.wide_resnet50_2() if '50' in mtype else models.wide_resnet101_2()
    if 'mobilenet' in mtype:
        return models.mobilenet_v2()
    if 'shufflenet' in mtype:
        return models.shufflenet_v2_x1_0()
    if 'vgg' in mtype:
        return models.vgg16_bn()
    if 'densenet' in mtype:
        return models.densenet161()
    if 'resnext50' in mtype:
        return models.resnext50_32x4d()
    # Previously fell through to `return arch` with `arch` unbound.
    raise ValueError(f'Unsupported architecture type: {mtype!r}')

class FineTuner(object):
    """Train a linear classifier on top of a frozen robust-pretrained backbone.

    The backbone is restored from ``{_ROBUST_ROOT}/{mkey}.ckpt`` with
    ``make_and_restore_model``. All of its children except the last are kept,
    with the input normalizer prepended, and frozen. A fresh ``nn.Linear``
    head is then trained with Adam. Only the head's state dict is
    checkpointed, to ``./models_best/{dset}/{mkey}[_trial{trial}].pth``.

    Args:
        mkey: checkpoint key of the form ``'{arch}_{norm}_eps{eps}'``,
            e.g. ``'resnet50_l2_eps3'``.
        dset: ``'waterbirds'`` or an ImageNet-9 variant whose name contains
            ``'mixed'`` (e.g. ``'mixed_rand'``).
        epochs: number of training epochs.
        trial: optional trial index; when given, ``_trial{trial}`` is
            appended to the save path.
    """

    def __init__(self, mkey='resnet50_l2_eps3', dset='waterbirds', epochs=15, trial=None):
        # Data first: init_model reads self.num_classes, which init_loaders sets.
        self.dset = dset
        self.init_loaders()
        self.mkey = mkey
        self.init_model(mkey)
        # Only the classifier's parameters are optimized (see init_model).
        self.optimizer = torch.optim.Adam(self.parameters, lr=0.0005,
                                          betas=(0.9, 0.999), weight_decay=0.0001)
        self.save_path = './models_best/{}/{}{}.pth'.format(
            dset, mkey, '' if trial is None else f'_trial{trial}')
        self.criterion = nn.CrossEntropyLoss()
        self.best_acc = 0
        self.num_epochs = epochs

    def init_model(self, mkey):
        """Restore the robust backbone for ``mkey`` and attach a linear head.

        Sets the following attributes:
            self.model: feat_net -> flatten -> classifier, moved to ``device``.
            self.parameters: the classifier's parameters (the only trainable ones).
            self.gradcam_layer: last block of ``layer4`` for ResNet keys, else None.
        """
        if 'resnet' in mkey:
            # NOTE(review): every 'wide' key is mapped to wide_resnet50_2.
            arch = mkey.split('_')[0] if 'wide' not in mkey else 'wide_resnet50_2'
            add_custom_forward = False
        else:
            # mkey is '{arch}_{norm}_eps{eps}', and arch may itself contain '_'
            # (e.g. 'vgg16_bn'), so drop the last two fields. Unlike slicing
            # at '_l2', this also resolves linf keys.
            arch = get_arch(mkey.rsplit('_', 2)[0])
            add_custom_forward = True
        train_ds = ImageNet('/tmp')
        net, _ = make_and_restore_model(arch=arch, dataset=train_ds,
                                        resume_path=f'{_ROBUST_ROOT}/{mkey}.ckpt',
                                        parallel=False, add_custom_forward=add_custom_forward)
        # Non-ResNet keys (built via get_arch with add_custom_forward=True) are
        # accessed one level deeper, at net.model.model.
        if 'resnet' not in mkey:
            children = net.model.model.named_children()
        else:
            children = net.model.named_children()
        # Backbone = input normalizer + every child except the final one.
        feat_net = nn.Sequential(OrderedDict([('normalizer', net.normalizer), *list(children)[:-1]]))
        model = nn.Sequential()
        model.add_module('feat_net', feat_net)
        model.add_module('flatten', nn.Flatten())
        # Infer the head's input width from a dummy batch. Eval mode + no_grad
        # keep this probe from updating BatchNorm running statistics.
        model.eval()
        with torch.no_grad():
            in_ftrs = model(torch.zeros(5, 3, 224, 224).to(device)).shape[1]
        model.add_module('classifier', nn.Linear(in_features=in_ftrs, out_features=self.num_classes, bias=True))
        self.gradcam_layer = model.feat_net.layer4[-1] if 'resnet' in mkey else None

        # Freeze the backbone; only the classifier is trained.
        for param in model.feat_net.parameters():
            param.requires_grad = False

        self.model = model.to(device)
        self.parameters = list(model.classifier.parameters())

    def init_loaders(self):
        """Build ``self.loaders`` ({'train', 'test'}) and set ``self.num_classes``.

        Raises:
            ValueError: if ``self.dset`` is neither 'waterbirds' nor contains 'mixed'.
        """
        if self.dset == 'waterbirds':
            from wilds import get_dataset
            from wilds.common.data_loaders import get_train_loader, get_eval_loader
            dataset = get_dataset("waterbirds", root_dir='/REDACTED/data')
            transform = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()])
            train_dset, test_dset = [dataset.get_subset(s, transform=transform) for s in ['train', 'test']]
            train_loader = get_train_loader("standard", train_dset, batch_size=32, num_workers=16)
            # Evaluation needs no shuffling, so use WILDS' eval loader for the test split.
            test_loader = get_eval_loader("standard", test_dset, batch_size=32, num_workers=16)
            self.num_classes = 2
        elif 'mixed' in self.dset:  # mixed_same, mixed_rand
            in9_ds = ImageNet9('{}/{}'.format(_IN_9_ROOT, self.dset))
            test_loader = in9_ds.make_loaders(batch_size=64, workers=16, shuffle_val=True)
            # NOTE(review): the head is trained and evaluated on the same loader here.
            train_loader = test_loader
            self.num_classes = 9
        else:
            raise ValueError(f"Unsupported dset {self.dset!r}: expected 'waterbirds' or a 'mixed_*' variant")

        self.loaders = {'train': train_loader, 'test': test_loader}

    def save_model(self):
        """Save the classifier's state dict and ``self.best_acc`` to ``self.save_path``."""
        self.model.eval()
        # The save directory (./models_best/{dset}) may not exist yet.
        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_dict = {'linear_layer': self.model.classifier.state_dict(),
                     'acc': self.best_acc}
        torch.save(save_dict, self.save_path)
        print('\nSaved model with accuracy: {:.3f} to {}\n'.format(self.best_acc, self.save_path))

    def restore_model(self):
        """Load the classifier weights saved at ``self.save_path`` and switch to eval mode."""
        print('Loading model from {}'.format(self.save_path))
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        save_dict = torch.load(self.save_path, map_location=device)
        self.model.classifier.load_state_dict(save_dict['linear_layer'])
        self.model.eval()

    def gradcam_layer(self):
        # NOTE: init_model assigns an instance attribute of the same name, which
        # shadows this method on every constructed instance; `ft.gradcam_layer`
        # is therefore the layer itself. Kept only for backward compatibility.
        return self.gradcam_layer

    def process_epoch(self, phase):
        """Run one pass over ``self.loaders[phase]``, training only when phase == 'train'.

        Returns:
            (avg_loss, avg_acc) as Python floats, averaged per sample.

        Raises:
            ValueError: if the loader yields no samples.
        """
        is_train = phase == 'train'
        self.model.train(is_train)
        # The backbone is frozen. Keep it in eval mode so its BatchNorm running
        # statistics never change: only the classifier is saved, so drifting
        # statistics would make a restored model disagree with what was measured.
        self.model.feat_net.eval()
        correct, running_loss, total = 0, 0.0, 0
        with torch.set_grad_enabled(is_train):
            for dat in tqdm(self.loaders[phase]):
                # Loaders may yield extra entries (e.g. metadata); only x, y are used.
                x, y = dat[0].to(device), dat[1].to(device)
                logits = self.model(x)
                loss = self.criterion(logits, y)
                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                batch_size = x.shape[0]
                correct += (logits.argmax(dim=1) == y).sum().item()
                total += batch_size
                # The criterion averages over the batch; re-weight by batch size
                # so the epoch loss is a true per-sample mean.
                running_loss += loss.item() * batch_size
        if total == 0:
            raise ValueError(f'{phase} loader yielded no samples')
        return running_loss / total, correct / total

    def _maybe_save(self, test_acc):
        """Checkpoint the classifier if ``test_acc`` beats ``self.best_acc``."""
        if test_acc > self.best_acc:
            self.best_acc = test_acc
            self.save_model()

    def finetune(self):
        """Train for ``self.num_epochs`` epochs, checkpointing on test accuracy.

        The model is evaluated every third epoch and once more at the end. The
        best-so-far classifier is saved after any of these evaluations.
        """
        print('\nBeginning finetuning of model to be saved at {}\n'.format(self.save_path))
        for epoch in range(self.num_epochs):
            train_loss, train_acc = self.process_epoch('train')
            if (epoch + 1) % 3 == 0:
                _, test_acc = self.process_epoch('test')
                self._maybe_save(test_acc)
            print('Epoch: {}/{}......Train Loss: {:.3f}......Train Acc: {:.3f}'.format(
                epoch + 1, self.num_epochs, train_loss, train_acc))
        test_loss, test_acc = self.process_epoch('test')
        # Also consider the final evaluation; otherwise runs with fewer than
        # three epochs would never save a checkpoint.
        self._maybe_save(test_acc)
        print('Test Loss: {:.3f}......Test Acc: {:.3f}'.format(test_loss, test_acc))
        print('\n\nFinetuning Complete\n\n')

if __name__ == '__main__':
    # Command-line driver: fit a linear head on every requested
    # (architecture, norm, epsilon) robust checkpoint. Running with no
    # arguments reproduces the previous hard-coded sweep:
    # wide_resnet50_2, L2 eps 5, no linf, on waterbirds, 15 epochs, no trial suffix.
    parser = argparse.ArgumentParser(description='Linear-probe finetuner for robust models')
    parser.add_argument('--arches', nargs='+', default=['wide_resnet50_2'],
                        help='architecture prefixes, e.g. resnet18 resnet50 mobilenet '
                             'densenet shufflenet resnext50_32x4d vgg16_bn')
    parser.add_argument('--dset', type=str, default='waterbirds',
                        help="'waterbirds' or an ImageNet-9 variant such as 'mixed_rand'")
    # Epsilons stay strings so the checkpoint key is spelled exactly as typed
    # (e.g. 'eps5', not 'eps5.0').
    parser.add_argument('--l2-eps', nargs='*', default=['5'],
                        help='L2 epsilons, e.g. 0 0.25 0.5 1 3 5')
    parser.add_argument('--linf-eps', nargs='*', default=[],
                        help='Linf epsilons, e.g. 0.5 1.0 2.0 4.0 8.0')
    parser.add_argument('--epochs', type=int, default=15)
    parser.add_argument('--trials', type=int, default=0,
                        help='if > 0, run trials 1..N and suffix each save path with _trialK')
    args = parser.parse_args()

    # A single trial=None run (no save-path suffix) unless --trials is given.
    trials = list(range(1, args.trials + 1)) or [None]
    for trial in trials:
        for arch in args.arches:
            for adv_train_norm, epsilons in (('l2', args.l2_eps), ('linf', args.linf_eps)):
                for adv_train_eps in epsilons:
                    mkey = f'{arch}_{adv_train_norm}_eps{adv_train_eps}'
                    finetuner = FineTuner(mkey, args.dset, epochs=args.epochs, trial=trial)
                    finetuner.finetune()
