import gc
import random
import argparse
import warnings
import numpy as np
import pandas as pd

from tqdm import tqdm
from typing import ClassVar

import torch
import torch.nn as nn
import torch.utils.data as Data

from model import IndCDE
from utils import VideoDataset

warnings.filterwarnings('ignore')

def load_data(path: str, batch_size: int = 1):
    """Create the train and test DataLoaders for the dataset stored at ``path``.

    Each split is a ``VideoDataset(path, <is_train>, 0.8)``. The 0.8 is
    presumably the train fraction of the split; confirm in ``utils.VideoDataset``.
    Only the training loader is shuffled.

    Returns:
        ``(train_loader, test_loader)``.
    """
    loaders = []
    for is_train in (True, False):
        split = VideoDataset(path, is_train, 0.8)
        # shuffle only the training split; the test loader keeps dataset order
        loaders.append(Data.DataLoader(split, batch_size=batch_size, shuffle=is_train))
    train_loader, test_loader = loaders
    return train_loader, test_loader

def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and torch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This is a function and must be called. Assigning ``True`` to it would
    # replace the torch API with a bool and enable nothing.
    # warn_only=True: ops without a deterministic implementation warn instead
    # of raising, so training still runs when such kernels are hit.
    torch.use_deterministic_algorithms(True, warn_only=True)


def _run_epoch(model, loader, criterion, input_t, device, optimizer=None, scheduler=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is trained on each batch: backward,
    gradient clipping to L2 norm 0.1, optimizer step, gradient reset, and a
    per-batch ``scheduler`` step when a scheduler is given. Otherwise the
    model is evaluated with gradients disabled.

    ``input_t`` is ``torch.arange`` or ``torch.ones``. It builds the ``t``
    tensor passed as the first argument of ``model(t, x)``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    num_sample = 0
    with torch.set_grad_enabled(training):
        for x, y in tqdm(loader):
            # One value per position along dim 1 of x, repeated for each sample
            # in the batch. Presumably dim 1 is the frame/time axis; confirm
            # against VideoDataset.
            t = input_t(x.size(1), device=device).expand(x.size(0), -1)
            x, y = x.to(device), y.to(device)
            output = model(t, x)
            loss = criterion(output, y)

            total_correct += int(torch.sum(torch.argmax(output, dim=-1) == y))
            total_loss += loss.item() * x.size(0)  # criterion returns the batch mean
            num_sample += x.size(0)

            if training:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1, norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

    return total_loss / num_sample, total_correct / num_sample


def _save_history(frame: pd.DataFrame, version: str, tag: str) -> None:
    """Write ``frame`` to ``{version}_{tag}.csv``.

    If that write fails with an OS error (for example, a PermissionError
    because the file is open in another program), fall back to
    ``{version}_{tag}_1.csv``.
    """
    try:
        frame.to_csv(f'{version}_{tag}.csv')
    except OSError:
        print(f'Fail to save the file {tag}.csv')
        frame.to_csv(f'{version}_{tag}_1.csv')


def main(arg):
    """Train ``arg.model`` and evaluate it after every epoch.

    ``arg`` must provide ``seed``, ``dataset``, ``batch_size``, ``device``,
    ``model`` (a class), ``model_para`` (positional constructor arguments)
    and ``num_epochs``.

    After each epoch the whole model is saved to ``{version}.pkl``, and the
    loss/accuracy history goes to ``{version}_Train.csv`` and
    ``{version}_Test.csv`` in the working directory.
    """
    _seed_everything(arg.seed)

    train_loader, test_loader = load_data(arg.dataset, arg.batch_size)

    print(arg.device)

    model = arg.model(*arg.model_para).to(arg.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The scheduler is stepped once per training batch, so T_max counts batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=arg.num_epochs * len(train_loader),
        eta_min=0.00005,
        last_epoch=-1,
    )

    criterion = nn.CrossEntropyLoss()

    train_loss, train_acc = [], []
    test_loss, test_acc = [], []

    # CDE-style models get an increasing index as their t input; others get ones.
    model_name = arg.model.__name__.lower()
    if 'cde' in model_name or 'ind' in model_name:
        print('torch.arange')
        input_t = torch.arange
    else:
        print('torch.ones')
        input_t = torch.ones

    version = arg.model.__name__ + '_' + str(arg.model_para)
    for epoch in range(arg.num_epochs):
        print(version, 'Epoch {}/{}'.format(epoch + 1, arg.num_epochs))

        loss, acc = _run_epoch(model, train_loader, criterion, input_t, arg.device,
                               optimizer, scheduler)
        train_loss.append(loss)
        train_acc.append(acc)
        print(' ', train_loss[-1], train_acc[-1])

        loss, acc = _run_epoch(model, test_loader, criterion, input_t, arg.device)
        test_loss.append(loss)
        test_acc.append(acc)
        print(' ', test_loss[-1], test_acc[-1])

        torch.save(model, f'{version}.pkl')

        _save_history(pd.DataFrame({'Train Loss': train_loss, 'Train Acc': train_acc}),
                      version, 'Train')
        _save_history(pd.DataFrame({'Test Loss': test_loss, 'Test Acc': test_acc}),
                      version, 'Test')

    gc.collect()

if __name__ == '__main__':
    torch.set_num_threads(1)
    n_class = 3
    # Sweep over hidden sizes; every other setting is fixed. The command line
    # is not read, so each run's configuration is built directly as a
    # Namespace with the attributes main() expects.
    for hidden_size in [32, 64]:
        config = argparse.Namespace(
            seed=42,
            batch_size=4,
            device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
            model=IndCDE,
            # Positional IndCDE constructor arguments. The meaning of 64 and 6
            # is not visible here; check model.IndCDE.
            model_para=(64, hidden_size, n_class, 6),
            dataset='data',
            num_epochs=50,
        )
        main(arg=config)