import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from cornet_s import CORnet_S


def get_model(map_location=None):
    """Return a ``CORnet_S`` wrapped in ``torch.nn.DataParallel`` with pretrained weights.

    The checkpoint is fetched with ``torch.hub.load_state_dict_from_url`` from
    the cornet-models S3 bucket and its ``'state_dict'`` entry is loaded into
    the wrapped model.

    Args:
        map_location: forwarded to ``load_state_dict_from_url``; selects where
            the checkpoint tensors are placed (e.g. ``'cpu'`` or ``'cuda'``).

    Returns:
        The ``DataParallel`` wrapper; callers can reach the bare network via
        ``.module``.
    """
    model_hash = '1d3f7974'
    url = f'https://s3.amazonaws.com/cornet-models/cornet_s-{model_hash}.pth'
    checkpoint = torch.hub.load_state_dict_from_url(url, map_location=map_location)

    # NOTE(review): the model is wrapped before loading, presumably because the
    # checkpoint's keys carry DataParallel's 'module.' prefix — confirm.
    wrapped = torch.nn.DataParallel(CORnet_S())
    wrapped.load_state_dict(checkpoint['state_dict'])
    return wrapped


class exp2_exp:
    """Repeatedly lesion one CORnet-S region by random pruning, retraining after each round.

    A pretrained CORnet-S is loaded (see ``get_model``) and one of its regions
    (``V1``, ``V2``, ``V4`` or ``IT``, chosen by ``region_idx``) is selected.
    Each of ``lesion_iters`` rounds:

    1. randomly prunes every ``Conv2d`` in that region (``amount=0.2``);
    2. retrains the whole network for ``retrain_iters`` steps;
    3. evaluates on the validation loader.

    Expected ``args`` keys (values may be strings; they are cast here):
    ``region_idx``, ``lesion_iters``, ``retrain_iters``, ``lr``,
    ``weight_decay``.
    """

    def __init__(self, args):
        self.region_idx = int(args['region_idx'])
        self.lesion_iters = int(args['lesion_iters'])
        self.args = args
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f'using {self.device}')

        # Load the checkpoint onto the device actually in use. A hard-coded
        # 'cuda' here breaks on CPU-only machines.
        model = get_model(self.device)
        model = model.module  # unwrap DataParallel

        for param in model.parameters():
            param.requires_grad = True

        self.model = model.to(self.device).float()

        self.criterion = nn.CrossEntropyLoss()

        self.optimizer = torch.optim.SGD(
            list(self.model.parameters()),
            lr=float(self.args['lr']),
            momentum=0.9,
            weight_decay=float(self.args['weight_decay']),
        )

        # run() calls scheduler.step() once per lesion round.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer, T_max=self.lesion_iters)

    def run(self, train_loader, val_loader):
        """Run the lesion/retrain loop and return validation metrics per round.

        Returns:
            dict with lists under ``'loss'``, ``'top1_acc'`` and ``'top5_acc'``.
            Index 0 is the pre-lesion baseline. Index ``i >= 1`` is measured
            after the i-th prune + retrain round.
        """
        # These keys must match the ones _record appends to.
        metrics = {'loss': [], 'top1_acc': [], 'top5_acc': []}

        layers = [self.model.V1, self.model.V2, self.model.V4, self.model.IT]
        layer = layers[self.region_idx]

        conv_layers = [module for module in layer.modules()
                       if isinstance(module, nn.Conv2d)]

        for param in self.model.parameters():
            param.requires_grad = True

        loss, acc1, acc5 = self.test(val_loader)
        print(f'{acc1:.4f},{acc5:.4f}')
        self._record(metrics, loss, acc1, acc5)

        for i in range(self.lesion_iters):
            print(f'Epoch {i + 1}')

            pruned_frac = self._prune_random(conv_layers, amount=0.2)
            print('Pruned: ' + str(pruned_frac))

            self.retrain(train_loader)

            loss, acc1, acc5 = self.test(val_loader)
            self._record(metrics, loss, acc1, acc5)

            self.scheduler.step()

        return metrics

    @staticmethod
    def _record(metrics, loss, acc1, acc5):
        """Append one evaluation result to the metrics dict built in ``run``."""
        metrics['loss'].append(loss)
        metrics['top1_acc'].append(acc1)
        metrics['top5_acc'].append(acc5)

    @staticmethod
    def _prune_random(conv_layers, amount=0.2):
        """Randomly prune each layer's weight and return the overall pruned fraction.

        When applied repeatedly, torch's pruning container computes ``amount``
        over the still-unpruned entries, so the lesion accumulates across rounds.
        Returns 0.0 when ``conv_layers`` is empty.
        """
        pruned, total = 0, 0
        for conv in conv_layers:
            prune.random_unstructured(conv, name='weight', amount=amount)
            mask = conv.weight_mask
            pruned += int((mask == 0).sum().item())
            total += mask.numel()
        return pruned / total if total else 0.0

    def retrain(self, train_loader):
        """Train all parameters for ``int(args['retrain_iters'])`` SGD steps.

        Cycles through ``train_loader``, restarting it when exhausted, so the
        step count does not depend on the loader length.
        """
        # test() leaves the model in eval mode. Without switching back, any
        # BatchNorm/Dropout layers would train with inference behaviour.
        self.model.train()

        for param in self.model.parameters():
            param.requires_grad = True

        n_iters = int(self.args['retrain_iters'])
        iterator = iter(train_loader)
        for it in range(n_iters):
            try:
                inputs, targets = next(iterator)
            except StopIteration:
                iterator = iter(train_loader)
                inputs, targets = next(iterator)

            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            loss, acc1, acc5 = self.compute_metrics(inputs, targets)

            # Log every 200 steps and on the final step.
            if it % 200 == 0 or it == n_iters - 1:
                print(f'\t{it + 1},{loss.item():.4f},{acc1:.4f},{acc5:.4f}')

            loss.backward()
            self.optimizer.step()

    def accuracy(self, output, target, topk=(1, 5)):
        """Count correct predictions in the top-k outputs for each k in ``topk``.

        Args:
            output: 2-D scores, one row per sample (``topk`` runs along dim 1).
            target: class indices, one per row of ``output``.
            topk: the k values to evaluate.

        Returns:
            List of 1-element float tensors holding the number of samples whose
            target is among the k highest-scoring classes (a count, not a rate).
        """
        with torch.no_grad():
            maxk = max(topk)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()  # (maxk, batch)
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k)
            return res

    def compute_metrics(self, inputs, targets):
        """Forward a batch and return ``(loss tensor, top-1 count, top-5 count)``."""
        output = self.model(inputs)
        loss = self.criterion(output, targets)
        top1, top5 = self.accuracy(output, targets, topk=(1, 5))
        return loss, top1.item(), top5.item()

    def test(self, loader):
        """Evaluate the model on ``loader`` (leaves the model in eval mode).

        Returns:
            ``(mean per-batch loss, top-1 accuracy, top-5 accuracy)``. The
            accuracies are correct counts divided by ``len(loader.dataset)``.

        Raises:
            ValueError: if ``loader`` has no batches.
        """
        self.model.eval()

        n_batches = len(loader)
        if n_batches == 0:
            raise ValueError('loader yields no batches')

        test_loss = 0.0
        test_top1 = 0.0
        test_top5 = 0.0

        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                loss, top1, top5 = self.compute_metrics(inputs=inputs, targets=targets)

                # Accumulate Python floats rather than a device tensor.
                test_loss += loss.item()
                test_top1 += top1
                test_top5 += top5

        n_samples = len(loader.dataset)
        return (test_loss / n_batches,
                test_top1 / n_samples,
                test_top5 / n_samples)
