import os
import argparse
import json
import sys
import time
import torch
import torchvision
from pathlib import Path
from torch import nn, optim
from torchvision import datasets, transforms
from trainer import ESCNNEncoder, PredictorEqv, ProjectionMLP, Identity
from escnn import nn as enn
from escnn import gspaces
from datasets import *

# Command-line interface for the linear-probe / finetune evaluation.
parser = argparse.ArgumentParser(description='Linear-probe evaluation of pretrained encoder features')
parser.add_argument('--data', default='/data1/home/data/stl10', type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--pretrained-dir', default='./experiments/stl10_escnn18',
                    type=Path, metavar='FILE',
                    help='path to pretrained model')
parser.add_argument('--weights', default='freeze', type=str,
                    choices=('finetune', 'freeze'),
                    help='finetune or freeze resnet weights')
parser.add_argument('--train-percent', default=100, type=int,
                    choices=(100, 10, 1),
                    help='size of training set in percent')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size')
# Parsed as JSON so that a value given on the command line is a number/list rather than a
# raw string (previously even "--LUT_lr 0" produced the string '0', which is != 0 and
# switched the lookup-table schedule on with an unusable value).
# Example: --LUT_lr '[[35, 0.1], [70, 0.02], [85, 0.004], [100, 0.0008]]'
parser.add_argument('--LUT_lr', default=0, type=json.loads,
                    help='multistep learning-rate lookup table as a JSON list of [epoch, lr] '
                         'pairs (lr of the first pair whose epoch bound is not yet reached '
                         'is used); 0 disables it')
parser.add_argument('--lr-backbone', default=0.0, type=float, metavar='LR',
                    help='backbone base learning rate')
parser.add_argument('--lr-classifier', default=1.0, type=float, metavar='LR',
                    help='classifier base learning rate')
parser.add_argument('--schedule', default=[30, 40, 50], nargs="*", type=int,
                    help="learning rate schedule (when to drop lr by a ratio)")
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--print-freq', default=10, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--gpus', default='0', type=str)
parser.add_argument('--arch', default='escnn18', type=str, help='model architecture',
                    choices=['resnet18', 'resnet50', 'escnn18', 'escnn50', 'nin'])
parser.add_argument('--num-classes', default=196, type=int)
parser.add_argument('--pretrain-set', default='stl10', type=str, help='pretrain dataset',
                    choices=['stl10', 'stl10-R', 'imagenet100', 'imagenet100-R', 'caltech256', 'caltech256-R',
                             'cifar10', 'cifar10-R', 'cifar10-essl', 'cifar10-essl-R'])
parser.add_argument('--eval-set', default='stl10', type=str, help='evaluation dataset',
                    choices=['stl10', 'imagenet100', 'stanford_cars', 'fgvc_aircraft', 'cub_200_2011',
                             'cifar10', 'cifar100', 'caltech256', 'MTARSI'])
parser.add_argument('--connector', default='softmax', type=str, help='equivariance connection map',
                    choices=['softmax', 'identity', 'tanh', 'shift'])
parser.add_argument('--ssl', type=str, help='ssl model',
                    choices=['simclr', 'simsiam', 'moco'])
parser.add_argument('--qk', default='', choices=['', 'q', 'k'],
                    help='When using MOCO, between query and key, set the backbone encoder during linear probing')
parser.add_argument('--use_gpool', action='store_true')
parser.add_argument('--rotated', action='store_true')
parser.add_argument('--inv_linear_probe', action='store_true')
parser.add_argument('--save_linear_probe', action='store_true')
parser.add_argument('--random_rotation', action='store_true', help='To have Dataset with Random Rotation Augmentation')
parser.add_argument('--adjust_learning_rate', action='store_true')
parser.add_argument('--iterations', default=0, type=int)
parser.add_argument('--final_name', default='final.pth', type=str)
parser.add_argument('--N', default=4, type=int)
parser.add_argument('--circular_transform', default=False, action='store_true')
parser.add_argument('--circular_range', default=10, type=int)
parser.add_argument('--cifar_layer', default=3, type=int, choices=[0, 1, 2, 3])
parser.add_argument('--nin_classifier', default=False, action='store_true')

def main():
    """Script entry point: parse the command line, attach GPU info, and run the evaluation."""
    cli_args = parser.parse_args()
    # Number of CUDA devices visible to this process, stored next to the user options.
    cli_args.ngpus_per_node = torch.cuda.device_count()
    # "--gpus 0,1" -> [0, 1]
    cli_args.gpus = list(map(int, cli_args.gpus.split(',')))
    main_worker(cli_args)

def main_worker(args):
    """Train a classifier head on a pretrained encoder and evaluate it after every epoch.

    Builds the evaluation datasets and ``Encoder``, loads pretrained weights from
    ``args.pretrained_dir / args.final_name``, initialises the head, then runs
    ``args.epochs`` epochs of training followed by validation. Training and validation
    statistics are printed and appended as JSON lines to a stats file in
    ``args.pretrained_dir``. With ``--save_linear_probe`` the classifier head is also
    saved after every epoch and once more at the end.

    Args:
        args: parsed command-line namespace (see ``parser``). ``args.num_classes`` is
            overwritten with the class count returned by the dataset loader.

    Raises:
        ValueError: for inconsistent options, i.e. ``--ssl moco`` without ``--qk q|k``
            or an unknown ``--weights`` mode.
    """
    torch.backends.cudnn.benchmark = True

    # Data loading code
    if args.random_rotation:
        train_dataset, val_dataset, num_classes = load_eval_random_rotation_sets(args)
    else:
        train_dataset, val_dataset, num_classes = load_eval_datasets(args)
    # Encoder sizes its classification head from this value.
    args.num_classes = num_classes

    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)

    model = Encoder(args)
    model.cuda()
    _load_pretrained_weights(model, args)

    # Fresh initialisation of the classification head.
    if args.nin_classifier:
        normalize_weights(model.fc)
    else:
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()

    if args.weights == 'freeze':
        model.backbone.requires_grad_(False)
        model.fc.requires_grad_(True)
        if args.inv_linear_probe:
            model.predictor_eqv.requires_grad_(False)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = _build_optimizer(model, args)
    # NOTE(review): this scheduler is never stepped, so it does not change the learning
    # rate (constructing it only records 'initial_lr' in each param group). Confirm
    # whether cosine decay was intended before relying on it.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    best_acc = argparse.Namespace(top1=0, top5=0)
    start_time = time.time()
    # Line-buffered so each JSON record is flushed as soon as it is written; the
    # context manager guarantees the file is closed even if training fails.
    with open(args.pretrained_dir / _stats_filename(args), 'a', buffering=1) as stats_file:
        for epoch in range(args.epochs):
            # train
            _set_train_mode(model, args)
            if args.LUT_lr == 0:
                if args.adjust_learning_rate:
                    adjust_learning_rate(optimizer, epoch, args)
            else:
                adjust_learning_rate_LUT(optimizer, epoch, args.LUT_lr)

            # `step` is a global iteration counter that continues across epochs.
            for step, (images, target) in enumerate(train_loader, start=epoch * len(train_loader)):
                output = model(images.cuda(non_blocking=True))
                loss = criterion(output, target.cuda(non_blocking=True))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if step % args.print_freq == 0:
                    pg = optimizer.param_groups
                    lr_classifier = pg[0]['lr']
                    # The backbone group (index 1) only exists when finetuning.
                    lr_backbone = pg[1]['lr'] if len(pg) == 2 else 0
                    _log_stats(dict(epoch=epoch, step=step, lr_backbone=lr_backbone,
                                    lr_classifier=lr_classifier, loss=loss.item(),
                                    time=int(time.time() - start_time)), stats_file)

            if args.save_linear_probe:
                state = dict(epoch=epoch + 1, model=model.fc.state_dict(), optimizer=optimizer.state_dict())
                torch.save(state, args.pretrained_dir / _fc_filename('checkcpoint_fc', args))

            # evaluate
            acc1, acc5 = _evaluate(model, val_loader)
            best_acc.top1 = max(best_acc.top1, acc1)
            best_acc.top5 = max(best_acc.top5, acc5)
            _log_stats(dict(epoch=epoch, acc1=acc1, acc5=acc5,
                            best_acc1=best_acc.top1, best_acc5=best_acc.top5), stats_file)

    if args.save_linear_probe:
        torch.save(dict(fc=model.fc.state_dict()), args.pretrained_dir / _fc_filename('final_fc', args))


def _run_suffix(args):
    """Return the '<iterations>_<qk>_<checkpoint name>' tag shared by every output file name."""
    # final_name[:-4] strips a 4-character extension such as '.pth'.
    return f'{args.iterations}_{args.qk}_{args.final_name[:-4]}'


def _stats_filename(args):
    """Return the file name of the JSON-lines stats log for this run."""
    tag = 'stats_linear_probe'
    if args.rotated:
        tag += '_rotated'
    if args.inv_linear_probe:
        tag += '_inv'
    return f'{args.eval_set}_{tag}_{_run_suffix(args)}.txt'


def _fc_filename(prefix, args):
    """Return the file name under which the classifier head is saved.

    ``prefix`` is 'checkcpoint_fc' (per-epoch checkpoint; spelling kept so existing
    tooling still finds the files) or 'final_fc' (end of training).
    """
    tag = prefix
    if args.rotated:
        tag += '_rotated'
    if args.inv_linear_probe:
        tag += '_hx'
    elif args.use_gpool:
        tag += '_inv'
    return f'{tag}_{_run_suffix(args)}.pth'


def _load_pretrained_weights(model, args):
    """Load the pretrained backbone (and, for --inv_linear_probe, the equivariant predictor).

    For MoCo checkpoints, ``args.qk`` selects the query ('q') or key ('k') encoder.
    Loading is non-strict; missing and unexpected keys are printed.

    Raises:
        ValueError: if ``args.ssl == 'moco'`` and ``args.qk`` is neither 'q' nor 'k'.
    """
    if args.ssl == 'moco':
        if args.qk not in ('q', 'k'):
            raise ValueError("--ssl moco requires --qk to be 'q' or 'k'")
        key_suffix = f'_{args.qk}'
    else:
        key_suffix = ''
    state_dict = torch.load(os.path.join(args.pretrained_dir, args.final_name), map_location='cpu')

    missing_keys, unexpected_keys = model.backbone.load_state_dict(
        state_dict['backbone' + key_suffix], strict=False)
    print(missing_keys)
    print(unexpected_keys)

    if args.inv_linear_probe:
        missing_keys, unexpected_keys = model.predictor_eqv.load_state_dict(
            state_dict['predictor_eqv' + key_suffix], strict=False)
        print(missing_keys)
        print(unexpected_keys)


def _build_optimizer(model, args):
    """Create the SGD optimizer: classifier group first, backbone group only when finetuning."""
    classifier_parameters, model_parameters = [], []
    for name, param in model.named_parameters():
        # Every parameter of the head (plain Linear or NIN MLP) lives under 'fc.'.
        if name.startswith('fc.'):
            classifier_parameters.append(param)
        else:
            model_parameters.append(param)

    param_groups = [dict(params=classifier_parameters, lr=args.lr_classifier)]
    if args.weights == 'finetune':
        param_groups.append(dict(params=model_parameters, lr=args.lr_backbone))

    # A membership test is required here: "x == 'cifar10' or 'cifar10-essl'" is always
    # truthy, which made the non-CIFAR branch unreachable.
    if args.pretrain_set in ('cifar10', 'cifar10-essl'):
        return optim.SGD(param_groups, 0, momentum=0.9, weight_decay=5e-4, nesterov=True)
    return optim.SGD(param_groups, 0, momentum=0.9, weight_decay=args.weight_decay)


def _set_train_mode(model, args):
    """Switch ``model`` to training mode while keeping frozen components in eval mode.

    The whole model is switched to train mode first. Evaluation calls ``model.eval()``,
    so without this the head (e.g. the BatchNorm layers of the NIN classifier) would
    keep training in eval mode from the second epoch on.

    Raises:
        ValueError: if ``args.weights`` is neither 'finetune' nor 'freeze'.
    """
    if args.weights not in ('finetune', 'freeze'):
        raise ValueError(f'unknown --weights mode: {args.weights!r}')
    model.train()
    if args.weights == 'freeze':
        model.backbone.eval()
        if args.inv_linear_probe:
            model.predictor_eqv.eval()


def _evaluate(model, val_loader):
    """Put ``model`` in eval mode and return its (top-1, top-5) accuracy in percent on ``val_loader``."""
    model.eval()
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')
    with torch.no_grad():
        for images, target in val_loader:
            output = model(images.cuda(non_blocking=True))
            acc1, acc5 = accuracy(output, target.cuda(non_blocking=True), topk=(1, 5))
            top1.update(acc1[0].item(), images.size(0))
            top5.update(acc5[0].item(), images.size(0))
    return top1.avg, top5.avg


def _log_stats(stats, stats_file):
    """Print ``stats`` as one JSON line to stdout and append the same line to ``stats_file``."""
    line = json.dumps(stats)
    print(line)
    print(line, file=stats_file)

class Encoder(nn.Module):
    """Pretrained backbone followed by a trainable classification head ``self.fc``.

    For 'escnn18' / 'escnn50' / 'nin' the backbone is an ``ESCNNEncoder``.
    - With ``--use_gpool``, its ``num_out_trivial_repr`` trivial-representation features
      feed the head.
    - Otherwise the features are ``num_out_regular_repr`` regular fields of size
      ``args.N``.
    - With ``--inv_linear_probe`` (and no ``--use_gpool``), an equivariant predictor
      produces per-sample logits over the ``order`` cyclic shifts of those fields. The
      shifted features are then either mixed by ``connector(logits)`` or, for the
      'shift' connector, the argmax shift is applied, before the head.

    For 'resnet18' / 'resnet50' a torchvision ResNet with its own ``fc`` replaced by
    ``nn.Identity`` is used.

    The head is a single ``nn.Linear`` or, with ``--nin_classifier``, a
    Linear-BN-ReLU-Linear-BN-ReLU-Linear MLP.
    """

    def __init__(self, args):
        super().__init__()
        self.inv_linear_probe = False
        if args.arch in ('escnn18', 'escnn50', 'nin'):
            self.backbone = ESCNNEncoder(args.arch, use_gpool=args.use_gpool)

            if args.use_gpool:
                feature_dim = self.backbone.num_out_trivial_repr
            else:
                num_out_regular_repr = self.backbone.num_out_regular_repr
                self.order = args.N
                # Each regular field contributes `order` feature components.
                feature_dim = num_out_regular_repr * self.order

                if args.inv_linear_probe:
                    self.inv_linear_probe = args.inv_linear_probe
                    # NOTE(review): forward() uses the backbone's order while feature_dim
                    # above uses args.N; the two are assumed to agree.
                    self.order = self.backbone.order
                    self.gspace = gspaces.rot2dOnR2(N=self.order)
                    self.in_type = enn.FieldType(self.gspace, num_out_regular_repr * [self.gspace.regular_repr])
                    hidden_type = enn.FieldType(self.gspace, 512 * [self.gspace.regular_repr])
                    out_type_eqv = enn.FieldType(self.gspace, [self.gspace.regular_repr])
                    self.predictor_eqv = PredictorEqv(self.in_type, hidden_type, out_type_eqv)
                    if args.connector == 'softmax':
                        self.connector = torch.nn.Softmax(dim=1)
                    elif args.connector == 'identity':
                        self.connector = Identity()
                    elif args.connector == 'tanh':
                        self.connector = torch.nn.Tanh()
                    elif args.connector == 'shift':
                        self.connector = None
                        # Row i is arange(order) rolled left by i, i.e. the gather pattern of shift i.
                        permute_patterns = [torch.roll(torch.arange(self.order), shifts=-i).tolist()
                                            for i in range(self.order)]
                        # A non-persistent buffer follows model.cuda()/.to() (no hard-coded
                        # .cuda() here) and is not written into the state_dict.
                        self.register_buffer('permute_tensor', torch.tensor(permute_patterns), persistent=False)

        elif args.arch == 'resnet18':
            self.backbone = torchvision.models.resnet18(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 512
        elif args.arch == 'resnet50':
            self.backbone = torchvision.models.resnet50(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048

        if args.nin_classifier:
            self.fc = nn.Sequential()

            nFeats = min(args.num_classes * 20, 2048)
            self.fc.add_module('Linear_1', nn.Linear(feature_dim, nFeats, bias=False))
            self.fc.add_module('BatchNorm_1', nn.BatchNorm1d(nFeats))
            self.fc.add_module('ReLU_1', nn.ReLU(inplace=True))
            self.fc.add_module('Linear_2', nn.Linear(nFeats, nFeats, bias=False))
            self.fc.add_module('BatchNorm_2', nn.BatchNorm1d(nFeats))
            self.fc.add_module('ReLU_2', nn.ReLU(inplace=True))
            self.fc.add_module('Final_fc', nn.Linear(nFeats, args.num_classes))
        else:
            self.fc = nn.Linear(feature_dim, args.num_classes)

    def forward(self, x):
        """Return class logits for input batch ``x``."""
        if not self.inv_linear_probe:
            return self.fc(self.backbone(x))

        RX = self.backbone(x)
        b, c = RX.shape  # the backbone output is unpacked as a flat (batch, features) tensor
        # View the features as c // order regular fields of `order` components each.
        RX_re = RX.reshape([b, c // self.order, self.order])
        RX_type_eqv = self.in_type(RX.reshape([b, c, 1, 1]))
        eqv_logit = self.predictor_eqv(RX_type_eqv).tensor.flatten(1)
        if self.connector is not None:
            eqv_score = self.connector(eqv_logit)
            # (b, c, order): column i holds the features cyclically shifted by i.
            permuted_reprs = torch.stack(
                [torch.roll(RX_re, shifts=-i, dims=2).reshape([b, c]) for i in range(self.order)], dim=-1)
            # (b, c, order) @ (b, order, 1) -> (b, c, 1). Squeeze only the last dim so a
            # batch of size 1 keeps its batch dimension.
            HX = torch.matmul(permuted_reprs, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
        else:
            eqv_idx = torch.argmax(eqv_logit, dim=1)
            # One gather pattern per sample, broadcast over all c // order fields.
            batch_perm = self.permute_tensor[eqv_idx].unsqueeze(1).expand(-1, c // self.order, -1)
            HX = RX_re.gather(2, batch_perm).reshape([b, c])
        return self.fc(HX)
    
class AverageMeter:
    """Track the latest value of a metric together with its count-weighted running mean."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**vars(self))

def adjust_learning_rate_LUT(optimizer, iters, LUT):
    """Set the learning rate of every param group from a step lookup table.

    Args:
        optimizer: object whose ``param_groups`` (dicts with an 'lr' key) are updated in place.
        iters: current position in the schedule (the caller passes the epoch index).
        LUT: sequence of ``(stepvalue, lr)`` pairs sorted by increasing ``stepvalue``.
            The lr of the first pair with ``iters < stepvalue`` is used. Once ``iters``
            is past the last ``stepvalue``, the last pair's lr is kept (this previously
            raised UnboundLocalError).

    Returns:
        The learning rate that was applied.

    Raises:
        ValueError: if ``LUT`` is empty.
    """
    if not LUT:
        raise ValueError('LUT must contain at least one (stepvalue, lr) pair')
    lr = next((base_lr for stepvalue, base_lr in LUT if iters < stepvalue), LUT[-1][1])

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule.

    Each param group's base learning rate is multiplied by 0.1 for every milestone in
    ``args.schedule`` with ``epoch >= milestone``.

    The base learning rate is the group's 'initial_lr' entry: torch LR schedulers set it
    at construction, and otherwise it is recorded from the group's 'lr' on the first
    call. Each group therefore decays from its own configured rate. Previously every
    group was set from ``args.lr_classifier``, so a finetuned backbone group silently
    ignored ``--lr-backbone``.
    """
    for param_group in optimizer.param_groups:
        lr = param_group.setdefault('initial_lr', param_group['lr'])
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.0
        param_group["lr"] = lr

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages of ``output`` scores against ``target`` labels.

    For each ``k`` in ``topk`` the result holds a one-element tensor with the percentage
    of samples whose target is among the ``k`` highest-scoring classes (along dim 1).
    """
    with torch.no_grad():
        largest_k = max(topk)
        scale = 100.0 / target.size(0)

        # (largest_k, batch): row j holds every sample's j-th best class index.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]

def normalize_weights(model):
    """Re-initialise every ``nn.Linear`` inside ``model`` in place.

    Weights are drawn from a normal distribution with mean 0 and
    std = sqrt(2 / out_features). Biases, when present, are set to zero. Other module
    types are left untouched.

    Args:
        model: an ``nn.Module``. It is searched recursively via ``model.modules()``, so an
            ``nn.Sequential`` head, nested containers, or a bare ``nn.Linear`` all work.
    """
    # Local import: this file never imports numpy as `np` itself, so the previous
    # np.sqrt call depended on a name leaking in from a star import.
    import math

    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            with torch.no_grad():
                std_val = math.sqrt(2.0 / layer.out_features)
                layer.weight.normal_(0.0, std_val)
                if layer.bias is not None:
                    layer.bias.fill_(0.0)

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
