import torch
from utilsMetaDA import *
from data_handler import DataHandler

import datetime
import copy
import os
import pathlib
import random
import numpy as np


class SDSModel:
    """Meta-learning model with a learned, label-free adaptation loss.

    Components (built in ``network_init``):
      * ``theta`` -- feature extractor (project ``LSTM``); a copy of it is
        adapted at test time;
      * ``phi`` -- classifier head (project ``MLP``) applied to theta's output;
      * ``omega1`` / ``omega2`` -- ``DomainMetric`` followed by ``DomainShift``.
        ``omega2(omega1(theta(x)))`` is a scalar "coupled loss" that needs no
        labels, so it can be minimized on unlabeled data (see ``test``).

    ``my_train`` runs ``pre_train`` (supervised warm-up of theta/phi) and then
    ``find_omega`` (alternating classifier updates and meta-updates of omega).
    """

    def __init__(self, flags):
        self.configure(flags)
        self.seed_init()
        self.network_init(flags)

    def __del__(self):
        print('-------release source-------\n')

    def configure(self, flags):
        """Read settings from ``flags``, create the log directory and load data.

        Reads ``flags.seed``, ``flags.total_log_dir``, ``flags.name`` and
        ``flags.num``; the latter is passed to ``DataHandler`` (presumably it
        selects the held-out target domain -- confirm against DataHandler).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seed = flags.seed

        self.theta_layers = [310, 512, 256]
        self.phi_layers = [256, 100, 3]
        self.sum_d_layers = [256, 192, 128]

        # One log directory per run name; per-``num`` CSV files live inside it.
        self.log_path = os.path.join(flags.total_log_dir, flags.name)
        pathlib.Path(self.log_path).mkdir(parents=True, exist_ok=True)

        data = DataHandler(flags.num)
        self.src_data = data.src_data
        self.src_label = data.src_label
        self.trg_data = data.trg_data
        self.trg_label = data.trg_label

    def network_init(self, flags):
        """Build theta/phi/omega networks, the losses, and one Adam optimizer each."""
        self.phi = MLP(self.phi_layers).to(self.device)
        # theta's second argument is phi_layers[0], so its output width presumably
        # matches phi's input width -- confirm against the project LSTM.
        self.theta = LSTM(310, self.phi_layers[0], 2).to(self.device)

        self.ce_loss = torch.nn.CrossEntropyLoss()
        # Per-sample (unreduced) CE, used for the per-sample meta reward.
        # ``reduction='none'`` replaces the deprecated ``reduce=False``.
        self.ce_loss_list = torch.nn.CrossEntropyLoss(reduction='none')
        self.omega1 = DomainMetric(self.sum_d_layers).to(self.device)
        self.omega2 = DomainShift(self.sum_d_layers[-1]).to(self.device)

        self.opt_phi = torch.optim.Adam(self.phi.parameters(), lr=flags.lr, weight_decay=flags.weight_decay)
        self.opt_theta = torch.optim.Adam(self.theta.parameters(), lr=flags.lr, weight_decay=flags.weight_decay)
        self.opt_omega1 = torch.optim.Adam(filter(lambda p: p.requires_grad, self.omega1.parameters()),
                                           lr=flags.lr_omega, weight_decay=flags.weight_decay)
        self.opt_omega2 = torch.optim.Adam(filter(lambda p: p.requires_grad, self.omega2.parameters()),
                                           lr=flags.lr_omega, weight_decay=flags.weight_decay)

    def seed_init(self):
        """Seed torch (CPU and all GPUs), numpy and ``random`` with ``self.seed``."""
        seed = self.seed
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may pick different algorithms run to run; keep it off.
        torch.backends.cudnn.benchmark = False

    def _to_tensors(self, data, label):
        """Convert a data/label pair to (float inputs, 1-D long labels) on ``self.device``.

        Labels are flattened with ``reshape(-1)`` instead of ``squeeze`` so that
        a single-sample batch keeps shape ``(1,)``; squeeze would yield a 0-d
        tensor that breaks ``len()`` and ``CrossEntropyLoss``.
        """
        x = torch.tensor(data, requires_grad=False).float().to(self.device)
        y = torch.tensor(label, requires_grad=False).long().to(self.device)
        return x, y.reshape(-1)

    def _sample_batch(self, data, label, batch_size):
        """Draw batch indices with ``batch_index_generator`` and return that batch as tensors."""
        batch_index = batch_index_generator(data.shape[0], batch_size)
        return self._to_tensors(data[batch_index], label[batch_index])

    def _adapted_theta(self, lr):
        """Return a copy of ``theta`` after one virtual gradient step of size ``lr``.

        Uses the gradients currently stored on ``self.theta``'s parameters.
        All callers fill them with ``backward(create_graph=True)``, so the step
        stays differentiable w.r.t. omega. State entries without a gradient
        (including any buffers) are copied unchanged. Gradients are matched by
        name rather than by position in ``state_dict()``.
        """
        grads = {name: param.grad for name, param in self.theta.named_parameters()}
        theta_updated = {}
        for name, value in self.theta.state_dict().items():
            grad = grads.get(name)
            theta_updated[name] = value if grad is None else value - lr * grad
        temp_theta = copy.deepcopy(self.theta)
        return fix_nn(temp_theta, theta_updated)

    def pre_train(self, flags):
        """Supervised warm-up of theta and phi on all source domains.

        Runs at most 2000 steps. Each step sums the CE loss of one batch from
        every source domain and takes one Adam step. Target accuracy is printed
        every 5 steps. Training stops early once the mean accuracy on the
        sampled training batches exceeds 85%.
        """
        for step in range(2000):
            pre_train_loss_main = 0.0
            pre_train_acc = []
            for k in range(len(self.src_data)):
                batch_index = batch_index_generator(self.src_data[k].shape[0], flags.batch_size)
                cur_data_tr = self.src_data[k][batch_index]
                cur_label_tr = self.src_label[k][batch_index]

                x_train, y_train = self._to_tensors(cur_data_tr, cur_label_tr)
                pre_train_loss_main += self.ce_loss(self.phi(self.theta(x_train)), y_train)
                # Accuracy of the current (pre-step) parameters on this batch.
                pre_train_acc.append(self.naive_test(cur_data_tr, cur_label_tr))

            mean_acc = sum(pre_train_acc) / len(pre_train_acc)

            self.opt_phi.zero_grad()
            self.opt_theta.zero_grad()
            pre_train_loss_main.backward()
            self.opt_phi.step()
            self.opt_theta.step()
            if (step + 1) % 5 == 0:
                acc_test = self.naive_test(trg_data=self.trg_data, trg_label=self.trg_label)
                print("--iteration: {}. test accuracy is {:.5f}%".format(step + 1, acc_test))
                print("------------------ train accuracy is {:.5f}%".format(mean_acc))
                print("----------current main train loss is {}".format(pre_train_loss_main.item()))
            if mean_acc > 85:
                return  # quit pretrain process

    def find_omega(self, flags):
        """Alternate classifier training with meta-learning of omega1/omega2.

        Each of ``flags.ite_find_omega`` outer iterations:
          1. updates the networks with ``dg_train`` or ``da_train``, depending
             on ``flags.dadg``;
          2. splits the source domains at random into 10 meta-train domains and
             the remaining meta-test domains;
          3. runs ``flags.ite_omega_mt`` omega updates. Each takes a virtual
             theta step on the meta-train coupled loss and rewards it when it
             lowers per-sample CE on the meta-test domains;
          4. evaluates on the target domain and appends a row to the CSV log.
        omega1 is only stepped while ``ite_fo <= flags.max_ite_fo``.

        Raises:
            ValueError: if ``flags.dadg`` is neither ``'dg'`` nor ``'da'``.
        """
        csv_path = os.path.join(self.log_path, 'find_omega{}.csv'.format(flags.num))
        record_acc(csv_path, '{}'.format(flags.__dict__))
        for ite_fo in range(flags.ite_find_omega):
            if flags.dadg == 'dg':
                classifier_loss = self.dg_train(flags, ite_fo)
            elif flags.dadg == 'da':
                classifier_loss = self.da_train(flags, ite_fo)
            else:
                raise ValueError('Please set flags.dadg correctly! ')

            _, acc_train_mean = self.train_test(flags)

            train_len = 10
            shuffle_idx = np.random.permutation(range(self.src_data.shape[0]))
            meta_train_data = self.src_data[shuffle_idx[:train_len]]
            meta_train_label = self.src_label[shuffle_idx[:train_len]]
            meta_test_data = self.src_data[shuffle_idx[train_len:]]
            meta_test_label = self.src_label[shuffle_idx[train_len:]]

            for _ in range(flags.ite_omega_mt):
                omega_loss = 0.0
                coupled_loss = 0.0
                for k in range(meta_train_data.shape[0]):
                    x_mtr, _ = self._sample_batch(meta_train_data[k], meta_train_label[k], flags.batch_size)
                    coupled_loss += self.omega2(self.omega1(self.theta(x_mtr)))

                self.opt_omega1.zero_grad()
                self.opt_omega2.zero_grad()
                self.opt_theta.zero_grad()
                # create_graph=True keeps theta's grads differentiable w.r.t. omega.
                coupled_loss.backward(create_graph=True, retain_graph=True)
                temp_theta = self._adapted_theta(flags.lr_maml)

                for k in range(meta_test_data.shape[0]):
                    x_mte, y_mte = self._sample_batch(meta_test_data[k], meta_test_label[k], flags.batch_size)
                    ce_old = self.ce_loss_list(self.phi(self.theta(x_mte)), y_mte)
                    ce_new = self.ce_loss_list(self.phi(temp_theta(x_mte)), y_mte)
                    reward = ce_old - ce_new  # positive when the virtual step helped
                    omega_loss -= torch.tanh(reward * 2).sum()  # minimizing it maximizes reward

                omega_loss.backward(retain_graph=False, create_graph=False)
                if ite_fo <= flags.max_ite_fo:
                    self.opt_omega1.step()
                self.opt_omega2.step()
                torch.cuda.empty_cache()

            acc_test, loss_test = self.test(self.trg_data, self.trg_label, flags)
            record_acc(csv_path,
                       '{}'.format(flags.num),
                       '{}'.format(flags.i),
                       '{}'.format(self.seed),
                       '{}'.format(ite_fo),
                       '{}'.format(classifier_loss),
                       '{}'.format(omega_loss.item()),
                       '{}'.format(coupled_loss.item()),
                       '{:.5f}'.format(acc_train_mean),
                       '{}'.format(acc_test),
                       '{}'.format(loss_test))

    def da_train(self, flags, ite_fo):
        """One joint update of theta, phi and omega over all source domains.

        This variant behaves more like domain adaptation. For each source
        domain, take a batch and compute its coupled loss and CE. Build a
        virtual theta step from the coupled loss. Add
        ``pr_coupled * CE(stepped theta) + CE(theta)`` to the total loss.
        Omega gradients are zeroed once and accumulate over all domains.

        Returns:
            float: the summed total loss.
        """
        total_loss = 0.0
        self.opt_omega1.zero_grad()
        self.opt_omega2.zero_grad()

        for ite_i in range(len(self.src_data)):
            domain_x_i, domain_y_i = self._sample_batch(self.src_data[ite_i], self.src_label[ite_i],
                                                        flags.batch_size)
            feature_x = self.theta(domain_x_i)
            coupled_loss = self.omega2(self.omega1(feature_x))
            cls_loss = self.ce_loss(self.phi(feature_x), domain_y_i)

            # Only theta's grads are reset per domain; omega's keep accumulating.
            self.opt_theta.zero_grad()
            coupled_loss.backward(create_graph=True, retain_graph=True)
            temp_theta = self._adapted_theta(flags.lr_maml)

            meta_loss = self.ce_loss(self.phi(temp_theta(domain_x_i)), domain_y_i)
            total_loss += meta_loss * flags.pr_coupled + cls_loss

        self.opt_theta.zero_grad()
        self.opt_phi.zero_grad()
        total_loss.backward()
        self.opt_theta.step()
        self.opt_phi.step()
        if ite_fo <= flags.max_ite_fo:
            self.opt_omega1.step()
        self.opt_omega2.step()
        torch.cuda.empty_cache()
        return total_loss.item()

    def dg_train(self, flags, ite_fo):
        """Run ``flags.ite_omega_mt`` joint updates of theta, phi and omega.

        The source domains are split at random into 10 meta-train domains and
        the remaining meta-test domains. Each step:
          * sums the coupled loss and CE on meta-train batches;
          * takes a virtual theta step on the coupled loss;
          * adds the stepped theta's CE on meta-test batches;
          * minimizes ``CE + pr_coupled * coupled / (ite_mt // 200 + 1)``.

        Returns:
            float | None: CE of the last step, or ``None`` if no step ran.
        """
        train_len = 10
        shuffle_idx = np.random.permutation(range(self.src_data.shape[0]))
        meta_train_data = self.src_data[shuffle_idx[:train_len]]
        meta_train_label = self.src_label[shuffle_idx[:train_len]]
        meta_test_data = self.src_data[shuffle_idx[train_len:]]
        meta_test_label = self.src_label[shuffle_idx[train_len:]]

        last_cls_loss = None
        for ite_mt in range(flags.ite_omega_mt):
            # Reset every step. The previous step's graph was freed by
            # total_loss.backward(), so accumulating across steps would fail.
            coupled_loss = 0.0
            cls_loss = 0.0
            for k in range(meta_train_data.shape[0]):
                x_mtr, y_mtr = self._sample_batch(meta_train_data[k], meta_train_label[k], flags.batch_size)
                coupled_loss += self.omega2(self.omega1(self.theta(x_mtr)))
                cls_loss += self.ce_loss(self.phi(self.theta(x_mtr)), y_mtr)

            self.opt_omega1.zero_grad()
            self.opt_omega2.zero_grad()
            self.opt_theta.zero_grad()
            coupled_loss.backward(create_graph=True, retain_graph=True)
            temp_theta = self._adapted_theta(flags.lr_maml)

            for k in range(meta_test_data.shape[0]):
                x_mte, y_mte = self._sample_batch(meta_test_data[k], meta_test_label[k], flags.batch_size)
                cls_loss += self.ce_loss(self.phi(temp_theta(x_mte)), y_mte)

            self.opt_theta.zero_grad()
            self.opt_phi.zero_grad()
            total_loss = cls_loss + coupled_loss * flags.pr_coupled / (ite_mt // 200 + 1)
            total_loss.backward(retain_graph=False, create_graph=False)
            if ite_fo <= flags.max_ite_fo:
                self.opt_omega1.step()
            self.opt_omega2.step()
            self.opt_phi.step()
            self.opt_theta.step()
            torch.cuda.empty_cache()
            last_cls_loss = cls_loss.item()
        return last_cls_loss

    def my_train(self, flags):
        """Pre-train, meta-train omega, evaluate on the target, log, optionally save theta.

        Returns:
            float: target accuracy at index 0 of ``test``'s accuracy list, i.e.
            before any test-time adaptation. NOTE(review): confirm index 0
            (rather than the last, fully adapted entry) is the intended metric.
        """
        begin_time = datetime.datetime.now()
        self.pre_train(flags)
        self.find_omega(flags)
        end_time = datetime.datetime.now()
        acc_final, _ = self.test_test(flags)
        acc_final = acc_final[0]
        self.write_log(begin_time, end_time, acc_final, flags)
        if flags.save == 1:
            torch.save(self.theta, 'final_theta{}{}.pt'.format(flags.i, flags.num))
        return acc_final

    def test_test(self, flags):
        """Run ``test`` on the target-domain data."""
        return self.test(self.trg_data, self.trg_label, flags)

    def train_test(self, flags):
        """Run ``test`` on every source domain.

        Returns:
            (list, float): per-domain accuracy after the first adaptation step
            (``test(...)[0][1]``, so ``flags.ite_test`` must be >= 1), and
            their mean.
        """
        acc_list = []
        for k in range(len(self.src_data)):
            acc_train, _ = self.test(self.src_data[k], self.src_label[k], flags)
            acc_list.append(acc_train[1])
        mean_acc = sum(acc_list) / len(acc_list)
        return acc_list, mean_acc

    def naive_test(self, trg_data, trg_label):
        """Return the accuracy (in %) of ``phi(theta(x))`` without any adaptation.

        Evaluated under ``torch.no_grad()``, so no autograd graph is built.
        """
        x_test, y_test = self._to_tensors(trg_data, trg_label)
        with torch.no_grad():
            pred_test = self.phi(self.theta(x_test)).argmax(1)
        return (pred_test == y_test).float().mean().item() * 100

    def test(self, x_data, y_data, flags):
        """Evaluate with label-free test-time adaptation of a copy of theta.

        A deep copy of theta is updated ``flags.ite_test`` times with SGD
        (lr ``flags.lr_maml_test``) on the coupled loss
        ``omega2(omega1(theta(x)))``, which uses no labels. ``self.theta`` itself
        is never modified.

        Returns:
            (list[float], list[float]): accuracy (%) and CE loss. Index 0 is
            before adaptation; index i is after the i-th adaptation step.
        """
        x_test, y_test = self._to_tensors(x_data, y_data)

        temp_theta = copy.deepcopy(self.theta)
        temp_opt_theta = torch.optim.SGD(temp_theta.parameters(), lr=flags.lr_maml_test)
        acc_test = [self.naive_test(x_data, y_data)]
        with torch.no_grad():
            loss_test = [self.ce_loss(self.phi(self.theta(x_test)), y_test).item()]

        for _ in range(flags.ite_test):
            coupled_loss = self.omega2(self.omega1(temp_theta(x_test)))
            temp_opt_theta.zero_grad()
            coupled_loss.backward()
            temp_opt_theta.step()
            with torch.no_grad():
                logits = self.phi(temp_theta(x_test))
            acc_test.append((logits.argmax(1) == y_test).float().mean().item() * 100)
            loss_test.append(self.ce_loss(logits, y_test).item())
            torch.cuda.empty_cache()

        # The adaptation backward passes also left gradients on omega1/omega2.
        # Clear them (and theta's) so evaluation does not leak into training state.
        self.opt_theta.zero_grad()
        self.opt_omega1.zero_grad()
        self.opt_omega2.zero_grad()
        return acc_test, loss_test

    def write_log(self, begin_time, end_time, acc_final, flags, log_name='log.txt'):
        """Append a run summary to ``<total_log_dir>/<log_name>``, mark the CSV log complete, print the summary."""
        log_path_name = os.path.join(flags.total_log_dir, log_name)
        # total_seconds() includes whole days; ``timedelta.seconds`` wraps every 24 h.
        minutes_used = (end_time - begin_time).total_seconds() / 60
        log_end_time = (
            'Training process successfully! \n'
            'The beginning time is: ' + begin_time.strftime('%D -- %H:%M:%S') + ',  '
            + 'The ending time is: ' + end_time.strftime('%D -- %H:%M:%S')
            + '.  {:.2f} minutes have been used for training.\n'.format(minutes_used)
            + 'The final accuracy on the test set is {:.2f}%.\n'.format(acc_final)
            + 'It\'s the {}th training, test domain is {}. \n'.format(flags.i, flags.num)
            + 'flags are : {}'.format(flags.__dict__)
        )

        with open(log_path_name, mode='a') as f:
            f.write(log_end_time + '\n')

        with open(os.path.join(self.log_path, 'find_omega{}.csv'.format(flags.num)), mode='a') as f:
            f.write('Training completed. Flags are : {}'.format(flags.__dict__) + '\n')

        print(log_end_time)


def main():
    """Repeat the full training 10 times, each over target indices 0..14.

    Every repetition draws a fresh seed from numpy's global RNG. For each
    ``flags.num`` in 0..14 it trains one ``SDSModel``; ``num`` is passed to
    ``DataHandler``, presumably selecting the target domain. It then prints
    the mean final accuracy followed by the per-index list.
    """
    flags = SetHyperParameter()
    flags.i = 0
    flags.seed = np.random.randint(2**31 - 1)
    flags.total_log_dir = os.path.join('logs', 'New_model')
    flags.log_path = os.path.join(flags.total_log_dir, str(flags.i))

    # Overrides applied on top of SetHyperParameter()'s defaults.
    # NOTE(review): meta_train_len is not read by SDSModel here (it hard-codes 10).
    overrides = {
        'lr': 0.0002,
        'weight_decay': 0.0002,
        'meta_train_len': 7,
        'ite_find_omega': 20,
        'ite_omega_mt': 50,
        'ite_opt_theta_phi': 200,
        'batch_size': 1024,
        'lr_maml': 0.0002,
        'lr_omega': 0.0002,
        'ite_test': 5,
    }
    for attr_name, attr_value in overrides.items():
        setattr(flags, attr_name, attr_value)

    num_repeats, num_targets = 10, 15
    for _repeat in range(num_repeats):
        flags.i += 1
        flags.seed = np.random.randint(2**31 - 1)
        flags.log_path = os.path.join(flags.total_log_dir, str(flags.i))
        accuracies = []
        for target_index in range(num_targets):
            flags.num = target_index
            model = SDSModel(flags)
            accuracies.append(model.my_train(flags))
        print(sum(accuracies) / len(accuracies), accuracies)
