import os

# Change the working directory to the parent directory, so that relative paths
# used later (e.g. "./logs/..." in LexiconDataset) resolve from there.
# NOTE(review): assumes the script is launched from a project subdirectory — confirm.
os.chdir("../")
# AIST++ data root, resolved relative to the new working directory.
data_dir = "../aistplusplus"
import sys
import yaml
import shutil
import math
import random
import numpy as np
import logging
from pathlib import Path

from src.data.dataset.loader import AISTDataset

from tqdm import tqdm
import pandas as pd
from src.data.dataset.cluster_misc import lexicon, get_names, genre_list, vidn_parse
from src.metrics import preprocess, metric_nmi, ngram_ent

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
from models.LSTM import LSTMClassifier


class LexiconDataset(Dataset):
    """Per-video 3D keypoint sequences labelled with a genre index.

    The set of videos is read from ``./logs/{exp}/all_advanced_{trval}_{K}.pkl``
    (one entry per distinct value of its ``name`` column, in ``groupby`` key
    order).  For each video, the 3D keypoints are loaded with the official
    AIST++ loader and truncated to the first ``max_frames`` frames.  They are
    then flattened to a float tensor of shape ``(frames, -1)``.

    Args:
        exp: Experiment name; selects the ``./logs/{exp}`` directory.
        K: Number used in the pickle filename.
        trval: Split tag used in the pickle filename (``"tr"``/``"val"`` in this script).
        max_frames: Maximum number of frames kept per video (default 240,
            the value that used to be hard-coded).

    Raises:
        ValueError: if the pickle contains no video names.
    """

    def __init__(self, exp, K, trval, max_frames=240):
        official_loader = AISTDataset(os.path.join(data_dir, "annotations"))
        pkl_path = f"./logs/{exp}/all_advanced_{trval}_{K}.pkl"
        df = pd.read_pickle(pkl_path)
        # Only the video names are needed.  Iterating groupby keys keeps the same
        # (sorted) index -> video order as before.
        self.names = [name for name, _ in df.groupby("name")]
        if not self.names:
            raise ValueError(f"No videos found in {pkl_path}")

        self.data = []
        self.lengths = []
        for name in self.names:
            keypoints = official_loader.load_keypoint3d(name)[:max_frames]
            length = keypoints.shape[0]
            # Flatten all per-frame dims into one feature vector per frame.
            self.data.append(torch.tensor(keypoints.reshape(length, -1)).float())
            self.lengths.append(length)

        max_len = max(self.lengths)
        min_len = min(self.lengths)
        mean_len = sum(self.lengths) / len(self)
        print(f"For {trval} (K={K}), Lexicon Dataeset is built with {len(self)} instances, max {max_len}, mean {mean_len}, min {min_len}")

    def __len__(self):
        """Return the number of videos."""
        return len(self.names)

    def __getitem__(self, item):
        """Return ``(keypoints, genre_idx)`` for the ``item``-th video.

        The whole stored (already truncated) sequence is returned; no random
        cropping happens here.  The label is the position of the first three
        characters of the video name in ``genre_list``.
        """
        name = self.names[item]
        genre_idx = genre_list.index(name[:3])
        return self.data[item], genre_idx

def collate_here(lop):
    """Collate ``(sequence, label)`` samples into a time-major, zero-padded batch.

    Args:
        lop: List of ``(tensor, int)`` pairs, where each tensor has shape
            ``(length, features)`` and all share the same ``features``.

    Returns:
        ``(padded, labels, lengths)``:
        - ``padded``: tensor of shape ``(max_len, batch, features)``, with shorter
          sequences zero-padded at the end;
        - ``labels``: int64 tensor of shape ``(batch,)``;
        - ``lengths``: list of the original sequence lengths.
    """
    sequences = [seq for seq, _ in lop]
    labels = [label for _, label in lop]
    lengths = [seq.shape[0] for seq in sequences]
    # With batch_first=False (the default), pad_sequence zero-pads and returns
    # (max_len, batch, *features).  The feature width comes from the data
    # instead of the previously hard-coded 17 * 3.
    padded = torch.nn.utils.rnn.pad_sequence(sequences)
    return padded, torch.tensor(labels), lengths

def clip_gradient(model, clip_value):
    """Clamp every existing gradient of ``model`` in place to ``[-clip_value, clip_value]``.

    Parameters whose ``.grad`` is ``None`` are left untouched.
    """
    # Library equivalent of the manual `p.grad.data.clamp_` loop.  It skips
    # parameters without gradients and avoids the deprecated `.data` access.
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)


def train_model(model, train_iter, optimizer=None):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Classifier called as ``model(inputs, lengths)``.  Its output is
            expected to be per-class scores with classes on dim 1.
        train_iter: Iterable of ``(inputs, labels, lengths)`` batches.
        optimizer: Optimizer to step.  If omitted, an Adam optimizer is created
            on the first call and cached on ``model``, so later epochs reuse it
            and keep its running moment estimates.  The previous code rebuilt it
            every call, which reset Adam's state each epoch.  The optimizer is
            built from the module-level ``learning_rate`` / ``weight_decay`` and
            the parameters that have ``requires_grad`` at that time.

    Returns:
        Loss and accuracy averaged over samples rather than batches, so a
        smaller last batch is not over-weighted.  The loss average assumes
        ``loss_fn`` uses mean reduction.

    Raises:
        ValueError: if ``train_iter`` yields no samples.
    """
    if optimizer is None:
        optimizer = getattr(model, "_train_model_optimizer", None)
        if optimizer is None:
            optimizer = torch.optim.Adam(
                [p for p in model.parameters() if p.requires_grad],
                lr=learning_rate,
                weight_decay=weight_decay,
            )
            model._train_model_optimizer = optimizer

    # Send inputs to wherever the model lives, rather than assuming CUDA.
    device = next(model.parameters()).device
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    model.train()
    for inputs, target, lengths in train_iter:
        inputs = inputs.to(device)
        target = target.long().to(device)
        optimizer.zero_grad()
        prediction = model(inputs, lengths)
        loss = loss_fn(prediction, target)
        loss.backward()
        clip_gradient(model, 1e-1)
        optimizer.step()

        batch = target.shape[0]
        total_loss += loss.item() * batch
        total_correct += (prediction.argmax(dim=1) == target).sum().item()
        total_seen += batch

    if total_seen == 0:
        raise ValueError("train_iter yielded no samples")
    return total_loss / total_seen, 100.0 * total_correct / total_seen


def eval_model(model, val_iter):
    """Evaluate ``model`` without gradients and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Classifier called as ``model(inputs, lengths)``.  Its output is
            expected to be per-class scores with classes on dim 1.
        val_iter: Iterable of ``(inputs, labels, lengths)`` batches.

    Returns:
        Loss and accuracy averaged over samples rather than batches.  The loss
        average assumes the module-level ``loss_fn`` uses mean reduction.

    Raises:
        ValueError: if ``val_iter`` yields no samples.
    """
    # Send inputs to wherever the model lives, rather than assuming CUDA.
    device = next(model.parameters()).device
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    model.eval()
    with torch.no_grad():
        for inputs, target, lengths in val_iter:
            inputs = inputs.to(device)
            target = target.long().to(device)
            prediction = model(inputs, lengths)
            loss = loss_fn(prediction, target)

            batch = target.shape[0]
            total_loss += loss.item() * batch
            total_correct += (prediction.argmax(dim=1) == target).sum().item()
            total_seen += batch

    if total_seen == 0:
        raise ValueError("val_iter yielded no samples")
    return total_loss / total_seen, 100.0 * total_correct / total_seen


# Hyper-parameters.  train_model / eval_model read learning_rate, weight_decay
# and loss_fn as module globals, so these stay at module level.
learning_rate = 1e-3
weight_decay = 0
batch_size = 50
output_size = 10
hidden_size = 256
embedding_length = 256
exp = "protect_rnn_1"
K = 240
num_epochs = 200
log_every = 10

# Fall back to CPU so the script also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMClassifier(output_size, hidden_size, K+1, embedding_length, wordske="ske").to(device)
loss_fn = F.cross_entropy

# Split the "tr" data 80/20 into train/validation; the "val" data serves as the test set.
tr_dataset = LexiconDataset(exp, K, trval="tr")
tr_size = len(tr_dataset) * 4 // 5
val_size = len(tr_dataset) - tr_size
train_set, val_set = torch.utils.data.random_split(tr_dataset, [tr_size, val_size])
train_iter = DataLoader(train_set, collate_fn=collate_here, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=False)
# Validation and test sets are each evaluated as one full-size batch.
val_iter = DataLoader(val_set, collate_fn=collate_here, batch_size=val_size, shuffle=True, num_workers=8, drop_last=False)

test_dataset = LexiconDataset(exp, K, trval="val")
test_iter = DataLoader(test_dataset, collate_fn=collate_here, batch_size=len(test_dataset), shuffle=True, drop_last=False, num_workers=8)

best_val_acc = -1
# Defined up front so the final report cannot hit a NameError if no epoch runs.
test_loss = test_acc = float("nan")
for epoch in range(num_epochs):
    train_loss, train_acc = train_model(model, train_iter)
    val_loss, val_acc = eval_model(model, val_iter)
    # Keep the test score from the epoch with the best validation accuracy;
    # on ties the later epoch wins.
    if val_acc >= best_val_acc:
        test_loss, test_acc = eval_model(model, test_iter)
        best_val_acc = val_acc

    if epoch % log_every == 0:
        print(f'Epoch: {epoch + 1:03}, Train Loss: {train_loss:.7f}, '
              f'Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, '
              f'Val. Acc: {val_acc:.2f}%')

tbd = f'For K={K}, Test Loss: {test_loss:.3f}; Test. Acc: {test_acc:.2f}%'
print(tbd)