import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import time
import datetime
import argparse
import numpy as np
from pathlib import Path
from collections import Counter

from utils import *
from losses import SupConLoss

class ToySupCon(nn.Module):
    """Small MLP encoder whose outputs are L2-normalised.

    ``encoder`` maps 2 -> 128 -> 64 -> 2 features with a ReLU after each of
    the first two linear layers. ``forward`` additionally rescales the result
    to unit L2 norm along dimension 1.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(2, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ]
        # The attribute name and the layer order fix the state-dict keys
        # (``encoder.0.*``, ``encoder.2.*``, ``encoder.4.*``).
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``self.encoder(x)`` normalised to unit L2 norm along dim 1."""
        embedding = self.encoder(x)
        return F.normalize(embedding, dim=1)

class LinearClassifier(nn.Module):
    """Classification head: a ReLU followed by a 2 -> 4 linear layer."""

    def __init__(self):
        super().__init__()
        # ReLU sits at index 0, so the linear layer's state-dict keys are
        # ``classifier.1.weight`` / ``classifier.1.bias``.
        self.classifier = nn.Sequential(nn.ReLU(), nn.Linear(2, 4))

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for ``x``."""
        return self.classifier(x)


# Command-line interface. Flag names and defaults are unchanged so existing
# invocations keep working; only help texts and validation were added.
parser = argparse.ArgumentParser(description='ToySupCon Task - A Simple Binary Image Classification')
parser.add_argument('--gpu-id', default=2, type=int,
                    help='index of the CUDA device to use (falls back to CPU if unavailable)')
parser.add_argument('--seed-num', default=1, type=int,
                    help='value passed to set_random_seed')
parser.add_argument('--save', action='store_true', default=False,
                    help='write checkpoints and result arrays to --res-dir')
parser.add_argument('--res-dir', default='result', type=str,
                    help='directory that receives the saved files')
parser.add_argument('--res-tag', default='LT_toy_supcon', type=str,
                    help='filename prefix of the saved files')

hyper = parser.add_argument_group('params')
hyper.add_argument('--n_epochs', default=100, type=int,
                   help='number of training epochs')
hyper.add_argument('--lr', default=1e-2, type=float,
                   help='SGD learning rate')
hyper.add_argument('--batch-size', default=16, type=int,
                   help='mini-batch size of the train and test loaders')
hyper.add_argument('--encoding', action='store_true', default=False,
                   help='train the encoder with SupConLoss instead of the linear classifier')
hyper.add_argument('--ckpt', type=str,
                   help='encoder state-dict to load; required unless --encoding is given')

args = parser.parse_args()

# The classifier stage loads a pre-trained encoder from --ckpt. Reject a
# missing path here with a usage message instead of failing later inside
# torch.load(None).
if not args.encoding and not args.ckpt:
    parser.error('--ckpt is required unless --encoding is given')

# set_random_seed is not defined in this file; presumably it comes from the
# star import of ``utils``.
set_random_seed(seed_num=args.seed_num)

# Use the requested GPU only when it actually exists; otherwise run on CPU so
# the script still works on hosts with no (or fewer) CUDA devices.
if torch.cuda.is_available() and 0 <= args.gpu_id < torch.cuda.device_count():
    device = torch.device('cuda:{}'.format(args.gpu_id))
else:
    print('[warn] cuda:{} is not available; running on CPU'.format(args.gpu_id))
    device = torch.device('cpu')

# Per-class sample counts: imbalanced for training, balanced for testing.
num_points_tr = [1000, 500, 100, 10]
num_points_ts = [200, 200, 200, 200]

# One 2-D centre per class (the four corners of a square).
loc = [
    torch.tensor([[-3, 3]]),
    torch.tensor([[3, 3]]),
    torch.tensor([[3, -3]]),
    torch.tensor([[-3, -3]]),
]


def _sample_blobs(counts):
    """Draw ``counts[k]`` standard-normal 2-D points shifted by ``loc[k]``.

    Returns two lists with one entry per class: the point tensors and the
    matching (float) label tensors filled with the class index.
    """
    points, labels = [], []
    for cls, (n_points, centre) in enumerate(zip(counts, loc)):
        noise = torch.normal(mean=torch.zeros(n_points, 2),
                             std=torch.ones(n_points, 2))
        points.append(noise + centre)
        labels.append(torch.ones(n_points) * cls)
    return points, labels


# Train blobs are sampled before test blobs, class by class, so the random
# stream is consumed in the same order as a plain pair of loops would.
tr_x, tr_y = _sample_blobs(num_points_tr)
ts_x, ts_y = _sample_blobs(num_points_ts)

tr_x, tr_y = torch.vstack(tr_x), torch.cat(tr_y).long()
ts_x, ts_y = torch.vstack(ts_x), torch.cat(ts_y).long()

# Datasets are plain lists of (point, label) tensor pairs.
trainset = list(zip(tr_x, tr_y))
testset = list(zip(ts_x, ts_y))

trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False)

if args.encoding:
    # Stage 1: train the encoder itself with the contrastive criterion.
    net = ToySupCon().to(device)
    criterion = SupConLoss(args)
else:
    # Stage 2: train a linear head on top of a pre-trained encoder.
    encoder = ToySupCon()
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    encoder.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    # Only the inner Sequential is kept, so the features fed to the head skip
    # the L2 normalisation applied in ToySupCon.forward.
    encoder = encoder.encoder.to(device)
    encoder.eval()
    # The optimizer below only receives the head's parameters. Freezing the
    # encoder avoids building its autograd graph and accumulating gradients
    # that would never be used or zeroed.
    encoder.requires_grad_(False)
    net = LinearClassifier().to(device)
    criterion = nn.CrossEntropyLoss()

# Only ``net`` is optimised in either stage.
optimizer = optim.SGD(net.parameters(), lr=args.lr)

def supcon_train(dataloader, epoch):
    """Run one contrastive training epoch of the global ``net``.

    Every batch is concatenated with itself to form two identical views (no
    augmentation is applied here). The network output is split back into the
    two views, which are stacked along a new dimension 1 and passed to the
    global ``criterion`` together with the labels.

    Args:
        dataloader: iterable of (data, targets) batches exposing ``dataset``.
        epoch: 1-based epoch index, used only for the progress/ETA line.

    Returns:
        ``(loss, 0.0)``: the sum of the per-batch loss values divided by the
        number of samples in ``dataloader.dataset``, and a placeholder
        accuracy.
    """
    net.train()
    running_loss = 0.
    last_tick = time.time()

    for step, (data, targets) in enumerate(dataloader):
        data = data.to(device)
        targets = targets.to(device)
        n_samples = targets.shape[0]

        optimizer.zero_grad()

        feat = net(torch.cat([data, data], dim=0))
        # First half of the rows is view 1, second half is view 2.
        view_a, view_b = feat[:n_samples], feat[n_samples:]
        feat = torch.stack([view_a, view_b], dim=1)

        loss = criterion(feat, targets)
        running_loss += loss.item()
        loss.backward()

        optimizer.step()

        # verbose: ETA = remaining batches * duration of the last batch
        n_batches = len(dataloader)
        batches_done = (epoch - 1) * n_batches + step
        batches_left = args.n_epochs * n_batches - batches_done
        now = time.time()
        time_left = datetime.timedelta(seconds=batches_left * (now - last_tick))
        last_tick = time.time()
        print("\r[epoch {:3d}/{:3d}] [batch {:4d}/{:4d}] loss: {:.6f} (eta: {})".format(
            epoch, args.n_epochs, step+1, n_batches, loss, time_left), end=' ')

    # NOTE(review): this divides a sum of per-batch losses by the sample
    # count; whether that is a true per-sample mean depends on the reduction
    # used by the criterion -- confirm against SupConLoss.
    return running_loss / len(dataloader.dataset), 0.0


def train(dataloader, epoch):
    """Train the global classifier ``net`` for one epoch on encoder features.

    Inputs are passed through the global ``encoder`` (not optimised here) and
    then through ``net``; ``criterion``/``optimizer`` are the module-level
    objects.

    Args:
        dataloader: iterable of (data, targets) batches exposing ``dataset``.
        epoch: 1-based epoch index, used only for the progress/ETA line.

    Returns:
        ``(loss, acc)``: the per-sample mean loss and the accuracy over
        ``dataloader.dataset``.
    """
    net.train()
    tr_loss = 0.
    correct = 0

    prev_time = time.time()
    for idx, (data, targets) in enumerate(dataloader):
        data, targets = data.to(device), targets.to(device)

        optimizer.zero_grad()

        # The optimizer only holds ``net``'s parameters, so there is no need
        # to track gradients through the encoder.
        with torch.no_grad():
            feat = encoder(data)
        output = net(feat)

        loss = criterion(output, targets)
        # The criterion set up for this stage (nn.CrossEntropyLoss) returns
        # the batch mean. Weight it by the batch size so that dividing by
        # the dataset size below yields a true per-sample mean, also when
        # the last batch is smaller.
        tr_loss += loss.item() * targets.size(0)
        loss.backward()

        optimizer.step()

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(targets.view_as(pred)).sum().item()

        # verbose: ETA = remaining batches * duration of the last batch
        batches_done = (epoch - 1) * len(dataloader) + idx
        batches_left = args.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        print("\r[epoch {:3d}/{:3d}] [batch {:4d}/{:4d}] loss: {:.6f} (eta: {})".format(
            epoch, args.n_epochs, idx+1, len(dataloader), loss, time_left), end=' ')

    cnt = len(dataloader.dataset)
    tr_loss /= cnt
    tr_acc = correct / cnt

    return tr_loss, tr_acc

    
def test(dataloader):
    """Evaluate the global ``encoder`` + ``net`` pipeline without gradients.

    Args:
        dataloader: iterable of (data, targets) batches exposing ``dataset``.

    Returns:
        ``(loss, acc)``: the per-sample mean loss and the accuracy over
        ``dataloader.dataset``.
    """
    net.eval()
    ts_loss = 0.
    correct = 0

    with torch.no_grad():
        for idx, (data, targets) in enumerate(dataloader):
            data, targets = data.to(device), targets.to(device)

            feat = encoder(data)
            output = net(feat)

            loss = criterion(output, targets)
            # The criterion set up for this stage (nn.CrossEntropyLoss)
            # returns the batch mean. Weight it by the batch size so that
            # dividing by the dataset size below yields a true per-sample
            # mean, also when the last batch is smaller.
            ts_loss += loss.item() * targets.size(0)

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()

    cnt = len(dataloader.dataset)
    ts_loss /= cnt
    ts_acc = correct / cnt

    return ts_loss, ts_acc


# Per-epoch rows of [tr_loss, tr_acc, ts_loss, ts_acc]; only kept when saving.
if args.save:
    result = []

tr_acc = 0.
for epoch in range(1, args.n_epochs + 1):
    if args.encoding:
        # Contrastive stage: no evaluation pass and no per-epoch summary.
        tr_loss, tr_acc = supcon_train(trainloader, epoch)
        continue

    # Classifier stage: one training epoch followed by a test pass.
    tr_loss, tr_acc = train(trainloader, epoch)
    ts_loss, ts_acc = test(testloader)

    train_summary = "loss: {:.4f}, acc: {:.4f} ".format(tr_loss, tr_acc)
    test_summary = "ts_loss: {:.4f}, ts_acc: {:.4f} ".format(ts_loss, ts_acc)
    # Appended to the in-place progress line printed by train().
    print(train_summary + test_summary, end='')

    if args.save:
        result.append([tr_loss, tr_acc, ts_loss, ts_acc])
print()


def val(dataset):
    """Collect softmax confidences and encoder features sample by sample.

    ``dataset`` is iterated as (data, targets) pairs. Each ``data`` tensor is
    reshaped to a single-row batch, passed through the global ``encoder`` and
    then through ``net``.

    Returns:
        ``(confidences, features)`` as numpy arrays with one row per sample:
        the softmax of the ``net`` output and the ``encoder`` output.
    """
    net.eval()
    probs, feats = [], []

    with torch.no_grad():
        for data, targets in dataset:
            data, targets = data.to(device), targets.to(device)
            sample = data.reshape(1, -1)

            feat = encoder(sample)
            logits = net(feat)

            prob = F.softmax(logits, dim=-1)
            probs.append(prob.detach().cpu().numpy())
            feats.append(feat.detach().cpu().numpy())

    return np.vstack(probs), np.vstack(feats)


if args.save:
    out_dir = Path(args.res_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    def _res_path(template):
        """Return ``template`` filled with the result tag, inside the result dir."""
        return out_dir / template.format(args.res_tag)

    if args.encoding:
        # Contrastive stage: only the encoder weights are stored.
        torch.save(net.state_dict(), _res_path('{}_model.pth'))
    else:
        # dataset
        torch.save(trainset, _res_path('{}_dataset_tr.pth'))
        torch.save(testset, _res_path('{}_dataset_ts.pth'))

        # confidence & actv
        confid_tr, actv_tr = val(trainset)
        confid_ts, actv_ts = val(testset)
        arrays_to_save = [
            ('{}_confidence_tr.npy', confid_tr),
            ('{}_confidence_ts.npy', confid_ts),
            ('{}_actv_tr.npy', actv_tr),
            ('{}_actv_ts.npy', actv_ts),
        ]
        for template, array in arrays_to_save:
            np.save(_res_path(template), array)

        # model
        # NOTE(review): this is the same filename the encoding branch writes,
        # so running both stages with one --res-tag and --res-dir overwrites
        # the encoder checkpoint with the classifier weights.
        torch.save(net.state_dict(), _res_path('{}_model.pth'))

        # acc: one [tr_loss, tr_acc, ts_loss, ts_acc] row per epoch
        np.save(_res_path('{}_acc.npy'), np.vstack(result))

