import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
from tqdm import tqdm
import torchvision.models.video
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from models.r50lstm import Resnet50LSTM
from data_preprocessing import create_video_dataset
# Per-architecture training hyper-parameters: (batch size, learning rate, epochs).
_TRAIN_CONFIG = {
    'R2p1d': (8, 1e-3, 50),
    'R50LSTM': (16, 5e-5, 100),
}


def _build_model(cmd, device, pretrained=False):
    """Instantiate the network selected by ``cmd`` and move it to ``device``.

    ``pretrained`` only affects ``'R2p1d'``: when true, the torchvision
    backbone is created with ``pretrained=True``.
    """
    if cmd == 'R2p1d':
        # Only pass the keyword when it is actually requested, so the
        # weight-less construction stays identical to a bare call.
        if pretrained:
            backbone = torchvision.models.video.r2plus1d_18(pretrained=True)
        else:
            backbone = torchvision.models.video.r2plus1d_18()
        # The backbone keeps its own classifier; a small MLP head maps its
        # 400 outputs down to the 28 target classes.
        model = nn.Sequential(
            backbone,
            nn.ReLU(),
            nn.Linear(400, 128),
            nn.ReLU(),
            nn.Linear(128, 28))
    elif cmd == 'R50LSTM':
        model = Resnet50LSTM(num_classes=28, drop_rate=0.1, fc_precede_mean=True,
                             resnet_feat_type='layer4')
    else:
        raise ValueError(
            f"Unknown model name {cmd!r}; expected one of {sorted(_TRAIN_CONFIG)}")
    return model.to(device)


def _evaluate(model, loader, device, progress=False):
    """Return ``(correct, total)`` prediction counts over all of ``loader``.

    The model is switched to eval mode so that dropout / batch-norm layers do
    not behave as in training (and running statistics are not updated).
    """
    model.eval()
    correct = 0
    total = 0
    batches = tqdm(loader) if progress else loader
    with torch.no_grad():
        for X, Y in batches:
            pred = model(X.permute(0, 4, 1, 2, 3).to(device))
            correct += (pred.argmax(dim=1).cpu() == Y).sum().item()
            total += Y.size(0)
    return correct, total


def train(cmd: str, device, load: bool) -> nn.Module:
    """Load a saved model, or train one from scratch and report its accuracy.

    Args:
        cmd: Architecture name, either ``'R2p1d'`` or ``'R50LSTM'``.
        device: Torch device (or device string) the model and batches are
            moved to.
        load: When ``True``, build the architecture and restore its weights
            from ``Model_pt/{cmd}.pth`` without training. Otherwise train on
            ``dataset/dataset.npy`` / ``dataset/dataset_label.npy``.

    Returns:
        The loaded model, or the trained model (left in eval mode).

    Raises:
        ValueError: If ``cmd`` is not a known architecture name.

    Notes:
        * Batches are permuted with ``(0, 4, 1, 2, 3)``, i.e. the last axis of
          each 5-D sample batch is moved to position 1. Presumably the stored
          layout is channels-last (N, T, H, W, C) -- TODO confirm.
        * Training writes periodic checkpoints to
          ``Model_pt/epoch{epoch}_{cmd}.pth`` only; the ``Model_pt/{cmd}.pth``
          file read by the ``load`` branch is never written here and must be
          provided separately (e.g. by copying a chosen checkpoint).
    """
    # Fail fast, before any expensive dataset loading, on an unknown name.
    if cmd not in _TRAIN_CONFIG:
        raise ValueError(
            f"Unknown model name {cmd!r}; expected one of {sorted(_TRAIN_CONFIG)}")

    # `== True` (rather than truthiness) is kept deliberately so that
    # non-bool arguments keep selecting the same branch as before.
    if load == True:  # noqa: E712
        model = _build_model(cmd, device)
        model.load_state_dict(torch.load(f'Model_pt/{cmd}.pth', map_location=device))
        return model

    batch_size, learning_rate, total_epoch = _TRAIN_CONFIG[cmd]

    dataset = np.load("dataset/dataset.npy")
    dataset_label = np.load("dataset/dataset_label.npy")
    # 70% train; the remaining 30% is split again into 70% test / 30% val.
    X_train, X_test, Y_train, Y_test = train_test_split(
        dataset, dataset_label, test_size=0.3, random_state=41)
    X_test, X_val, Y_test, Y_val = train_test_split(
        X_test, Y_test, test_size=0.3, random_state=41)

    train_ds = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(Y_train))
    val_ds = TensorDataset(torch.FloatTensor(X_val), torch.LongTensor(Y_val))

    # Release the numpy copies that are no longer needed; the test split is
    # kept until the final evaluation.
    del dataset, dataset_label, X_train, Y_train, X_val, Y_val

    data_loader = DataLoader(dataset=train_ds,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True)
    # Evaluation loaders must see every sample exactly once: no shuffling
    # needed and no dropped tail batch, otherwise accuracy is under-reported.
    val_data_loader = DataLoader(dataset=val_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 drop_last=False)
    total_batch = len(data_loader)

    model = _build_model(cmd, device, pretrained=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # Halve the learning rate every 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.5)

    total_loss = []
    val_acc = []
    for epoch in tqdm(range(total_epoch)):
        # _evaluate() leaves the model in eval mode, so re-enable training
        # behaviour at the start of every epoch (for every architecture).
        model.train()
        avg_cost = 0.0

        for X, Y in tqdm(data_loader):
            optimizer.zero_grad()
            prediction = model(X.permute(0, 4, 1, 2, 3).to(device))
            loss = criterion(prediction, Y.to(device))
            loss.backward()
            optimizer.step()
            avg_cost += loss.item() / total_batch

        scheduler.step()

        correct, seen = _evaluate(model, val_data_loader, device)
        epoch_acc = correct / seen if seen else 0.0

        print('[Epoch: {}] cost = {} '.format(epoch + 1, avg_cost), end="")
        print(f"Val acc:{epoch_acc * 100}%", end=" ")

        val_acc.append(epoch_acc)
        total_loss.append(avg_cost)

        # `epoch` is 0-based, so e.g. "epoch5" holds the weights after the
        # sixth pass over the training data.
        if epoch % 5 == 0 and epoch != 0:
            torch.save(model.state_dict(), f"Model_pt/epoch{epoch}_{cmd}.pth")

    plt.plot(total_loss, label="Train")
    plt.legend()
    plt.ylabel("Loss")
    plt.xlabel("Epochs")
    plt.show()

    test_ds = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(Y_test))
    test_data_loader = DataLoader(dataset=test_ds,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  drop_last=False)

    correct, seen = _evaluate(model, test_data_loader, device, progress=True)
    test_acc = (correct / seen) * 100 if seen else 0.0
    print(f"Test data Model acc:{test_acc}%")

    return model


