# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import json
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf

from solo.args.umap import parse_args_umap
from solo.data.classification_dataloader import prepare_data
from solo.methods import METHODS
from solo.utils.auto_umap import OfflineUMAP
from solo.methods.base import BaseMethod
import torch.nn as nn
import torch
from torch.optim import lr_scheduler
import copy
from torch import cuda, nn, optim
from tqdm import tqdm, trange
import numpy
from torch.nn.functional import normalize
from torch.autograd import Variable

def _gram_residual(W, device):
    """Return ``G - I`` where ``G`` is the smaller Gram matrix of ``W``.

    ``W`` is flattened to ``(W.shape[0], W[0].numel())``; the Gram matrix is
    taken on whichever side is smaller so the residual is at most
    ``min(rows, cols)`` square.
    """
    rows = W.shape[0]
    cols = W[0].numel()
    flat = W.reshape(-1, cols)
    flat_t = flat.t()
    if rows > cols:
        gram = torch.matmul(flat_t, flat)
    else:
        gram = torch.matmul(flat, flat_t)
    # Match the weight's dtype so the subtraction does not change precision.
    ident = torch.eye(gram.shape[0], dtype=gram.dtype, device=device)
    return gram - ident


def l2_reg_ortho_loss_func(mdl, device, weight=1e-2, method='risp'):
    """Sum an orthogonality penalty over every weight tensor of ``mdl``.

    Parameters with fewer than two dimensions (biases, norm scales, ...) are
    skipped. For every remaining parameter the residual ``R = G - I`` is built
    from its flattened Gram matrix and one of two penalties is accumulated:

    * ``'so'``: the squared Frobenius norm of ``R``.
    * ``'risp'``: ``||R v||^2`` where ``v`` is the normalised image of a
      random vector under ``R``, i.e. a single power-iteration step towards
      the largest singular value of ``R``.

    Args:
        mdl: module whose ``parameters()`` are regularised.
        device: device on which the identity matrix and the random probe
            vector are created; it must be the device holding ``mdl``.
        weight: accepted for interface compatibility; it is not applied to
            the returned value.
            NOTE(review): presumably a scaling coefficient - confirm with
            the callers before multiplying it in.
        method: ``'risp'`` or ``'so'``.

    Returns:
        A scalar tensor with the summed penalty (zero if ``mdl`` has no
        parameter with at least two dimensions).

    Raises:
        ValueError: if ``method`` is neither ``'risp'`` nor ``'so'``.
    """
    if method not in ('risp', 'so'):
        raise ValueError(
            "wrong method: {!r} (expected 'risp' or 'so')".format(method)
        )

    l2_reg = None
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue

        residual = _gram_residual(W, device)

        if method == 'risp':
            probe = torch.rand(
                residual.shape[1], 1, dtype=residual.dtype, device=device
            )
            # eps-guarded normalisation: an exactly orthogonal weight gives a
            # zero residual, and a plain division would produce NaN there.
            direction = torch.nn.functional.normalize(
                torch.matmul(residual, probe), dim=0, eps=1e-12
            )
            term = torch.matmul(residual, direction).pow(2).sum()
        else:
            term = residual.pow(2).sum()

        l2_reg = term if l2_reg is None else l2_reg + term

    if l2_reg is None:
        # Nothing to regularise; return a tensor so callers can still add it.
        return torch.zeros((), device=device)
    return l2_reg


def main():
    """Benchmark the ``'so'`` and ``'risp'`` regularizers on each backbone.

    For every backbone name listed below, the backbone is built, moved to the
    benchmark device, and each regularizer is evaluated five times. The mean
    and standard deviation of the wall-clock time (in seconds) are printed
    per backbone and per method.
    """
    import time

    # Prefer the GPU index the script was written for, but fall back so the
    # benchmark still runs on machines with fewer (or no) GPUs.
    if torch.cuda.device_count() > 3:
        device = "cuda:3"
    elif torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    on_cuda = device.startswith("cuda")

    def _timed_call(backbone, method):
        """Return the seconds taken by one regularizer evaluation."""
        # CUDA kernels run asynchronously; synchronise so the clock brackets
        # the actual computation rather than just the kernel launches.
        if on_cuda:
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        l2_reg_ortho_loss_func(backbone, device=device, method=method)
        if on_cuda:
            torch.cuda.synchronize(device)
        return time.perf_counter() - start

    # build the model
    models = ['resnet18', 'resnet50', 'wide_resnet28w2', 'vit_tiny', 'vit_small', 'vit_base']
    for model in models:
        backbone_model = BaseMethod._BACKBONES[model]
        backbone = backbone_model(method='byol')

        if model.startswith("resnet"):
            # remove fc layer
            backbone.fc = nn.Identity()

        # The regularizer builds its helper tensors on `device`, so the
        # weights have to live there too.
        backbone = backbone.to(device)

        times_so = []
        times_risp = []
        for _ in range(5):
            times_so.append(_timed_call(backbone, 'so'))
            times_risp.append(_timed_call(backbone, 'risp'))

        times_so = numpy.array(times_so)
        times_risp = numpy.array(times_risp)
        print('{} so mean:{} std:{}'.format(model, times_so.mean(), times_so.std()))
        print('{} risp mean:{} std:{}'.format(model, times_risp.mean(), times_risp.std()))
            
            
            


    


if __name__ == "__main__":
    main()
