import math
import torch
import torch.nn as nn
import sys
from torch import nn, optim
import time
import numpy as np
import random
import os
import scipy.io as sio
from torch.nn import functional as F
from scipy.io import savemat
from module import SparseCLModel
from util import Pack, LossManager, LARS
from util import adjust_learning_rate, adjust_moco_momentum
import pandas as pd
import util
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision
from dataloader import CIFARPairTransform, CIFARSingleTransform, CIFARSingleTransform2
from metric import align_loss, uniform_loss
import torchvision

class SparseCL(nn.Module):
    """Pretraining and evaluation driver for a ``SparseCLModel`` on CIFAR.

    Holds the model plus six data loaders built from ``args``:

    * ``train_dataloader`` / ``val_dataloader``: two augmented views per image
      (``CIFARPairTransform(train_transform=True)``); the train loader drives
      ``train_model`` and the val loader drives ``calc_alignment``.
    * ``memory_train_dataloader`` / ``memory_val_dataloader``: pair transform
      with ``train_transform=False``; the val loader drives ``calc_uniformity``.
    * ``linear_train_dataloader`` / ``linear_val_dataloader``: single-view
      transforms used by the linear-evaluation loop (``linear_train_val``).

    Supported ``args.dataset`` values are the keys of ``_DATASET_CLASSES``.
    """

    # Supported ``args.dataset`` values -> class name in ``torchvision.datasets``.
    # Resolved lazily, so an unsupported value is rejected before any work is done.
    _DATASET_CLASSES = {"CIFAR-10": "CIFAR10", "CIFAR-100": "CIFAR100"}

    def __init__(self, args):
        """Build the model and all data loaders described by ``args``.

        Args:
            args: run configuration. Here it reads ``dataset``, ``data_dir``,
                ``batch_size`` and ``workers``; it is also passed whole to
                ``SparseCLModel``.

        Raises:
            ValueError: if ``args.dataset`` is not one of the supported names.
        """
        super().__init__()
        dataset_attr = self._DATASET_CLASSES.get(args.dataset)
        if dataset_attr is None:
            raise ValueError(
                "Unsupported dataset {!r}; expected one of {}".format(
                    args.dataset, sorted(self._DATASET_CLASSES)
                )
            )
        self.args = args
        self.model = SparseCLModel(args).cuda()
        dataset_cls = getattr(torchvision.datasets, dataset_attr)

        def make_dataset(train, transform):
            # One split of the selected CIFAR dataset, downloaded on demand.
            return dataset_cls(root=args.data_dir, train=train, transform=transform, download=True)

        def make_loader(dataset, shuffle, drop_last=False):
            # Every loader shares batch size, worker count and pinned memory.
            return DataLoader(
                dataset,
                batch_size=args.batch_size,
                shuffle=shuffle,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=drop_last,
            )

        # Pretraining (train split) and alignment metric (val split): two augmented views.
        train_data = make_dataset(True, CIFARPairTransform(train_transform=True))
        val_data = make_dataset(False, CIFARPairTransform(train_transform=True))
        # Uniformity metric.
        memory_train_data = make_dataset(True, CIFARPairTransform(train_transform=False))
        memory_val_data = make_dataset(False, CIFARPairTransform(train_transform=False))
        # Linear evaluation.
        linear_train_data = make_dataset(True, CIFARSingleTransform(train_transform=True))
        linear_val_data = make_dataset(False, CIFARSingleTransform(train_transform=False))

        # drop_last keeps every pretraining batch full-sized.
        self.train_dataloader = make_loader(train_data, shuffle=True, drop_last=True)
        self.val_dataloader = make_loader(val_data, shuffle=False)
        self.memory_train_dataloader = make_loader(memory_train_data, shuffle=False)
        self.memory_val_dataloader = make_loader(memory_val_data, shuffle=False)
        self.linear_train_dataloader = make_loader(linear_train_data, shuffle=True)
        self.linear_val_dataloader = make_loader(linear_val_data, shuffle=False)

    @staticmethod
    def _checkpoint_path(args):
        """Return the path of the pretrained-model checkpoint for this run."""
        return os.path.join(args.checkpoint_dir, 'checkpoint-' + args.save_name_pre + '_model.pth')

    def train_model(self, args):
        """Pretrain the online encoder (and the predictor, if enabled) with LARS.

        The learning rate and the momentum rate passed to the model are updated
        every step from the fractional epoch via ``adjust_learning_rate`` and
        ``adjust_moco_momentum``. After each epoch, the sample-weighted mean
        total / alignment / sparsity losses so far are written to
        ``<args.saver_dir>/<args.save_name_pre>_statistics.csv``. After the last
        epoch, ``self.model.state_dict()`` is saved to the checkpoint path, and
        the module is left in eval mode.

        Raises:
            ValueError: if ``self.train_dataloader`` yields no batches (e.g. the
                training set is smaller than ``args.batch_size`` while
                ``drop_last=True``). Without this check the loss averaging would
                fail with a ZeroDivisionError.
        """
        iters_per_epoch = len(self.train_dataloader)
        if iters_per_epoch == 0:
            raise ValueError(
                "train_dataloader yields no batches (batch_size={} with drop_last=True "
                "may exceed the training set size)".format(args.batch_size)
            )
        self.train()

        # Split 1-D tensors (biases / norm params) from the rest into two groups.
        trained_modules = [self.model.online_encoder]
        if args.use_predictor:
            trained_modules.append(self.model.predictor)
        param_weights, param_biases = [], []
        for module in trained_modules:
            for param in module.parameters():
                (param_biases if param.ndim == 1 else param_weights).append(param)

        parameters = [{'params': param_weights}, {'params': param_biases}]
        optimizer = LARS(parameters, args.lr, weight_decay=args.weight_decay, momentum=args.momentum, trust_coefficient=args.eta)

        results = {'total_loss': [], 'alignment_loss': [], 'sparsity_loss': []}
        train_loss = LossManager()
        for epoch in range(1, args.epochs + 1):
            print("epoch:%d, lr:%4f"%(epoch, optimizer.param_groups[0]["lr"]))
            total_loss, alignment_loss, sparsity_loss, total_num = 0.0, 0.0, 0.0, 0
            for step, ((x1, x2), _) in enumerate(self.train_dataloader):
                batch_size = x1.size(0)
                x1, x2 = x1.cuda(non_blocking=True), x2.cuda(non_blocking=True)
                # The fractional epoch drives both the LR and the momentum schedules.
                progress = (epoch - 1) + step / iters_per_epoch
                adjust_learning_rate(optimizer, progress, args)
                momentum_rate = adjust_moco_momentum(progress, args)
                loss, loss_pack = self.model(x1, x2, momentum_rate)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss.add_loss(loss_pack)

                # Accumulate sample-weighted sums; averaged once per epoch below.
                total_num += batch_size
                total_loss += loss.item() * batch_size
                alignment_loss += loss_pack.alignment_loss.item() * batch_size
                sparsity_loss += loss_pack.sparsity_loss.item() * batch_size
                if (step + 1) % 10 == 0:
                    print(train_loss.pprint(window=30, prefix='Train Epoch: [{}/{}] Iters:[{}/{}]'.format(epoch, args.epochs, step+1, iters_per_epoch)))

            train_loss.clear()
            results['total_loss'].append(total_loss / total_num)
            results['alignment_loss'].append(alignment_loss / total_num)
            results['sparsity_loss'].append(sparsity_loss / total_num)

            # Rewrite the statistics file after every epoch so progress survives interruption.
            data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
            data_frame.to_csv('{}/{}_statistics.csv'.format(args.saver_dir, args.save_name_pre), index_label='epoch')

        torch.save(self.model.state_dict(), self._checkpoint_path(args))
        self.eval()

    def calc_uniformity(self, args):
        """Return the uniformity metric over the validation split.

        Embeds every image of ``memory_val_dataloader`` (first view only) with
        ``self.model.obtain_representation``, concatenates the second output
        (``z_norm``, presumably L2-normalised — confirm in module.py) and scores
        it with ``uniform_loss``. The module's previous train/eval mode is
        restored afterwards. ``args`` is unused and kept for interface
        compatibility.
        """
        was_training = self.training
        self.eval()
        z_norm_bank = []
        with torch.no_grad():
            for (data, _), _target in tqdm(self.memory_val_dataloader, desc='calculation of uniformity in val dataset'):
                _, z_norm = self.model.obtain_representation(data.cuda(non_blocking=True))
                z_norm_bank.append(z_norm)

            z_norm_bank = torch.cat(z_norm_bank, dim=0).contiguous()
            uniformity = uniform_loss(z_norm_bank).item()

        print('Uniformity:{:.4f}'.format(uniformity))
        self.train(was_training)
        return uniformity

    def calc_alignment(self, args):
        """Return the alignment metric between the two views of the validation split.

        Both views from ``val_dataloader`` are embedded with
        ``self.model.obtain_representation``, and ``align_loss`` is computed on
        the concatenated second outputs. The module's previous train/eval mode is
        restored afterwards. ``args`` is unused and kept for interface
        compatibility.
        """
        was_training = self.training
        self.eval()
        z1_bank, z2_bank = [], []
        with torch.no_grad():
            for (x1, x2), _target in tqdm(self.val_dataloader, desc='calculation of alignment in val dataset'):
                _, z1_norm = self.model.obtain_representation(x1.cuda(non_blocking=True))
                _, z2_norm = self.model.obtain_representation(x2.cuda(non_blocking=True))
                z1_bank.append(z1_norm)
                z2_bank.append(z2_norm)

        z1_bank = torch.cat(z1_bank, dim=0).contiguous()
        z2_bank = torch.cat(z2_bank, dim=0).contiguous()
        alignment = align_loss(z1_bank, z2_bank).item()
        print('Alignment:{:.4f}'.format(alignment))
        self.train(was_training)
        return alignment

    @staticmethod
    def _topk_correct(out, target, ks=(1, 5)):
        """Count the rows of ``out`` whose ``target`` index is among the top-k logits, for each k.

        k is clamped to the number of classes, so top-5 still works with fewer
        than 5 classes.

        Returns:
            tuple of ints, one count per entry of ``ks``.
        """
        max_k = min(max(ks), out.size(-1))
        hits = out.topk(max_k, dim=-1).indices == target.unsqueeze(dim=-1)
        return tuple(hits[:, :k].any(dim=-1).sum().item() for k in ks)

    def linear_train_val(self, args, epoch, optimizer, criterion, model, is_train=True):
        """Run one epoch of linear-probe training (``is_train``) or validation.

        Returns:
            (mean loss, top-1 accuracy %, top-5 accuracy %) over the epoch.
        """
        data_loader = self.linear_train_dataloader if is_train else self.linear_val_dataloader

        # Only the fc head is trainable (see linear_model), so the model stays in
        # eval mode even while training, keeping the frozen backbone's normalisation
        # statistics fixed.
        model.eval()
        total_loss, total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0, 0
        data_bar = tqdm(data_loader)
        with (torch.enable_grad() if is_train else torch.no_grad()):
            for data, target in data_bar:
                data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
                out = model(data)
                loss = criterion(out, target)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                batch_size = data.size(0)
                total_num += batch_size
                total_loss += loss.item() * batch_size
                correct_1, correct_5 = self._topk_correct(out, target)
                total_correct_1 += correct_1
                total_correct_5 += correct_5

                data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
                                     .format('Train' if is_train else 'Val', epoch, args.linear_epochs, total_loss / total_num,
                                             total_correct_1 / total_num * 100, total_correct_5 / total_num * 100,))

        return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100

    def linear_model(self, args):
        """Linear-probe evaluation of the pretrained online encoder.

        Builds a ResNet-18 with a CIFAR-style stem and loads the checkpoint's
        ``online_encoder.*`` weights (excluding ``online_encoder.fc*``). It then
        freezes everything but a freshly initialised ``fc`` head and trains that
        head with SGD plus a cosine LR schedule for ``args.linear_epochs``.

        Per-epoch metrics are written to
        ``<args.results>/<save_name_pre>_<linear_lr>_linear.csv``. A final extra
        row holds the best metrics: the train/test stats of the epoch with the
        best test top-1, plus the best test top-5 tracked on its own.

        Returns:
            dict with the best metrics (same keys as the CSV columns).
        """
        model = torchvision.models.resnet.resnet18()
        # CIFAR-style stem: 3x3 stride-1 conv and no max-pool.
        # NOTE(review): padding=2 is unusual for a 3x3 conv (padding=1 keeps the
        # spatial size); presumably it mirrors the stem of SparseCLModel's online
        # encoder — confirm against module.py before changing it.
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
        model.maxpool = nn.Identity()
        feature_dim = model.fc.in_features
        del model.fc

        # Keep only the pretrained backbone. Strip the "online_encoder." prefix and
        # drop its "fc*" entries; the linear head is created below.
        state_dict = torch.load(self._checkpoint_path(args), map_location='cpu')
        prefix = 'online_encoder.'
        backbone_state = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix) and not key.startswith(prefix + 'fc')
        }
        model.load_state_dict(backbone_state, strict=True)

        model.fc = nn.Linear(feature_dim, args.num_class, bias=True)
        nn.init.normal_(model.fc.weight, mean=0.0, std=0.01)
        nn.init.zeros_(model.fc.bias)
        model = model.cuda()
        model.eval()

        head_names = {'fc.weight', 'fc.bias'}
        for name, param in model.named_parameters():
            if name not in head_names:
                param.requires_grad = False

        parameters = [p for p in model.parameters() if p.requires_grad]
        assert len(parameters) == 2  # weight, bias

        metric_keys = ('train_loss', 'train_acc@1', 'train_acc@5', 'test_loss', 'test_acc@1', 'test_acc@5')
        results = {key: [] for key in metric_keys}
        final_results = {key: 0 for key in metric_keys}

        optimizer = optim.SGD(parameters, args.linear_lr, momentum=0.9, weight_decay=1e-6)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.linear_epochs)
        loss_criterion = nn.CrossEntropyLoss().cuda()
        save_name = os.path.join(args.results, args.save_name_pre+"_"+str(args.linear_lr)+'_linear.csv')

        for epoch in range(1, args.linear_epochs + 1):
            print("epoch:%d, lr:%4f"%(epoch, optimizer.param_groups[0]["lr"]))
            train_loss, train_acc_1, train_acc_5 = self.linear_train_val(args, epoch, optimizer, loss_criterion, model, is_train=True)
            test_loss, test_acc_1, test_acc_5 = self.linear_train_val(args, epoch, optimizer, loss_criterion, model, is_train=False)
            epoch_metrics = (train_loss, train_acc_1, train_acc_5, test_loss, test_acc_1, test_acc_5)
            for key, value in zip(metric_keys, epoch_metrics):
                results[key].append(value)

            # Save statistics after every epoch.
            data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
            data_frame.to_csv(save_name, index_label='epoch')

            # Train/test stats come from the epoch with the best test top-1;
            # test top-5 is tracked as its own running maximum.
            if test_acc_1 > final_results['test_acc@1']:
                final_results.update({
                    'train_loss': train_loss,
                    'train_acc@1': train_acc_1,
                    'train_acc@5': train_acc_5,
                    'test_loss': test_loss,
                    'test_acc@1': test_acc_1,
                })
            if test_acc_5 > final_results['test_acc@5']:
                final_results['test_acc@5'] = test_acc_5
            scheduler.step()

        # The extra last row (index linear_epochs + 1) holds the best metrics.
        for key in metric_keys:
            results[key].append(final_results[key])

        data_frame = pd.DataFrame(data=results, index=range(1, args.linear_epochs + 2))
        data_frame.to_csv(save_name, index_label='epoch')
        return final_results


    