import torch
from PIL import Image
import numpy as np
import os
import pandas as pd
from tqdm import tqdm
import glob
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
import matplotlib.pyplot as plt
from main.clip_models.baseline import get_transforms, initialize_model

# Torch device used for the model, the loss weights and every batch (first GPU).
device = 'cuda:0'


class SmidDataset(Dataset):
    """SMID images paired with a binary label derived from the ``moral_mean`` norm.

    Every image listed in ``SMID_norms.csv`` is decoded once at construction time
    and kept in memory, so ``__getitem__`` only applies ``transform``.

    Args:
        preprocess: Stored as ``self.preprocess``; not used by this class.
        train: Selects the training (True) or held-out (False) partition when
            ``test_size`` is given. Ignored when ``test_size`` is None.
        verbose: If True, prompt on stdin for every image and, on a bare enter,
            plot it with its moral/valence means.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        data_set_path: Directory holding ``SMID_norms.csv`` and an ``img/``
            sub-directory with the image files.
        test_size: If not None, a stratified train/test split is made with
            sklearn's ``StratifiedShuffleSplit(test_size=test_size, n_splits=1,
            random_state=42)`` and only the partition chosen by ``train`` is kept.
            With the default None the whole dataset is used for both
            ``train=True`` and ``train=False`` (so train and test data coincide).

    Attributes:
        label_weights: ``[fraction of label-1 samples, fraction of label-0 samples]``
            over the full CSV (before any split). ``train`` passes this list as the
            CrossEntropyLoss ``weight``, so class 0 is weighted by the frequency of
            class 1 and vice versa, which up-weights the rarer class.
        img_paths, img_labels: numpy arrays for the selected partition.
        imgs: In-memory PIL images aligned with ``img_paths``/``img_labels``.
    """

    def __init__(self, preprocess, train=True, verbose=False, transform=None,
                 data_set_path='/workspace/datasets/SMID_images_400px/', test_size=None):
        self.preprocess = preprocess
        self.transform = transform

        csv_path = os.path.join(data_set_path, 'SMID_norms.csv')
        df = pd.read_csv(csv_path, sep=',', header=0)
        valence_means = df['valence_mean'].values
        moral_means = df['moral_mean'].values

        img_paths = []
        img_labels = []
        for idx, image_name in enumerate(tqdm(df['img_name'].values)):
            image_path = self._resolve_image_path(data_set_path, image_name)
            # Binary target: label 1 when the mean moral rating is below 2.5, else 0.
            img_labels.append(1 if moral_means[idx] < 2.5 else 0)
            if verbose:
                self._show_image(image_path, moral_means[idx], valence_means[idx])
            img_paths.append(image_path)

        if not img_paths:
            raise ValueError(f'No images listed in {csv_path}')

        n_total = len(img_labels)
        n_label_one = sum(img_labels)
        # Index 0 holds the share of label-1 samples and index 1 the share of
        # label-0 samples (each class ends up weighted by the other's frequency).
        self.label_weights = [n_label_one / n_total, (n_total - n_label_one) / n_total]
        print('label_weights', self.label_weights)

        img_paths = np.array(img_paths)
        img_labels = np.array(img_labels)

        if test_size is None:
            indices = np.arange(len(img_paths))
        else:
            splits = StratifiedShuffleSplit(test_size=test_size, n_splits=1, random_state=42)
            train_idx, test_idx = next(iter(splits.split(img_paths, img_labels)))
            indices = train_idx if train else test_idx

        self.img_paths = img_paths[indices]
        self.img_labels = img_labels[indices]
        # Build the image cache from the *selected* paths so it stays aligned with
        # img_labels, and copy each image into memory so its file is closed
        # immediately instead of keeping thousands of descriptors open.
        self.imgs = [self._load_image(path) for path in self.img_paths]

    @staticmethod
    def _resolve_image_path(data_set_path, image_name):
        """Return the file ``<data_set_path>/img/<image_name>.<any extension>``.

        Raises:
            FileNotFoundError: if no file with that stem exists.
        """
        stem = os.path.join(data_set_path, 'img', image_name)
        # Escape so glob metacharacters in the name are matched literally; sort so
        # the pick is deterministic if several extensions exist.
        matches = sorted(glob.glob(glob.escape(stem) + '.*'))
        if not matches:
            raise FileNotFoundError(f'No image file matching {stem}.*')
        return matches[0]

    @staticmethod
    def _load_image(image_path):
        """Decode ``image_path`` fully into memory and close the underlying file."""
        with Image.open(image_path) as img:
            return img.copy()

    @staticmethod
    def _show_image(image_path, moral_mean, valence_mean):
        """Wait for input; on a bare enter, plot the image titled with its norms."""
        input_text = input('Press enter to show img')
        if input_text == '':
            with Image.open(image_path) as img:
                plt.imshow(img)
                plt.title(f'Moral mean {moral_mean:.3f}\nValence mean {valence_mean:.3f}')
                plt.axis('off')
                plt.show()
            plt.close()

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``, applying ``transform`` to the image when set."""
        image = self.imgs[idx]
        label = self.img_labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


def accuracy(y_pred, y_test, cm):
    """Return batch accuracy in percent and accumulate a confusion matrix in place.

    Args:
        y_pred: Per-class scores; the predicted class is the argmax over the last dim.
        y_test: Integer ground-truth labels, one per prediction.
        cm: Square nested list of ints, mutated in place. ``cm[t][p]`` is increased
            by the number of samples with true label ``t`` predicted as ``p``, for
            ``t, p`` in ``range(len(cm))``. Labels outside that range are not counted.

    Returns:
        0-dim float tensor: percentage of samples whose argmax equals ``y_test``.
    """
    y_pred_tag = torch.argmax(y_pred, dim=-1)
    n_classes = len(cm)
    for true_cls in range(n_classes):
        is_true = y_test == true_cls
        for pred_cls in range(n_classes):
            cm[true_cls][pred_cls] += int((is_true & (y_pred_tag == pred_cls)).sum().item())
    correct_results_sum = (y_pred_tag == y_test).sum().float()
    return correct_results_sum / y_test.shape[0] * 100


def setup_dataset(model, transform_train, transform_test, verbose=False):
    """Create the SMID train/test datasets and wrap them in batch-64 DataLoaders.

    ``model`` is accepted for interface compatibility but is not used. Only the
    training dataset receives ``verbose``.

    Returns:
        ``(train_dataloader, test_dataloader)``: the training loader shuffles; the
        test loader keeps order and keeps its final partial batch.
    """
    batch_size = 64
    train_set = SmidDataset(preprocess=None, train=True, verbose=verbose, transform=transform_train)
    test_set = SmidDataset(preprocess=None, train=False, transform=transform_test)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False),
    )


def setup_model():
    """Build the 2-class model and its transforms, and move the model to ``device``.

    Returns:
        ``(model, transform_train, transform_test)``.
    """
    num_classes = 2
    # NOTE(review): the positional True is presumably a feature-extraction flag of
    # initialize_model — confirm against main.clip_models.baseline.
    net, input_size = initialize_model(num_classes, True, use_pretrained=True)
    train_tf, test_tf = get_transforms(input_size)
    net.to(device)
    return net, train_tf, test_tf


def _cm_accuracy(cm):
    """Return accuracy in percent from confusion-matrix counts (0.0 if it is empty)."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return 100.0 * correct / total if total else 0.0


def _evaluate(dataloader, model, criterion):
    """Run ``model`` over ``dataloader`` in eval mode without gradient tracking.

    Returns:
        ``(mean per-batch loss, sample-level accuracy in percent, 2x2 confusion matrix)``.
    """
    model.eval()  # disable dropout and stop batch-norm running statistics from updating
    total_loss = 0.0
    cm = [[0, 0], [0, 0]]
    with torch.no_grad():
        for X_batch, y_batch in tqdm(dataloader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            logits = model(X_batch)
            total_loss += criterion(logits, y_batch).item()
            accuracy(logits.softmax(dim=-1), y_batch, cm)
    return total_loss / max(len(dataloader), 1), _cm_accuracy(cm), cm


def _train_one_epoch(dataloader, model, criterion, optimizer):
    """Run one optimisation pass over ``dataloader`` in train mode.

    Returns:
        ``(mean per-batch loss, sample-level accuracy in percent, 2x2 confusion matrix)``.
    """
    model.train()
    total_loss = 0.0
    cm = [[0, 0], [0, 0]]
    for X_batch, y_batch in tqdm(dataloader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # The metric does not need to be part of the autograd graph.
        accuracy(logits.detach().softmax(dim=-1), y_batch, cm)
    return total_loss / max(len(dataloader), 1), _cm_accuracy(cm), cm


def _report_test(dataloader, model, criterion):
    """Evaluate on ``dataloader`` and print loss, accuracy and the confusion matrix."""
    loss, acc, cm = _evaluate(dataloader, model, criterion)
    print(f'Test: | Loss: {loss:.5f} | Acc: {acc:.3f}')
    print(cm)


def train(train_dataloader, test_dataloader, model, optimizer, epochs=5):
    """Evaluate, train for ``epochs`` epochs, then evaluate again, printing metrics.

    The loss is CrossEntropyLoss weighted by ``train_dataloader.dataset.label_weights``.
    Loss is averaged over the batches of the loader being reported (the test loader
    for test metrics), and accuracy is computed over all samples rather than as a
    mean of per-batch accuracies. The model is left in eval mode on return.

    Args:
        train_dataloader: Loader whose dataset exposes ``label_weights``.
        test_dataloader: Loader used for the before/after evaluations.
        model: Module returning per-class logits for a batch.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of training passes over ``train_dataloader``.
    """
    class_weights = torch.tensor(train_dataloader.dataset.label_weights,
                                 dtype=torch.float32, device=device)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    print("Eval before training")
    _report_test(test_dataloader, model, criterion)

    print("Training")
    for e in range(1, epochs + 1):
        loss, acc, cm = _train_one_epoch(train_dataloader, model, criterion, optimizer)
        print(f'Epoch {e:03}: | Loss: {loss:.5f} | Acc: {acc:.3f}')
        print(cm)

    print("Eval after training")
    _report_test(test_dataloader, model, criterion)


def main():
    """Seed torch, build the model and data loaders, then train for 100 epochs with SGD."""
    torch.manual_seed(1)
    model, transform_train, transform_test = setup_model()
    loaders = setup_dataset(model, transform_train, transform_test, verbose=False)
    train_loader, test_loader = loaders
    sgd = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    train(train_loader, test_loader, model, sgd, epochs=100)


if __name__ == '__main__':
    # Only run training when executed as a script, not on import.
    main()
