import os
import math
import time
import copy
import argparse
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from resnet import resnet18
from utils import knn_monitor, fix_seed, count_parameters_in_MB, CircularCrop

from escnn import nn as enn
from escnn import gspaces
from equivision.models import c4resnet18, c8resnet18

# Per-channel normalization shared by every transform pipeline in this file.
normalize = T.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2023, 0.1994, 0.2010],
)

# Augmentation-free pipeline: convert to a tensor, then normalize. Used directly
# for the memory/test loaders and as the final stage of the two-view transforms.
single_transform = T.Compose([T.ToTensor(), normalize])


class ContrastiveLearningTransform:
    """Turn one image into two independently augmented, normalized views.

    The augmentation chain is: random resized crop to 33x33, random horizontal
    flip, color jitter (applied with probability 0.8) and random grayscale.
    Each view is then passed through the module-level ``single_transform``.
    """

    def __init__(self):
        color_jitter = T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8)
        self.transform = T.Compose([
            T.RandomResizedCrop(size=33, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            color_jitter,
            T.RandomGrayscale(p=0.2),
        ])

    def __call__(self, x):
        # The random pipeline is run once per view, so the two views differ.
        return [single_transform(self.transform(x)) for _ in range(2)]

    
class ContrastiveLearningRotAugTransform:
    """Two-view augmentation that additionally applies a random rotation.

    Same chain as the plain contrastive transform, with ``CircularCrop(0)``
    (project helper from ``utils``) and a bilinear random rotation drawn from
    [-45, 45] degrees inserted right after the random resized crop.
    """

    def __init__(self):
        rotation = T.RandomRotation(
            (-45, 45), interpolation=T.InterpolationMode.BILINEAR
        )
        color_jitter = T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8)
        self.transform = T.Compose([
            T.RandomResizedCrop(size=33, scale=(0.2, 1.0)),
            CircularCrop(0),
            rotation,
            T.RandomHorizontalFlip(p=0.5),
            color_jitter,
            T.RandomGrayscale(p=0.2),
        ])

    def __call__(self, x):
        # The random pipeline is run once per view, so the two views differ.
        return [single_transform(self.transform(x)) for _ in range(2)]



def adjust_learning_rate(epochs, warmup_epochs, base_lr, optimizer, loader, step):
    """Set the optimizer's learning rate for global iteration ``step``.

    The schedule is a linear warm-up from 0 to ``base_lr`` over
    ``warmup_epochs`` epochs followed by a half-cosine decay to 0 over the
    remaining epochs. ``len(loader)`` is the number of iterations per epoch.

    Args:
        epochs: total number of training epochs.
        warmup_epochs: number of epochs spent in linear warm-up.
        base_lr: peak learning rate reached at the end of warm-up.
        optimizer: object exposing ``param_groups`` (a list of dicts).
        loader: sized object; its length is the iterations per epoch.
        step: zero-based global iteration index.

    Returns:
        The learning rate written into every parameter group.
    """
    steps_per_epoch = len(loader)
    total_steps = epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch

    if step >= warmup_steps:
        end_lr = 0
        decay_step = step - warmup_steps
        decay_total = total_steps - warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * decay_step / decay_total))
        lr = base_lr * q + end_lr * (1 - q)
    else:
        lr = base_lr * step / warmup_steps

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


def negative_cosine_similarity_loss(p, z):
    """Return the negated mean cosine similarity between ``p`` and ``z``.

    ``z`` is detached first, so gradients flow only through ``p``.
    The similarity is taken along the last dimension.
    """
    target = z.detach()
    similarity = F.cosine_similarity(p, target, dim=-1)
    return -similarity.mean()


def info_nce_loss(z1, z2, temperature=0.5):
    """One-directional InfoNCE loss between two batches of embeddings.

    Both inputs are L2-normalized along dim 1; the similarity matrix
    ``z1 @ z2.T`` is divided by ``temperature`` and fed to cross-entropy with
    the diagonal (row i matches column i) as the positive pair.

    Args:
        z1: 2-D tensor of embeddings, one row per sample.
        z2: 2-D tensor of embeddings with the same number of rows as ``z1``.
        temperature: softmax temperature applied to the similarities.

    Returns:
        Scalar loss tensor on the same device as the inputs.
    """
    z1 = torch.nn.functional.normalize(z1, dim=1)
    z2 = torch.nn.functional.normalize(z2, dim=1)

    logits = (z1 @ z2.T) / temperature
    # Build the targets on the logits' device rather than hard-coding CUDA, so
    # the loss also works on CPU and on non-default GPUs.
    labels = torch.arange(z2.shape[0], dtype=torch.long, device=logits.device)
    return torch.nn.functional.cross_entropy(logits, labels)

def cross_entropy_AB(A, B, eps=1e-6):
    """Cross-entropy between softmax(A) (target) and softmax(B) (prediction).

    Both arguments are logits; softmax is taken along dim 1. ``eps`` is added
    to the predicted probabilities before the logarithm. The per-row values
    are summed and divided by the number of rows of ``A``.
    """
    batch_size = A.shape[0]
    target_probs = torch.softmax(A, dim=1)
    log_pred = torch.log(torch.softmax(B, dim=1) + eps)
    per_sample = -(target_probs * log_pred).sum(1)
    return per_sample.sum(0) / batch_size

class PredictorEqv(nn.Module):
    """Two-layer escnn head: 1x1 R2Conv -> InnerBatchNorm -> ReLU -> 1x1 R2Conv.

    Args:
        in_type: escnn field type of the input.
        hidden_type: escnn field type between the two convolutions.
        out_type: escnn field type of the output.
    """

    def __init__(self, in_type, hidden_type, out_type):
        super().__init__()
        layers = [
            enn.R2Conv(in_type, hidden_type, kernel_size=1, stride=1, padding=0),
            enn.InnerBatchNorm(hidden_type),
            enn.ReLU(hidden_type),
            enn.R2Conv(hidden_type, out_type, kernel_size=1, stride=1, padding=0),
        ]
        self.net = enn.SequentialModule(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class ProjectionMLP(nn.Module):
    """Projection head: Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm.

    Both linear layers are bias-free and the final batch norm has no affine
    parameters.

    Args:
        in_dim: size of the input features.
        hidden_dim: width of the hidden layer.
        out_dim: size of the projected output.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected


class PredictionMLP(nn.Module):
    """Prediction head: bias-free Linear -> BatchNorm -> ReLU -> Linear (with bias).

    Args:
        in_dim: size of the input features.
        hidden_dim: width of the hidden layer.
        out_dim: size of the prediction.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        prediction = self.net(x)
        return prediction


class ESCNNEncoder(nn.Module):
    """Rotation-equivariant ResNet-18 encoder with an optional "guide" head.

    The backbone output is flattened to ``(batch, channels)``. With
    ``guide=True`` an equivariant predictor produces ``order`` logits per
    sample; their softmax weights a sum of the ``order`` cyclic shifts of the
    features (taken along the last axis of the ``(batch, channels // order,
    order)`` view), and that weighted sum is returned instead of the raw
    features.

    Args:
        guide: enable the guide head described above.
        use_gpool: forwarded to the backbone constructor; when set,
            ``feature_dim`` is ``base_width * 8`` instead of
            ``base_width * 8 * order``.
        N: rotation group order; only 4 and 8 are supported.

    Raises:
        ValueError: if ``N`` is neither 4 nor 8.

    NOTE(review): with both ``guide`` and ``use_gpool`` set, ``in_type`` is
    still built for ``base_width * 8 * order`` channels -- confirm that this
    combination is meant to be supported.
    """

    def __init__(self, guide=False, use_gpool=False, N=4):
        super().__init__()
        if N == 4:
            self.backbone = c4resnet18(pretrained=False, use_gpool=use_gpool)
        elif N == 8:
            self.backbone = c8resnet18(pretrained=False, use_gpool=use_gpool)
        else:
            # Fail here instead of with an AttributeError on self.backbone below.
            raise ValueError(f'unsupported rotation order N={N}; expected 4 or 8')
        self.order = self.backbone.order
        self.num_out_regular_repr = self.backbone.base_width * 8
        self.feature_dim = self.num_out_regular_repr * self.order
        if use_gpool:
            self.feature_dim = self.num_out_regular_repr

        self.guide = guide
        if guide:
            self.gspace = gspaces.rot2dOnR2(N=self.order)
            self.in_type = enn.FieldType(self.gspace, self.num_out_regular_repr*[self.gspace.regular_repr])
            hidden_type = enn.FieldType(self.gspace, 512*[self.gspace.regular_repr])
            out_type_eqv = enn.FieldType(self.gspace, [self.gspace.regular_repr])
            self.predictor_eqv = PredictorEqv(self.in_type, hidden_type, out_type_eqv)
            self.connector = torch.nn.Softmax(dim=1)

    def _flat_features(self, x):
        """Run the backbone and flatten its output to (batch, channels)."""
        feats = self.backbone(x).tensor
        return feats.view(feats.size(0), -1)

    def _apply_guide(self, RX):
        """Return (guide logits, softmax-weighted sum of cyclic shifts of RX)."""
        b, c = RX.shape
        RX_type_eqv = self.in_type(RX.reshape([b, c, 1, 1]))
        eqv_logit = self.predictor_eqv(RX_type_eqv).tensor.flatten(1)
        eqv_score = self.connector(eqv_logit)
        RX_re = RX.reshape([b, c // self.order, self.order])
        # One cyclically shifted copy of the features per group element,
        # stacked on a trailing axis: (b, c, order).
        permuted_reprs = torch.stack(
            [torch.roll(RX_re, shifts=-i, dims=2).reshape([b, c]) for i in range(self.order)],
            dim=-1,
        )
        # Squeeze only the trailing singleton axis; a bare squeeze() would
        # also drop the batch axis when b == 1.
        HX = torch.matmul(permuted_reprs, eqv_score.unsqueeze(dim=-1)).squeeze(-1)
        return eqv_logit, HX

    def extract_guided_output(self, x):
        """Return ``(eqv_logit, HX)`` when guided, otherwise the flat features."""
        RX = self._flat_features(x)
        if self.guide:
            return self._apply_guide(RX)
        return RX

    def forward(self, x):
        """Return the guided features ``HX`` when guided, otherwise the flat features."""
        RX = self._flat_features(x)
        if self.guide:
            _, HX = self._apply_guide(RX)
            return HX
        return RX



class Branch(nn.Module):
    """Encoder followed by a projection MLP.

    Args:
        args: namespace providing ``dim_proj`` (comma-separated
            "hidden,out" widths) and, when no encoder is passed, ``guide``,
            ``use_gpool`` and ``order`` for building an ``ESCNNEncoder``.
        encoder: optional pre-built encoder exposing ``order`` and
            ``feature_dim``; a fresh ``ESCNNEncoder`` is created when falsy.
    """

    def __init__(self, args, encoder=None):
        super().__init__()
        proj_widths = list(map(int, args.dim_proj.split(',')))
        self.encoder = encoder if encoder else ESCNNEncoder(
            guide=args.guide, use_gpool=args.use_gpool, N=args.order
        )
        self.order = self.encoder.order

        self.projector = ProjectionMLP(
            self.encoder.feature_dim, proj_widths[0], proj_widths[1]
        )

    def forward(self, x):
        return self.projector(self.encoder(x))



def knn_loop(encoder, train_loader, test_loader):
    """Return the k-NN monitor accuracy (k=200) of ``encoder`` on CUDA.

    ``train_loader`` serves as the memory bank and ``test_loader`` as the
    query set; the computation is delegated to ``utils.knn_monitor``.
    """
    monitor_kwargs = dict(
        net=encoder.cuda(),
        memory_data_loader=train_loader,
        test_data_loader=test_loader,
        device='cuda',
        k=200,
        hide_progress=True,
    )
    return knn_monitor(**monitor_kwargs)


def ssl_loop(args, encoder=None):
    """Pre-train a ``Branch`` with SimCLR or SimSiam, or load one from a checkpoint.

    When ``args.checkpoint_path`` is set, no training happens: the checkpoint's
    ``state_dict`` is loaded into a fresh ``Branch`` and the log file is opened
    in append mode. Otherwise the model is trained on CIFAR-10 for
    ``args.epochs`` epochs; after every epoch a k-NN accuracy is computed and
    logged, and a checkpoint is written every ``args.save_every`` epochs (plus
    an initial ``0.pth``).

    Args:
        args: parsed command-line namespace (see the argument parser at the
            bottom of this file).
        encoder: optional pre-built encoder forwarded to ``Branch``.

    Returns:
        ``(encoder, file_to_update)``: the branch's encoder and the still-open
        log file. The caller is responsible for closing the file.

    Raises:
        ValueError: if ``args.loss`` is neither ``'simclr'`` nor ``'simsiam'``.
    """
    log_path = os.path.join(args.path_dir, 'train_and_eval.log')

    if args.checkpoint_path:
        print('checkpoint provided => moving to evaluation')
        main_branch = Branch(args, encoder=encoder).cuda()
        saved_dict = torch.load(args.checkpoint_path)['state_dict']
        main_branch.load_state_dict(saved_dict)
        # The log directory may not exist yet when only evaluating a checkpoint.
        os.makedirs(args.path_dir, exist_ok=True)
        file_to_update = open(log_path, 'a')
        file_to_update.write(f'evaluating {args.checkpoint_path}\n')
        return main_branch.encoder, file_to_update

    # Fail fast on an unknown loss, before any data is downloaded or any log
    # file is truncated.
    if args.loss not in ('simclr', 'simsiam'):
        raise ValueError(f'unsupported loss: {args.loss!r}')
    use_predictor = args.loss == 'simsiam'

    # logging
    os.makedirs(args.path_dir, exist_ok=True)
    file_to_update = open(log_path, 'w')

    # dataset
    if args.rotaug:
        train_transform = ContrastiveLearningRotAugTransform()
    else:
        train_transform = ContrastiveLearningTransform()
    train_dataset = torchvision.datasets.CIFAR10(
        '../data', train=True, transform=train_transform, download=True
    )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=True,
        batch_size=args.bsz,
        pin_memory=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    memory_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            '../data', train=True, transform=single_transform, download=True
        ),
        shuffle=False,
        batch_size=args.bsz,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            '../data', train=False, transform=single_transform, download=True,
        ),
        shuffle=False,
        batch_size=args.bsz,
        pin_memory=True,
        num_workers=args.num_workers
    )

    # models
    main_branch = Branch(args, encoder=encoder).cuda()
    if use_predictor:
        dim_proj = [int(x) for x in args.dim_proj.split(',')]
        predictor = PredictionMLP(dim_proj[1], args.dim_pred, dim_proj[1]).cuda()

    # optimization
    base_lr = args.lr * args.bsz / 256
    optimizer = torch.optim.SGD(
        main_branch.parameters(),
        momentum=0.9,
        lr=base_lr,
        weight_decay=args.wd
    )
    if use_predictor:
        # The schedule below only touches `optimizer`, so the predictor keeps
        # this constant learning rate throughout training.
        pred_optimizer = torch.optim.SGD(
            predictor.parameters(),
            momentum=0.9,
            lr=base_lr,
            weight_decay=args.wd
        )

    # macros
    backbone = main_branch.encoder
    projector = main_branch.projector

    param_print = 'encoder params: {}M'.format(count_parameters_in_MB(backbone))
    print(param_print)
    file_to_update.write(param_print + '\n')
    file_to_update.flush()

    def forward_step(inputs):
        """Compute the SSL loss (plus the optional guide loss) for one batch of view pairs."""
        x1 = inputs[0].cuda()
        x2 = inputs[1].cuda()

        if args.guide:
            eqv_logit1, b1 = backbone.extract_guided_output(x1)
            eqv_logit2, b2 = backbone.extract_guided_output(x2)
        else:
            b1 = backbone(x1)
            b2 = backbone(x2)

        z1 = projector(b1)
        z2 = projector(b2)

        if use_predictor:
            p1 = predictor(z1)
            p2 = predictor(z2)
            loss = negative_cosine_similarity_loss(p1, z2) / 2 + negative_cosine_similarity_loss(p2, z1) / 2
        else:
            loss = info_nce_loss(z1, z2) / 2 + info_nce_loss(z2, z1) / 2

        if args.guide:
            ori_loss = cross_entropy_AB(eqv_logit1, eqv_logit2) / 2 + cross_entropy_AB(eqv_logit2, eqv_logit1) / 2
            loss = loss + args.beta * ori_loss

        return loss

    # logging
    start = time.time()
    torch.save(dict(epoch=0, state_dict=main_branch.state_dict()), os.path.join(args.path_dir, '0.pth'))
    scaler = GradScaler()

    # training
    for e in range(1, args.epochs + 1):
        # declaring train
        main_branch.train()
        if use_predictor:
            predictor.train()

        # epoch
        for it, (inputs, _) in enumerate(train_loader, start=(e - 1) * len(train_loader)):
            # adjust
            lr = adjust_learning_rate(epochs=args.epochs,
                                      warmup_epochs=args.warmup_epochs,
                                      base_lr=base_lr,
                                      optimizer=optimizer,
                                      loader=train_loader,
                                      step=it)
            # zero grad
            main_branch.zero_grad()
            if use_predictor:
                predictor.zero_grad()

            # optimization step
            if args.fp16:
                with autocast():
                    loss = forward_step(inputs)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                if use_predictor:
                    scaler.step(pred_optimizer)
                # update() must run once, after every optimizer has stepped;
                # otherwise the predictor's gradients would be unscaled with
                # the already-updated scale factor.
                scaler.update()
            else:
                loss = forward_step(inputs)
                loss.backward()
                optimizer.step()
                if use_predictor:
                    pred_optimizer.step()

        if args.fp16:
            with autocast():
                knn_acc = knn_loop(backbone, memory_loader, test_loader)
        else:
            knn_acc = knn_loop(backbone, memory_loader, test_loader)

        # `loss` and `lr` are the values from the last iteration of the epoch.
        line_to_print = (
            f'epoch: {e} | knn_acc: {knn_acc:.3f} | '
            f'loss: {loss.item():.3f} | lr: {lr:.6f} | '
            f'time_elapsed: {time.time() - start:.3f}'
        )
        file_to_update.write(line_to_print + '\n')
        file_to_update.flush()
        print(line_to_print)

        if e % args.save_every == 0:
            torch.save(dict(epoch=e, state_dict=main_branch.state_dict()),
                       os.path.join(args.path_dir, f'{e}.pth'))

    return main_branch.encoder, file_to_update


def eval_loop(encoder, file_to_update, ind=None):
    """Train a linear classifier on frozen encoder features and report test accuracy.

    A ``nn.Linear(encoder.feature_dim, 10)`` probe is trained for 100 epochs
    on CIFAR-10 with SGD (lr 30, cosine decay, no warm-up) while the encoder
    stays in eval mode under ``torch.no_grad()``. Every 10 epochs the test
    accuracy is computed, printed and written to ``file_to_update``.

    NOTE: this function reads the module-level ``args`` (``num_workers`` and
    ``fp16``), so it only works when that global has been set by the
    ``__main__`` block.

    Args:
        encoder: frozen feature extractor exposing ``feature_dim``.
        file_to_update: open, writable text file used for logging.
        ind: run index, only used in the log line.

    Returns:
        The test accuracy (in percent) measured at the last evaluated epoch.
    """
    # dataset
    train_transform = T.Compose([
        T.RandomResizedCrop(33, interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize
    ])
    test_transform = T.Compose([
        T.Resize(36, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(33),
        T.ToTensor(),
        normalize
    ])

    train_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10('../data', train=True, transform=train_transform, download=True),
        shuffle=True,
        batch_size=256,
        pin_memory=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10('../data', train=False, transform=test_transform, download=True),
        shuffle=False,
        batch_size=256,
        pin_memory=True,
        num_workers=args.num_workers
    )

    classifier = nn.Linear(encoder.feature_dim, 10).cuda()

    # optimization
    optimizer = torch.optim.SGD(
        classifier.parameters(),
        momentum=0.9,
        lr=30,
        weight_decay=0
    )
    scaler = GradScaler()

    def forward_step(inputs, y):
        """Cross-entropy of the probe on features from the frozen encoder."""
        with torch.no_grad():
            b = encoder(inputs.cuda())
        logits = classifier(b)
        return F.cross_entropy(logits, y.cuda())

    # training
    for e in range(1, 101):
        # declaring train
        classifier.train()
        encoder.eval()
        # epoch
        for it, (inputs, y) in enumerate(train_loader, start=(e - 1) * len(train_loader)):
            # adjust
            adjust_learning_rate(epochs=100,
                                 warmup_epochs=0,
                                 base_lr=30,
                                 optimizer=optimizer,
                                 loader=train_loader,
                                 step=it)
            # zero grad
            classifier.zero_grad()

            # optimization step
            if args.fp16:
                with autocast():
                    loss = forward_step(inputs, y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = forward_step(inputs, y)
                loss.backward()
                optimizer.step()

        if e % 10 == 0:
            classifier.eval()
            # Count hits over samples rather than averaging per-batch
            # accuracies: the last test batch is smaller than the others, so
            # an unweighted mean over batches would over-weight it.
            hits = 0
            seen = 0
            for images, labels in test_loader:
                with torch.no_grad(), autocast(enabled=args.fp16):
                    b = encoder(images.cuda())
                    preds = classifier(b).argmax(dim=1)
                hits += (preds == labels.cuda()).sum().item()
                seen += labels.shape[0]
            accuracy = hits / seen * 100
            # final report of the accuracy
            line_to_print = (
                f'seed: {ind} | accuracy (%) @ epoch {e}: {accuracy:.2f}'
            )
            file_to_update.write(line_to_print + '\n')
            file_to_update.flush()
            print(line_to_print)

    return accuracy


def main(args):
    """Run SSL pre-training (or checkpoint loading) followed by five linear probes.

    Each probe gets a fresh ``ESCNNEncoder`` loaded with the pre-trained
    weights; the mean and standard deviation of the five accuracies are
    printed and appended to the log file, which is closed on exit.
    """
    fix_seed(args.seed)
    encoder, file_to_update = ssl_loop(args)
    try:
        accs = []
        for i in range(5):
            eval_encoder = ESCNNEncoder(guide=args.guide, use_gpool=args.use_gpool, N=args.order).cuda()
            eval_encoder.load_state_dict(encoder.state_dict())
            accs.append(eval_loop(eval_encoder, file_to_update, i))
        line_to_print = f'aggregated linear probe: {np.mean(accs):.3f} +- {np.std(accs):.3f}'
        file_to_update.write(line_to_print + '\n')
        file_to_update.flush()
        print(line_to_print)
    finally:
        # ssl_loop hands over an open log file; release it even if a probe fails.
        file_to_update.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model heads
    parser.add_argument('--dim_proj', type=str, default='2048,2048')
    parser.add_argument('--dim_pred', type=int, default=512)

    # optimization
    parser.add_argument('--epochs', type=int, default=800)
    parser.add_argument('--lr', type=float, default=0.03)
    parser.add_argument('--bsz', type=int, default=512)
    parser.add_argument('--wd', type=float, default=0.0005)
    parser.add_argument('--loss', type=str, default='simclr', choices=['simclr', 'simsiam'])
    parser.add_argument('--save_every', type=int, default=400)
    parser.add_argument('--warmup_epochs', type=int, default=10)

    # run bookkeeping
    parser.add_argument('--path_dir', type=str, default='../experiment/tmp')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--checkpoint_path', type=str, default=None)
    parser.add_argument('--fp16', action='store_true')

    # equivariance / guide options
    parser.add_argument('--beta', type=float, default=0.0)
    parser.add_argument('--rotaug', action='store_true')
    parser.add_argument('--guide', action='store_true')
    parser.add_argument('--order', type=int, default=4)
    parser.add_argument('--use_gpool', action='store_true')

    # Keep `args` as a module-level name: eval_loop reads it as a global.
    args = parser.parse_args()

    main(args)
