import os
import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

DEBERTA_MODEL_NAMES = ["microsoft/deberta-v3-xsmall", "microsoft/deberta-v3-small", "microsoft/deberta-v3-base", "microsoft/deberta-v3-large"]

def get_tokenizer(model):
    """Load the Hugging Face tokenizer matching the given model name."""
    return AutoTokenizer.from_pretrained(model)

# Each example carries a 3-dim one-hot label identifying its source task:
# [commonsense morality, utility, power] (see the load_* functions below).
NUM_LABELS = 3

def load_model(args, load_path=None, cache_dir=None):
    # Set the dropout attribute under both naming conventions; AutoConfig
    # silently drops any key the chosen architecture does not define.
    config = AutoConfig.from_pretrained(
        args.model,
        num_labels=NUM_LABELS,
        cache_dir=cache_dir,  # None falls back to the default HF cache
        dropout_rate=args.dropout,                  # T5-style configs
        hidden_dropout_prob=args.dropout,           # BERT/DeBERTa-style configs
        attention_probs_dropout_prob=args.dropout,  # BERT/DeBERTa-style configs
    )

    model = AutoModelForSequenceClassification.from_pretrained(args.model, config=config)
    
    if load_path is not None:
        # Load to CPU first so checkpoints saved on GPU also work on CPU-only hosts.
        model.load_state_dict(torch.load(load_path, map_location="cpu"))

    if args.ngpus > 0:
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpus)))

    print('\nPretrained model "{}" loaded'.format(args.model))
    # Exclude biases and LayerNorm weights from weight decay, per standard
    # transformer fine-tuning practice.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    return model, optimizer
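
# A minimal usage sketch of load_model (defined but not executed at import
# time). The attribute names on `args` are assumptions mirroring what
# load_model reads; the real training script's argument parser is expected
# to supply them.
def _example_load_model():
    from argparse import Namespace
    args = Namespace(
        model="microsoft/deberta-v3-small",
        dropout=0.1,
        ngpus=0,  # CPU-only for the sketch
        weight_decay=0.01,
        learning_rate=2e-5,
    )
    model, optimizer = load_model(args)
    return model, optimizer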

def split_data(split, data, nsplits=5):
    """Hold out the `split`-th contiguous 1/nsplits chunk as the test fold."""
    all_idxs = np.arange(len(data))
    train_mask = np.ones(len(data), dtype=bool)
    test_mask = np.zeros(len(data), dtype=bool)
    fold_size = len(data) // nsplits
    start = fold_size * split
    # Fold any remainder into the last split so every example appears in
    # exactly one test fold.
    end = len(data) if split == nsplits - 1 else fold_size * (split + 1)
    train_mask[start:end] = False
    test_mask[start:end] = True
    train_data = torch.utils.data.Subset(data, all_idxs[train_mask])
    test_data = torch.utils.data.Subset(data, all_idxs[test_mask])
    return train_data, test_data
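
# Quick sanity sketch for split_data: each of the 5 folds holds out a
# different contiguous fifth of a toy dataset (the data are assumed to have
# been shuffled upstream, since folds are contiguous).
def _example_split_data():
    dummy = TensorDataset(torch.arange(100).unsqueeze(1))
    for split in range(5):
        train_data, test_data = split_data(split, dummy, nsplits=5)
        assert len(train_data) + len(test_data) == len(dummy)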


def load_paired_cm_sentences(data_dir, split="train"):
    if "long" in split:
        # Long-example subset of the commonsense-morality data; not yet
        # supported in the paired format.
        path = os.path.join(data_dir, "cm_{}.csv".format(split.split("long_")[1]))
        df = pd.read_csv(path)
        df = df[df["is_short"] == False]
        raise NotImplementedError("paired long splits are not supported")
    else:
        path = os.path.join(data_dir, "cm_{}.csv".format(split))
        df = pd.read_csv(path)

    if split == "ambig":
        labels = [-1 for _ in range(df.shape[0])]
        sentences = [df.iloc[i, 0] for i in range(df.shape[0])]
        raise NotImplementedError
    else:
        labels = [df.iloc[i, 0] for i in range(df.shape[0])]
        sentences = [df.iloc[i, 1] for i in range(df.shape[0])]

    # Expand each labeled sentence into a two-slot pair: sentences labeled 1
    # fill the second slot and sentences labeled 0 the first, with the empty
    # string padding the other slot. Downstream, consecutive rows are stacked
    # back into pairs.
    paired_sentences = []
    for s, l in zip(sentences, labels):
        if l == 1:
            paired_sentences.append("")
            paired_sentences.append(s)
        elif l == 0:
            paired_sentences.append(s)
            paired_sentences.append("")
        else:
            raise ValueError("Unknown label {}".format(l))
    paired_labels = [[1, 0, 0] for _ in range(len(paired_sentences))]  # task indicator: commonsense morality
    
    return paired_sentences, paired_labels
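
# Illustration of the pairing convention above, with made-up inputs. A label
# of 1 is assumed to mark the morally wrong sentence, as in the ETHICS
# commonsense-morality data.
def _example_paired_cm_pairs():
    sentences = ["I stole the money.", "I walked the dog."]
    labels = [1, 0]
    paired = []
    for s, l in zip(sentences, labels):
        paired.extend(["", s] if l == 1 else [s, ""])
    # paired == ["", "I stole the money.", "I walked the dog.", ""]
    return paired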

def load_util_sentences(data_dir, split="train"):
    """Read util_{split}.csv (two scenario columns, no header) and interleave
    the columns so consecutive rows form a pair."""
    path = os.path.join(data_dir, "util_{}.csv".format(split))
    df = pd.read_csv(path, header=None)
    sentences = []
    for i in range(df.shape[0]):
        sentences.append(df.iloc[i, 0])
        sentences.append(df.iloc[i, 1])
    labels = [[0, 1, 0] for _ in range(len(sentences))]  # task indicator: utility
    return sentences, labels

# Fine-grained power categories; the per-type labels are currently disabled
# in favor of a single task-indicator label (see the commented-out lines below).
types_of_power = ['coercive', 'reward', 'legitimate', 'referent', 'expert', 'informational', 'economic', 'political', 'military', 'personal']

def load_power_sentences(data_dir, split="train"):
    power_label = [0, 0, 1]  # task indicator: power

    train_sentences, test_sentences, train_labels, test_labels = [], [], [], []

    # NOTE: the train/test assignment below is random and unseeded, so the
    # split differs across calls.
    with open(os.path.join(data_dir, 'power_examples.json')) as f:
        data = json.load(f)
    for p_type in data.keys():
        for d in data[p_type]:
            res = d['res']
            if len(res.keys()) < 5:
                continue  # skip incomplete generations (need scenarios 1-5)

            if np.random.random() < 0.1:
                sentences, labels = test_sentences, test_labels
            else:
                sentences, labels = train_sentences, train_labels

            # Adjacent scenario pairs: (1, 2), (2, 3), (3, 4), (4, 5).
            for i1, i2 in zip(range(1, 5), range(2, 6)):
                s1, s2 = res[str(i1)], res[str(i2)]
                sentences.append(s1)
                sentences.append(s2)
                # label = [p == p_type for p in types_of_power]
                label = power_label
                labels.extend([label] * 2)

    with open(os.path.join(data_dir, 'power_examples_2.json')) as f:
        data = json.load(f)
    for p_type in data.keys():
        for d in data[p_type]:
            res = d['res']
            if len(res['scenarios'].keys()) < 5:
                continue  # skip incomplete generations (need scenarios 1-5)

            res_scenarios = res['scenarios']
            context = res['context']

            if np.random.random() < 0.1:
                sentences, labels = test_sentences, test_labels
            else:
                sentences, labels = train_sentences, train_labels

            # Adjacent scenario pairs, each prefixed with the shared context.
            for i1, i2 in zip(range(1, 5), range(2, 6)):
                s1, s2 = res_scenarios[str(i1)], res_scenarios[str(i2)]
                sentences.append(context + " " + s1)
                sentences.append(context + " " + s2)
                # label = [p == p_type for p in types_of_power]
                label = power_label
                labels.extend([label] * 2)

    path = os.path.join(data_dir, "all_data.csv")
    df = pd.read_csv(path, header=0, index_col=0)
    for i in range(df.shape[0]):
        context = df.iloc[i, 0]

        if np.random.random() < 0.1:
            sentences, labels = test_sentences, test_labels
        else:
            sentences, labels = train_sentences, train_labels

        # Put the completion on the '>' side of the comparison first.
        if df.iloc[i, 3] == 'A > B':
            sentences.append(context + " " + df.iloc[i, 1])
            sentences.append(context + " " + df.iloc[i, 2])
        else:
            sentences.append(context + " " + df.iloc[i, 2])
            sentences.append(context + " " + df.iloc[i, 1])
        # label = [p in df.iloc[i, 4] for p in types_of_power]
        label = power_label
        labels.extend([label] * 2)
    
    if split == "train":
        return train_sentences, train_labels
    elif split == "test":
        return test_sentences, test_labels
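
# Sketch of the JSON layouts load_power_sentences expects, inferred from the
# access patterns above (field names come from the code; values are made up):
#   power_examples.json:
#     {"coercive": [{"res": {"1": "scenario ...", ..., "5": "scenario ..."}}, ...], ...}
#   power_examples_2.json:
#     {"coercive": [{"res": {"context": "setting ...",
#                            "scenarios": {"1": "...", ..., "5": "..."}}}, ...], ...}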

def load_process_data(args, data_dir, dataset, split="train", model_name: str = None):
    load_fn = {"paired_cm": load_paired_cm_sentences, "util": load_util_sentences, "power": load_power_sentences}[dataset]
    data_dir = os.path.join(data_dir, dataset.removeprefix("paired_"))
    sentences, labels = load_fn(data_dir, split=split)
    labels = torch.tensor(labels)

    tokenizer = get_tokenizer(args.model)
    tokenized_inputs = tokenizer(sentences, padding='max_length', truncation=True, max_length=args.max_length, return_tensors="pt")
    inputs, masks = tokenized_inputs["input_ids"], tokenized_inputs["attention_mask"]
    # An example fits in the context window iff a boundary position is still a
    # pad token (the first-position check covers left-padding tokenizers).
    within_bounds = [inputs[i, -1] == tokenizer.pad_token_id or inputs[i, 0] == tokenizer.pad_token_id for i in range(len(inputs))]
    if np.mean(within_bounds) < 1:
        print("{} {} fraction of examples within context window ({} tokens): {:.3f}".format(dataset, split, args.max_length, np.mean(within_bounds)))
    # Consecutive (even, odd) rows form a pair; stack each pair along dim 1,
    # giving tensors of shape [n_pairs, 2, ...].
    even_inputs, odd_inputs = inputs[0::2], inputs[1::2]
    even_labels, odd_labels = labels[0::2], labels[1::2]
    even_masks, odd_masks = masks[0::2], masks[1::2]
    inputs = torch.stack([even_inputs, odd_inputs], dim=1)
    labels = torch.stack([even_labels, odd_labels], dim=1)
    masks = torch.stack([even_masks, odd_masks], dim=1)

    data = TensorDataset(inputs, masks, labels)
    return data
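
# End-to-end sketch (defined but not executed at import): build the paired
# dataset and iterate mini-batches. The `args` fields and the data-directory
# layout (e.g. <data_dir>/cm/cm_train.csv) are assumptions based on the code
# above.
def _example_load_process_data(data_dir="data"):
    from argparse import Namespace
    args = Namespace(model="microsoft/deberta-v3-small", max_length=64)
    data = load_process_data(args, data_dir, "paired_cm", split="train")
    loader = torch.utils.data.DataLoader(data, batch_size=8, shuffle=True)
    for inputs, masks, labels in loader:
        # inputs/masks: [batch, 2, max_length]; labels: [batch, 2, 3]
        break
    return data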
