# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from prettytable import PrettyTable
import math
import os
import logging
import time
import datetime
from tqdm import tqdm
import argparse
from torch import nn
import torch as th


from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
from utils import dp , dp_sq


####################### Reproducibility: fix all random seeds ########################
def set_seed(seed=4389572):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, and forces cuDNN onto deterministic (non-benchmarked)
    kernels so repeated runs produce the same results.

    Args:
        seed: integer seed shared by all generators.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting
        it here only affects child interpreters launched afterwards, not the
        current process.
    """
    import random

    # The stdlib RNG was previously left unseeded; helper modules that use it
    # would otherwise differ between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly cover every visible GPU (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(4389572)

############# Parameter counter #########################
def count_parameters(model, show_table=False):
    """Print and return the number of trainable parameters in *model*.

    Only parameters with ``requires_grad=True`` are counted.

    Args:
        model: a ``torch.nn.Module``.
        show_table: when True, also print a per-parameter breakdown as a
            PrettyTable. Defaults to False, which matches the previous output
            (the table used to be built on every call but never printed).

    Returns:
        The total number of trainable parameter elements.
    """
    trainable = [(name, parameter.numel())
                 for name, parameter in model.named_parameters()
                 if parameter.requires_grad]
    total_params = sum(count for _, count in trainable)
    if show_table:
        table = PrettyTable(["Modules", "Parameters"])
        for name, count in trainable:
            table.add_row([name, count])
        print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
                                                                



# ---- Logging setup --------------------------------------------------------
# Detach every handler currently installed on the root logger before the
# script's own 'main' logger is configured.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)

# Each run gets its own timestamped log file under ./logs.
os.makedirs('logs', exist_ok=True)
log_file = 'logs/log_{}.log'.format(
    datetime.datetime.now().strftime("%Y_%B_%d_%I-%M-%S%p"))
# Touch the file immediately so it exists even before anything is logged.
with open(log_file, 'a'):
    pass

# The 'main' logger and both of its handlers run at INFO level.
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

# Console output.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

# Persistent copy in the log file.
fh = logging.FileHandler(log_file)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)

logger.addHandler(ch)
logger.addHandler(fh)

def adjust_learning_rate(optimizer, epoch, args, gamma=0.9):
    """Set every param group's learning rate to ``args.lr * gamma ** (epoch - 1)``.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: 1-based epoch index; ``epoch == 1`` restores ``args.lr``.
        args: namespace carrying the base learning rate as ``args.lr``.
        gamma: per-epoch multiplicative decay factor (0.9, as before).

    Returns:
        The learning rate that was written to every param group.
    """
    # The previous version built a one-entry dict keyed by ``epoch`` and then
    # tested ``epoch in`` that dict, which is always true; compute directly.
    lr = args.lr * gamma ** (epoch - 1)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def log(str):
    """Write *str* to the 'main' logger at INFO level.

    The parameter name shadows the ``str`` builtin inside this function only;
    it is kept unchanged so keyword callers (``log(str=...)``) keep working.
    """
    logger.info(str)


# Report GPU availability, then pick the device used by the training loop.
log(f'Is GPU available? {torch.cuda.is_available()}')
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' or '1'/'0' to bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is True, so any
    non-empty value used to switch the option on.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (true/false, yes/no, 1/0), got {!r}'.format(value))


parser = argparse.ArgumentParser(description="DomainGen_Graph")
# Datasets that main() knows how to configure.
datasets = ['ONP', 'Moons', 'MNIST', 'Elec2', 'RC', 'WR']
parser.add_argument("--prompt", default=0, type=int,
                    help="non-zero to apply prompting.")
parser.add_argument("--no_features", default=1, type=int,
                    help="Number of input features.")
parser.add_argument("--gen", default=0, type=int,
                    help="non-zero to enable prompt generation.")
parser.add_argument("--dataset", default="Moons", type=str, choices=datasets,
                    help="one of: {}".format(", ".join(sorted(datasets))))

# Hyper-parameters. Dimensions and layer counts are integers; they used to be
# parsed as floats, so e.g. "--noise_dim 16" became 16.0.
parser.add_argument("--noise_dim", default=16, type=int,
                    help="the dimension of the LSTM input noise.")
parser.add_argument("--num_rnn_layer", default=1, type=int,
                    help="the number of RNN hierarchical layers.")
parser.add_argument("--latent_dim", default=16, type=int,
                    help="the latent dimension of RNN variables.")

parser.add_argument("--noise_type", choices=["Gaussian", "Uniform"], default="Gaussian",
                    help="The noise type to feed into the generator.")
parser.add_argument("--num_workers", default=0, type=int,
                    help="the number of threads for loading data.")
parser.add_argument("--epoches", default=20, type=int,
                    help="the number of epoches for each task.")
parser.add_argument("--batch_size", default=128, type=int,
                    help="the mini-batch size.")
parser.add_argument("--lr", default=1e-4, type=float,
                    help="the unified learning rate for each single task.")
parser.add_argument("--num_task", default=8, type=int,
                    help="number of tasks (main() passes its own per-dataset count to the data loaders).")
parser.add_argument("--is_test", default=True, type=_str2bool,
                    help="if this is a testing period (true/false).")
parser.add_argument("--model_channel", default=128, type=int,
                    help="base model width; the Transformer encoder runs at twice this.")
parser.add_argument("--hidden_dim", default=256, type=int,
                    help="feed-forward hidden dim of the Transformer encoder layers.")
parser.add_argument("--emb_dim_p", default=256, type=int,
                    help="prompt embedding dim.")
parser.add_argument("--seq", default=1, type=int,
                    help="1 for sequential training (dp_sq loaders), otherwise dp loaders.")
parser.add_argument("--num_layers", default=2, type=int,
                    help="number of layers.")
parser.add_argument("--e_type", default="max", type=str,
                    help="e type.")
parser.add_argument("--num_layers_p", default=2, type=int,
                    help="number of prompt layers.")

parser.add_argument("--num_layers_m", default=2, type=int,
                    help="number of Transformer encoder layers in the main model.")

parser.add_argument("--hp", default=2, type=int,
                    help="hidden dim of the prompt module.")

args = parser.parse_args()
hidden_dim = args.hidden_dim
num_layers_p = args.num_layers_p
num_layers_m = args.num_layers_m
emb_dimp = args.emb_dim_p

prompting = args.prompt != 0

hp = args.hp
seq = args.seq
print("Seq is", seq)

# (sic) name kept as-is in case other code refers to it.
generatirng_prompts = args.gen != 0

dataset = args.dataset
batch_size = args.batch_size



class TransformerModel(nn.Module):
    """Encoder-only Transformer that predicts a sequence twice as long as its input.

    The input is extended along the time axis with an equal-length block of
    zeros, embedded to width ``2 * model_channels``, given sinusoidal position
    encodings, run through a Transformer encoder, and projected back to
    ``out_channels`` features per step. The training and evaluation code in
    this file compare the positions after the input window with the targets.

    Args:
        num_feature: unused; kept for call-site compatibility.
        in_channels: size of the last input dimension after ``unsqueeze(2)``
            (presumably 1 for a ``(batch, seq_len)`` input).
        model_channels: base embedding width; the encoder runs at twice this.
        out_channels: number of output features per time step.
        dim_feedforward: encoder feed-forward width; defaults to the
            module-level ``hidden_dim`` (``--hidden_dim``).
        num_layers: number of encoder layers; defaults to the module-level
            ``num_layers_m`` (``--num_layers_m``).
    """

    def __init__(
        self,
        num_feature,
        in_channels,
        model_channels,
        out_channels,
        dim_feedforward=None,
        num_layers=None,
        ):
        super().__init__()
        if dim_feedforward is None:
            dim_feedforward = hidden_dim
        if num_layers is None:
            num_layers = num_layers_m
        self.in_channels = in_channels
        self.model_channels = model_channels
        # Bug fix: this used to be set to in_channels, ignoring out_channels.
        self.out_channels = out_channels
        # Submodules are created in the original order so that parameter
        # initialisation consumes the RNG identically and the state_dict /
        # optimizer-state layouts stay unchanged.
        self.pos_encoder = PositionalEncoding(model_channels * 2, 0.001)
        self._feature_size = self.model_channels
        self.input_emb = nn.Linear(self.in_channels, self.model_channels)
        self.input_emb2 = nn.Linear(self.model_channels, self.model_channels * 2)
        self.activation = nn.GELU()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(self.model_channels * 2, 1, dim_feedforward,
                                       0.01, self.activation, batch_first=True),
            num_layers)
        self.output_linear2 = nn.Linear(self.model_channels, self.out_channels)
        self.output_linear1 = nn.Linear(self.model_channels * 2, self.model_channels)

    def forward(self, x, P_T=None, P_S=None):
        """Run the model on *x*.

        Args:
            x: input batch; assumed ``(batch, seq_len)`` -- TODO confirm. After
                ``unsqueeze(2)`` its last dimension must equal ``in_channels``.
            P_T, P_S: prompt tensors; currently unused.

        Returns:
            For a ``(batch, seq_len)`` input, a tensor of shape
            ``(batch, 2 * seq_len, out_channels)``.
        """
        x = x.float().unsqueeze(2)
        # Append zero placeholders of the same length along the time axis.
        x = torch.cat([x, torch.zeros_like(x)], 1)
        he = self.input_emb2(self.input_emb(x))
        he = self.transformer(self.pos_encoder(he))
        return self.output_linear2(self.output_linear1(he))

def evaluation(dataloader, backbone, arsgs, P_T=None, P_S=None, pred_offset=96):
    """Evaluate *backbone* on *dataloader*, log MSE and RMSE, and return the MSE.

    Args:
        dataloader: iterable of ``(X, Y, mask)`` batches; ``mask`` is ignored.
        backbone: model mapping ``X`` to a prediction sequence. The first
            *pred_offset* time steps of the prediction are skipped and the rest
            is compared with ``Y`` (which must hold the same number of elements).
            Assumes the model has at least one parameter (used to pick the device).
        arsgs: parsed CLI arguments; unused, kept for interface compatibility.
        P_T, P_S: prompt tensors; unused, kept for interface compatibility.
        pred_offset: number of leading prediction steps to skip (96, as before).

    Returns:
        Mean squared error over every evaluated element as a float, or ``None``
        if the dataloader produced no data.
    """
    backbone.eval()
    # Follow the model's device instead of hard-coding CUDA.
    eval_device = next(backbone.parameters()).device
    total_sq_err = 0.0
    total_count = 0
    with torch.no_grad():
        for X, Y, _mask in dataloader:
            X, Y = X.float().to(eval_device), Y.float().to(eval_device)
            pred = backbone(X)
            # Slice the time axis only: the old ``pred.squeeze()`` also dropped
            # the batch axis whenever a batch held a single sample.
            forecast = pred[:, pred_offset:]
            target = Y.reshape(forecast.shape)
            # Sum errors and count elements so a short final batch is not
            # over-weighted, as it was when per-batch means were averaged.
            total_sq_err += F.mse_loss(forecast, target, reduction='sum').item()
            total_count += target.numel()
    if total_count == 0:
        log("Evaluation skipped: the dataloader yielded no samples.")
        return None
    avg_mse = total_sq_err / total_count
    # The old message was labelled RMSE but reported the (unrooted) MSE.
    log("Average MSE is :{}".format(avg_mse))
    log("Average RMSE is :{}".format(math.sqrt(avg_mse)))
    return avg_mse


class PROMPTEmbedding(nn.Module):
    """Learnable prompt tokens prepended to externally supplied prompt tensors.

    Args:
        n_tokens: number of learned prompt tokens (previously ignored; the
            embedding was always a single token).
        random_range: the learned tokens are initialised from
            ``U(-random_range, random_range)`` (previously ignored; ``U(0, 1)``
            was used).
        emb_dim: width of each token; defaults to ``args.model_channel`` from
            the module-level parsed CLI arguments.
    """

    def __init__(self,
                 n_tokens: int = 1,
                 random_range: float = 0.5,
                 emb_dim=None):
        super(PROMPTEmbedding, self).__init__()
        if emb_dim is None:
            emb_dim = args.model_channel
        self.n_tokens = n_tokens
        w = torch.empty(n_tokens, emb_dim)
        self.learned_embedding = nn.parameter.Parameter(
            nn.init.uniform_(w, -random_range, random_range))

    def forward(self, P_T, P_S, x):
        """Concatenate ``[learned tokens, P_T, P_S]`` along dimension 1.

        Args:
            P_T, P_S: prompt tensors; assumed ``(batch, tokens, emb_dim)`` so they
                can be concatenated with the learned tokens -- TODO confirm.
            x: only its batch size ``x.shape[0]`` is used.

        Returns:
            Tensor of shape ``(batch, n_tokens + P_T.shape[1] + P_S.shape[1], emb_dim)``.
        """
        learned_embedding = self.learned_embedding.repeat(x.shape[0], 1, 1)
        return torch.cat([learned_embedding, P_T, P_S], 1)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch, then apply dropout.

    Even feature columns hold ``sin(pos * 10000^(-2i/d_model))`` and odd columns
    the matching cosine. Odd ``d_model`` values are supported (the previous
    version raised a shape error for them).

    Args:
        d_model: feature width of the inputs this module is applied to.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = th.arange(max_len).unsqueeze(1)
        div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = th.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = th.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column,
        # so trim the frequencies to fit (a no-op for even d_model).
        pe[0, :, 1::2] = th.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for the first ``x.size(1)`` positions.

        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim], with
                ``seq_len <= max_len`` and ``embedding_dim == d_model``.
        """
        x = x + self.pe[0:1, :x.size(1)]
        return self.dropout(x)
     

# Per-dataset settings: name -> (num_task, data_size, num_instances).
# data_size is not used below; it is kept as a record of the values the old
# if/elif chain carried (presumably the per-sample feature count).
_DATASET_SPECS = {
    'Moons': (10, 2, 220),
    'MNIST': (11, 2, 200),
    'ONP': (6, 58, None),
    'Elec2': (41, 8, None),
    'RC': (21, 14, None),
    'WR': (16, 12, None),
}


def _prepare_directories(args):
    """Create the output, embedding and model directories for this run.

    Returns the model-directory prefix; checkpoints go to ``<prefix>_freeze``.
    """
    run_tag = '{}-ch-{}-dim-{}-lap-{}-lam{}-embp-{}-hp-{}'.format(
        args.dataset, args.model_channel, args.hidden_dim,
        args.num_layers_p, args.num_layers_m, args.emb_dim_p, args.hp)
    output_directory = 'outputs/outputs-' + run_tag
    embds_directory = 'embds/embds-' + run_tag
    model_directory = 'models/models-' + run_tag
    # Ensure every path individually: the _freeze/_prompt siblings used to be
    # created only together with a brand-new model directory, so a partially
    # existing tree made the final torch.save fail.
    for path in (output_directory, embds_directory, model_directory,
                 model_directory + "_freeze", model_directory + "_prompt"):
        os.makedirs(path, exist_ok=True)
    return model_directory


def _train_sequentially(backbone, optimizer, task_loaders, args, checkpoint_path):
    """Train *backbone* on each dataloader in order with one shared optimizer.

    After the last epoch of every task the model/optimizer state is written to
    *checkpoint_path*; each task overwrites the previous task's checkpoint.
    """
    # The first 96 output positions line up with the input window
    # (TransformerModel doubles the sequence); assumes a 96-step window.
    pred_offset = 96
    for task_id, dataloader in enumerate(task_loaders):
        for epoch in range(args.epoches):
            backbone.train()
            epoch_losses = []
            last_loss = None
            for X, Y, _mask in dataloader:
                X, Y = X.float().to(device), Y.float().to(device)
                optimizer.zero_grad()
                pred = backbone(X)
                # Slice the time axis only: the old pred.squeeze() also dropped
                # the batch axis whenever a batch held a single sample.
                forecast = pred[:, pred_offset:]
                loss = F.mse_loss(forecast, Y.reshape(forecast.shape))
                loss.backward()
                optimizer.step()
                # Detach so the list does not keep every batch's graph alive.
                last_loss = loss.detach()
                epoch_losses.append(last_loss)
            if epoch_losses:
                print(torch.stack(epoch_losses).mean().item())
            if epoch == args.epoches - 1:
                log('Saving checkpoint for task {} to {}'.format(task_id, checkpoint_path))
                torch.save({'epoch': epoch,
                            'model_state_dict': backbone.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': last_loss},
                           checkpoint_path)
            # NOTE(review): the schedule restarts for every task, but the first
            # epoch of each later task still runs at the previous task's final
            # decayed rate because the reset is applied only after that epoch.
            adjust_learning_rate(optimizer, epoch + 1, args)


def main(arsgs):
    """Train the backbone on every task except the last, then evaluate on the last.

    Args:
        arsgs: parsed command-line namespace; all settings are read from it
            (previously the function ignored it and read the global ``args``).

    Raises:
        ValueError: if ``arsgs.dataset`` is not a supported dataset (this used
            to surface later as an ``UnboundLocalError``).
    """
    args = arsgs
    if args.dataset not in _DATASET_SPECS:
        raise ValueError('Unsupported dataset {!r}; expected one of: {}'.format(
            args.dataset, ', '.join(sorted(_DATASET_SPECS))))
    num_task, _data_size, num_instances = _DATASET_SPECS[args.dataset]

    model_directory = _prepare_directories(args)
    log('use {} data'.format(args.dataset))
    log('-' * 40)

    if args.seq == 1:
        dataloaders = dp_sq(args, num_task, num_instances)
    else:
        dataloaders = dp(args, num_task, num_instances)

    # Use the module-level device rather than .cuda() so CPU-only runs work.
    backbone = TransformerModel(1, args.no_features, args.model_channel, 1).to(device)
    count_parameters(backbone)
    optimizer = torch.optim.Adam(backbone.parameters(), lr=args.lr)

    # All loaders but the last are training tasks; the last one is held out.
    _train_sequentially(backbone, optimizer, dataloaders[:-1], args,
                        model_directory + "_freeze/" + "model.pth")
    evaluation(dataloaders[-1], backbone, arsgs)


        

       
        

if __name__ == "__main__":
    # Seeding, logging and CLI parsing already ran at import time above;
    # this entry point only starts training. (The unused run timestamp that
    # used to be computed here was removed.)
    print("Start Training...")
    main(args)









