from .brokers_export import BrokerExport
from flask_restful import Resource, request
from flask_jwt_extended import jwt_required
from dateutil.parser import parse
from .. import pusher_client
from ..security_service import verify_role
from ..models import BrokersConnections, Portfolios
from ..responses import ErrorResponses, SuccessResponses
from ..user.users_params import *
import itertools
from ..import_params import ImportParams
from ..trade.trades_regroups import *
from ..server_info import ServerInfo
from ..event.events import UserEventList as EventUser
import hmac
import base64
import hashlib
import requests
import time as tt
from datetime import datetime as dt
import datetime as dtt

class OkxExport(BrokerExport,ServerInfo, Resource):

    @verify_role('broker export')
    def post(self):
        """Synchronise the requesting user's OKX order history.

        The request body (``request.json``) looks like
        ``{"portfolio": "", "broker": "tda", "auto_login": true, "partial": true}``.
        For every credential in ``self.credentials`` this validates and logs
        in, then either delegates the work to a lambda (``call_lambda``) or
        downloads the orders, saves accounts/files, and interprets and saves
        the orders.

        Returns:
            The result of ``end_sync`` (on completion or on an early exit), a
            ``Response.response`` when ``init_sync`` does not return a
            ``UserEvents`` and this is not a lambda run, or
            ``ErrorResponses.error_500`` when anything raises.
        """
        try:
            version = {}
            self.data = {}
            self.type_category = []
            self.orders_execid = []
            self.data = request.json
            self.orders_filerow = []
            # Start synchronization and update progress
            new_event_id = self.init_sync()
            self._update_progress('Saving Pre-Data', '1%')
            if not isinstance(new_event_id, UserEvents) and not self.is_lambda():
                return Response.response(new_event_id['status'], out=new_event_id['export'],
                                         message=new_event_id['message'], title=new_event_id['title'])

            self.orders_filerow = ImportParams.filerow_orders(self.params['get_session_userid'],
                                                              var_sync='tradeId',
                                                              var_file='ordId',
                                                              broker='okx')
            for credential in self.credentials:
                self._update_progress('Checking Credentials', '5%')
                self.credential = credential
                self.not_error = self.validate()
                if self.not_error != True:
                    _end = self.end_sync()
                    if _end != True:
                        return _end

                self.symbols = []
                # Re-assert the current credential after end_sync() may have run.
                self.credential = credential
                self._update_progress('Login', '6%')
                self.not_error = self.login()
                if self.not_error != True:
                    _end = self.end_sync(code=1)
                    if _end != True:
                        return _end
                self._update_progress('Connecting to the server', '7%')
                if self.credential:
                    self.symbols = self.symbols + self.credential.symbols_list

                # Fall back to the symbols sent in the request. The previous
                # condition ``'add' in data or data['add'] == True`` raised
                # KeyError whenever 'add' was absent, so only the membership
                # test could ever make it true; that is what is kept here.
                # NOTE(review): if the intent was "add is True", use
                # ``self.data.get('add') == True`` instead.
                requested_symbols = self.data.get('symbols', [])
                if requested_symbols != [] and self.symbols == [] and 'add' in self.data:
                    self.symbols = requested_symbols

                try:
                    self.account_pusher = 'Account #' + str(self.get_api_key_information())
                except Exception:
                    self.account_pusher = ''
                self.time_out_lambda = False
                if self.condition_lambda() and not self.aws_autosync:
                    try:
                        self.not_error = self.call_lambda()
                    except Exception as err:
                        self.not_error = err
                        self.executed_lambda = False

                if not self.executed_lambda:
                    if self.symbols == [] and not self.add:
                        self._update_progress('Getting pairs', '17%')
                        self.get_symbols()
                    elif 'look_pairs' in self.data and not self.add and (
                            self.data.get('look_pairs', False) == True
                            or self.data.get('new_pairs', False) == True):
                        self._update_progress('Getting pairs', '17%')
                        self.get_symbols()

                    self.symbols = list(set(self.symbols))

                    self._update_progress('Saving data', '18%')
                    self.save_configs()
                    self._update_progress('Retrieving orders', '19%')

                    self.not_error = self.loop_symbols()
                    if self.not_error != True:
                        _end = self.end_sync(code=7)
                        if _end != True:
                            return _end
                    # get_okx_orders_new appends each order wrapped in a
                    # one-element list; flatten to a list of order dicts.
                    self.orders = list(itertools.chain.from_iterable(self.orders))

                    for order in self.orders:
                        # cTime is in milliseconds.
                        created_at = int(order['cTime']) / 1000
                        order['created_at'] = created_at
                        order['created_at_formated'] = dt.utcfromtimestamp(created_at).strftime('%Y-%m-%d %H:%M:%S')

                    # Newest orders first.
                    self.orders = sorted(self.orders, key=lambda i: i['created_at'], reverse=True)

                    self.all_orders = self.orders

                    enabled_types = self.get_types()
                    if enabled_types != []:
                        for t in enabled_types:
                            self.save_account(self.api_key_information, category=t)
                    else:
                        self.save_account(self.api_key_information, category='FUTURE')
                    self._update_progress('Saving data', '19%')
                    self.save_files(file_type='json')

                    # Process orders: reload the already-imported row keys first.
                    self._update_progress('Reading orders', '19%')
                    self.orders_filerow = ImportParams.filerow_orders(self.params['get_session_userid'],
                                                                      var_sync='tradeId',
                                                                      var_file='ordId',
                                                                      broker='okx')
                    self.not_error = self.interpret_save_orders()
                    if self.not_error != True:
                        self.not_error = str(self.not_error)
                        _end = self.end_sync(code=0)
                        if _end != True:
                            return _end

                    self.user.users_configs.tda_import = True
                    self.user.users_configs.save()
                    self.not_error = self.save_params()
                    self.sync_status(healthy=True, message=self.out['mgs'] if 'mgs' in self.out else 'Sync Success', status_code=200)

            # End of the credential loop.
            return self.end_sync(event_id=new_event_id)

        except Exception as err:
            if self.is_lambda() and self.can_call_lambda():
                # NOTE(review): ``star_timestamp`` is read as spelled here;
                # confirm that is the attribute name set elsewhere.
                self.save_params_lambda((int(tt.time()) - self.star_timestamp), 500, 'Fail')

            self.unlock_portfolios()
            self._update_progress('warning', 'Imported fail', True)
            return ErrorResponses.error_500(self.data, err, version)

    def validate(self):
        """Load credentials and enabled instrument types for the current sync.

        Reads ``request.json`` into ``self.data``, saves the user's
        ``not_warning_import`` preference and sets host/sync flags. When a
        stored ``self.credential`` exists and the request is not an "add"
        request, keys and type flags are copied from that credential.
        Otherwise they are read from the request body, and the request is
        rejected if an active connection with the same keys already exists.

        Returns:
            True when credentials are loaded and at least one instrument type
            is enabled; otherwise "Wrong credentials",
            "account_is_already_connected" or "select_type". If an unexpected
            error occurs and a stored credential is available, that credential
            is loaded as a fallback, ``self.data`` is set to None and None is
            returned.
        """
        try:
            self.data = request.json
            self.user.users_configs.not_warning_import = self.data.get('not_warning_import', True)
            self.user.users_configs.save()
            self.symbols = self.data.get('symbols', self.symbols)
            self.auto_import = self.data.get('auto_import', False)
            # Both branches of the old ``test_account`` ternary used this same
            # URL. Reading ``self.data['test_account']`` there also raised
            # KeyError for stored-credential syncs that do not send the key.
            # NOTE(review): OKX demo trading is normally selected with an
            # ``x-simulated-trading: 1`` header, which is not sent — confirm.
            self.host = "https://www.okx.com"
            self.api_key_information = ''
            self.new_sync = True
            # post() reads ``self.add``; set it on both branches, not just one.
            self.add = self.data.get('add', False)
            if self.credential and self.data.get('add', False) == False:
                self._load_okx_credential_fields()
                self.new_sync = False
            else:
                if (self.data['api_key'] == '' or self.data['secret_key'] == ''
                        or self.data['passphrase'] == '' or self.data['test_account'] == ''):
                    return "Wrong credentials"

                api_key = self.data['api_key'].strip()
                secret_key = self.data['secret_key'].strip()
                passphrase = self.data['passphrase'].strip()
                account_check = BrokersConnections.find_by(**{'user_id': self.user.id, "api_key": api_key,
                                                              "secret_key": secret_key, "passphrase": passphrase,
                                                              "active": True})

                self.api_key = api_key
                self.secret_key = secret_key
                self.passphrase = passphrase
                self.test_account = self.data['test_account']
                self.spot = self._okx_request_flag('spot')
                self.future1 = self._okx_request_flag('future1')
                self.future2 = self._okx_request_flag('future2')
                self.option = self._okx_request_flag('option')
                self.margin1 = self._okx_request_flag('margin1')
                self.new_pairs = self._okx_request_flag('new_pairs')
                self.new_sync = True
                if account_check:
                    return "account_is_already_connected"

            if not any([self.spot, self.future1, self.margin1, self.future2, self.option]):
                return 'select_type'
            return True

        except Exception:
            if not self.credential:
                return "Wrong credentials"
            self._load_okx_credential_fields()
            self.api_key_information = ''
            self.host = "https://www.okx.com"
            self.user.users_configs.not_warning_import = True
            self.user.users_configs.save()
            self.data = None
            return None

    def _load_okx_credential_fields(self):
        """Copy keys and instrument-type flags from ``self.credential``."""
        credential = self.credential
        self.api_key = credential.api_key
        self.secret_key = credential.secret_key
        self.passphrase = credential.passphrase
        self.test_account = credential.test_account
        self.spot = credential.spot
        self.future1 = credential.future1
        self.future2 = credential.future2
        self.option = credential.option
        self.margin1 = credential.margin1
        self.new_pairs = credential.new_pairs

    def _okx_request_flag(self, key):
        """Return ``self.data[key]``, or False when it is missing or ``""``."""
        value = self.data.get(key, "")
        return value if value != "" else False
                     
    def parse_params_to_str(self,params):
        url = '?'
        for key, value in params.items():
            if(value != ''):
                url = url + str(key) + '=' + str(value) + '&'
        url = url[0:-1]        
        return url
            
    def get_timestamp(self):
        now = dtt.datetime.utcnow()
        t = now.isoformat("T", "milliseconds")
        return t + "Z"
    
    def signature(self, timestamp,  request_path, body):
        if str(body) == '{}' or str(body) == 'None':
            body = ''
            
        message = str(timestamp) + 'GET' + request_path # + str(body)
        mac = hmac.new(bytes(self.secret_key, encoding='utf8'), bytes(message, encoding='utf-8'), digestmod='sha256')
        d = mac.digest()

        return base64.b64encode(d)
 
    def request_okx(self, request_path, params):
        timestamp = self.get_timestamp()
        request_path = request_path + self.parse_params_to_str(params)  
        body = json.dumps(params)  
        sign = self.signature( timestamp,  request_path,  body)   
        header = dict()
        header['Content-Type'] = 'application/json'
        header['OK-ACCESS-KEY'] = self.api_key
        header['OK-ACCESS-SIGN'] = sign
        header['OK-ACCESS-TIMESTAMP'] = str(timestamp)
        header['OK-ACCESS-PASSPHRASE'] = self.passphrase 
        response =  self.session_v5.request("GET", self.host + request_path , headers=header)   
        if not str(response.status_code).startswith('2'):
            try:
                json_res = response.json()      
                return json_res                     
            except ValueError:
                return 'Invalid JSON error message from Okx: {}'.format(response.text)
        return response.json()        
                 
    def get_api_key_information(self):
        url='/api/v5/account/config'
        payload={}
        response = self.request_okx(url, payload)
        if  response["code"] == '0' and response["data"]!= []:
            try:
                self.api_key_information=response["data"][0]["uid"] 
                return response["data"][0]["uid"] 
            except:
                return 'Error getting User ID'
        else:
            return 'Error getting User ID'
        
    def login(self):
        """Open an HTTP session and check the credentials against OKX.

        Creates ``self.session_v5`` and calls ``/api/v5/account/config``. On
        success, stores the account ``uid`` in ``self.api_key_information``.

        Returns:
            True on success. Otherwise returns an error message:
            ``'Invalid Api Key/Secret Key/Passphrase'`` for a bad key, OKX's
            own ``msg`` for other API errors, or
            ``'Error when making the request.'`` when the call or the response
            handling fails.
        """
        self.session_v5 = requests.Session()
        try:
            response = self.request_okx('/api/v5/account/config', {})
            if response["code"] == '0' and response["data"] != []:
                self.api_key_information = response["data"][0]["uid"]
                return True
            if response['msg'] == 'Invalid OK-ACCESS-KEY':
                return 'Invalid Api Key/Secret Key/Passphrase'
            return response['msg']
        except Exception:
            # request_okx can raise (network errors) or return a plain string,
            # which makes the dict lookups above fail; report both generically.
            return 'Error when making the request.'
        
    def get_symbols(self):
        self.symbols = []
        url='/api/v5/public/instruments'        
        types = ["SPOT","MARGIN","SWAP","FUTURES","OPTION"] #SPOT,MARGIN,SWAP,FUTURES,OPTION
        for type in types:            
            payload={
                "instType": type
                }
            response = self.request_okx(url, payload)            
            if  response["code"] == '0' and response["data"]!= []:
                for symbol in response['data']:
                    self.symbols.append(symbol['instFamily'].replace('-', ''))
            #else:
            #   print(type, response) 
              
    def get_types(self):
        types = []
        if self.spot:
            types.append('SPOT')
        if self.future1:
            types.append('FUTURES')
        if self.future2:
            types.append('SWAP')
        if self.option:
            types.append('OPTION')
        if self.margin1:
            types.append('MARGIN')
        return types
    def loop_symbols(self):
        try:
            get_orders_event = EventUser.post_user_events(100, new_user=self.user_id)
        except:
            pass

        self.orders = []
        self.all_orders = []
        order = None
        f = 1
        start_time = 0
        symbol_count = 0
        paginated = False
        self.symbols.sort()
        types = ["SPOT","MARGIN","SWAP","FUTURES","OPTION"] #SPOT,MARGIN,SWAP,FUTURES,OPTION
        total= len(types)  
                
        for t in types:   
            if self.data['partial']== True:
                start_time = int(self.get_last_date_a(user_account=self.api_key_information, category=t)*1000) #-86400 menos un dia
                if start_time>0:
                    paginated = True
                else:
                    paginated = False
                    start_time = 0 
            else:
                paginated = False 
                start_time = 0 
            p = round((f/total)*100)   
            self._update_progress(self.account_pusher, 'Extracting orders - {}%'.format(p) )
            try:
                self.get_okx_orders_new(t, start_time, paginated)                          
            except Exception as e:
                #print(e)
                #print(traceback.print_exc())
                pass
            f = f +1
            symbol_count = symbol_count+1
        
        try:
            event_success = EventUser.update_user_events(get_orders_event)
        except:
            pass
        return True
 
    def get_okx_orders_new(self, _type,  start_time, paginated):
        url='/api/v5/trade/orders-history-archive' 
        msg_order_len = 0                               
        end_time = int(tt.time()*1000)       
        order_len = 100
        i = 1
        symbols=self.data['symbols'] 
        if paginated==True:
            payload={
                "instType": _type,
                "begin":  str(start_time),
                "limit": "100"
                }
        else:
            payload={
                "instType": _type,
                "end":  str(end_time),
                "limit": "100"
                }          
        while order_len == 100:                      
            try:
                tt.sleep(0.1)
                order = self.request_okx(url, payload)
                if  order['code'] == '0'  and ('data' in order and order['data'] != [] ):
                    order_len = len(order['data'])     
                    for row_order in order['data']:
                        if row_order['instType'] not in self.get_types():
                            continue
                        SymbolList=row_order['instId'].split("-")
                        symbol=SymbolList[0]+SymbolList[1]
                        if symbol in symbols or not symbols:
                            row_order['symbol']=symbol
                            if row_order['instType']=='FUTURES':
                                row_order['category']='FUTURE'
                            else:
                                row_order['category']=row_order['instType']
                            self.type_category.append(row_order['category'])    
                            self.used_symbols.append(symbol)
                            self.orders.append([row_order])
                        ordId =  row_order['ordId']
                    if order_len==100:
                        if paginated==True:
                            payload={
                                "instType": _type,
                                "begin":  str(start_time),
                                "limit": "100",
                                "before": ordId
                                }
                        else:
                            payload={
                                "instType": _type,
                                "end":  str(end_time),
                                "limit": "100",
                                "before": ordId
                                } 
                else:
                    order_len = 0
            except Exception as e:
                order_len = 0
            
            i = i+1
            msg_order_len = order_len + msg_order_len
                           
    def interpret_orders(self):
        #print('interpreter') 
        result_b = []
        b=len(self.orders)
        new_date_time = []
        pip_value_order = dict()
        pip_value_order_btc = dict()
        verify_njson_len = 0
        
        verify_njson_len = 0 
        i = 0
        
        new_list = []
        for order in self.orders:
                         
            check_json = json.dumps(order)
            check_json = hashlib.md5(check_json.encode('utf-8')).hexdigest()
            if check_json in new_list:
                continue
                
            i = i +1
            original_file_row = json.loads(json.dumps(order))
            njson = json.dumps(order)
            njson = hashlib.md5(njson.encode('utf-8')).hexdigest()
                
            if 'side'in order:
                action = order['side'].upper() if 'side' in order else order['transaction category'].upper() if 'transaction category' in order else ''
            else:
                action = 'BUY' if order['isBuyer']== True else 'SELL'
                
            action = 'BUY' if action == 'BUY' or action=='COVER' else 'SELL' if action == 'SELL' or action == 'SHORT' else ''
            
            order['date'] =  int(order['cTime'] )
            order['date'] = dt.utcfromtimestamp(order['date']/1000).strftime('%Y-%m-%d %H:%M:%S')
            order['type'] = 'TRADE'

            # continue
            
            order['price'] = order['fillPx']      
            try:          
                if float(order['fillPx']) != 0:
                    order['quantity'] = float(order['fillSz']) /  float(order['fillPx'])
                else:
                    continue 
            except:
                continue
            
            if not order['symbol'] or not order['date'] or not order['price'] or not action or order['type'] != "TRADE":
                continue

            pip_value = 1
            date_time = order['date']
            original_file_row['date_tz'] = date_time
            self.any_error, date, time = ImportParams.get_param_datetime(date_time,b)
            new_date_time = ImportParams.convert_date(date, time, self.date_format,True,'UTC')

            order['date'] = new_date_time[0]
            order['time'] = new_date_time[1]
            fp = str(order['price']).replace(',', '').replace('$','')
            decimal = fp[::-1].find('.')
            decimal = decimal if decimal > 1 else 2
            price = round(float(fp),decimal)
            
            sm = len(order['symbol'])
            type = 'crypto'
            option = 'CRYPTO'
            strike = ''
            expire = ''

            if order['category'] in self.params['user_portfolios']:
                user_portfolio = self.params['user_portfolios'][order['category']]
            else:
                user_portfolio = self.params['user_portfolio']           

            #################### VERIFY FILE ROW ######################
            if njson in self.orders_filerow:
                verify_njson_len = verify_njson_len + 1
                continue
            if 'tradeId' in original_file_row and original_file_row['tradeId']:
                order_id = '{}'.format(original_file_row['tradeId'])
                if order_id in self.orders_filerow:
                    verify_njson_len = verify_njson_len + 1
                    continue
            try:
                var_date = original_file_row['date_tz']
                var_date = ImportParams.parse_date(var_date).strftime('%Y-%m-%d %H:%M:%S')
                njson3 = '{}{}{}{}{}{}{}'.format(float(fp) if fp else 0.00,
                                                 var_date,
                                                 option,
                                                 1 if action == 'BUY' else 2,
                                                 float(str(order['sz']).replace(',', '')) if 'sz' in order \
                                                      else '',
                                                 float(strike) if ImportParams.isfloat(strike) else 0.0,
                                                 expire
                                                )
                if njson3 in self.orders_filerow:
                    verify_njson_len = verify_njson_len + 1
                    continue
            except:
                pass
            ###########################################################
            sym = order['symbol'][-3:]
            if not 'execValue' in order:
                pip_value = self.convert_usdt(order, pip_value_order, date )
            elif sym in ['USD']:
                value_order = float(order['sz']) * float(order['px'])
                exec_value = float(order['execValue']) if ImportParams.isfloat(order['execValue']) else ''
                pip_value_exec = (exec_value / value_order)
                pip_value = pip_value_exec * price
            elif 'USDT' in order['symbol']:
                pip_value = 1
            else:
                value_order = float(order['sz']) * float(order['px'])
                exec_value = float(order['execValue']) if ImportParams.isfloat(order['execValue']) else ''
                if exec_value and order['symbol']:                    
                    if value_order != 0:
                        pip_value = (exec_value / value_order)
                    else:
                        pip_value = exec_value
                        
                    pip_value = pip_value * float(order['price'])
                     
                    #pip_value_btc = ImportParams.pip_value_crypto_(order['date'], base='usdt', currency='btc')"""
            
            pip_value_btc = 1
            if order['symbol'] == 'BTCUSD':
                row_btc = '{}{}'.format(date,sym)
                if not row_btc in pip_value_order_btc:
                    pip_value_btc = ImportParams.pip_value_crypto_(order['date'], base='usdt', currency='btc')
                    pip_value_order_btc[row_btc] = pip_value_btc
                else:
                    pip_value_btc = pip_value_order_btc[row_btc]
                
            if len(order['symbol']) > 8 and len(order['symbol'].split('-')) == 4:
                split_symbol = order['symbol'].split('-')
                type = 'crypto option'
                option = 'CALL' if split_symbol[3].upper() == 'C' else 'PUT' if split_symbol[3].upper() == 'P' else 'OPTION'
                strike = split_symbol[2]
                expire = str(parse(split_symbol[1]).strftime("%d %b %y")).upper()
                order['symbol'] = '#{}{}'.format(split_symbol[0],'USDT')

            order['orderby'] = b
            order['trade_notes'] = order['trade_notes'] if 'trade_notes' in order else ''
            order['type_stock'] = type
            order['type_option'] = option
            order['action'] = action
            order['price'] = price
            order['shares'] = str(order['sz']).replace(',', '') if 'sz' in order else str(order['qty']).replace(',', '')
            
            if 'fee' in order:
                order['fees'] = (float(order['fee']) * float(pip_value_btc)) if 'fee' in order else '0.00'
            elif 'cumExecFee' in order:
                order['fees'] = order['cumExecFee'] if 'cumExecFee' in order else '0.00'
            else:
                order['fees'] = (float(order['feeAmount']) * float(pip_value_btc)) if 'feeAmount' in order else '0.00'

            if ImportParams.isfloat(order['fees']) and 'USD' in order['symbol'] \
                  and not 'USDT' in order['symbol'] and not 'BTCUSD' in order['symbol']:
                order['fees'] = float(order['fees']) * float(price)
            order['njson'] = njson
            order['decimal'] = decimal
            order['expire'] = expire
            order['strike'] = strike
            order['pip_value'] = pip_value
            order['original_file_row'] = original_file_row
            order['broker'] = self.params['broker']
            order['userid'] = self.params['get_session_userid']
            order['portfolio'] = user_portfolio
            order['app_broker'] = 1
            order['size'] = 1
            data_item = ImportParams.get_result_append(order)
            self.out_result.append(data_item)
            b = b - 1