import torch
from transformers import GPT2Tokenizer, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Config

from DualTime import *

import numpy as np
import json
import pickle
from torch.utils.data import Dataset
import copy
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import math
import logging
from datetime import datetime
from sklearn.preprocessing import LabelEncoder
from torch.utils.tensorboard import SummaryWriter
from types import SimpleNamespace
import argparse
from torch.optim.lr_scheduler import StepLR
import random


class InstructionDataset_TimeSeries(Dataset):
    """Dataset of (description, label, time-series) samples read from a pickle file.

    The file ``data_path + partition + '.pkl'`` must unpickle to a mapping with
    the keys ``'description'``, ``'X'``, ``'Y'`` and ``'sl_Y'``.

    Args:
        data_path: Path prefix the partition name is appended to (plain string
            concatenation, so a directory must end with a separator).
        partition: Partition name, e.g. ``"train"``, ``"val"`` or ``"test"``.
        args: Optional namespace. When it has a truthy ``detection`` attribute,
            samples carry the ``'sl_Y'`` labels; otherwise (including
            ``args=None``) they carry the ``'Y'`` labels.
    """

    def __init__(self, data_path, partition="train", args=None):
        self.args = args
        # NOTE: unpickling executes code embedded in the file; only load
        # dataset files from a trusted source.
        with open(data_path + partition + '.pkl', 'rb') as file:
            data = pickle.load(file)
        self.text_ann = data['description']
        self.ts_ann = data['X']

        # NOTE(review): every partition fits its own LabelEncoder. The integer
        # ids only agree across train/val/test if each partition contains the
        # same set of label values - confirm for the datasets in use.
        self.labels = LabelEncoder().fit_transform(data['Y'])
        self.sl_labels = LabelEncoder().fit_transform(data['sl_Y'])
        print('Partition:', partition, ', Size:', len(self.labels))

    def __len__(self):
        """Return the number of time-series samples."""
        return len(self.ts_ann)

    def __getitem__(self, index):
        """Return ``(description, label, time_series_tensor)`` for one sample."""
        ts_sample = self.ts_ann[index, ...]
        description = self.text_ann[index]

        # `args` defaults to None in __init__, so fall back to the
        # classification labels instead of failing on `None.detection`.
        if getattr(self.args, 'detection', False):
            labels = self.sl_labels[index]
        else:
            labels = self.labels[index]

        # Non-string descriptions are replaced by a single space so the caller
        # always receives a str.
        if not isinstance(description, str):
            description = ' '

        return description, labels, torch.tensor(ts_sample)


def get_args_parser():
    """Create the command-line parser for the dual multi-modal script.

    The parser is built with ``add_help=False``. Options are registered in the
    order they appear in the table below.

    Returns:
        argparse.ArgumentParser: parser exposing every option of the script.
    """
    cli = argparse.ArgumentParser("Dual Multi-modal Modeling", add_help=False)

    # One entry per option: (flag, keyword arguments for add_argument).
    option_table = (
        ("--cuda", {"type": int, "default": 0}),
        ("--gate_fusion", {
            "action": "store_true",
            "help": "Use gate fusion to combine the two modalities",
        }),
        ("--num_epochs", {"type": int, "default": 40}),
        ("--batch_size", {"type": int, "default": 32}),
        ("--test_only", {"action": "store_true"}),
        ("--test_model_path", {"type": str, "default": None}),
        ("--patch_len", {"type": int, "default": 25}),
        ("--lr", {"type": float, "default": 0.001}),
        ("--detection", {"action": "store_true"}),
        ("--adapter_layers", {"type": int, "default": 11}),
        ("--adapter_len", {"type": int, "default": 5}),
        # store_false: the value is True unless the flag is passed.
        ("--revin", {"action": "store_false"}),
        ("--model_type", {"type": str, "default": '', "help": 'Model variant to use'}),
        ("--gpt_dropout", {"type": float, "default": 0.1, "help": 'GPT2 dropout rate'}),
        ("--exp_lr", {
            "action": "store_true",
            "help": 'Whether to use exponential learning rate scheduler',
        }),
        ("--exp_gamma", {
            "type": float,
            "default": 0.85,
            "help": 'Exponential learning rate scheduler gamma',
        }),
        ("--bert_projector", {
            "action": "store_true",
            "help": 'Use a linear layer to project the hidden states of the bert model',
        }),
        ("--seed", {"type": int, "default": 42, "help": 'Random seed'}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli


if __name__ == '__main__':
    # Script entry point: builds the DualTime model with a linear classification
    # head, trains the adapter/gate parameters, evaluates on the validation and
    # test partitions after every epoch, and finally reports the test metrics of
    # the epoch with the lowest validation loss. With --test_only no training is
    # done; an optional checkpoint is loaded and evaluated once on the test set.

    # Local import: `os` is not imported at the top of this file and is only
    # needed here to create the output directories.
    import os

    def set_seed(seed):
        """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def _forward_batch(batch):
        """Run both model branches on one batch; return (class scores, labels)."""
        raw_text, labels, ts = batch
        labels = labels.to(device)
        ts = ts.to(device).to(torch.float32)

        output_ts = model.ts_main_forward(raw_text, ts_sample=ts)
        output_text = model.text_main_forward(raw_text, ts_sample=ts)
        # Take the last position along dim 1 of each branch output
        # (presumably (batch, seq_len, hidden) - TODO confirm) and average
        # the two branches before the linear head.
        hidden_emb = (output_ts[:, -1, :] + output_text[:, -1, :]) / 2
        return linear_model(hidden_emb), labels

    def _evaluate_test():
        """Predict on the test loader and return the list of test metrics."""
        linear_model.eval()
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for batch in data_loader_test:
                scores, labels = _forward_batch(batch)
                _, predictions = scores.max(1)
                all_preds.extend(predictions.cpu().numpy())
                all_targets.extend(labels.cpu().numpy())

        accuracy = accuracy_score(all_targets, all_preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_targets, all_preds, average='macro')
        precision_individual, recall_individual, f1_individual, _ = precision_recall_fscore_support(
            all_targets, all_preds, average=None)
        return [accuracy, precision, recall, f1,
                precision_individual, recall_individual, f1_individual]

    def _log_metrics(results):
        """Write one set of test metrics, followed by two separator lines, to the log."""
        accuracy, precision, recall, f1, precision_individual, recall_individual, f1_individual = results
        logging.info(f'Accuracy: {accuracy:.4f}')
        logging.info(f'Macro Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
        logging.info(f'Per Class Precision: {precision_individual}')
        logging.info(f'Per Class Recall: {recall_individual}')
        logging.info(f'Per Class F1: {f1_individual}')
        logging.info(f'______________________________________________________')
        logging.info(f'______________________________________________________')

    args = get_args_parser().parse_args()
    print(args)

    seed = args.seed
    set_seed(seed)
    # Fall back to the CPU instead of failing later when CUDA is unavailable.
    use_cuda = torch.cuda.is_available()
    device = ('cuda:' + str(args.cuda)) if use_cuda else 'cpu'
    model_name = args.model_type.split('_')[-1]

    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if not args.detection:
        log_file_path = f'logs/mm_{model_name}_dual_adapter_training_logs_{current_time}.log'
    else:
        log_file_path = f'logs/detection_mm_{model_name}_dual_adapter_training_logs{current_time}.log'

    # Neither the log handler nor torch.save creates missing directories.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('save_models', exist_ok=True)

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        filename=log_file_path,
                        filemode='a')

    args_str = ', '.join(f'{k}={v}' for k, v in vars(args).items())
    logging.info(f'Args Input: {args_str}')
    if not use_cuda:
        logging.warning('CUDA is not available; running on CPU')
    if args.test_only:
        logging.info(f'Test Only Mode')
        args.num_epochs = 0  # skip the training loop entirely

    # Classification head: 5 classes by default, 4 in detection mode.
    # 768 is presumably the GPT-2 hidden size - TODO confirm against DualTime.
    if not args.detection:
        linear_model = torch.nn.Linear(768, 5)
    else:
        linear_model = torch.nn.Linear(768, 4)
    # Use torch.nn explicitly; a bare `nn` is only in scope if the star import
    # from DualTime happens to export it.
    criterion = torch.nn.CrossEntropyLoss()

    # NOTE(review): the tokenizer is not referenced again in this script; it is
    # kept so the 'gpt2' tokenizer files are still fetched/cached as before.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): 'batch_size', 'patch_len' and 'revin' are hard-coded here
    # and do not follow --batch_size / --patch_len / --revin; only
    # 'ts_emb_patch_len' uses a CLI value. Confirm which keys DualTime reads.
    dataset_config = {
        'ts_emb_patch_len': args.patch_len,
        'batch_size': 64,
        'context_points': 1000,
        'patch_len': 25,
        'stride': 25,
        'revin': 1,
        'mask_ratio': 0.0,
        'vars': 12
    }

    dataset_config = SimpleNamespace(**dataset_config)
    config_ = GPT2Config(use_cache=False, adapter_layers=args.adapter_layers, adapter_len=args.adapter_len,
                         adapter_head=12, attn_pdrop=args.gpt_dropout, embd_pdrop=args.gpt_dropout,
                         resid_pdrop=args.gpt_dropout, bert_projector=args.bert_projector,
                         gate_fusion=args.gate_fusion, ts_config=dataset_config)

    model = DualTime(config_)

    # strict=False: the checkpoint does not have to cover every model parameter.
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine;
    # the model is moved to `device` further below.
    params_ = torch.load('./gpt2_pretrained/pretrained_gpt2.pth', map_location='cpu')
    model.load_state_dict(params_, strict=False)
    print('GPT2 params loaded...')

    # Only parameters whose name contains one of these substrings are trained.
    tuned_list = ['adapter', 'gate']

    for name, param in model.named_parameters():
        param.requires_grad = any(tuned in name for tuned in tuned_list)

    if args.test_only:
        if args.test_model_path is not None:
            # Same layout as the checkpoint written during training below:
            # {'linear_model': <head state>, 'model': <tuned parameters only>}.
            checkpoint = torch.load(args.test_model_path, map_location='cpu')
            linear_model.load_state_dict(checkpoint['linear_model'])
            model.load_state_dict(checkpoint['model'], strict=False)
            logging.info(f'Loaded checkpoint from {args.test_model_path}')
        else:
            logging.warning('No --test_model_path given; evaluating without a fine-tuned checkpoint')

    # Frozen parameters never receive gradients, so only trainable ones are
    # handed to the optimizer.
    trainable_params = list(linear_model.parameters()) + [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=args.lr)
    if args.exp_lr:
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.exp_gamma)
    else:
        scheduler = StepLR(optimizer, step_size=5, gamma=0.3)

    data_path = './datasets/'
    dataset_train = InstructionDataset_TimeSeries(
        data_path=data_path, partition="train", args=args)
    dataset_val = InstructionDataset_TimeSeries(
        data_path=data_path, partition="val", args=args)
    dataset_test = InstructionDataset_TimeSeries(
        data_path=data_path, partition="test", args=args)

    data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
    data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)

    model = model.to(device)
    linear_model = linear_model.to(device)

    best_val_loss = float('inf')
    num_epochs = args.num_epochs
    val_epoch_list = []   # summed validation loss of every epoch
    results_list = []     # test metrics of every epoch, same order
    for epoch in tqdm(range(num_epochs), desc='Epochs'):
        logging.info(f'Epoch {epoch+1}')
        linear_model.train()
        model.train()
        loss_train = 0.0
        with tqdm(data_loader_train, total=len(data_loader_train), desc=f'Epoch {epoch + 1}') as tepoch:
            for batch in tepoch:
                scores, labels = _forward_batch(batch)
                loss = criterion(scores, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_train += loss.item()
                tepoch.set_postfix(loss=loss.item())

        # Average over the number of batches (not the last batch index).
        train_batches = max(len(data_loader_train), 1)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_train/train_batches:.4f}')
        scheduler.step()

        linear_model.eval()
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in data_loader_val:
                scores, labels = _forward_batch(batch)
                val_loss += criterion(scores, labels).item()
        val_batches = max(len(data_loader_val), 1)
        logging.info(f'Epoch {epoch+1} Val Loss: {val_loss/val_batches:.4f}')
        val_epoch_list.append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Persist only the linear head and the tuned (adapter/gate) tensors.
            selected_model_state_dict = {k: v for k, v in model.state_dict().items()
                                         if any(tuned in k for tuned in tuned_list)}
            combined_state_dict = {'linear_model': linear_model.state_dict(),
                                   'model': selected_model_state_dict}
            torch.save(combined_state_dict, f'save_models/combined_mm_model_best_{current_time}.pth')

        epoch_results = _evaluate_test()
        _log_metrics(epoch_results)
        results_list.append(epoch_results)

    if results_list:
        # Report the test metrics of the epoch with the lowest validation loss.
        best_epoch = int(np.argmin(val_epoch_list))
        best_results = results_list[best_epoch]
    else:
        # No epoch was run (e.g. --test_only): evaluate the current weights once.
        best_results = _evaluate_test()

    logging.info(f'______________________________________________________')
    logging.info(f'______________________________________________________')
    logging.info(f'Best Model Test Results')
    _log_metrics(best_results)

