import os
import argparse
import json
import sys
import time
import torch
import torchvision
from pathlib import Path
from torch import nn, optim
from torchvision import datasets, transforms
from trainer import ESCNNEncoder, PredictorEqv, ProjectionMLP, Identity
from escnn import nn as enn
from escnn import gspaces
from datasets import *
from tqdm import tqdm

# Command-line interface of the rotated-evaluation script.  Options are
# registered in a fixed order because that order is what ``--help`` shows.
parser = argparse.ArgumentParser(
    description='Evaluate resnet50 features on ImageNet',
)

# --- dataset / checkpoint locations -------------------------------------
parser.add_argument(
    '--data',
    type=Path,
    default='/data1/home/data/stl10',
    metavar='DIR',
    help='path to dataset',
)
parser.add_argument(
    '--pretrained-dir',
    type=Path,
    default='./experiments/stl10_training_alpha_0.0_beta_0.1',
    metavar='FILE',
    help='path to pretrained model',
)

# --- run configuration --------------------------------------------------
parser.add_argument(
    '--weights',
    type=str,
    default='freeze',
    choices=('finetune', 'freeze'),
    help='finetune or freeze resnet weights',
)
parser.add_argument(
    '--workers',
    type=int,
    default=8,
    metavar='N',
    help='number of data loader workers',
)
# Comma-separated device ids; main() turns this string into a list of ints.
parser.add_argument(
    '--gpus',
    type=str,
    default='0',
)
parser.add_argument(
    '--batch-size',
    type=int,
    default=256,
    metavar='N',
    help='mini-batch size',
)

# --- model / dataset selection ------------------------------------------
parser.add_argument(
    '--arch',
    type=str,
    default='escnn18',
    choices=['resnet18', 'resnet50', 'escnn18', 'escnn50'],
    help='model architecture',
)
# main_worker() replaces this value with the class count reported by the
# evaluation-set loader.
parser.add_argument(
    '--num-classes',
    type=int,
    default=196,
)
parser.add_argument(
    '--pretrain-set',
    type=str,
    default='stl10',
    choices=['stl10', 'imagenet100', 'caltech256'],
    help='pretrain dataset',
)
parser.add_argument(
    '--eval-set',
    type=str,
    default='stl10',
    choices=[
        'stl10',
        'imagenet100',
        'stanford_cars',
        'fgvc_aircraft',
        'cub_200_2011',
        'cifar10',
        'cifar100',
        'caltech256',
    ],
    help='evaluation dataset',
)
parser.add_argument(
    '--connector',
    type=str,
    default='softmax',
    choices=['softmax', 'identity', 'tanh', 'shift'],
    help='equivariance connection map',
)

# --- boolean switches ---------------------------------------------------
parser.add_argument('--use_gpool', action='store_true')
parser.add_argument('--rotated', action='store_true')
parser.add_argument('--inv_linear_probe', action='store_true')

# --- linear-head checkpoint and run tag ---------------------------------
parser.add_argument(
    '--fc_checkpoint',
    type=str,
    default='final_fc_hx.pth',
)
# In this file the value is only used as a suffix of the stats-file name.
parser.add_argument(
    '--iterations',
    type=int,
    default=0,
)


def main():
    """Parse the command line, normalise GPU-related fields and run the evaluation.

    ``--gpus`` arrives as a comma-separated string and is replaced by a list
    of ints; the number of CUDA devices visible to torch is stored on the
    namespace as ``ngpus_per_node`` before handing over to ``main_worker``.
    """
    args = parser.parse_args()
    gpu_ids = list(map(int, args.gpus.split(',')))
    args.ngpus_per_node = torch.cuda.device_count()
    args.gpus = gpu_ids
    main_worker(args)

def _stats_file_name(args):
    """Return the stats-file name for this run.

    The name is ``<eval_set>_rotation_stats_linear_probe[_rotated][_inv]_<iterations>.txt``
    where ``_rotated`` is present when ``args.rotated`` is set and ``_inv``
    when ``args.inv_linear_probe`` is set.
    """
    variant = ''
    if args.rotated:
        variant += '_rotated'
    if args.inv_linear_probe:
        variant += '_inv'
    return f'{args.eval_set}_rotation_stats_linear_probe{variant}_{args.iterations}.txt'


def _load_and_report(module, weights):
    """Load ``weights`` into ``module`` non-strictly and print the key mismatches."""
    missing_keys, unexpected_keys = module.load_state_dict(weights, strict=False)
    print(missing_keys)
    print(unexpected_keys)


def _build_frozen_model(args):
    """Create the ``Encoder`` on GPU, restore its pretrained parts and freeze them.

    ``final.pth`` in ``args.pretrained_dir`` supplies the ``'backbone'`` entry
    (and ``'predictor_eqv'`` when ``args.inv_linear_probe`` is set);
    ``args.fc_checkpoint`` in the same directory supplies the ``'fc'`` entry.
    """
    model = Encoder(args)
    model.cuda()

    pretrained = torch.load(os.path.join(args.pretrained_dir, 'final.pth'), map_location='cpu')
    _load_and_report(model.backbone, pretrained['backbone'])
    if args.inv_linear_probe:
        _load_and_report(model.predictor_eqv, pretrained['predictor_eqv'])

    head = torch.load(os.path.join(args.pretrained_dir, args.fc_checkpoint), map_location='cpu')
    _load_and_report(model.fc, head['fc'])

    model.backbone.requires_grad_(False)
    model.fc.requires_grad_(False)
    if args.inv_linear_probe:
        model.predictor_eqv.requires_grad_(False)
    return model


def _evaluate(model, loader):
    """Return the (top-1, top-5) sample-weighted mean accuracy of ``model`` over ``loader``."""
    model.eval()
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')
    with torch.no_grad():
        for images, target in loader:
            output = model(images.cuda(non_blocking=True))
            acc1, acc5 = accuracy(output, target.cuda(non_blocking=True), topk=(1, 5))
            top1.update(acc1[0].item(), images.size(0))
            top5.update(acc5[0].item(), images.size(0))
    return top1.avg, top5.avg


def main_worker(args):
    """Evaluate the frozen encoder + linear head on every rotated validation set.

    For each dataset returned by ``load_eval_rotated_sets(args)`` the top-1 and
    top-5 accuracies are computed and one JSON line is printed and appended to
    the stats file inside ``args.pretrained_dir``.

    Args:
        args: parsed command-line namespace.  ``args.num_classes`` is
            overwritten with the class count reported by the dataset loader.

    Raises:
        NotImplementedError: if ``args.weights`` is not ``'freeze'``; only
            frozen evaluation is implemented.
    """
    # Reject the unsupported mode up front with a real exception: a bare
    # ``assert`` is stripped under ``python -O`` and would only fire after the
    # datasets and checkpoints had already been loaded.
    if args.weights != 'freeze':
        raise NotImplementedError(
            f"--weights {args.weights!r} is not supported by this script; use 'freeze'"
        )

    # The context manager guarantees the stats file is flushed and closed even
    # when loading or evaluation raises.
    with open(args.pretrained_dir / _stats_file_name(args), 'a', buffering=1) as stats_file:
        torch.backends.cudnn.benchmark = True

        val_datasets, num_classes = load_eval_rotated_sets(args)
        args.num_classes = num_classes

        loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
        # Index-based access is kept because only ``len()`` and ``[i]`` are
        # known to be supported by the object the loader returns.
        val_loaders = [
            torch.utils.data.DataLoader(val_datasets[i], shuffle=True, **loader_kwargs)
            for i in range(len(val_datasets))
        ]

        model = _build_frozen_model(args)

        for i, loader in enumerate(tqdm(val_loaders)):
            acc1, acc5 = _evaluate(model, loader)
            # NOTE(review): the "best" fields are recomputed from zero for every
            # rotation, so they always equal that rotation's own accuracy.  They
            # are kept unchanged so the stats-file schema stays the same.
            # NOTE(review): ``degree`` presumes the i-th dataset is rotated by
            # 10*i degrees -- confirm against load_eval_rotated_sets.
            stats = dict(
                degree=10 * i,
                acc1=acc1,
                acc5=acc5,
                best_acc1=max(0, acc1),
                best_acc5=max(0, acc5),
            )
            line = json.dumps(stats)
            print(line)
            print(line, file=stats_file)

class Encoder(nn.Module):
    """Backbone followed by a linear classification head.

    Depending on ``args.arch`` the backbone is an ``ESCNNEncoder`` or a
    torchvision ResNet whose ``fc`` layer is replaced by an identity.

    For escnn backbones without group pooling and with
    ``args.inv_linear_probe`` set, an equivariant predictor is added.  In
    ``forward`` its output is used to cyclically re-order each group of
    ``order`` feature channels before the linear head is applied.  In every
    other configuration ``self.inv_linear_probe`` stays ``False`` and the head
    is applied to the backbone features directly.

    Args:
        args: namespace providing ``arch``, ``num_classes`` and, for escnn
            architectures, ``use_gpool``, ``inv_linear_probe`` and ``connector``.

    Raises:
        ValueError: if ``args.arch`` or ``args.connector`` is not one of the
            supported values.
    """

    def __init__(self, args):
        super().__init__()
        self.inv_linear_probe = False
        if args.arch in ('escnn18', 'escnn50'):
            self.backbone = ESCNNEncoder(args.arch, use_gpool=args.use_gpool)
            self.order = self.backbone.order

            if args.use_gpool:
                feature_dim = self.backbone.num_out_trivial_repr
            else:
                num_out_regular_repr = self.backbone.num_out_regular_repr
                # Each regular representation contributes ``order`` channels.
                feature_dim = num_out_regular_repr * self.order

                if args.inv_linear_probe:
                    self.inv_linear_probe = args.inv_linear_probe
                    self.gspace = gspaces.rot2dOnR2(N=self.order)
                    self.in_type = enn.FieldType(self.gspace, num_out_regular_repr*[self.gspace.regular_repr])
                    hidden_type = enn.FieldType(self.gspace, 512*[self.gspace.regular_repr])
                    out_type_eqv = enn.FieldType(self.gspace, [self.gspace.regular_repr])
                    self.predictor_eqv = PredictorEqv(self.in_type, hidden_type, out_type_eqv)

                    if args.connector == 'softmax':
                        self.connector = torch.nn.Softmax(dim=1)
                    elif args.connector == 'identity':
                        self.connector = Identity()
                    elif args.connector == 'tanh':
                        self.connector = torch.nn.Tanh()
                    elif args.connector == 'shift':
                        self.connector = None
                        # Row k holds the indices [k, k+1, ..., k-1] (mod order).
                        permute_patterns = [
                            torch.roll(torch.arange(self.order), shifts=-i).tolist()
                            for i in range(self.order)
                        ]
                        # A non-persistent buffer follows ``.cuda()``/``.to()``
                        # with the rest of the module instead of being pinned
                        # to CUDA at construction, and stays out of state_dict.
                        self.register_buffer(
                            'permute_tensor', torch.tensor(permute_patterns), persistent=False
                        )
                    else:
                        raise ValueError(f'unsupported connector: {args.connector!r}')

        elif args.arch == 'resnet18':
            self.backbone = torchvision.models.resnet18(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 512
        elif args.arch == 'resnet50':
            self.backbone = torchvision.models.resnet50(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048
        else:
            # Previously fell through to an undefined ``feature_dim``.
            raise ValueError(f'unsupported architecture: {args.arch!r}')

        self.fc = nn.Linear(feature_dim, args.num_classes)

    def forward(self, x):
        """Return the linear-head output for a batch ``x``.

        In the invariant-probe path the backbone output must be 2-D
        ``(batch, channels)``; the result keeps the batch dimension even for a
        single-sample batch.
        """
        if not self.inv_linear_probe:
            return self.fc(self.backbone(x))

        RX = self.backbone(x)
        b, c = RX.shape
        RX_type_eqv = self.in_type(RX.reshape([b, c, 1, 1]))
        eqv_logit = self.predictor_eqv(RX_type_eqv).tensor.flatten(1)
        RX_re = RX.reshape([b, c // self.order, self.order])

        # Explicit ``None`` check: 'shift' stores ``None``, and the choice of
        # path must not depend on how a connector module defines truthiness.
        if self.connector is not None:
            # Soft path: weighted sum of all cyclic shifts of the features,
            # with weights given by the connector applied to the logits.
            eqv_score = self.connector(eqv_logit)
            permuted_reprs = torch.stack(
                [torch.roll(RX_re, shifts=-i, dims=2).reshape([b, c]) for i in range(self.order)],
                dim=-1,
            )
            # Drop only the trailing singleton: a bare ``squeeze()`` would also
            # remove the batch dimension when b == 1 and yield 1-D logits.
            HX = torch.matmul(permuted_reprs, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
        else:
            # Hard path: apply the single cyclic shift selected by the argmax.
            eqv_idx = torch.argmax(eqv_logit, dim=1)
            batch_perm = self.permute_tensor[eqv_idx].unsqueeze(1).expand(-1, c // self.order, -1)
            HX = RX_re.gather(2, batch_perm).reshape([b, c])

        return self.fc(HX)

class AverageMeter:
    """Track the most recent value of a metric and its weighted running mean.

    ``name`` labels the metric and ``fmt`` is the format spec (including the
    leading colon) applied to both numbers when the meter is printed.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the last value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Splice the stored spec into a template, then fill it from the
        # instance attributes: "<name> <val> (<avg>)".
        template = '{{name}} {{val{spec}}} ({{avg{spec}}})'.format(spec=self.fmt)
        return template.format(**vars(self))
    
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for each k in ``topk``.

    ``output`` holds per-class scores with one row per sample and ``target``
    the true class indices.  The result is a list with one single-element
    tensor per requested k, in the order given.
    """
    with torch.no_grad():
        deepest = max(topk)
        n_samples = target.size(0)

        # Class indices of the `deepest` highest scores, one column per sample
        # after the transpose, best guess in row 0.
        ranked = output.topk(deepest, 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

# Run the evaluation only when the file is executed as a script, so that
# importing this module does not parse the command line or start any work.
if __name__ == '__main__':
    main()
