# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from prettytable import PrettyTable
import math
import os
import logging
import time
import datetime
from tqdm import tqdm
import argparse
from torch import nn
import torch as th


from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
from utils import  dp , dp_sq

# class PositionalEncoding(nn.Module):

#     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
#         super().__init__()
#         self.dropout = nn.Dropout(p=dropout)

#         position = torch.arange(max_len).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
#         pe = torch.zeros(max_len, 1, d_model)
#         pe[:, 0, 0::2] = torch.sin(position * div_term)
#         pe[:, 0, 1::2] = torch.cos(position * div_term)
#         self.register_buffer('pe', pe)

#     def forward(self, x):
#         """
#         Arguments:
#             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
#         """
#         x = x + self.pe[:x.size(0)]
#         return self.dropout(x)

####################### Seeding for reproducibility ########################
def set_seed(seed=4389572):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and torch (CPU and every visible
    CUDA device), and forces deterministic cuDNN kernels.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here only influences child interpreters started afterwards; it does not
    change hash randomisation of the already-running process.

    Args:
        seed: integer seed shared by all generators.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(4389572)

############# Parameter counter #########################
def count_parameters(model, show_table=False):
    """Print and return the number of trainable parameters in *model*.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    weights are excluded.

    Args:
        model: a ``torch.nn.Module``.
        show_table: when True, also print a per-parameter breakdown built
            with PrettyTable. Defaults to False, so no table is built.

    Returns:
        int: total number of trainable scalar parameters.
    """
    trainable = [(name, p.numel()) for name, p in model.named_parameters() if p.requires_grad]
    total_params = sum(count for _, count in trainable)
    if show_table:
        table = PrettyTable(["Modules", "Parameters"])
        for name, count in trainable:
            table.add_row([name, count])
        print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
                                                                

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature indices receive ``sin(pos * w_i)`` and odd indices
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    stored in the buffer ``pe`` of shape ``(1, max_len, d_model)``, so it is
    part of the module's state dict.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: maximum sequence length that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = th.arange(max_len).unsqueeze(1)
        div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = th.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = th.sin(position * div_term)
        # There are floor(d_model / 2) odd columns but ceil(d_model / 2)
        # frequencies; trim so odd d_model does not raise a shape mismatch.
        pe[0, :, 1::2] = th.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, embedding_dim]`` with
                ``seq_len <= max_len`` and ``embedding_dim == d_model``.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
     

# ---------------------------------------------------------------------------
# Logging: INFO-level records go to the console and to a timestamped file
# under ./logs.
# ---------------------------------------------------------------------------
# Detach whatever handlers the root logger already has (iterate over a copy
# because removeHandler mutates the list).
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)

os.makedirs('logs', exist_ok=True)
log_file = 'logs/log_{}.log'.format(datetime.datetime.now().strftime("%Y_%B_%d_%I-%M-%S%p"))
# Touch the file so it exists even before the first record is written.
with open(log_file, 'a'):
    pass

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)


def _attach_handler(target, handler_obj):
    """Set *handler_obj* to INFO with the shared formatter, add it to *target*, return it."""
    handler_obj.setLevel(logging.INFO)
    handler_obj.setFormatter(formatter)
    target.addHandler(handler_obj)
    return handler_obj


# Console handler first, then the file handler.
ch = _attach_handler(logger, logging.StreamHandler())
fh = _attach_handler(logger, logging.FileHandler(log_file))


def log(str):
    """Emit *str* at INFO level through the 'main' logger.

    NOTE(review): the parameter name shadows the builtin ``str``.
    """
    logger.info(str)


log('Is GPU available? {}'.format(torch.cuda.is_available()))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``0``.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"``) to True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description="DomainGen_Graph")
# Dataset names handled by main(); shown in the --dataset help text.
datasets = ['ONP', 'Moons', 'MNIST', 'Elec2', 'RC', 'WR']
parser.add_argument("--prompt", default=0, type=int,
                    help="if prompting applied")
parser.add_argument("--no_features", default=1, type=int,
                    help="Number of features.")
parser.add_argument("--gen", default=0, type=int,
                    help="prompt generation.")
parser.add_argument("--dataset", default="Moons", type=str,
                    help="one of: {}".format(", ".join(sorted(datasets))))

# Hyper-parameters (dimensions and layer counts are integers)
parser.add_argument("--noise_dim", default=16, type=int,
                    help="the dimension of the LSTM input noise.")
parser.add_argument("--num_rnn_layer", default=1, type=int,
                    help="the number of RNN hierarchical layers.")
parser.add_argument("--latent_dim", default=16, type=int,
                    help="the latent dimension of RNN variables.")

parser.add_argument("--noise_type", choices=["Gaussian", "Uniform"], default="Gaussian",
                    help="The noise type to feed into the generator.")
parser.add_argument("--num_workers", default=0, type=int,
                    help="the number of threads for loading data.")
parser.add_argument("--epoches", default=20, type=int,
                    help="the number of epochs for each task.")
parser.add_argument("--batch_size", default=128, type=int,
                    help="the mini-batch size.")
parser.add_argument("--lr", default=0.0001, type=float,
                    help="the unified learning rate for each single task.")
parser.add_argument("--num_task", default=8, type=int,
                    help="number of tasks")
parser.add_argument("--is_test", default=True, type=_str2bool,
                    help="if this is a testing period (true/false).")
parser.add_argument("--model_channel", default=128, type=int,
                    help="backbone model width (replaced by --emb_dim_p when --gen is non-zero).")
parser.add_argument("--hidden_dim", default=256, type=int,
                    help="feed-forward width of the backbone transformer encoder.")

parser.add_argument("--emb_dim_p", default=256, type=int,
                    help="prompt embedding dimension.")

parser.add_argument("--seq", default=1, type=int,
                    help="sequential training or not.")
parser.add_argument("--num_layers", default=2, type=int,
                    help="number of layers.")
parser.add_argument("--e_type", default="max", type=str,
                    help="e type .")
parser.add_argument("--num_layers_p", default=2, type=int,
                    help="number of layers in the prompt transformer encoder.")

parser.add_argument("--num_layers_m", default=2, type=int,
                    help="number of layers in the backbone transformer encoder.")

parser.add_argument("--hp", default=2, type=int,
                    help="feed-forward width of the prompt transformer encoder.")
parser.add_argument("--ag", default=0.1, type=float,
                    help="dropout of the prompt transformer encoder.")

args = parser.parse_args()

# Hyper-parameters read as module-level names by the model classes and main().
hidden_dim = args.hidden_dim
num_layers_p = args.num_layers_p
num_layers_m = args.num_layers_m
emb_dimp = args.emb_dim_p

e_type = args.e_type
prompting = args.prompt != 0

hp = args.hp
seq = args.seq
print("Seq is", seq)

# A non-zero --gen makes the backbone width equal to the prompt embedding size.
generatirng_prompts = args.gen != 0
model_chan = emb_dimp if generatirng_prompts else args.model_channel

dataset = args.dataset
batch_size = args.batch_size

def adjust_learning_rate(optimizer, epoch, args, decay=0.9):
    """Apply exponential per-epoch decay: ``lr = args.lr * decay ** (epoch - 1)``.

    Epoch 1 uses the base rate ``args.lr``, and each later epoch multiplies it
    by ``decay``. Every param group of *optimizer* is updated in place.

    Args:
        optimizer: a ``torch.optim.Optimizer``.
        epoch: 1-based integer epoch index.
        args: namespace providing the base learning rate as ``args.lr``.
        decay: per-epoch multiplicative factor (default 0.9).

    Returns:
        float: the learning rate that was applied.
    """
    lr = args.lr * (decay ** (epoch - 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr



def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q, k, v: tensors whose last two dims are (sequence, features); any
            leading dims (e.g. batch, heads) must be compatible for matmul.
        d_k: feature size used for the ``1 / sqrt(d_k)`` scaling.
        mask: optional tensor broadcastable to the score shape; positions
            where ``mask == 0`` are excluded from the softmax.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        Tensor ``weights @ v``.
    """
    scores = th.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # The dtype's most negative finite value is used instead of a literal
        # -1e9, which cannot be represented in float16 and makes masked_fill raise.
        scores = scores.masked_fill(mask == 0, th.finfo(scores.dtype).min)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return th.matmul(weights, v)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention built on ``attention``.

    ``q``, ``k`` and ``v`` go through separate ``d_model -> d_model`` linear
    projections and are split along the last dim into ``heads`` chunks of size
    ``d_model // heads``. Each head is attended independently, then the heads
    are re-concatenated and passed through an output projection.

    Args:
        heads: number of attention heads; must evenly divide ``d_model``.
        d_model: feature size of the inputs and of the output.
        dropout: dropout applied to the attention weights.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, heads, d_model, dropout = 0.01):
        super().__init__()
        # forward() splits features with view(..., heads, d_model // heads);
        # a non-divisible d_model would mis-shape the heads or fail there.
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                "d_model ({}) must be divisible by a positive number of heads ({})".format(d_model, heads))
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k``/``v``; each input's last dim is ``d_model``.

        Returns a tensor of shape ``(batch, q_len, d_model)``.
        """
        bs = q.size(0)
        # Project, split into heads and move the head axis forward:
        # (bs, seq, d_model) -> (bs, seq, h, d_k) -> (bs, h, seq, d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        # Merge heads back: (bs, h, seq, d_k) -> (bs, seq, d_model)
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)


class EncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention then feed-forward, each residual.

    Each sub-layer normalizes the running tensor, applies the sub-layer and
    dropout, and adds the result back onto the residual stream.

    NOTE(review): ``hidden_dim`` is accepted but unused; the feed-forward is
    built as ``FeedForward(d_model, d_model, ...)``. ``nn.InstanceNorm1d``
    treats dim 1 as channels and normalizes over the last dim, so a
    (batch, seq, d_model) input is normalized per position over its features
    (no affine parameters). Confirm this is the intended normalization.
    """

    def __init__(self, d_model, heads, dropout, hidden_dim, activation):
        super().__init__()
        # Submodules are created in this order; the order fixes how parameter
        # initialization consumes the RNG.
        self.norm_1 = nn.InstanceNorm1d(d_model)
        self.norm_2 = nn.InstanceNorm1d(d_model)
        self.gen_attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model, d_model, dropout, activation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        normed = self.norm_1(x)
        attended = self.gen_attn(normed, normed, normed)
        x = x + self.dropout(attended)

        normed = self.norm_2(x)
        return x + self.dropout(self.ff(normed))


class FeedForward(nn.Module):
    """Position-wise feed-forward block: ``dropout(activation(linear_1(x)))``.

    NOTE(review): ``linear_2`` is constructed (so it appears in the state dict
    and in parameter counts) but forward never applies it. The output width is
    therefore ``d_ff``, not ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout, activation):
        super().__init__()
        # Creation order determines how initialization consumes the RNG.
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.activation = activation

    def forward(self, x):
        hidden = self.linear_1(x)
        hidden = self.activation(hidden)
        return self.dropout(hidden)



class TransformerModel_old(nn.Module):
    """Earlier backbone variant in which the prompt tokens are *prepended*.

    The input ``x`` of shape ``(batch, seq, in_channels)`` is extended along
    the sequence dim with an equal-length block of zeros. It is then linearly
    embedded to ``model_channels``, prefixed with the prompt tokens ``out_p``,
    and encoded by a ``nn.TransformerEncoder`` (no positional encoding). The
    prompt positions are removed before the output projection, so the result
    has shape ``(batch, 2 * seq, out_channels)``.

    Relies on module-level ``hidden_dim`` (feed-forward width) and
    ``num_layers_m`` (encoder depth).

    NOTE(review): ``num_feature`` is accepted but unused.
    """

    def __init__(
        self,
        num_feature,
        in_channels,
        model_channels,
        out_channels,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self._feature_size = self.model_channels
        self.input_emb = nn.Linear(self.in_channels, self.model_channels)
        self.activation = nn.GELU()
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.model_channels, 1, hidden_dim, 0.01, self.activation, batch_first=True), num_layers_m)
        self.output_linear2 = nn.Linear(self.model_channels, self.out_channels)

    def forward(self, x  ,out_p =None, task_id=0):
        """Encode ``x`` conditioned on prompt tokens ``out_p``.

        Args:
            x: input of shape ``(batch, seq, in_channels)``.
            out_p: prompt tokens of shape ``(1, P, model_channels)`` or
                ``(P, model_channels)``; repeated across the batch.
            task_id: unused.

        Raises:
            ValueError: if ``out_p`` is not given.
        """
        if out_p is None:
            raise ValueError("TransformerModel_old.forward requires prompt tokens `out_p`")
        x = x.float()
        x = torch.cat([x, torch.zeros_like(x)], 1)
        he = self.input_emb(x)
        prompt = out_p.repeat(x.shape[0], 1, 1)
        n_prompt = prompt.shape[1]
        he = torch.cat([prompt, he], 1)
        he = self.transformer(he)
        # The prompts sit at the FRONT of the sequence, so strip them there;
        # trimming the tail would drop real positions and shift every output.
        he = he[:, n_prompt:, :]
        return self.output_linear2(he)



class TransformerModel(nn.Module):
    """Transformer backbone conditioned on appended prompt tokens.

    The input window is extended along the sequence dim with an all-zero block
    of the same length. Presumably that block is the forecast horizon: callers
    compare ``pred[:, 96:]`` with the targets. The sequence is embedded
    ``in_channels -> model_channels -> 2 * model_channels``, the prompt tokens
    ``out_p`` are appended, and it runs through a ``nn.TransformerEncoder``
    with sinusoidal position encodings. The prompt positions are dropped
    before the two output projections back to ``in_channels`` features.

    Relies on module-level ``hidden_dim`` (feed-forward width) and
    ``num_layers_m`` (encoder depth).

    NOTE(review): ``num_feature`` and ``out_channels`` are accepted but unused;
    the output always has ``in_channels`` features.
    """

    def __init__(
        self,
        num_feature,
        in_channels,
        model_channels,
        out_channels
        ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = in_channels
        self.pos_encoder = PositionalEncoding(model_channels * 2, 0.001)
        self._feature_size = self.model_channels
        self.input_emb = nn.Linear(self.in_channels, self.model_channels)
        self.input_emb2 = nn.Linear(self.model_channels, self.model_channels * 2)
        self.activation = nn.GELU()

        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.model_channels*2, 1, hidden_dim, 0.1, self.activation, batch_first=True), num_layers_m)
        self.output_linear2 = nn.Linear(self.model_channels, self.out_channels)
        self.output_linear1 = nn.Linear(self.model_channels * 2, self.model_channels)

    def forward(self, x  , out_p=None, task_id=0):
        """Predict over the input window plus an equal-length zero block.

        Args:
            x: ``(batch, seq)`` univariate input, or ``(batch, seq,
                in_channels)``.
            out_p: prompt tokens whose last dim is ``2 * model_channels``,
                shaped ``(1, P, ...)`` or ``(P, ...)``; repeated across the
                batch.
            task_id: unused.

        Returns:
            Tensor of shape ``(batch, 2 * seq, in_channels)``.

        Raises:
            ValueError: if ``out_p`` is not given.
        """
        if out_p is None:
            raise ValueError("TransformerModel.forward requires prompt tokens `out_p`")
        x = x.float()
        if x.dim() == 2:
            # (batch, seq) -> (batch, seq, 1); 3-D inputs already carry features.
            x = x.unsqueeze(2)
        x = torch.cat([x, torch.zeros_like(x)], 1)
        he = self.input_emb2(self.input_emb(x))
        prompt = out_p.repeat(x.shape[0], 1, 1)
        n_prompt = prompt.shape[1]
        he = torch.cat([he, prompt], 1)
        he = self.transformer(self.pos_encoder(he))
        # Drop exactly the appended prompt positions (also correct when n_prompt == 0).
        he = self.output_linear1(he[:, : he.shape[1] - n_prompt])
        return self.output_linear2(he)


class prompt_model(nn.Module):
    """Produce prompt tokens for the backbone from a bank of stored prompts.

    ``forward`` encodes the prompt bank ``PS_tmp`` with sinusoidal position
    encodings and a small ``nn.TransformerEncoder`` of width
    ``2 * model_channels``. It selects the encoded entry at ``task_id`` and
    appends one learned embedding token, returning two tokens.

    Relies on module-level ``hp`` (encoder feed-forward width), ``args.ag``
    (encoder dropout) and ``num_layers_p`` (encoder depth).

    NOTE(review): ``in_channels`` and ``out_channels`` are accepted but
    unused; ``self.in_channels`` is set to ``model_channels``.
    """

    def __init__(
        self,
        num_feature,
        in_channels,
        model_channels,
        out_channels,
    ):
        super().__init__()

        self.num_points = num_feature
        self.in_channels = model_channels
        self.model_channels = model_channels * 2
        self._feature_size = self.model_channels
        self.activation = nn.GELU()
        self.pos_encoder = PositionalEncoding(self.model_channels, 0.001)

        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.model_channels, 1, hp, args.ag, self.activation, batch_first=True), num_layers_p)
        w = torch.empty(1, 1, self.model_channels)
        self.PS_now = nn.parameter.Parameter(nn.init.xavier_uniform_(w))
        # Sized from this instance's own width so the concatenation in forward
        # always matches the encoder output width.
        w2 = torch.empty(1, 1, self.model_channels)
        self.learned_embedding = nn.parameter.Parameter(nn.init.xavier_uniform_(w2))

    def init_prompt(self, new_prompt):
        """Point ``PS_now``'s data at ``new_prompt``'s storage (shape may change)."""
        self.PS_now.data = new_prompt.data

    def forward(self, X, PS_tmp, task_id):
        """Build prompt tokens for ``task_id``.

        Args:
            X: unused; kept for call compatibility.
            PS_tmp: prompt bank of shape ``(1, n, 2 * model_channels)``. The
                leading dim must be 1 to concatenate with the learned token.
            task_id: index into the bank's second dim (``-1`` selects the last).

        Returns:
            tuple: ``(prompt, PS_now)`` where ``prompt`` has shape
            ``(1, 2, 2 * model_channels)``.
        """
        encoded = self.transformer(self.pos_encoder(PS_tmp))
        selected = encoded[:, task_id, :].unsqueeze(1)
        out = torch.cat([selected, self.learned_embedding], 1)
        return out, self.PS_now



       

def evaluation(dataloader, backbone,arsgs, PS_total, prompt_unit):
    """Log and return the test MSE of *backbone* on *dataloader*.

    Both models are switched to eval mode, and ``PS_total`` is loaded into
    ``prompt_unit``. Every batch is conditioned on the prompt built from the
    last bank entry (``task_id=-1``). Predictions from sequence position 96
    onward are compared with ``Y``.

    Args:
        dataloader: iterable of ``(X, Y, mask)`` batches; ``mask`` is unused.
        backbone: model called as ``backbone(X, out_p)``.
        arsgs: unused; kept for call compatibility.
        PS_total: prompt bank passed to ``prompt_unit``.
        prompt_unit: prompt generator returning ``(out_p, _)``.

    Returns:
        float: MSE averaged over all samples (each batch weighted by its
        size), or ``nan`` when the dataloader yields no batches.
    """
    backbone.eval()
    prompt_unit.eval()
    prompt_unit.init_prompt(PS_total)
    # Evaluate wherever the backbone lives instead of assuming CUDA.
    eval_device = next(backbone.parameters()).device
    total_sq_err = 0.0
    n_samples = 0
    with torch.no_grad():
        for X, Y, _mask in dataloader:
            X, Y = X.float().to(eval_device), Y.float().to(eval_device)
            out_p, _ = prompt_unit(X, PS_total, -1)
            pred = backbone(X, out_p)
            batch = X.shape[0]
            # Flatten per sample instead of squeeze() so a batch of size 1
            # keeps its batch dimension.
            mse = F.mse_loss(pred[:, 96:].reshape(batch, -1), Y.reshape(batch, -1))
            total_sq_err += mse.item() * batch
            n_samples += batch
    if n_samples == 0:
        log("Average MSE is :nan (empty dataloader)")
        return float('nan')
    avg_mse = total_sq_err / n_samples
    log("Average MSE is :{}".format(avg_mse))
    return avg_mse



def _dataset_config(name):
    """Return ``(num_task, data_size, num_instances)`` for a supported dataset.

    Raises:
        ValueError: if *name* is not a supported dataset.
    """
    configs = {
        'Moons': (10, 2, 220),
        'MNIST': (11, 2, 200),
        'ONP': (6, 58, None),
        'Elec2': (41, 8, None),
        'RC': (21, 14, None),
        'WR': (16, 12, None),
    }
    try:
        return configs[name]
    except KeyError:
        raise ValueError("unsupported dataset {!r}; expected one of: {}".format(
            name, ", ".join(sorted(configs)))) from None


def _load_prompt_bank(directory, count):
    """Load ``task{i}_new.npz`` (key ``ps``) for ``i < count`` and return them as one tensor on ``device``.

    The arrays are stacked, squeezed and given a leading dim of 1. Presumably
    the result is ``(1, count, dim)`` when each stored prompt is a single
    vector; verify against the saved files.
    """
    prompts = []
    for i in range(count):
        # NpzFile keeps the archive open until closed; the context manager closes it.
        with np.load(directory + '/task' + str(i) + '_' + 'new' + '.npz') as archive:
            prompts.append(archive['ps'])
    bank = th.as_tensor(np.stack(prompts)).to(device)
    return bank.squeeze().unsqueeze(0)


def main(arsgs):
    """Train the prompt generator domain by domain, then evaluate on the last domain.

    For each training dataloader with ``task_id >= 1`` (all but the last), the
    generator sees the first ``task_id`` stored prompts. It is trained for
    ``args.epoches`` epochs with MSE on ``pred[:, 96:]``. The final dataloader
    is then scored with ``evaluation``.

    Reads module-level globals (``args``, ``hidden_dim``, ``emb_dimp``, ``hp``,
    ``seq``, ``prompting``, ``model_chan``, ``device``); *arsgs* is only
    forwarded to ``evaluation``.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    suffix = '{}-ch-{}-dim-{}-lap-{}-lam{}-embp-{}-hp-{}'.format(
        args.dataset, args.model_channel, hidden_dim, args.num_layers_p, args.num_layers_m, emb_dimp, hp)
    PS_directory = 'PS/PS-' + suffix
    output_directory = 'outputs/outputs-' + suffix
    model_directory = 'models/models-' + suffix
    embds_directory = 'embds/embds-' + suffix
    for path in (output_directory, embds_directory, model_directory,
                 model_directory + "_freeze", model_directory + "_prompt"):
        os.makedirs(path, exist_ok=True)

    no_features = args.no_features
    log('use {} data'.format(args.dataset))
    log('-'*40)
    num_task, data_size, num_instances = _dataset_config(args.dataset)

    # Defining dataloaders for each domain
    if seq == 1:
        dataloaders = dp_sq(args, num_task, num_instances)
    else:
        dataloaders = dp(args, num_task, num_instances)

    backbone = TransformerModel(1, no_features, model_chan, no_features).to(device)
    if prompting:
        # Load the pretrained backbone and freeze it.
        checkpoint = torch.load(model_directory + "_freeze/" + "model.pth")
        backbone.load_state_dict(checkpoint['model_state_dict'], strict=True)
        for para in backbone.parameters():
            para.requires_grad = False
    # NOTE(review): without --prompt the backbone keeps its random init and is
    # never optimized (optimizer2 only holds prompt_unit parameters). Confirm intended.
    count_parameters(backbone)

    prompt_unit = prompt_model(num_task, emb_dimp, model_chan, model_chan).to(device)
    count_parameters(prompt_unit)
    optimizer2 = torch.optim.Adam(prompt_unit.parameters(), lr=args.lr)
    # NOTE(review): the bank size 9 is hard-coded rather than derived from num_task.
    PS_total = _load_prompt_bank(PS_directory, 9)
    backbone.train()
    prompt_unit.train()
    for task_id, dataloader in enumerate(dataloaders[:-1]):
        PS_total_new = PS_total[:, :task_id, :]
        prompt_unit.init_prompt(PS_total_new)
        if task_id == 0:
            continue
        tk = task_id - 1  # index of the last prompt in PS_total_new

        for epoch in range(args.epoches):
            for X, Y, _mask in dataloader:
                X, Y = X.float().to(device), Y.float().to(device)
                out_p, _ = prompt_unit(X, PS_total_new, tk)
                optimizer2.zero_grad()
                pred = backbone(X, out_p)
                batch = X.shape[0]
                # Flatten per sample rather than squeeze() so a final batch of
                # size 1 keeps its batch dimension.
                loss = F.mse_loss(pred[:, 96:].reshape(batch, -1), Y.reshape(batch, -1))
                loss.backward()
                optimizer2.step()

            adjust_learning_rate(optimizer2, epoch + 1, args)

    evaluation(dataloaders[-1], backbone, arsgs, PS_total, prompt_unit)



        

if __name__ == "__main__":
    # Script entry point: run training and evaluation with the parsed CLI args.
    print("Start Training...")
    main(args)









