import os
import torch
import random
import copy
from fairseq.models.roberta import RobertaModel, RobertaHubInterface
from fairseq import hub_utils
from fairseq.data.data_utils import collate_tokens
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import MinMaxScaler


# Training-time evaluation switch: when EVAL_OFTEN is set, train_one_epoch
# periodically runs eval_short(), roughly once per EVAL_EVERY iterations.
EVAL_OFTEN = True
EVAL_EVERY = 10000

# Device that the regressor head and the target tensors are moved to.
device = 'cuda'

# RoBERTa-base encoder fetched through torch.hub from the fairseq repository
# and placed on the GPU; its parameters are optimized alongside the head.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
roberta.cuda()



def _read_lines(path):
    """Return the lines of a UTF-8 text file, each without its trailing newline.

    The file is opened in a ``with`` block so the handle is always closed.
    ``newline='\\n'`` keeps the original splitting rule: only ``\\n`` ends a line.
    """
    with open(path, newline='\n', encoding='utf-8') as f:
        return [line.rstrip('\n') for line in f]


train_in = _read_lines('../train/in.tsv')  # shuffled below
dev_in = _read_lines('../dev-0/in.tsv')  # shuffled below

# One numeric target (a year) per line, aligned with the corresponding in.tsv.
train_year = [float(line) for line in _read_lines('../train/expected.tsv')]
dev_year = [float(line) for line in _read_lines('../dev-0/expected.tsv')]

# Inputs kept in file order, used when writing predictions line by line.
# Strings are immutable, so a shallow copy is enough here.
dev_in_not_shuffled = list(dev_in)  # not shuffled
test_in = _read_lines('../test-A/in.tsv')  # not shuffled

# Shuffle each split's texts and targets with one shared permutation,
# so every text keeps its own year.
pairs = list(zip(train_in, train_year))
random.shuffle(pairs)
train_in, train_year = zip(*pairs)
pairs = list(zip(dev_in, dev_year))
random.shuffle(pairs)
dev_in, dev_year = zip(*pairs)

# The scaler is fit on training years only and reused for dev. It works on
# 2-D (n_samples, 1) arrays, hence the single-column reshapes.
scaler = MinMaxScaler()

train_year_scaled = scaler.fit_transform(np.array(train_year).reshape(-1, 1))
dev_year_scaled = scaler.transform(np.array(dev_year).reshape(-1, 1))


class RegressorHead(torch.nn.Module):
    """Map a pooled 768-d RoBERTa feature vector to a single scaled target.

    Only ``linearxxx`` is used in ``forward``. ``linear1``, ``linear2`` and the
    two zero-rate dropouts are leftovers from an earlier two-layer variant.
    They are kept on purpose: removing them would change the ``state_dict``
    keys and break strict ``load_state_dict`` of existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order, so seeded init is unchanged.
        self.linear1 = torch.nn.Linear(768, 300)
        self.linear2 = torch.nn.Linear(300, 1)
        self.linearxxx = torch.nn.Linear(768, 1)
        self.dropout1 = torch.nn.Dropout(0.0)
        self.dropout2 = torch.nn.Dropout(0.0)
        self.m = torch.nn.LeakyReLU(0.1)

    def forward(self, x):
        """Project ``x`` to one value per row and softly squash it toward [0, 1].

        Below 0 the LeakyReLU leaves a slope of 0.1. Above 1 a mirrored
        LeakyReLU (``1 - m(1 - y)``) does the same. Inside [0, 1] the value
        passes through unchanged, up to floating-point rounding.
        """
        floored = self.m(self.linearxxx(x))
        return 1 - self.m(1 - floored)

# Scalar-output head that consumes pooled RoBERTa features.
regressor_head = RegressorHead().to(device)

# A single Adam optimizer updates the encoder and the head together, with a
# small learning rate. The loss is the summed (not averaged) squared error.
optimizer = torch.optim.Adam(
    [*roberta.parameters(), *regressor_head.parameters()],
    lr=1e-6,
)
criterion = torch.nn.MSELoss(reduction='sum').to(device)

BATCH_SIZE = 1
def get_train_batch(dataset_in, dataset_y):
    """Yield (features, targets) for consecutive BATCH_SIZE-sized slices.

    Each text in a slice goes through four steps:
      * BPE-encode it with ``roberta.encode``;
      * truncate it to its first 512 tokens (the rest of a long text is
        dropped, not processed in further windows);
      * run it through ``roberta.extract_features``;
      * mean-pool over the token axis.
    The pooled rows of the slice are then concatenated along dim 0.

    Args:
        dataset_in: sequence of raw texts.
        dataset_y: targets aligned with ``dataset_in``. Each slice is
            converted with ``torch.FloatTensor`` and squeezed.

    Yields:
        features: one pooled feature row per text in the slice.
        years: the slice's targets on ``device``; a 0-d tensor when
            ``BATCH_SIZE == 1``.
    """
    for i in tqdm(range(0, len(dataset_in), BATCH_SIZE)):
        batch_of_text = dataset_in[i:i + BATCH_SIZE]

        # Encode every text in the slice, not only the first one, so the
        # feature rows line up with the BATCH_SIZE targets sliced below.
        # Assumes extract_features returns (1, tokens, hidden) for one
        # sequence, so the mean over dim 1 gives a (1, hidden) row.
        pooled = [
            torch.mean(roberta.extract_features(roberta.encode(text)[:512]), 1)
            for text in batch_of_text
        ]
        features = torch.cat(pooled, 0)
        years = torch.FloatTensor(dataset_y[i:i + BATCH_SIZE]).to(device).squeeze()

        yield features, years


def eval():
    """Print RMSE of the model on the whole (shuffled) dev set.

    Three figures are reported:
      * scaled  - RMSE in the MinMax-scaled target space;
      * plain   - RMSE in years after ``scaler.inverse_transform``;
      * clipped - RMSE in years when predictions are first clamped to
        [0, 1], i.e. to the range of years seen in training.

    Leaves ``roberta`` and ``regressor_head`` in eval mode. The name
    shadows the builtin ``eval``; it is kept for existing callers.
    """
    criterion_eval = torch.nn.MSELoss(reduction='sum')
    roberta.eval()
    regressor_head.eval()
    loss = 0.0
    loss_clipped = 0.0
    loss_scaled = 0.0
    # Evaluation never calls backward(), so skip building the autograd graph.
    # This also covers the RoBERTa forward pass run inside get_train_batch,
    # which executes while the loop below pulls from the generator.
    with torch.no_grad():
        for batch, year in tqdm(get_train_batch(dev_in, dev_year_scaled)):

            x = regressor_head(batch.to(device)).squeeze()
            x_clipped = torch.clamp(x, 0.0, 1.0)

            # The scaler was fit on a single column, so give it
            # (n_samples, 1). This also holds when a batch has several rows.
            original_x = torch.FloatTensor(scaler.inverse_transform(x.cpu().numpy().reshape(-1, 1)))
            original_x_clipped = torch.FloatTensor(scaler.inverse_transform(x_clipped.cpu().numpy().reshape(-1, 1)))
            original_year = torch.FloatTensor(scaler.inverse_transform(year.cpu().numpy().reshape(-1, 1)))

            loss_scaled += criterion_eval(x, year).item()
            loss += criterion_eval(original_x, original_year).item()
            loss_clipped += criterion_eval(original_x_clipped, original_year).item()
    print(' full valid loss scaled: ' + str(np.sqrt(loss_scaled/len(dev_year))))
    print(' full valid loss: ' + str(np.sqrt(loss/len(dev_year))))
    print(' full valid loss clipped: ' + str(np.sqrt(loss_clipped/len(dev_year))))

def eval_short():
    """Print dev-set RMSE on at most the first 1000 (shuffled) dev examples.

    This is a cheaper variant of ``eval`` used for periodic checks during
    training. It reports the same three figures: scaled-space RMSE, RMSE in
    years, and RMSE in years with predictions clamped to [0, 1]. Each sum is
    averaged over the number of examples actually evaluated. Leaves the
    models in eval mode.
    """
    criterion_eval = torch.nn.MSELoss(reduction='sum')
    roberta.eval()
    regressor_head.eval()
    # Fewer than 1000 examples are evaluated when the dev set is smaller.
    n_eval = min(1000, len(dev_in))
    loss = 0.0
    loss_clipped = 0.0
    loss_scaled = 0.0
    # No backward pass here, so skip autograd (including inside get_train_batch).
    with torch.no_grad():
        for batch, year in tqdm(get_train_batch(dev_in[:n_eval], dev_year_scaled[:n_eval])):

            x = regressor_head(batch.to(device)).squeeze()
            x_clipped = torch.clamp(x, 0.0, 1.0)

            # The scaler was fit on a single column: feed it (n_samples, 1).
            original_x = torch.FloatTensor(scaler.inverse_transform(x.cpu().numpy().reshape(-1, 1)))
            original_x_clipped = torch.FloatTensor(scaler.inverse_transform(x_clipped.cpu().numpy().reshape(-1, 1)))
            original_year = torch.FloatTensor(scaler.inverse_transform(year.cpu().numpy().reshape(-1, 1)))

            loss_scaled += criterion_eval(x, year).item()
            loss += criterion_eval(original_x, original_year).item()
            loss_clipped += criterion_eval(original_x_clipped, original_year).item()
    # All three sums cover the same n_eval examples, so they share one divisor.
    print('valid loss scaled: ' + str(np.sqrt(loss_scaled/n_eval)))
    print('valid loss: ' + str(np.sqrt(loss/n_eval)))
    print('valid loss clipped: ' + str(np.sqrt(loss_clipped/n_eval)))


def train_one_epoch():
    """Run one pass over the shuffled training set, updating encoder and head.

    The loss is ``criterion``: summed squared error in scaled target space.
    When ``EVAL_OFTEN`` is set, a report is made after every ``EVAL_EVERY``
    batches. It prints the RMSE over the training examples seen since the
    previous report, runs ``eval_short``, and then restores train mode.
    """
    roberta.train()
    regressor_head.train()
    loss_value = 0.0
    samples_since_report = 0
    for iteration, (batch, year) in enumerate(get_train_batch(train_in, train_year_scaled), start=1):
        # The optimizer holds the parameters of both models, so this clears all grads.
        optimizer.zero_grad()

        x = regressor_head(batch.to(device)).squeeze()

        loss = criterion(x, year)
        loss_value += loss.item()
        samples_since_report += year.numel()
        loss.backward()
        optimizer.step()

        # "== 0" makes every reporting window exactly EVAL_EVERY batches long.
        if EVAL_OFTEN and iteration % EVAL_EVERY == 0:
            # criterion sums over the batch, so divide by examples (not
            # batches) to get a per-example RMSE.
            print('train loss: ' + str(np.sqrt(loss_value / samples_since_report)))
            eval_short()
            roberta.train()
            regressor_head.train()
            loss_value = 0.0
            samples_since_report = 0
    # Release the last step's gradients before any following eval/predict.
    optimizer.zero_grad()


def predict_dev():
    """Write clipped year predictions for the dev set to ../dev-0/out.tsv.

    Uses ``dev_in_not_shuffled``, so output line k is the prediction for line
    k of ../dev-0/in.tsv. Predictions are clamped to [0, 1] in scaled space
    (the training-year range) before being mapped back to years.
    """
    roberta.eval()
    regressor_head.eval()
    # Targets are not needed for prediction. get_train_batch only needs a
    # sequence of matching length, so pass explicit placeholders rather than
    # dev labels that were shuffled out of alignment with this input.
    placeholder_y = np.zeros((len(dev_in_not_shuffled), 1))
    with torch.no_grad(), open('../dev-0/out.tsv', 'w') as f_out:
        for batch, _ in tqdm(get_train_batch(dev_in_not_shuffled, placeholder_y)):
            x = regressor_head(batch).squeeze()
            x_clipped = torch.clamp(x, 0.0, 1.0)
            # The scaler was fit on a single column: feed it (n_samples, 1).
            original_x_clipped = scaler.inverse_transform(x_clipped.cpu().numpy().reshape(-1, 1))
            for y in original_x_clipped.ravel():
                f_out.write(str(y) + '\n')

def predict_test():
    """Write clipped year predictions for the test set to ../test-A/out.tsv.

    ``test_in`` is in file order, so output line k is the prediction for line
    k of ../test-A/in.tsv. Predictions are clamped to [0, 1] in scaled space
    (the training-year range) before being mapped back to years.
    """
    roberta.eval()
    regressor_head.eval()
    # The test set has no labels. get_train_batch only needs a sequence of
    # matching length, so pass explicit placeholders instead of dev targets.
    placeholder_y = np.zeros((len(test_in), 1))
    with torch.no_grad(), open('../test-A/out.tsv', 'w') as f_out:
        for batch, _ in tqdm(get_train_batch(test_in, placeholder_y)):
            x = regressor_head(batch).squeeze()
            x_clipped = torch.clamp(x, 0.0, 1.0)
            # The scaler was fit on a single column: feed it (n_samples, 1).
            original_x_clipped = scaler.inverse_transform(x_clipped.cpu().numpy().reshape(-1, 1))
            for y in original_x_clipped.ravel():
                f_out.write(str(y) + '\n')


# Restore the fine-tuned encoder, then the head, and write dev and test
# predictions.
# NOTE(review): torch.load unpickles these files; only load trusted checkpoints.
for _module, _checkpoint_path in (
    (roberta, 'checkpoints/roberta_to_regressor3.pt'),
    (regressor_head, 'checkpoints/regressor_head3.pt'),
):
    _module.load_state_dict(torch.load(_checkpoint_path))

predict_dev()
predict_test()
