from __future__ import print_function

import collections
import pandas as pd
import torch
import torch.nn as nn
import torchtext
import torchvision
def totensor(dtype):
    """Return a converter that wraps its input with ``torch.tensor(x, dtype=dtype)``."""
    def convert(x):
        return torch.tensor(x, dtype=dtype)
    return convert
from IPython import embed
from utils.dataloader import *
import random

def sequential_transforms(*transforms):
    """Chain callables into a single left-to-right pipeline.

    ``sequential_transforms(f, g)(x)`` evaluates ``g(f(x))``; with no
    transforms the returned function hands back its input unchanged.
    """
    def run_pipeline(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result

    return run_pipeline
    
def vocab_func(vocab):
    """Build a callable that numericalizes a token sequence.

    The returned function looks up every token with ``vocab[token]`` and
    returns the resulting ids as a new list in the original order.
    """
    def numericalize(tok_iter):
        return [vocab[token] for token in tok_iter]

    return numericalize
     
class TextClassificationDataset(torch.utils.data.Dataset):
    """In-memory dataset of pre-transformed ``(text, label)`` pairs.

    Args:
        vocab: Vocabulary object; stored only so ``get_vocab`` can return it.
        raw_data: Iterable of ``(label, text)`` pairs.
        transforms: ``(label_transform, text_transform)``; the label transform
            is applied to element 0 and the text transform to element 1 of
            each raw pair.

    Note the order flip: raw pairs are ``(label, text)`` but stored/returned
    items are ``(text, label)``.
    """

    def __init__(self, vocab, raw_data, transforms):
        super(TextClassificationDataset, self).__init__()
        label_transform = transforms[0]
        text_transform = transforms[1]
        self._data = [(text_transform(item[1]), label_transform(item[0])) for item in raw_data]
        self._vocab = vocab

    def __getitem__(self, i):
        return self._data[i]

    def __len__(self):
        # The dataset's only storage is `_data`; there is no `_labels` attribute.
        return len(self._data)

    def __iter__(self):
        for x in self._data:
            yield x

    def get_vocab(self):
        return self._vocab
        
class Tokenizer:
    """Text-to-token splitter with optional lower-casing and fixed length.

    Args:
        tokenize_fn: Tokenizer name handed to
            ``torchtext.data.utils.get_tokenizer`` (``'basic_english'`` by default).
        lower: Lower-case every token when true.
        max_length: If given, token lists are cut to this many tokens and
            right-padded with ``'<pad>'`` up to this length.
    """

    def __init__(self, tokenize_fn='basic_english', lower=True, max_length=None):
        self.tokenize_fn = torchtext.data.utils.get_tokenizer(tokenize_fn)
        self.lower = lower
        self.max_length = max_length

    def tokenize(self, s):
        """Split ``s`` into tokens, applying lower-casing and padding as configured."""
        pieces = self.tokenize_fn(s)
        if self.lower:
            pieces = [piece.lower() for piece in pieces]
        limit = self.max_length
        if limit is None:
            return pieces
        # Truncate first, then right-pad with '<pad>' until the list is `limit` long.
        fixed = list(pieces[:limit])
        fixed.extend(['<pad>'] * (limit - len(fixed)))
        return fixed

def build_vocab_from_data(raw_data, tokenizer, **vocab_kwargs):
    """Count token frequencies over ``raw_data`` and wrap them in a Vocab.

    Args:
        raw_data: Iterable of examples whose ``text`` attribute is already a
            list of tokens (``get_data`` tokenizes while building examples).
        tokenizer: Unused; kept for signature compatibility.
        **vocab_kwargs: Forwarded to ``torchtext.vocab.Vocab`` (e.g. ``max_size``).

    Returns:
        ``torchtext.vocab.Vocab`` built from the token counter.
    """
    token_freqs = collections.Counter()
    for example in raw_data:
        # Only the text is needed; labels get their own vocab via LABEL.build_vocab.
        token_freqs.update(example.text)
    return torchtext.vocab.Vocab(token_freqs, **vocab_kwargs)

def process_raw_data(raw_data, tokenizer, vocab):
    """Turn torchtext examples into a ``TextClassificationDataset``.

    Text tokens are mapped through ``vocab`` and labels through the
    module-level ``LABEL`` field's ``vocab.stoi`` (looked up when each item is
    transformed); both become ``torch.long`` tensors. ``tokenizer`` is unused.
    """
    pairs = [(example.label, example.text) for example in raw_data]
    to_long = totensor(dtype=torch.long)
    text_transform = sequential_transforms(vocab_func(vocab), to_long)
    label_transform = sequential_transforms(lambda label: LABEL.vocab.stoi[label], to_long)
    return TextClassificationDataset(vocab, pairs, (label_transform, text_transform))

class Collator:
    """Batch collation helper that pads variable-length token tensors.

    Args:
        pad_idx: Value used to fill positions past the end of shorter sequences.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def collate(self, batch):
        """Collate a list of ``(label, text)`` pairs into padded tensors.

        Returns:
            ``(labels, text, lengths)``: a LongTensor of labels, the sequences
            padded by ``nn.utils.rnn.pad_sequence`` (default layout,
            ``(max_len, batch)``), and a LongTensor of the original lengths.
        """
        # NOTE(review): items are unpacked as (label, text), whereas
        # TextClassificationDataset / IMDB in this module yield (text, label);
        # confirm the item order of whichever dataset feeds this collator.
        labels, text = zip(*batch)
        labels = torch.LongTensor(labels)
        lengths = torch.LongTensor([len(x) for x in text])
        text = nn.utils.rnn.pad_sequence(text, padding_value=self.pad_idx)
        return labels, text, lengths

def initialize_parameters(m):
    """Initialise one module's parameters in place (intended for ``Module.apply``).

    * ``nn.Embedding``: weight uniform in [-0.05, 0.05].
    * ``nn.GRU``: each of the three gate blocks stacked in a parameter is
      initialised separately: Xavier-uniform for ``weight_ih*``, orthogonal
      for ``weight_hh*``, zeros for biases.
    * ``nn.Linear``: Xavier-uniform weight; bias zeroed when the layer has one.

    Other module types are left untouched.
    """
    if isinstance(m, nn.Embedding):
        nn.init.uniform_(m.weight, -0.05, 0.05)
    elif isinstance(m, nn.GRU):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                init_fn = nn.init.xavier_uniform_
            elif 'weight_hh' in name:
                init_fn = nn.init.orthogonal_
            elif 'bias' in name:
                init_fn = nn.init.zeros_
            else:
                continue
            # PyTorch stacks the r, z, n gates along dim 0; chunk(3) yields
            # in-place views so each gate block is initialised on its own.
            for gate in param.chunk(3):
                init_fn(gate)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Linear(bias=False) has m.bias = None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):
    """Build an embedding weight tensor seeded from pretrained word vectors.

    Starts from a detached float copy of ``initial_embedding.weight``. For every
    vocabulary token present in ``pretrained_vectors.stoi`` the corresponding
    row is overwritten with the pretrained vector. Tokens without a pretrained
    vector keep their initial row and are reported back.

    Args:
        initial_embedding: ``nn.Embedding`` supplying the shape and fallback rows.
        pretrained_vectors: Object exposing ``stoi`` (token -> row index) and
            ``vectors`` (tensor of rows), e.g. ``torchtext.vocab.FastText``.
        vocab: Object exposing ``itos`` (row index -> token).
        unk_token: Unused; kept for signature compatibility.

    Returns:
        ``(embedding, unk_tokens)``: the new float tensor (no grad) and the list
        of vocabulary tokens that had no pretrained vector.
    """
    pretrained_embedding = initial_embedding.weight.detach().clone().float()
    num_rows = pretrained_embedding.size(0)
    unk_tokens = []
    # Rows beyond the vocabulary (the embedding may be sized larger than the
    # vocab) have no token and simply keep their initial values.
    for idx, token in enumerate(vocab.itos[:num_rows]):
        row = pretrained_vectors.stoi.get(token)
        if row is None:
            # Keep the initial row rather than borrowing an unrelated pretrained vector.
            unk_tokens.append(token)
        else:
            pretrained_embedding[idx] = pretrained_vectors.vectors[row]
    return pretrained_embedding, unk_tokens

def loadPretrainEmbedding(embedding):
    """Return a weight tensor for ``embedding`` seeded from FastText vectors.

    Loads ``torchtext.vocab.FastText('simple')`` and delegates to
    ``get_pretrained_embedding`` with the module-level ``vocab`` (set by
    ``get_data``, which must have run first).

    Args:
        embedding: ``nn.Embedding`` whose weight supplies the shape and fallback rows.

    Returns:
        The new weight tensor; the caller copies it into the module.
    """
    print("Initialize the embedding")
    fasttext = torchtext.vocab.FastText('simple')
    weights, _unk_tokens = get_pretrained_embedding(embedding, fasttext, vocab, '<unk>')
    return weights

def get_data(dataset_name):
    """Load a text-classification dataset and build its vocabularies.

    Supported names are ``'sst'``, ``'agnews'``, ``'amazon'`` and ``'imdb'``.
    Every branch starts from ``torchtext.datasets.IMDB.splits``. For the
    non-IMDB datasets those split objects only serve as ``Dataset`` containers
    bound to ``TEXT``/``LABEL``, and their examples are replaced with ones read
    from local CSV files under ``.data/``.

    Side effects: sets the module-level globals ``max_length``, ``max_size``,
    ``tokenizer``, ``TEXT``, ``LABEL``, ``vocab``, ``train_data`` and
    ``test_data``, which ``Model``, ``IMDB`` and ``process_raw_data`` read.

    Returns:
        ``(train_data, test_data, vocab)``.

    Raises:
        ValueError: if ``dataset_name`` is not a supported name.
    """
    global max_length, max_size, train_size, tokenizer, TEXT, LABEL, vocab, train_data, test_data
    # Datasets stored as lines ending in a one-character label.
    label_suffixed_dirs = {'sst': '.data/sst', 'amazon': '.data/amazon_5w'}
    if dataset_name not in ('sst', 'agnews', 'amazon', 'imdb'):
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected 'sst', 'agnews', 'amazon' or 'imdb'")
    max_length = 200
    max_size = 25000
    tokenizer = Tokenizer(max_length=max_length)
    TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer.tokenize, lower=True,
                                include_lengths=True, batch_first=True, fix_length=max_length)
    LABEL = torchtext.data.LabelField()
    raw_train_data, raw_test_data = torchtext.datasets.IMDB.splits(TEXT, LABEL)
    if dataset_name in label_suffixed_dirs:
        folder = label_suffixed_dirs[dataset_name]
        raw_train_data.examples = _read_label_suffixed_examples(folder + '/train.csv', tokenizer)
        raw_test_data.examples = _read_label_suffixed_examples(folder + '/test.csv', tokenizer)
    elif dataset_name == 'agnews':
        # The AG News CSVs must already exist locally (presumably fetched with
        # torchtext.datasets.AG_NEWS()).
        raw_train_data.examples = _read_agnews_examples('.data/ag_news_csv/train.csv', tokenizer)
        raw_test_data.examples = _read_agnews_examples('.data/ag_news_csv/test.csv', tokenizer)
    LABEL.build_vocab(raw_train_data)
    vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size=max_size)
    train_data = process_raw_data(raw_train_data, tokenizer, vocab)
    test_data = process_raw_data(raw_test_data, tokenizer, vocab)
    return train_data, test_data, vocab


def _read_label_suffixed_examples(path, tokenizer):
    """Read examples from lines whose last character is the label.

    After stripping, the label is the final character. The text is everything
    before the single separator character that precedes the label.
    """
    examples = []
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue  # a blank (e.g. trailing) line carries no label
            ex = torchtext.data.example.Example()
            ex.label, ex.text = line[-1], tokenizer.tokenize(line[:-2])
            examples.append(ex)
    return examples


def _read_agnews_examples(path, tokenizer):
    """Read examples from a header-less CSV: column 0 is the label, columns 1 and 2 are text joined with ' @ '."""
    frame = pd.read_csv(path, header=None)
    examples = []
    for label, text_a, text_b in zip(frame[0], frame[1], frame[2]):
        ex = torchtext.data.example.Example()
        ex.label, ex.text = label, tokenizer.tokenize(text_a + " @ " + text_b)
        examples.append(ex)
    return examples

class SentConvNet(nn.Module):
    """Multi-width 1-D CNN sentence classifier.

    Token ids are embedded, passed through three parallel Conv1d branches with
    kernel widths 3, 4 and 5 (each followed by ReLU and a global max-pool),
    concatenated, and fed through ``fc1 -> dropout -> fc2``.

    Args:
        args: Stored as ``self.args``; not otherwise used here.
        vocab_size: Number of embedding rows.
        embedding_size: Embedding dimension (Conv1d input channels).
        num_filters: Output channels of each convolution branch.
        num_classes: Size of the output layer.
        pad_idx: ``padding_idx`` of the embedding.
    """

    def __init__(self, args, vocab_size, embedding_size, num_filters, num_classes, pad_idx):
        super(SentConvNet, self).__init__()
        self.args = args
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)

        def branch(width):
            # padding = width - 1 keeps the conv output non-empty even for very short inputs.
            return nn.Sequential(
                nn.Conv1d(embedding_size, num_filters, width, padding=width - 1),
                nn.ReLU(),
                nn.AdaptiveMaxPool1d(1),
            )

        self.conv3 = branch(3)
        self.conv4 = branch(4)
        self.conv5 = branch(5)
        self.fc1 = nn.Linear(3 * num_filters, 100)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(100, num_classes)

    def forward(self, x):  # pylint: disable=arguments-differ
        # x: (batch, seq_len) token ids; Conv1d wants (batch, channels, seq_len).
        embedded = self.embedding(x).transpose(1, 2)
        pooled = [branch(embedded) for branch in (self.conv3, self.conv4, self.conv5)]
        features = torch.cat(pooled, dim=1).squeeze(2)
        # NOTE(review): there is no nonlinearity between fc1 and fc2.
        return self.fc2(self.dropout(self.fc1(features)))

class Bi_GRU(nn.Module):
    """Single-layer bidirectional GRU classifier over token ids.

    Args:
        input_dim: Number of embedding rows (vocabulary size).
        emb_dim: Embedding dimension.
        hid_dim: GRU hidden size per direction.
        output_dim: Number of output classes.
        pad_idx: ``padding_idx`` of the embedding.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(emb_dim, hid_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(2 * hid_dim, output_dim)

    def forward(self, text):
        # text: (batch, seq_len) token ids (the GRU is batch_first).
        # NOTE(review): sequences are not packed, so if inputs are right-padded
        # the backward direction starts on the padding tokens.
        _, h_n = self.gru(self.embedding(text))
        # h_n: (num_layers * 2, batch, hid_dim); with one layer, index -2 is the
        # final forward state and -1 the final backward state.
        final_state = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        # -> (batch, output_dim)
        return self.fc(self.dropout(final_state))

class Bi_LSTM(nn.Module):
    """Single-layer bidirectional LSTM classifier over token ids.

    Args:
        input_dim: Number of embedding rows (vocabulary size).
        emb_dim: Embedding dimension.
        hid_dim: LSTM hidden size per direction.
        output_dim: Number of output classes.
        pad_idx: ``padding_idx`` of the embedding.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(2 * hid_dim, output_dim)

    def forward(self, text):
        # text: (batch, seq_len) token ids (the LSTM is batch_first).
        # NOTE(review): sequences are not packed, so if inputs are right-padded
        # the backward direction starts on the padding tokens.
        _, (h_n, _c_n) = self.lstm(self.embedding(text))
        # h_n: (num_layers * 2, batch, hid_dim); with one layer, index -2 is the
        # final forward state and -1 the final backward state.
        final_state = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        # -> (batch, output_dim)
        return self.fc(self.dropout(final_state))

def Model(args):
    """Construct the text classifier selected by ``args.model``.

    Reads the module-level globals ``max_size``, ``LABEL`` and ``vocab`` set by
    ``get_data``, so ``get_data`` must be called first.

    Args:
        args: Namespace with ``model`` in {'CNN', 'GRU', 'LSTM'}. An optional
            truthy ``use_pretrained_embedding`` attribute initialises the
            embedding from FastText via ``loadPretrainEmbedding``; it is off
            when the attribute is absent.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``args.model`` is not a supported name.
    """
    # max_size tokens plus two extra rows (presumably the '<unk>'/'<pad>'
    # specials). This is sized for IMDb, where the vocab fills max_size;
    # smaller vocabs just leave the trailing rows unused.
    input_dim = max_size + 2
    emb_dim = 300
    hid_dim = 256
    output_dim = len(LABEL.vocab.itos)
    pad_idx = vocab['<pad>']
    if args.model == 'CNN':
        # hid_dim doubles as the number of filters per convolution width.
        model = SentConvNet(None, input_dim, emb_dim, hid_dim, output_dim, pad_idx)
    elif args.model == 'GRU':
        model = Bi_GRU(input_dim, emb_dim, hid_dim, output_dim, pad_idx)
    elif args.model == 'LSTM':
        model = Bi_LSTM(input_dim, emb_dim, hid_dim, output_dim, pad_idx)
    else:
        raise ValueError(f"Unknown model {args.model!r}; expected 'CNN', 'GRU' or 'LSTM'")
    # Pretrained initialisation is opt-in; it downloads/loads FastText vectors.
    if getattr(args, 'use_pretrained_embedding', False):
        pretrained_embedding = loadPretrainEmbedding(model.embedding)
        model.embedding.weight.data.copy_(pretrained_embedding)
    return model

class IMDB(torch.utils.data.Dataset):
    """Thin Dataset view over the module-level ``train_data`` / ``test_data``.

    Despite the name it wraps whichever dataset ``get_data`` loaded last. Items
    are ``(text, label)`` pairs taken from the wrapped dataset. ``targets``
    holds element 1 (the label) of every item, presumably for the partitioning
    loaders in ``utils.dataloader``.

    Args:
        train: Selects ``train_data`` when ``train == True``, else ``test_data``.
    """

    def __init__(self, train=True):
        if train == True:  # noqa: E712 -- equality, not truthiness, selects the split
            self.dataset = train_data
        else:
            self.dataset = test_data
        self.targets = [item[1] for item in self.dataset]
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        record = self.dataset[i]
        return record[0], record[1]

def getDataset(train=True):
    """Return the loaded train split (``train=True``) or test split wrapped in ``IMDB``."""
    return IMDB(train)

def basic_loader(num_clients, batch_size, loader_type):
    """Split the training split across ``num_clients`` client loaders.

    Args:
        num_clients: Number of client partitions.
        batch_size: Forwarded to the loader as ``bsz``.
        loader_type: Either a loader factory called as
            ``loader_type(num_clients, dataset, bsz=batch_size)``, or the string
            ``'dirichlet-0.9'`` selecting ``dirichletLoader`` with ``alpha=0.9``.
            Any other string trips an ``assert``.
    """
    dataset = getDataset(train=True)
    if not isinstance(loader_type, str):
        return loader_type(num_clients, dataset, bsz=batch_size)
    if loader_type == 'dirichlet-0.9':
        return dirichletLoader(num_clients, dataset, alpha=0.9, bsz=batch_size)
    assert False

def get_dataloader(args):
    """Load ``args.dataset`` and build the per-client train loaders plus the test loader.

    Uses ``args.dataset``, ``args.num_clients``, ``args.batch_size``,
    ``args.loader_type``, ``args.loader_path`` and ``args.test_batch_size``.

    Returns:
        ``(train_loaders, test_loader, vocab)``.
    """
    _, _, vocab = get_data(args.dataset)
    train_loaders = train_dataloader(args.num_clients, args.batch_size,
                                     loader_type=args.loader_type, path=args.loader_path)
    test_loader = test_dataloader(args.test_batch_size)
    return train_loaders, test_loader, vocab
    
def train_dataloader(num_clients, batch_size, loader_type='iid', store=True, path=None):
    """Create the per-client training loaders, optionally cached on disk.

    Args:
        num_clients: Number of client partitions of the training split.
        batch_size: Batch size forwarded to the loader factory.
        loader_type: ``'iid'``, ``'byLabel'``, ``'dirichlet'`` or ``'dirichlet-0.9'``.
        store: If true and ``path`` is set, return the loader saved at ``path``
            when it can be loaded; otherwise build one and save it there.
        path: Cache file for ``torch.load``/``torch.save``; ``None`` disables caching.

    Returns:
        The loader built by ``basic_loader`` or the object loaded from ``path``.

    Raises:
        ValueError: if ``loader_type`` is not one of the accepted names.
    """
    accepted = ('iid', 'byLabel', 'dirichlet', 'dirichlet-0.9')
    if loader_type not in accepted:
        # A real exception rather than `assert`, which `python -O` strips.
        raise ValueError('Loader has to be one of ' + ', '.join(repr(name) for name in accepted))
    if loader_type == 'iid':
        loader_type = iidLoader
    elif loader_type == 'byLabel':
        loader_type = byLabelLoader
    # NOTE(review): 'dirichlet' passes validation but basic_loader only resolves
    # the string 'dirichlet-0.9'; both strings are passed through unchanged.

    use_cache = store and path is not None
    if use_cache:
        try:
            # torch.load unpickles arbitrary objects: only point `path` at trusted files.
            return torch.load(path)
        except Exception as err:  # missing, corrupt or incompatible cache: rebuild it
            print(f'loader not found, initialize one ({type(err).__name__}: {err})')
    else:
        print('initialize a data loader')
    loader = basic_loader(num_clients, batch_size, loader_type)
    if use_cache:
        torch.save(loader, path)
    return loader


def test_dataloader(test_batch_size):
    """Return a shuffled ``DataLoader`` over the loaded test split."""
    return torch.utils.data.DataLoader(getDataset(train=False),
                                       batch_size=test_batch_size, shuffle=True)


if __name__ == '__main__':
    # Smoke test: build a CNN, partition the training data, feed one batch through.
    from types import SimpleNamespace

    print("#Initialize a network")
    # Model() and the datasets read globals populated by get_data, so load data first.
    get_data('imdb')
    net = Model(SimpleNamespace(model='CNN'))
    batch_size = 10
    y = net(torch.randint(max_size + 2, (batch_size, 500)))
    print(f"Output shape of the network with batchsize {batch_size}:\t {y.size()}")

    print("\n#Initialize dataloaders")
    loader_types = ['iid', 'dirichlet-0.9']
    for loader_type in loader_types:
        # store=False: no cache path is available here.
        loader = train_dataloader(10, batch_size, loader_type=loader_type, store=False)
        print(f"Initialized {len(loader)} loaders (type: {loader_type}), "
              f"each with batch size {loader.bsz}.\nThe size of dataset in each loader are:")
        sizes = [len(loader[k].dataset) for k in range(len(loader))]
        print(sizes)
        print(f"Total number of data: {sum(sizes)}")

    print("\n#Feeding data to network")
    x = next(iter(loader[0]))[0]
    y = net(x)
    print(f"Size of input:  {x.shape} \nSize of output: {y.shape}")
