# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from prettytable import PrettyTable
import math
import os
import logging
import time
import datetime
from tqdm import tqdm
import argparse
from torch import nn
import torch as th


from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
from utils import dp , dp_sq


####################### Fix random seeds for reproducibility ########################
def set_seed(seed=4389572):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), and puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: integer seed shared by all generators.
    """
    import random  # local import: the module header does not import ``random``

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all GPUs, not only the current device; no-op without CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this only
    # affects child processes started afterwards, not the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(4389572)

############# Parameter counter #########################
def count_parameters(model):
    """Return the number of trainable scalar parameters in *model*.

    Only parameters with ``requires_grad=True`` are counted. A parameter shared
    between several sub-modules is counted once, because ``Module.parameters()``
    de-duplicates shared tensors.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
                                                                



# setup logging
# Reset the root logger, then route this script's messages (logger 'main') to
# both the console and a timestamped file under logs/.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
os.makedirs('logs', exist_ok=True)
_log_stamp = datetime.datetime.now().strftime("%Y_%B_%d_%I-%M-%S%p")
log_file = 'logs/log_{}.log'.format(_log_stamp)
with open(log_file, 'a'):
    pass  # create the file up front (FileHandler would also create it when it opens)

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)

# Both sinks use the INFO threshold and the same formatter: console first, then file.
ch = logging.StreamHandler()
fh = logging.FileHandler(log_file)
for _sink in (ch, fh):
    _sink.setLevel(logging.INFO)
    _sink.setFormatter(formatter)
    logger.addHandler(_sink)


def log(str):
    """Emit *str* at INFO level on the 'main' logger (console and log file)."""
    logger.info(str)


_cuda_available = torch.cuda.is_available()
log('Is GPU available? {}'.format(_cuda_available))
device = torch.device('cuda' if _cuda_available else 'cpu')

def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` treats every non-empty string (including "False")
    as True, so boolean flags need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description="DomainGen_Graph")
# Datasets for which main() has a domain configuration.
datasets = ['ONP', 'Moons', 'MNIST', 'Elec2', 'RC', 'WR']
parser.add_argument("--prompt", default=0, type=int,
                    help="if prompting applied")
parser.add_argument("--no_features", default=1, type=int,
                    help="Number of features.")
parser.add_argument("--gen", default=0, type=int,
                    help="prompt generation (non-zero to enable).")
parser.add_argument("--dataset", default="Moons", type=str, choices=datasets,
                    help="one of: {}".format(", ".join(sorted(datasets))))

# Hyper-parameters (integer sizes/counts are parsed as int, matching their defaults)
parser.add_argument("--noise_dim", default=16, type=int,
                    help="the dimension of the LSTM input noise.")
parser.add_argument("--num_rnn_layer", default=1, type=int,
                    help="the number of RNN hierarchical layers.")
parser.add_argument("--latent_dim", default=16, type=int,
                    help="the latent dimension of RNN variables.")

parser.add_argument("--noise_type", choices=["Gaussian", "Uniform"], default="Gaussian",
                    help="The noise type to feed into the generator.")
parser.add_argument("--num_workers", default=0, type=int,
                    help="the number of threads for loading data.")
parser.add_argument("--epoches", default=20, type=int,
                    help="the number of epochs for each task.")
parser.add_argument("--batch_size", default=128, type=int,
                    help="the mini-batch size.")
parser.add_argument("--lr", default=1e-4, type=float,
                    help="the unified learning rate for each single task.")
parser.add_argument("--num_task", default=8, type=int,
                    help="number of tasks")
parser.add_argument("--is_test", default=True, type=_str2bool,
                    help="if this is a testing period (true/false).")
parser.add_argument("--model_channel", default=128, type=int,
                    help="number of model channels.")
parser.add_argument("--hidden_dim", default=256, type=int,
                    help="feed-forward hidden dim of the Transformer encoder layers.")

parser.add_argument("--emb_dim_p", default=256, type=int,
                    help="emb dim p .")

parser.add_argument("--seq", default=1, type=int,
                    help="1 for sequential data loading/training, otherwise not.")

parser.add_argument("--num_layers", default=2, type=int,
                    help="number of layers.")
parser.add_argument("--e_type", default="max", type=str,
                    help="e type .")
parser.add_argument("--num_layers_p", default=2, type=int,
                    help="number of layers (p variant).")

parser.add_argument("--num_layers_m", default=2, type=int,
                    help="number of Transformer encoder layers in the backbone.")

parser.add_argument("--hp", default=2, type=int,
                    help="num hidden dim prompt .")

args = parser.parse_args()
# Short aliases used throughout the module.
hidden_dim = args.hidden_dim
num_layers_p = args.num_layers_p
num_layers_m = args.num_layers_m
emb_dimp = args.emb_dim_p
hp = args.hp
seq = args.seq
print("Seq is", seq)



# Convenience globals derived from the parsed command-line arguments.
generatirng_prompts = args.gen != 0

model_chan = args.model_channel
dataset = args.dataset
batch_size = args.batch_size

# All artifact directories share one suffix encoding the run's hyper-parameters.
_run_suffix = '{}-ch-{}-dim-{}-lap-{}-lam{}-embp-{}-hp-{}'.format(
    args.dataset, args.model_channel, hidden_dim, args.num_layers_p,
    args.num_layers_m, emb_dimp, hp)
output_directory = 'outputs/outputs-' + _run_suffix
model_directory = 'models/models-' + _run_suffix
embds_directory = 'embds/embds-' + _run_suffix
PS_directory = 'PS/PS-' + _run_suffix

def adjust_learning_rate(optimizer, epoch, args, gamma=0.9):
    """Set every param group's learning rate to ``args.lr * gamma ** (epoch - 1)``.

    This is an exponential decay keyed on a 1-based epoch index. Epoch 1 gives
    the base rate, epoch 2 applies one factor of *gamma*, and so on.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        epoch: 1-based integer epoch index.
        args: namespace providing the base learning rate as ``args.lr``.
        gamma: per-epoch multiplicative decay (0.9 reproduces the previous
            hard-coded schedule).
    """
    lr = args.lr * gamma ** (epoch - 1)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr




class TransformerModel(nn.Module):
    """Transformer-encoder sequence model with one learnable prompt token.

    In ``forward`` the input sequence is extended with an equal number of zero
    steps, each step is embedded to ``2 * model_channels``, the learned prompt
    token is appended as the last position, and the encoder output (with the
    prompt position dropped) is projected to ``out_channels`` per step.

    Args:
        num_feature: unused; kept for call compatibility.
        in_channels: size of each input step after ``forward`` adds a trailing axis.
        model_channels: base width; the encoder operates at ``2 * model_channels``.
        out_channels: per-step output size.
        ff_dim: feed-forward width of each encoder layer. Defaults to the
            module-level ``hidden_dim`` (set from ``--hidden_dim``).
        num_layers: number of encoder layers. Defaults to the module-level
            ``num_layers_m`` (set from ``--num_layers_m``).
    """

    def __init__(
        self,
        num_feature,
        in_channels,
        model_channels,
        out_channels,
        ff_dim=None,
        num_layers=None,
        ):
        super().__init__()
        # Fall back to the CLI-derived module globals this class used to hard-wire.
        if ff_dim is None:
            ff_dim = hidden_dim
        if num_layers is None:
            num_layers = num_layers_m
        self.in_channels = in_channels
        self.model_channels = model_channels
        # Honour the out_channels argument (it used to be overwritten with in_channels).
        self.out_channels = out_channels
        # Sub-modules are created in the original order so that parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        self.pos_encoder = PositionalEncoding(model_channels * 2, 0.001)
        self._feature_size = self.model_channels
        self.input_emb = nn.Linear(self.in_channels, self.model_channels)
        self.input_emb2 = nn.Linear(self.model_channels, self.model_channels * 2)
        self.activation = nn.GELU()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                self.model_channels * 2, 1, ff_dim, 0.01, self.activation, batch_first=True),
            num_layers)
        self.output_linear2 = nn.Linear(self.model_channels, self.out_channels)
        self.output_linear1 = nn.Linear(self.model_channels * 2, self.model_channels)

        # The prompt token: a (1, 2 * model_channels) parameter, Xavier-initialised with gain 4.
        w = torch.empty(1, model_channels * 2)
        self.learned_embedding = nn.Parameter(nn.init.xavier_uniform_(w, gain=4))

    def forward(self, x, P_T=None, P_S=None):
        """Run the model on a batch.

        Args:
            x: input batch. A trailing axis is added and the result is concatenated
                with a 3-D prompt tensor, so x must be 2-D, presumably
                (batch, seq_len); the new axis must match ``in_channels``.
            P_T: unused.
            P_S: ignored; it is overwritten with the learned prompt token.

        Returns:
            ``(out, P_S)``. For a 2-D x of shape (B, L), ``out`` has shape
            (B, 2 * L, out_channels) and ``P_S`` is the prompt repeated per
            sample, shape (B, 1, 2 * model_channels).
        """
        x = x.float().unsqueeze(2)                              # (B, L, 1)
        # Append as many zero steps as there are inputs (presumably the forecast slots).
        x = torch.cat([x, torch.zeros_like(x)], 1)              # (B, 2L, 1)
        P_S = self.learned_embedding.repeat(x.shape[0], 1, 1)   # (B, 1, 2C)

        he = self.input_emb2(self.input_emb(x))                 # (B, 2L, 2C)
        he = torch.cat([he, P_S], 1)                            # prompt is the last token
        he = self.transformer(self.pos_encoder(he))
        he = self.output_linear1(he[:, :-1])                    # drop the prompt position
        out = self.output_linear2(he)
        return out, P_S


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Even embedding columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: embedding size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = th.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = th.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = th.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns; trim so odd d_model does not mismatch.
        pe[0, :, 1::2] = th.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        Args:
            x: tensor of shape [batch_size, seq_len, embedding_dim] with
                seq_len <= max_len.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
     





# dataset name -> (num_task, data_size, num_instances).
# data_size is not used by this script; it is kept for reference.
_DATASET_CONFIG = {
    'Moons': (10, 2, 220),
    'MNIST': (11, 2, 200),
    'ONP': (6, 58, None),
    'Elec2': (41, 8, None),
    'RC': (21, 14, None),
    'WR': (16, 12, None),
}


def _run_directories(args):
    """Return the (outputs, models, embds, PS) directory paths tagged with this run's hyper-parameters."""
    tag = '{}-ch-{}-dim-{}-lap-{}-lam{}-embp-{}-hp-{}'.format(
        args.dataset, args.model_channel, args.hidden_dim, args.num_layers_p,
        args.num_layers_m, args.emb_dim_p, args.hp)
    return ('outputs/outputs-' + tag, 'models/models-' + tag,
            'embds/embds-' + tag, 'PS/PS-' + tag)


def main(arsgs):
    """Prompt-tune a frozen pre-trained backbone separately on each training domain.

    For every domain loader except the last, a fresh ``TransformerModel`` is
    restored from ``<model_directory>_freeze/model.pth``. All weights except
    ``learned_embedding`` are frozen, and the prompt is trained with MSE for
    ``args.epoches`` epochs. Each domain's final prompt vector is written to
    ``<PS_directory>/task<id>_new.npz`` under the key ``ps``.

    Args:
        arsgs: parsed command-line namespace (the module-level ``args`` when
            run as a script).

    Raises:
        ValueError: if ``args.dataset`` has no entry in ``_DATASET_CONFIG``.
    """
    # Use the namespace we were given; the previous body silently read the global ``args``.
    args = arsgs
    output_directory, model_directory, embds_directory, PS_directory = _run_directories(args)

    # exist_ok also creates the _freeze/_prompt siblings when the base directory already exists.
    for directory in (output_directory, PS_directory, embds_directory, model_directory,
                      model_directory + "_freeze", model_directory + "_prompt"):
        os.makedirs(directory, exist_ok=True)

    no_features = args.no_features
    log('use {} data'.format(args.dataset))
    log('-' * 40)

    try:
        num_task, _data_size, num_instances = _DATASET_CONFIG[args.dataset]
    except KeyError:
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset)) from None

    # One dataloader per domain; seq == 1 selects the sequential loader factory.
    make_dataloaders = dp_sq if args.seq == 1 else dp
    dataloaders = make_dataloaders(args, num_task, num_instances)

    # Load the frozen backbone once. map_location keeps CPU-only machines working;
    # load_state_dict copies the tensors, so reusing the dict across tasks is safe.
    checkpoint = torch.load(model_directory + "_freeze/" + "model.pth", map_location=device)

    # The last loader is skipped (presumably the held-out evaluation domain — confirm).
    for task_id, dataloader in enumerate(dataloaders[:-1]):
        backbone = TransformerModel(1, no_features, args.model_channel, no_features).to(device)
        backbone.load_state_dict(checkpoint['model_state_dict'], strict=False)
        # Only the prompt token is trained; every backbone weight stays frozen.
        for name, para in backbone.named_parameters():
            para.requires_grad = name == "learned_embedding"
        optimizer = torch.optim.Adam(
            [p for p in backbone.parameters() if p.requires_grad], lr=args.lr)
        backbone.train()
        for epoch in range(args.epoches):
            for X, Y, _mask in dataloader:
                X, Y = X.float().to(device), Y.float().to(device)
                optimizer.zero_grad()
                pred, _ = backbone(X)
                # NOTE(review): the 96 offset is hard-coded; it presumably equals the
                # input length, so pred[:, 96:] are the forecast steps — confirm.
                loss = F.mse_loss(pred.squeeze()[:, 96:], Y.squeeze())
                loss.backward()
                optimizer.step()
            adjust_learning_rate(optimizer, epoch + 1, args)

        if args.epoches > 0:
            # Save the prompt after the final optimizer step. The old code saved the
            # copy captured in the last forward pass, which was one step stale.
            prompt = backbone.learned_embedding.detach().cpu().squeeze().numpy()
            np.savez_compressed(PS_directory + '/task' + str(task_id) + "_new", ps=prompt)
     
     


        

if __name__ == "__main__":
    # Script entry point: train per-domain prompts using the module-level CLI arguments.
    print("Start Training...")
    main(args)









